GitOrigin-RevId: b9bb7352bc
| @@ -30,18 +30,18 @@ struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| template <typename src_ctype, typename dst_type = src_ctype> | |||
| struct TanhOp; | |||
| #define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ | |||
| #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ | |||
| using TanhOpBase::TanhOpBase; \ | |||
| using TanhOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||
| void operator()(const _neon_type2& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| vst1q_##_func_suffix(dst, vitem.val[0]); \ | |||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _neon_type operator()(const _neon_type& src) const { \ | |||
| _neon_type2 operator()(const _neon_type2& src) const { \ | |||
| auto one_val = vdupq_n_##_func_suffix(1.f); \ | |||
| auto two_val = vdupq_n_##_func_suffix(2.f); \ | |||
| auto val1 = src.val[0]; \ | |||
| @@ -62,10 +62,23 @@ struct TanhOp; | |||
| val2 = vsubq_##_func_suffix(one_val, val2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| _neon_type operator()(const _neon_type& src) const { \ | |||
| auto one_val = vdupq_n_##_func_suffix(1.f); \ | |||
| auto two_val = vdupq_n_##_func_suffix(2.f); \ | |||
| auto val1 = src; \ | |||
| val1 = vmulq_##_func_suffix(two_val, val1); \ | |||
| val1 = exp_ps_##_func_suffix(val1); \ | |||
| val1 = vaddq_##_func_suffix(one_val, val1); \ | |||
| auto rval1 = vrecpeq_##_func_suffix(val1); \ | |||
| rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), rval1); \ | |||
| val1 = vmulq_##_func_suffix(two_val, rval1); \ | |||
| val1 = vsubq_##_func_suffix(one_val, val1); \ | |||
| return val1; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, float32x4x2_t, f32, 4) | |||
| OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| OP(__fp16, float16x8x2_t, f16, 8) | |||
| OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) | |||
| #endif | |||
| #undef OP | |||
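For reference, the kernel above evaluates tanh through the identity tanh(x) = (e^{2x} − 1)/(e^{2x} + 1) = 1 − 2/(e^{2x} + 1): it doubles the input, exponentiates, adds one, and replaces the division with a NEON reciprocal estimate (`vrecpeq`) refined by one Newton-Raphson step (`vrecpsq`). A scalar sketch of the same computation, assuming `exp_ps_f32` matches `std::exp` lane-wise:

```cpp
// Scalar sketch (not part of the diff) of the vector kernel above:
//   tanh(x) = 1 - 2 / (e^{2x} + 1)
#include <cmath>

inline float tanh_ref(float x) {
    float e = std::exp(2.f * x);  // val = exp_ps(2 * src)
    float r = 1.f / (1.f + e);    // vrecpeq estimate + one vrecpsq refinement
    return 1.f - 2.f * r;         // 1 - 2 * r
}
```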
| @@ -19,9 +19,12 @@ | |||
| #include "src/arm_common/elemwise/opr_impl.h" | |||
| #include "src/arm_common/elemwise_multi_type/opr_impl.h" | |||
| #include "src/arm_common/local/opr_impl.h" | |||
| #include "src/arm_common/lstm/opr_impl.h" | |||
| #include "src/arm_common/lstm_cell/opr_impl.h" | |||
| #include "src/arm_common/pooling/opr_impl.h" | |||
| #include "src/arm_common/reduce/opr_impl.h" | |||
| #include "src/arm_common/resize/opr_impl.h" | |||
| #include "src/arm_common/rnn_cell/opr_impl.h" | |||
| #include "src/arm_common/separable_conv/opr_impl.h" | |||
| #include "src/arm_common/separable_filter/opr_impl.h" | |||
| #include "src/arm_common/type_cvt/opr_impl.h" | |||
| @@ -50,6 +53,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(Reduce) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM) | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wpragmas" | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm/lstm_utils.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./lstm_utils.h" | |||
| #include "src/arm_common/lstm/opr_impl.h" | |||
| #include "src/arm_common/lstm_cell/cell_kernel.h" | |||
| #include "src/arm_common/lstm_cell/opr_impl.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| LstmCellWeight::LstmCellWeight( | |||
| RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
| DType dtype) { | |||
| // weight_ih: [gate_hidden_size, input_size] | |||
| // weight_hh: [gate_hidden_size, hidden_size] | |||
| // bias_ih: [gate_hidden_size] | |||
| // bias_hh: [gate_hidden_size] | |||
| size_t gate_hidden_size = 4 * hidden_size; | |||
| TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype}; | |||
| TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype}; | |||
| TensorLayout bias_layout{{gate_hidden_size}, dtype}; | |||
| m_weight_size = 0; | |||
| m_weight_ih = TensorND(weight_ih_layout, weight_ptr); | |||
| m_weight_size += weight_ih_layout.span().dist_byte(); | |||
| weight_ptr += weight_ih_layout.span().dist_byte(); | |||
| m_weight_hh = TensorND(weight_hh_layout, weight_ptr); | |||
| m_weight_size += weight_hh_layout.span().dist_byte(); | |||
| weight_ptr += weight_hh_layout.span().dist_byte(); | |||
| if (has_bias) { | |||
| m_bias_ih = TensorND(bias_layout, weight_ptr); | |||
| m_weight_size += bias_layout.span().dist_byte(); | |||
| weight_ptr += bias_layout.span().dist_byte(); | |||
| m_bias_hh = TensorND(bias_layout, weight_ptr); | |||
| m_weight_size += bias_layout.span().dist_byte(); | |||
| } | |||
| } | |||
| LstmStates::LstmStates( | |||
| const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size, | |||
| DType dtype) { | |||
| auto& h_ptr = ptr[0]; | |||
| auto& c_ptr = ptr[1]; | |||
| TensorLayout layout{{batch_size, hidden_size}, dtype}; | |||
| m_h = TensorND(layout, h_ptr); | |||
| m_c = TensorND(layout, c_ptr); | |||
| m_memory_size = layout.span().dist_byte(); | |||
| } | |||
| TensorNDArray megdnn::arm_common::split_tensor( | |||
| _megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout) { | |||
| megdnn_assert( | |||
| tensor.layout.span().dist_byte() == nr_tensor * layout.span().dist_byte()); | |||
| TensorNDArray tensors; | |||
| auto ptr = tensor.get_ref_ptr(); | |||
| for (size_t i = 0; i < nr_tensor; i++) { | |||
| tensors.push_back(TensorND(layout, ptr)); | |||
| ptr += layout.span().dist_byte(); | |||
| } | |||
| return tensors; | |||
| } | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| template <> | |||
| void cell_opr_compute<LSTMCell, LstmStates>( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
| _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in bias_hh, const LstmStates& state_in, LstmStates& state_out, | |||
| Workspace cell_workspace, Handle* handle) { | |||
| auto opr = handle->create_operator<LSTMCellForward>(); | |||
| TensorLayout gates, h_new, c_new; | |||
| opr->deduce_layout( | |||
| input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout, | |||
| weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates); | |||
| auto workspace_bundle = LstmCellCompute::get_workspace_bundle( | |||
| input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout, | |||
| weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates); | |||
| workspace_bundle.set(cell_workspace.raw_ptr); | |||
| TensorND gates_tensor{workspace_bundle.get(0), gates}; | |||
| _megdnn_workspace new_workspace = { | |||
| static_cast<dt_byte*>(workspace_bundle.get(1)), | |||
| workspace_bundle.get_size(1)}; | |||
| LstmCellCompute::run( | |||
| input, weight_ih, bias_ih, state_in.m_h, weight_hh, bias_hh, state_in.m_c, | |||
| state_out.m_h, state_out.m_c, gates_tensor, new_workspace, handle); | |||
| } | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
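The `LstmCellWeight` constructor above walks a flattened weight blob: per cell it consumes `weight_ih`, `weight_hh`, then (when bias is enabled) `bias_ih` and `bias_hh`, each `4 * hidden_size` gates tall. A sketch with a hypothetical name mirroring how far the constructor advances `weight_ptr`, counted in elements (the code advances in bytes):

```cpp
// Hypothetical helper mirroring LstmCellWeight's pointer walk; returns the
// number of elements one cell occupies in the flattened weight blob.
#include <cstddef>

size_t lstm_cell_weight_elems(
        size_t hidden_size, size_t input_size, bool has_bias) {
    const size_t gate_hidden = 4 * hidden_size;  // i, f, g, o gates stacked
    size_t elems = gate_hidden * input_size      // weight_ih
                 + gate_hidden * hidden_size;    // weight_hh
    if (has_bias)
        elems += 2 * gate_hidden;                // bias_ih + bias_hh
    return elems;
}
```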
| @@ -0,0 +1,259 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm/lstm_utils.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/arm_common/lstm_cell/cell_kernel.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/naive/lstm/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| template <class CellOp, class States> | |||
| void cell_opr_compute( | |||
| _megdnn_tensor_in step_input, _megdnn_tensor_in weight_ih, | |||
| _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in bias_hh, const States& state_in, States& state_out, | |||
| Workspace cell_workspace, Handle* handle); | |||
| struct LstmCellWeight { | |||
| size_t m_weight_size = 0; | |||
| TensorND m_weight_ih, m_weight_hh, m_bias_ih, m_bias_hh; | |||
| // if there is no bias, a dummy bias tensor will be created from the workspace | |||
| LstmCellWeight( | |||
| RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
| DType dtype); | |||
| }; | |||
| struct LstmStates { | |||
| static size_t nr_states() { return 2; } | |||
| size_t m_memory_size; | |||
| TensorND m_h, m_c; | |||
| LstmStates( | |||
| const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size, | |||
| DType dtype); | |||
| }; | |||
| TensorNDArray split_tensor( | |||
| _megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout); | |||
| template <class CellWeight> | |||
| SmallVector<CellWeight> get_all_cells( | |||
| size_t dir_size, size_t num_layers, size_t input_size, size_t hidden_size, | |||
| bool bias, _megdnn_tensor_in flatten_weights) { | |||
| SmallVector<CellWeight> cell_weights; | |||
| cell_weights.reserve(dir_size * num_layers); | |||
| auto weight_ptr = flatten_weights.get_ref_ptr(); | |||
| for (size_t layer = 0; layer < num_layers; ++layer) { | |||
| for (size_t d = 0; d < dir_size; ++d) { | |||
| size_t cell_input_size = layer == 0 ? input_size : dir_size * hidden_size; | |||
| CellWeight cell_weight( | |||
| weight_ptr, hidden_size, cell_input_size, bias, | |||
| flatten_weights.layout.dtype); | |||
| weight_ptr += cell_weight.m_weight_size; | |||
| cell_weights.push_back(cell_weight); | |||
| } | |||
| } | |||
| return cell_weights; | |||
| } | |||
| template <class States> | |||
| SmallVector<States> get_all_status( | |||
| _megdnn_tensor_in hx, _megdnn_tensor_in cx, size_t hidden_size, | |||
| size_t batch_size, size_t num_layers, size_t dir_size, DType dtype) { | |||
| SmallVector<States> states; | |||
| auto hx_ptr = hx.get_ref_ptr(); | |||
| auto cx_ptr = cx.get_ref_ptr(); | |||
| for (size_t layer = 0; layer < num_layers * dir_size; ++layer) { | |||
| States state({hx_ptr, cx_ptr}, hidden_size, batch_size, dtype); | |||
| hx_ptr += state.m_memory_size; | |||
| cx_ptr += state.m_memory_size; | |||
| states.push_back(state); | |||
| } | |||
| return states; | |||
| } | |||
| template <class Cell, typename CellOpr, class States> | |||
| void exec_kernel( | |||
| SmallVector<Cell>& cells, const TensorNDArray& inputs, | |||
| const SmallVector<States>& states_in, SmallVector<States>& states_out, | |||
| TensorNDArray& outputs, size_t num_layers, size_t dir_size, Handle* handle, | |||
| WorkspaceBundle workspace_bundle) { | |||
| megdnn_assert(cells.size() == num_layers * dir_size); | |||
| megdnn_assert( | |||
| states_in.size() == states_out.size() && | |||
| states_in.size() == num_layers * dir_size); | |||
| megdnn_assert(outputs.size() == inputs.size()); | |||
| //! four shared workspaces plus one workspace per tmp state | |||
| megdnn_assert(workspace_bundle.nr_workspace() == 4 + States::nr_states()); | |||
| size_t seq_len = inputs.size(); | |||
| size_t batch_size = inputs[0].layout.shape[0]; | |||
| size_t input_size = inputs[0].layout.shape[1]; | |||
| size_t hidden_size = cells[0].m_weight_hh.layout.shape[1]; | |||
| TensorLayout batch_output_layout{ | |||
| {hidden_size}, outputs[0].layout.dtype}; // one batch row of hy | |||
| TensorLayout cell_output_layout{ | |||
| {batch_size, hidden_size}, outputs[0].layout.dtype}; // output hy | |||
| TensorLayout seq_output_layout{ | |||
| {batch_size, dir_size * hidden_size}, outputs[0].layout.dtype}; | |||
| TensorLayout cell_first_input_layout{ | |||
| {batch_size, input_size}, inputs[0].layout.dtype}; // input | |||
| TensorLayout cell_input_layout{ | |||
| {batch_size, dir_size * hidden_size}, inputs[0].layout.dtype}; | |||
| TensorLayout tmp_output_layout{ | |||
| {seq_len, batch_size, dir_size * hidden_size}, outputs[0].layout.dtype}; | |||
| //! fetch sub-workspaces | |||
| Workspace cell_workspace( | |||
| static_cast<dt_byte*>(workspace_bundle.get(0)), | |||
| workspace_bundle.get_size(0) + workspace_bundle.get_size(1)); | |||
| auto&& tmp_inputs_1 = split_tensor( | |||
| TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len, | |||
| cell_input_layout); | |||
| auto&& tmp_outputs_1 = split_tensor( | |||
| TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len, | |||
| seq_output_layout); | |||
| auto&& tmp_inputs_2 = split_tensor( | |||
| TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len, | |||
| cell_input_layout); | |||
| auto&& tmp_outputs_2 = split_tensor( | |||
| TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len, | |||
| seq_output_layout); | |||
| using IoPair = std::pair<TensorNDArray, TensorNDArray>; | |||
| IoPair io_pair1 = {tmp_inputs_1, tmp_outputs_2}; | |||
| IoPair io_pair2 = {tmp_inputs_2, tmp_outputs_1}; | |||
| SmallVector<IoPair> io_pairs = {io_pair1, io_pair2}; | |||
| SmallVector<RefPtr> ptr; | |||
| for (size_t index = 0; index < States::nr_states(); index++) { | |||
| ptr.push_back(workspace_bundle.get(4 + index)); | |||
| } | |||
| auto&& tmp_state = States(ptr, hidden_size, batch_size, outputs[0].layout.dtype); | |||
| for (size_t layer = 0; layer < num_layers; layer++) { | |||
| auto layer_inputs = io_pairs[layer % 2].first; | |||
| auto layer_outputs = io_pairs[layer % 2].second; | |||
| //! if last layer, write directly to the output tensors | |||
| if (num_layers - 1 == layer) { | |||
| layer_outputs = outputs; | |||
| } | |||
| if (0 == layer) { | |||
| layer_inputs = inputs; | |||
| } | |||
| for (size_t d = 0; d < dir_size; ++d) { | |||
| size_t cell_idx = layer * dir_size + d; | |||
| auto& cell = cells[cell_idx]; | |||
| auto& state_in_origin = states_in[cell_idx]; | |||
| auto& state_out_origin = states_out[cell_idx]; | |||
| auto state_in = state_in_origin; | |||
| auto state_out = tmp_state; | |||
| for (size_t i = 0; i < seq_len; ++i) { | |||
| size_t step = d == 0 ? i : seq_len - 1 - i; | |||
| auto& step_input = layer_inputs[step]; | |||
| auto& step_output = layer_outputs[step]; | |||
| if (i == seq_len - 1) { | |||
| state_out = state_out_origin; | |||
| } | |||
| //! task 1 | |||
| //! the cell opr dispatches its task internally, so none is dispatched here | |||
| cell_opr_compute<CellOpr, LstmStates>( | |||
| step_input, cell.m_weight_ih, cell.m_weight_hh, cell.m_bias_ih, | |||
| cell.m_bias_hh, state_in, state_out, cell_workspace, handle); | |||
| //! task 2 | |||
| //! copy output to contiguous space | |||
| auto copy_to_output = [=]() { | |||
| //! if dir_size > 1 and batch_size > 1, reorder into the output | |||
| size_t stride = batch_output_layout.span().dist_byte(); | |||
| if (dir_size > 1 && batch_size > 1) { | |||
| int8_t* source = static_cast<int8_t*>(state_out.m_h.raw_ptr()); | |||
| int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) + | |||
| d * stride; | |||
| for (size_t b = 0; b < batch_size; b++) { | |||
| memcpy(dst, source, stride); | |||
| source += stride; | |||
| dst += dir_size * stride; | |||
| } | |||
| } else { | |||
| void* source = state_out.m_h.raw_ptr(); | |||
| int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) + | |||
| d * stride; | |||
| memcpy(dst, source, state_out.m_h.layout.span().dist_byte()); | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN( | |||
| static_cast<naive::HandleImpl*>(handle), copy_to_output()); | |||
| //! state_in and state_out are read and written in place | |||
| if (0 == i) { | |||
| state_in = tmp_state; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename CellOpr> | |||
| WorkspaceBundle get_workspace_bundle( | |||
| const TensorLayout& input, const TensorLayout& output, | |||
| const TensorLayout& flatten_weights, size_t hidden_size, size_t dir_size, | |||
| size_t states_size) { | |||
| size_t batch_size = input.shape[1]; | |||
| size_t input_size = input.shape[2]; | |||
| size_t gate_hidden_size = flatten_weights.shape[0]; | |||
| // cell workspace | |||
| TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype}; | |||
| TensorLayout weight_hh{ | |||
| {gate_hidden_size, dir_size * hidden_size}, flatten_weights.dtype}; | |||
| TensorLayout bias{{1, gate_hidden_size}, flatten_weights.dtype}; | |||
| TensorLayout hx{{batch_size, dir_size * hidden_size}, input.dtype}; | |||
| auto cell_opr = inplace_cpu_handle()->create_operator<CellOpr>(); | |||
| TensorLayout h_new, c_new, gates; | |||
| cell_opr->deduce_layout( | |||
| input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates); | |||
| SmallVector<size_t> workspaces; | |||
| //! the cell opr compute workspace | |||
| size_t cell_opr_workspace = cell_opr->get_workspace_in_bytes( | |||
| input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates); | |||
| workspaces.push_back(gates.span().dist_byte()); | |||
| workspaces.push_back(cell_opr_workspace); | |||
| //! double tmp output memory | |||
| size_t tmp_output_workspace = output.span().dist_byte(); | |||
| workspaces.push_back(tmp_output_workspace); | |||
| workspaces.push_back(tmp_output_workspace); | |||
| //! tmp states memory | |||
| size_t tmp_state_workspace = hx.span().dist_byte(); | |||
| for (size_t i = 0; i < states_size; i++) { | |||
| workspaces.push_back(tmp_state_workspace); | |||
| } | |||
| return {nullptr, workspaces}; | |||
| } | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
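`exec_kernel` above double-buffers the sequence between two scratch regions (workspaces 2 and 3): `io_pairs` alternates them per layer so layer L reads what layer L−1 wrote, while layer 0 is redirected to the real inputs and the last layer writes the real outputs. A minimal standalone sketch of that ping-pong, with hypothetical names:

```cpp
// Standalone sketch of exec_kernel's ping-pong buffering (hypothetical
// names): buf_a/buf_b stand for the two tmp sequence workspaces.
#include <cstddef>
#include <vector>

using Seq = std::vector<float>;

static void run_one_layer(const Seq& in, Seq& out) {
    out = in;  // stand-in for the per-direction, per-step cell loop
}

void run_layers(
        size_t num_layers, const Seq& inputs, Seq& outputs, Seq& buf_a,
        Seq& buf_b) {
    for (size_t layer = 0; layer < num_layers; ++layer) {
        // even layers read buf_a and write buf_b, odd layers the reverse;
        // the first/last layer are redirected to the real tensors
        const Seq& in = layer == 0 ? inputs : (layer % 2 ? buf_b : buf_a);
        Seq& out =
                layer + 1 == num_layers ? outputs : (layer % 2 ? buf_a : buf_b);
        run_one_layer(in, out);
    }
}
```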
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/arm_common/lstm/opr_impl.h" | |||
| #include "./lstm_utils.h" | |||
| #include "src/arm_common/lstm_cell/opr_impl.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_arm_common_lstm) | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| void LSTMImpl::exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
| _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
| _megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out, | |||
| _megdnn_workspace workspace) { | |||
| MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(0)) { | |||
| size_t dir_size = param().bidirectional ? 2 : 1; | |||
| size_t num_layers = param().num_layers; | |||
| size_t hidden_size = param().hidden_size; | |||
| size_t seq_len = input.layout.shape[0]; | |||
| size_t batch_size = input.layout.shape[1]; | |||
| size_t input_size = input.layout.shape[2]; | |||
| //! to support input ptr changes in record mode, this task should be | |||
| //! dispatched to the device | |||
| auto&& cell_weights = get_all_cells<LstmCellWeight>( | |||
| dir_size, num_layers, input_size, hidden_size, param().bias, | |||
| flatten_weights); | |||
| auto&& cell_states_in = get_all_status<LstmStates>( | |||
| hx, cx, hidden_size, batch_size, num_layers, dir_size, hx.layout.dtype); | |||
| auto&& cell_states_out = get_all_status<LstmStates>( | |||
| hy, cy, hidden_size, batch_size, num_layers, dir_size, hy.layout.dtype); | |||
| auto&& inputs = split_tensor( | |||
| input, seq_len, | |||
| TensorLayout{{batch_size, input_size}, input.layout.dtype}); | |||
| auto&& outputs = split_tensor( | |||
| output, seq_len, | |||
| TensorLayout{ | |||
| {batch_size, dir_size * hidden_size}, output.layout.dtype}); | |||
| auto workspace_bundle = get_workspace_bundle<LSTMCell>( | |||
| input.layout, output.layout, flatten_weights.layout, hidden_size, | |||
| dir_size, LstmStates::nr_states()); | |||
| workspace_bundle.set(workspace.raw_ptr); | |||
| exec_kernel<LstmCellWeight, LSTMCell, LstmStates>( | |||
| cell_weights, inputs, cell_states_in, cell_states_out, outputs, | |||
| num_layers, dir_size, handle(), workspace_bundle); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| size_t LSTMImpl::get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout& flatten_weights, const TensorLayout& output, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||
| MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(1)) { | |||
| size_t dir_size = param().bidirectional ? 2 : 1; | |||
| size_t hidden_size = param().hidden_size; | |||
| auto bundle = get_workspace_bundle<LSTMCell>( | |||
| input, output, flatten_weights, hidden_size, dir_size, | |||
| LstmStates::nr_states()); | |||
| return bundle.total_size_in_bytes(); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
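`get_all_cells` consumes `flatten_weights` sequentially, so its second dimension must sum every cell's rows: the input-hidden width (`input_size` for layer 0, otherwise `dir_size * hidden_size`) plus `hidden_size` for the hidden-hidden weight, plus two extra columns per cell when bias is on. A sketch of that size computation (hypothetical helper; it matches how the tests later in this diff build the flat weight shape `{4 * hidden_size, flatten}`):

```cpp
// Hypothetical helper: expected second dim of flatten_weights, mirroring
// how get_all_cells walks the blob.
#include <cstddef>

size_t lstm_flatten_weight_dim(
        size_t num_layers, size_t dir_size, size_t input_size,
        size_t hidden_size, bool bias) {
    size_t flatten = 0;
    for (size_t layer = 0; layer < num_layers; ++layer)
        for (size_t d = 0; d < dir_size; ++d) {
            flatten += layer == 0 ? input_size : dir_size * hidden_size;  // ih
            flatten += hidden_size;                                       // hh
        }
    if (bias)
        flatten += 2 * dir_size * num_layers;  // bias_ih + bias_hh columns
    return flatten;
}
```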
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/lstm/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| class LSTMImpl : public naive::LSTMImpl { | |||
| public: | |||
| using naive::LSTMImpl::LSTMImpl; | |||
| void exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
| _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
| _megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||
| _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
| const TensorLayout& flatten_weights, const TensorLayout& output, | |||
| const TensorLayout& hy, const TensorLayout& cy, | |||
| const TensorLayout& reserve_space) override; | |||
| //! arm_common only stores the output tensor; the other tensors are only | |||
| //! used for computing grad, so arm ignores them | |||
| size_t get_reserve_size_in_bytes(const TensorLayout&) override { return 1; } | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,273 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm_cell/cell_kernel.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./cell_kernel.h" | |||
| #include "src/arm_common/lstm_cell/opr_impl.h" | |||
| #include "src/common/lstm_cell.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/sigmoid.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/tanh.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| namespace { | |||
| template <class Op, bool bias> | |||
| struct ElemwiseCompute { | |||
| static Op op; | |||
| static inline float32x4x2_t compute_8( | |||
| float* dst, float* tmp, float* ih, float* hh) { | |||
| float32x4_t dst0 = vld1q_f32(dst); | |||
| float32x4_t dst1 = vld1q_f32(dst + 4); | |||
| float32x4_t tmp0 = vld1q_f32(tmp); | |||
| float32x4_t tmp1 = vld1q_f32(tmp + 4); | |||
| auto mid0 = vaddq_f32(dst0, tmp0); | |||
| auto mid1 = vaddq_f32(dst1, tmp1); | |||
| float32x4_t out0, out1; | |||
| if (bias) { | |||
| float32x4_t ih0 = vld1q_f32(ih); | |||
| float32x4_t ih1 = vld1q_f32(ih + 4); | |||
| float32x4_t hh0 = vld1q_f32(hh); | |||
| float32x4_t hh1 = vld1q_f32(hh + 4); | |||
| auto midd0 = vaddq_f32(ih0, hh0); | |||
| auto midd1 = vaddq_f32(ih1, hh1); | |||
| out0 = vaddq_f32(mid0, midd0); | |||
| out1 = vaddq_f32(mid1, midd1); | |||
| } else { | |||
| out0 = mid0; | |||
| out1 = mid1; | |||
| } | |||
| return {{op(out0), op(out1)}}; | |||
| } | |||
| static inline float32x4_t compute_4(float* dst, float* tmp, float* ih, float* hh) { | |||
| float32x4_t dst0 = vld1q_f32(dst); | |||
| float32x4_t tmp0 = vld1q_f32(tmp); | |||
| auto mid0 = vaddq_f32(dst0, tmp0); | |||
| float32x4_t out0; | |||
| if (bias) { | |||
| float32x4_t ih0 = vld1q_f32(ih); | |||
| float32x4_t hh0 = vld1q_f32(hh); | |||
| auto midd0 = vaddq_f32(ih0, hh0); | |||
| out0 = vaddq_f32(mid0, midd0); | |||
| } else { | |||
| out0 = mid0; | |||
| } | |||
| return op(out0); | |||
| } | |||
| static inline float compute_1(float* dst, float* tmp, float* ih, float* hh) { | |||
| float out; | |||
| if (bias) { | |||
| out = dst[0] + tmp[0] + ih[0] + hh[0]; | |||
| } else { | |||
| out = dst[0] + tmp[0]; | |||
| } | |||
| return op(out); | |||
| } | |||
| }; | |||
| template <class Op, bool bias> | |||
| Op ElemwiseCompute<Op, bias>::op = Op(); | |||
| template <bool bias> | |||
| void rnn_cell_elemwise_compute( | |||
| _megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in bias_hh, _megdnn_tensor_in cx, _megdnn_tensor_out h_new, | |||
| _megdnn_tensor_out c_new) { | |||
| size_t batch = dst.layout[0]; | |||
| size_t batch_length = dst.layout.total_nr_elems() / batch; | |||
| size_t base_length = batch_length / 4; | |||
| float *ih_ptr_ = nullptr, *hh_ptr_ = nullptr; | |||
| float* dst_ptr_ = dst.ptr<float>(); | |||
| float* tmp_ptr_ = tmp.ptr<float>(); | |||
| if (bias) { | |||
| ih_ptr_ = bias_ih.ptr<float>(); | |||
| hh_ptr_ = bias_hh.ptr<float>(); | |||
| } | |||
| float* cx_ptr_ = cx.ptr<float>(); | |||
| float* h_new_ptr_ = h_new.ptr<float>(); | |||
| float* c_new_ptr_ = c_new.ptr<float>(); | |||
| ElemwiseCompute<SigmoidOp<dt_float32>, bias> sigmoid_compute; | |||
| ElemwiseCompute<TanhOp<dt_float32>, bias> tanh_compute; | |||
| TanhOp<dt_float32> tanh_op; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| float* dst_ptr = dst_ptr_ + b * batch_length; | |||
| float* tmp_ptr = tmp_ptr_ + b * batch_length; | |||
| float* ih_ptr = ih_ptr_; | |||
| float* hh_ptr = hh_ptr_; | |||
| float* cx_ptr = cx_ptr_ + b * base_length; | |||
| float* h_new_ptr = h_new_ptr_ + b * base_length; | |||
| float* c_new_ptr = c_new_ptr_ + b * base_length; | |||
| size_t index = 0; | |||
| for (; index + 7 < base_length; index += 8) { | |||
| auto out_i = sigmoid_compute.compute_8(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||
| auto out_f = sigmoid_compute.compute_8( | |||
| dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||
| hh_ptr + base_length); | |||
| auto out_g = tanh_compute.compute_8( | |||
| dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||
| ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||
| auto out_o = sigmoid_compute.compute_8( | |||
| dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||
| ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||
| float32x4_t cx_0 = vld1q_f32(cx_ptr); | |||
| float32x4_t cx_1 = vld1q_f32(cx_ptr + 4); | |||
| //! f * cx + i * g | |||
| auto c_new_0 = vaddq_f32( | |||
| vmulq_f32(out_f.val[0], cx_0), | |||
| vmulq_f32(out_i.val[0], out_g.val[0])); | |||
| auto c_new_1 = vaddq_f32( | |||
| vmulq_f32(out_f.val[1], cx_1), | |||
| vmulq_f32(out_i.val[1], out_g.val[1])); | |||
| vst1q_f32(c_new_ptr, c_new_0); | |||
| vst1q_f32(c_new_ptr + 4, c_new_1); | |||
| auto h_new_0 = vmulq_f32(tanh_op(c_new_0), out_o.val[0]); | |||
| auto h_new_1 = vmulq_f32(tanh_op(c_new_1), out_o.val[1]); | |||
| vst1q_f32(h_new_ptr, h_new_0); | |||
| vst1q_f32(h_new_ptr + 4, h_new_1); | |||
| dst_ptr += 8; | |||
| tmp_ptr += 8; | |||
| ih_ptr += 8; | |||
| hh_ptr += 8; | |||
| cx_ptr += 8; | |||
| c_new_ptr += 8; | |||
| h_new_ptr += 8; | |||
| } | |||
| for (; index + 3 < base_length; index += 4) { | |||
| auto out_i = sigmoid_compute.compute_4(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||
| auto out_f = sigmoid_compute.compute_4( | |||
| dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||
| hh_ptr + base_length); | |||
| auto out_g = tanh_compute.compute_4( | |||
| dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||
| ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||
| auto out_o = sigmoid_compute.compute_4( | |||
| dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||
| ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||
| float32x4_t cx_v = vld1q_f32(cx_ptr); | |||
| //! f * cx + i * g | |||
| auto c_new = vaddq_f32(vmulq_f32(out_f, cx_v), vmulq_f32(out_i, out_g)); | |||
| vst1q_f32(c_new_ptr, c_new); | |||
| auto h_new = vmulq_f32(tanh_op(c_new), out_o); | |||
| vst1q_f32(h_new_ptr, h_new); | |||
| dst_ptr += 4; | |||
| tmp_ptr += 4; | |||
| ih_ptr += 4; | |||
| hh_ptr += 4; | |||
| cx_ptr += 4; | |||
| c_new_ptr += 4; | |||
| h_new_ptr += 4; | |||
| } | |||
| for (; index < base_length; index++) { | |||
| auto out_i = sigmoid_compute.compute_1(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||
| auto out_f = sigmoid_compute.compute_1( | |||
| dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||
| hh_ptr + base_length); | |||
| auto out_g = tanh_compute.compute_1( | |||
| dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||
| ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||
| auto out_o = sigmoid_compute.compute_1( | |||
| dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||
| ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||
| c_new_ptr[0] = out_f * cx_ptr[0] + out_i * out_g; | |||
| h_new_ptr[0] = tanh_op(c_new_ptr[0]) * out_o; | |||
| dst_ptr += 1; | |||
| tmp_ptr += 1; | |||
| ih_ptr += 1; | |||
| hh_ptr += 1; | |||
| cx_ptr += 1; | |||
| c_new_ptr += 1; | |||
| h_new_ptr += 1; | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| void LstmCellCompute::run( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
| _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) { | |||
| auto bundle = get_workspace_bundle( | |||
| input.layout, weight_ih.layout, bias_ih.layout, hx.layout, weight_hh.layout, | |||
| bias_hh.layout, cx.layout, h_new.layout, c_new.layout, gates.layout); | |||
| bundle.set(workspace.raw_ptr); | |||
| TensorND tmp{static_cast<void*>(bundle.get(0)), gates.layout}; | |||
| auto matmul_workspace = | |||
| megdnn::Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)}; | |||
| auto opr = handle->create_operator<MatrixMul>(); | |||
| opr->param().transposeB = true; | |||
| //! the opr will dispatch the compute task to the device, so record-mode | |||
| //! performance will not be affected | |||
| opr->exec(input, weight_ih, tmp, matmul_workspace); | |||
| opr->exec(hx, weight_hh, gates, matmul_workspace); | |||
| //! the optimized post compute: nonlinear(tmp + dst + bias_ih + bias_hh) | |||
| if (bias_ih.layout.ndim != 0 && bias_hh.layout.ndim != 0) { | |||
| MEGDNN_DISPATCH_CPU_KERN( | |||
| static_cast<naive::HandleImpl*>(handle), | |||
| rnn_cell_elemwise_compute<true>( | |||
| gates, tmp, bias_ih, bias_hh, cx, h_new, c_new)); | |||
| } else { | |||
| megdnn_assert(bias_ih.layout.ndim == 0 && bias_hh.layout.ndim == 0); | |||
| MEGDNN_DISPATCH_CPU_KERN( | |||
| static_cast<naive::HandleImpl*>(handle), | |||
| rnn_cell_elemwise_compute<false>( | |||
| gates, tmp, bias_ih, bias_hh, cx, h_new, c_new)); | |||
| } | |||
| } | |||
| WorkspaceBundle LstmCellCompute::get_workspace_bundle( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&, | |||
| const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout& gates) { | |||
| auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||
| opr->param().transposeB = true; | |||
| size_t matmul_workspace = std::max( | |||
| opr->get_workspace_in_bytes(input, weight_ih, gates), | |||
| opr->get_workspace_in_bytes(hx, weight_hh, gates)); | |||
| return WorkspaceBundle{nullptr, {gates.span().dist_byte(), matmul_workspace}}; | |||
| } | |||
| bool LstmCellCompute::is_optimized( | |||
| const TensorLayout& input, const TensorLayout&, const TensorLayout& bias_ih, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout& bias_hh, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout& gates) { | |||
| if (input.dtype.enumv() == DTypeEnum::Float32 && gates[1] == bias_ih[1] && | |||
| bias_ih[0] == 1 && bias_ih.eq_layout(bias_hh)) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
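The three loop tails above (8-, 4- and 1-wide) all evaluate the standard LSTM cell update on gate pre-activations laid out gate-major as [i | f | g | o]. A scalar reference for one element, as a sketch rather than the diff's API:

```cpp
// Scalar reference for one LSTM element; pre_* are the summed gate
// pre-activations x*W_ih^T + h*W_hh^T (+ bias_ih + bias_hh when present).
#include <cmath>

inline void lstm_cell_ref(
        float pre_i, float pre_f, float pre_g, float pre_o, float c_prev,
        float* h_new, float* c_new) {
    auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
    float i = sigmoid(pre_i);
    float f = sigmoid(pre_f);
    float g = std::tanh(pre_g);
    float o = sigmoid(pre_o);
    *c_new = f * c_prev + i * g;     // f * cx + i * g
    *h_new = o * std::tanh(*c_new);  // o * tanh(c_new)
}
```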
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm_cell/cell_kernel.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/naive/lstm_cell/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| struct LstmCellCompute { | |||
| static WorkspaceBundle get_workspace_bundle( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& cx, const TensorLayout& h_new, | |||
| const TensorLayout& c_new, const TensorLayout& gates); | |||
| static bool is_optimized( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& cx, const TensorLayout& h_new, | |||
| const TensorLayout& c_new, const TensorLayout& gates); | |||
| static void run( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
| _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
| _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
| _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle); | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm_cell/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/arm_common/lstm_cell/opr_impl.h" | |||
| #include "src/common/lstm_cell.h" | |||
| #include "src/naive/handle.h" | |||
| #include "./cell_kernel.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_arm_common_lstm_cell) | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| void LSTMCellImpl::exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
| _megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||
| //! only float32 with bias of shape {1, gate_hidden_size} takes the optimized path | |||
| MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(0)) { | |||
| if (!LstmCellCompute::is_optimized( | |||
| input.layout, weight_ih.layout, bias_ih.layout, hx.layout, | |||
| weight_hh.layout, bias_hh.layout, cx.layout, h_new.layout, | |||
| c_new.layout, gates.layout)) { | |||
| naive::LSTMCellImpl::exec( | |||
| input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||
| gates, workspace); | |||
| } else { | |||
| LstmCellCompute::run( | |||
| input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||
| gates, workspace, handle()); | |||
| } | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| size_t LSTMCellImpl::get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
| const TensorLayout& gates) { | |||
| MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(1)) { | |||
| if (!LstmCellCompute::is_optimized( | |||
| input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||
| gates)) { | |||
| return naive::LSTMCellImpl::get_workspace_in_bytes( | |||
| input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||
| gates); | |||
| } else { | |||
| return LstmCellCompute::get_workspace_bundle( | |||
| input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, | |||
| c_new, gates) | |||
| .total_size_in_bytes(); | |||
| } | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/lstm_cell/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/lstm_cell/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| class LSTMCellImpl : public naive::LSTMCellImpl { | |||
| public: | |||
| using naive::LSTMCellImpl::LSTMCellImpl; | |||
| void exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
| _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
| _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
| _megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& cx, const TensorLayout& h_new, | |||
| const TensorLayout& c_new, const TensorLayout& gates) override; | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,218 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/rnn_cell/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/arm_common/rnn_cell/opr_impl.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/none.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/relu.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/tanh.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_arm_common_rnn_cell) | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| namespace { | |||
| ElemwiseForward* get_elemwise_opr() { | |||
| static CpuOprDelegationStorage<1> storage; | |||
| return storage.get<ElemwiseForward>(); | |||
| } | |||
| template <typename Op> | |||
| void elemwise_compute( | |||
| float* dst_ptr, float* tmp_ptr, float* ih_ptr, float* hh_ptr, size_t batch, | |||
| size_t length) { | |||
| constexpr size_t SIMD_8 = 8; | |||
| constexpr size_t SIMD_4 = 4; | |||
| Op op; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| float* dst = dst_ptr + b * length; | |||
| float* tmp = tmp_ptr + b * length; | |||
| float* ih = ih_ptr; | |||
| float* hh = hh_ptr; | |||
| size_t index = 0; | |||
| for (; index + SIMD_8 - 1 < length; index += SIMD_8) { | |||
| float32x4_t dst0 = vld1q_f32(dst); | |||
| float32x4_t dst1 = vld1q_f32(dst + 4); | |||
| float32x4_t tmp0 = vld1q_f32(tmp); | |||
| float32x4_t tmp1 = vld1q_f32(tmp + 4); | |||
| float32x4_t ih0 = vld1q_f32(ih); | |||
| float32x4_t ih1 = vld1q_f32(ih + 4); | |||
| float32x4_t hh0 = vld1q_f32(hh); | |||
| float32x4_t hh1 = vld1q_f32(hh + 4); | |||
| auto mid0 = vaddq_f32(dst0, tmp0); | |||
| auto mid1 = vaddq_f32(dst1, tmp1); | |||
| auto midd0 = vaddq_f32(ih0, hh0); | |||
| auto midd1 = vaddq_f32(ih1, hh1); | |||
| auto out0 = vaddq_f32(mid0, midd0); | |||
| auto out1 = vaddq_f32(mid1, midd1); | |||
| vst1q_f32(dst, op(out0)); | |||
| vst1q_f32(dst + 4, op(out1)); | |||
| dst += SIMD_8; | |||
| tmp += SIMD_8; | |||
| ih += SIMD_8; | |||
| hh += SIMD_8; | |||
| } | |||
| for (; index + SIMD_4 - 1 < length; index += SIMD_4) { | |||
| float32x4_t dst0 = vld1q_f32(dst); | |||
| float32x4_t tmp0 = vld1q_f32(tmp); | |||
| float32x4_t ih0 = vld1q_f32(ih); | |||
| float32x4_t hh0 = vld1q_f32(hh); | |||
| auto mid0 = vaddq_f32(dst0, tmp0); | |||
| auto midd0 = vaddq_f32(ih0, hh0); | |||
| auto out0 = vaddq_f32(mid0, midd0); | |||
| vst1q_f32(dst, op(out0)); | |||
| dst += SIMD_4; | |||
| tmp += SIMD_4; | |||
| ih += SIMD_4; | |||
| hh += SIMD_4; | |||
| } | |||
| for (; index < length; index++) { | |||
| auto out = dst[0] + tmp[0] + ih[0] + hh[0]; | |||
| dst[0] = op(out); | |||
| dst++; | |||
| tmp++; | |||
| ih++; | |||
| hh++; | |||
| } | |||
| } | |||
| } | |||
| void rnn_cell_post_compute( | |||
| _megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in bias_hh, param::RNNCell::NonlineMode nonline_mode, | |||
| Handle* handle) { | |||
| using NonlineMode = param::RNNCell::NonlineMode; | |||
| megdnn_assert( | |||
| nonline_mode == NonlineMode::RELU || nonline_mode == NonlineMode::TANH || | |||
| nonline_mode == NonlineMode::IDENTITY, | |||
| "Now arm only support nonlinear mode Relu, TANH, IDENTITY."); | |||
| if (dst.layout.dtype.enumv() == DTypeEnum::Float32 && | |||
| dst.layout[1] == bias_ih.layout[1] && bias_ih.layout[0] == 1 && | |||
| bias_ih.layout.eq_layout(bias_hh.layout)) { | |||
| auto run = [=]() { | |||
| size_t batch = dst.layout[0]; | |||
| size_t length = bias_ih.layout.total_nr_elems(); | |||
| float* dst_ptr = dst.ptr<float>(); | |||
| float* tmp_ptr = tmp.ptr<float>(); | |||
| float* ih_ptr = bias_ih.ptr<float>(); | |||
| float* hh_ptr = bias_hh.ptr<float>(); | |||
| if (nonline_mode == NonlineMode::RELU) { | |||
| elemwise_compute<ReluOp<dt_float32>>( | |||
| dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||
| } else if (nonline_mode == NonlineMode::TANH) { | |||
| elemwise_compute<TanhOp<dt_float32>>( | |||
| dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||
| } else { | |||
| elemwise_compute<NoneOp<dt_float32>>( | |||
| dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run()); | |||
| } else { | |||
| //! this opr must be created by inplace handle | |||
| auto elem_opr = get_elemwise_opr(); | |||
| auto run = [=]() { | |||
| elem_opr->param().mode = Elemwise::Param::Mode::ADD; | |||
| elem_opr->exec({dst, tmp}, dst); | |||
| elem_opr->exec({dst, bias_ih}, dst); | |||
| elem_opr->exec({dst, bias_hh}, dst); | |||
| // activation | |||
| switch (nonline_mode) { | |||
| #define cb(_mode) \ | |||
| case NonlineMode::_mode: { \ | |||
| elem_opr->param().mode = Elemwise::Param::Mode::_mode; \ | |||
| elem_opr->exec({dst}, dst); \ | |||
| break; \ | |||
| } | |||
| cb(RELU); | |||
| cb(TANH); | |||
| #undef cb | |||
| case NonlineMode::IDENTITY: | |||
| break; | |||
| default: | |||
| megdnn_throw("unsupport nonlinear mode."); | |||
| } | |||
| }; | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run()); | |||
| } | |||
| } | |||
| } // namespace | |||
| WorkspaceBundle RNNCellImpl::get_workspace_bundle( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&, | |||
| const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&, | |||
| const TensorLayout& dst) { | |||
| MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(0)) { | |||
| auto opr = handle()->create_operator<MatrixMulForward>(); | |||
| opr->param().transposeB = true; | |||
| auto matmul_workspace = std::max( | |||
| opr->get_workspace_in_bytes(input, weight_ih, dst), | |||
| opr->get_workspace_in_bytes(hx, weight_hh, dst)); | |||
| auto tmp_workspace = dst.span().dist_byte(); | |||
| return WorkspaceBundle{nullptr, {tmp_workspace, matmul_workspace}}; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| size_t RNNCellImpl::get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& dst) { | |||
| return get_workspace_bundle(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst) | |||
| .total_size_in_bytes(); | |||
| } | |||
| void RNNCellImpl::exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
| _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
| MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(1)) { | |||
| auto bundle = get_workspace_bundle( | |||
| input.layout, weight_ih.layout, bias_ih.layout, hx.layout, | |||
| weight_hh.layout, bias_hh.layout, dst.layout); | |||
| bundle.set(workspace.raw_ptr); | |||
| auto nonline_mode = param().nonlineMode; | |||
| TensorND tmp{static_cast<void*>(bundle.get(0)), dst.layout}; | |||
| auto new_workspace = | |||
| Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)}; | |||
| //! this opr can't be created by inplace handle | |||
| auto opr = handle()->create_operator<MatrixMulForward>(); | |||
| opr->param().transposeB = true; | |||
| //! the opr will dispatch the compute task to the device, so record-mode | |||
| //! performance will not be affected | |||
| opr->exec(input, weight_ih, tmp, new_workspace); | |||
| opr->exec(hx, weight_hh, dst, new_workspace); | |||
| //! the optimized post compute: nonlinear(tmp + dst + bias_ih + bias_hh) | |||
| rnn_cell_post_compute(dst, tmp, bias_ih, bias_hh, nonline_mode, handle()); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
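Taken together, `RNNCellImpl::exec` computes `dst = act(input · weight_ih^T + bias_ih + hx · weight_hh^T + bias_hh)`: two matmuls write into `tmp` and `dst`, then the fused post step adds the four terms and applies the nonlinearity. A scalar sketch of the post step (hypothetical helper, float32 assumed):

```cpp
// Hypothetical scalar sketch of rnn_cell_post_compute's fused step.
#include <algorithm>
#include <cmath>

enum class Act { IDENTITY, RELU, TANH };

inline float rnn_cell_post_ref(
        float dst, float tmp, float b_ih, float b_hh, Act act) {
    float v = dst + tmp + b_ih + b_hh;
    switch (act) {
        case Act::RELU:
            return std::max(v, 0.f);
        case Act::TANH:
            return std::tanh(v);
        default:
            return v;
    }
}
```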
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/rnn_cell/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/naive/rnn_cell/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| class RNNCellImpl : public naive::RNNCellImpl { | |||
| public: | |||
| using naive::RNNCellImpl::RNNCellImpl; | |||
| void exec( | |||
| _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
| _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
| _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& dst) override; | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle( | |||
| const TensorLayout& input, const TensorLayout& weight_ih, | |||
| const TensorLayout& bias_ih, const TensorLayout& hx, | |||
| const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
| const TensorLayout& dst); | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -31,4 +31,6 @@ public: | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,225 @@ | |||
| /** | |||
| * \file dnn/test/arm_common/lstm.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "test/arm_common/fixture.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "megdnn/oprs/general.h" | |||
| #include "test/common/benchmarker.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/task_record_check.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| namespace { | |||
| //! in arm_common the reserve tensor is not used | |||
| void output_canonizer(const CheckerHelper::TensorValueArray& arr) { | |||
| const TensorND& reserve = arr.back(); | |||
| TensorND& modif_reserve = const_cast<TensorND&>(reserve); | |||
| modif_reserve.layout = TensorLayout(); | |||
| } | |||
| } // namespace | |||
| TEST_F(ARM_COMMON, LSTMCell) { | |||
| Checker<LSTMCell> checker(handle()); | |||
| checker.set_output_canonizer(output_canonizer); | |||
| checker.exec( | |||
| {{1, 10}, | |||
| {40, 10}, | |||
| {1, 40}, | |||
| {1, 10}, | |||
| {40, 10}, | |||
| {1, 40}, | |||
| {1, 10}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| for (size_t batch : {2}) | |||
| for (size_t n : {3, 4, 5, 23, 100}) | |||
| for (size_t out : {3, 6, 25, 100}) { | |||
| checker.exec( | |||
| {{batch, n}, | |||
| {out * 4, n}, | |||
| {1, out * 4}, | |||
| {batch, out}, | |||
| {out * 4, out}, | |||
| {1, out * 4}, | |||
| {batch, out}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| checker.exec( | |||
| {{batch, n}, | |||
| {out * 4, n}, | |||
| {batch, out * 4}, | |||
| {batch, out}, | |||
| {out * 4, out}, | |||
| {batch, out * 4}, | |||
| {batch, out}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| } | |||
| } | |||
| TEST_F(ARM_COMMON, LSTMCellRecord) { | |||
| TaskRecordChecker<LSTMCell> checker(0); | |||
| checker.exec( | |||
| {{1, 10}, | |||
| {40, 10}, | |||
| {1, 40}, | |||
| {1, 10}, | |||
| {40, 10}, | |||
| {1, 40}, | |||
| {1, 10}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| } | |||
| namespace { | |||
| void test_lstm(bool bias, bool direction, Handle* handle) { | |||
| Checker<LSTM> checker(handle, true); | |||
| //! because LSTM uses tanh and exp, the accumulated error over more | |||
| //! iterations can exceed 1e-3 | |||
| checker.set_epsilon(1e-2); | |||
| checker.set_output_canonizer(output_canonizer); | |||
| for (size_t input_size : {2, 8, 13}) | |||
| for (size_t hidden_size : {1, 4, 17}) { | |||
| size_t dir_size = direction == false ? 1 : 2; | |||
| LSTM::Param param; | |||
| param.bidirectional = direction; | |||
| size_t gate_hidden_size = 4 * hidden_size; | |||
| param.bias = bias; | |||
| param.hidden_size = hidden_size; | |||
| for (size_t seq_len : {1, 3, 5}) | |||
| for (size_t batch_size : {1, 2, 4}) | |||
| for (size_t number_layer : {1, 2, 4, 5, 8}) { | |||
| size_t flatten_size = 0; | |||
| for (size_t layer = 0; layer < number_layer; layer++) { | |||
| for (size_t dir = 0; dir < dir_size; dir++) { | |||
| flatten_size += layer == 0 | |||
| ? input_size | |||
| : dir_size * hidden_size; // ih | |||
| flatten_size += hidden_size; // hh | |||
| } | |||
| } | |||
| if (bias) { | |||
| flatten_size += 2 * dir_size * number_layer; | |||
| } | |||
| param.num_layers = number_layer; | |||
| checker.set_param(param).exec( | |||
| {{seq_len, batch_size, input_size}, // input | |||
| {number_layer * dir_size, batch_size, | |||
| hidden_size}, // hx | |||
| {number_layer * dir_size, batch_size, | |||
| hidden_size}, // hy | |||
| {gate_hidden_size, flatten_size}, // flat weight | |||
| {}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRECTION) { | |||
| test_lstm(false, false, handle()); | |||
| } | |||
| TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRECTION) { | |||
| test_lstm(true, false, handle()); | |||
| } | |||
| TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) { | |||
| test_lstm(false, true, handle()); | |||
| } | |||
| TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) { | |||
| test_lstm(true, true, handle()); | |||
| } | |||
| TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) { | |||
| TaskRecordChecker<LSTM> checker(0); | |||
| size_t input_size = 2; | |||
| size_t hidden_size = 2; | |||
| size_t gate_hidden_size = 4 * hidden_size; | |||
| LSTM::Param param; | |||
| param.bidirectional = false; | |||
| param.bias = false; | |||
| param.hidden_size = hidden_size; | |||
| // checker.set_output_canonizer(output_canonizer); | |||
| for (size_t seq_len : {1, 3, 5}) | |||
| for (size_t batch_size : {1, 2, 4}) | |||
| for (size_t number_layer : {1, 2, 4, 5, 8}) { | |||
| param.num_layers = number_layer; | |||
| checker.set_param(param).exec( | |||
| {{seq_len, batch_size, input_size}, // input | |||
| {number_layer, batch_size, hidden_size}, // hx | |||
| {number_layer, batch_size, hidden_size}, // hy | |||
| {number_layer, gate_hidden_size, | |||
| input_size + hidden_size}, // flat weight | |||
| {}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) { | |||
| Benchmarker<LSTM> optimized_bench(handle()); | |||
| constexpr size_t RUNS = 20; | |||
| auto run = [&](size_t hidden_size, size_t input_size) { | |||
| optimized_bench.set_times(RUNS).set_display(true); | |||
| size_t gate_hidden_size = 4 * hidden_size; | |||
| for (bool direction : {false, true}) { | |||
| LSTM::Param param; | |||
| param.hidden_size = hidden_size; | |||
| param.bidirectional = direction; | |||
| param.bias = false; | |||
| size_t dir_size = direction == false ? 1 : 2; | |||
| for (size_t seq_len : {1, 5, 8}) | |||
| for (size_t batch_size : {1, 8, 16}) | |||
| for (size_t number_layer : {1}) { | |||
| param.num_layers = number_layer; | |||
| size_t flatten_size = 0; | |||
| for (size_t layer = 0; layer < number_layer; layer++) { | |||
| for (size_t dir = 0; dir < dir_size; dir++) { | |||
| flatten_size += layer == 0 | |||
| ? input_size | |||
| : dir_size * hidden_size; // ih | |||
| flatten_size += hidden_size; // hh | |||
| } | |||
| } | |||
| optimized_bench.set_param(param).exec( | |||
| {{seq_len, batch_size, input_size}, // input | |||
| {number_layer * dir_size, batch_size, | |||
| hidden_size}, // hx | |||
| {number_layer * dir_size, batch_size, | |||
| hidden_size}, // hy | |||
| {gate_hidden_size, flatten_size}, // flat weight | |||
| {}, | |||
| {}, | |||
| {}, | |||
| {}}); | |||
| } | |||
| } | |||
| }; | |||
| run(512, 256); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * \file dnn/test/arm_common/rnn.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "test/arm_common/fixture.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "test/common/benchmarker.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/task_record_check.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| TEST_F(ARM_COMMON, RNNCell) { | |||
| Checker<RNNCell> checker(handle()); | |||
| using NonlineMode = param::RNNCell::NonlineMode; | |||
| param::RNNCell param; | |||
| for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) | |||
| for (size_t batch : {1, 4}) | |||
| for (size_t n : {3, 4, 5, 23, 100}) | |||
| for (size_t h : {5, 23, 100}) | |||
| for (size_t out : {3, 6, 25, 100}) { | |||
| param.nonlineMode = mode; | |||
| checker.set_param(param); | |||
| checker.exec( | |||
| {{batch, n}, | |||
| {out, n}, | |||
| {1, out}, | |||
| {batch, h}, | |||
| {out, h}, | |||
| {1, out}, | |||
| {}}); | |||
| checker.exec( | |||
| {{batch, n}, | |||
| {out, n}, | |||
| {batch, out}, | |||
| {batch, h}, | |||
| {out, h}, | |||
| {batch, out}, | |||
| {}}); | |||
| } | |||
| } | |||
| TEST_F(ARM_COMMON, RNNCellRecord) { | |||
| TaskRecordChecker<RNNCell> checker(0); | |||
| using NonlineMode = param::RNNCell::NonlineMode; | |||
| param::RNNCell param; | |||
| for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) { | |||
| param.nonlineMode = mode; | |||
| checker.set_param(param); | |||
| checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}}); | |||
| checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}}); | |||
| checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}}); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||