| @@ -23,6 +23,8 @@ namespace mindspore { | |||
| namespace kernel { | |||
| void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| using tag = dnnl::memory::format_tag; | |||
| using dim = dnnl::memory::dims; | |||
| std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | |||
| @@ -36,7 +38,9 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| if (bidirectional_) { | |||
| num_directions_ = 2; | |||
| } | |||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; | |||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { | |||
| MS_LOG(EXCEPTION) << "error iteration shape!"; | |||
| } | |||
| const int gate_size = 4 * hidden_size_; | |||
| for (int i = 0; i < num_layers_; ++i) { | |||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | |||
| @@ -44,18 +48,8 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| weight_size_ = weight_size_ * num_directions_; | |||
| weight_h_size_ = weight_h_size_ * num_directions_; | |||
| } | |||
| bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| using dt = dnnl::memory::data_type; | |||
| using tag = dnnl::memory::format_tag; | |||
| using dim = dnnl::memory::dims; | |||
| auto eng = MKLKernelEngine::Get().engine(); | |||
| dnnl::stream s(eng); | |||
| auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; | |||
| auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; | |||
| dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | |||
| if (bidirectional_) { | |||
| direction = dnnl::rnn_direction::bidirectional_concat; | |||
| @@ -63,68 +57,69 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| dim src_dims = {seq_len_, batch_size_, input_size_}; | |||
| dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | |||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | |||
| dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | |||
| dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | |||
| dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); | |||
| dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | |||
| dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | |||
| dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | |||
| dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc( | |||
| dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||
| auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); | |||
| dnnl::lstm_forward::desc desc = | |||
| dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, | |||
| formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, | |||
| dst_desc, dst_h_desc, dst_c_desc); | |||
| prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng); | |||
| primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_); | |||
| AddArgument(DNNL_ARG_SRC_LAYER, src_desc); | |||
| AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); | |||
| AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); | |||
| AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); | |||
| AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); | |||
| AddArgument(DNNL_ARG_BIAS, bias_desc); | |||
| AddArgument(DNNL_ARG_DST_LAYER, dst_desc); | |||
| AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); | |||
| AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); | |||
| AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); | |||
| } | |||
| // construct fw memory | |||
| auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng); | |||
| auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||
| src_memory.set_data_handle(inputs[0]->addr); | |||
| auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||
| auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||
| src_h_memory.set_data_handle(inputs[1]->addr); | |||
| src_c_memory.set_data_handle(inputs[2]->addr); | |||
| auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||
| auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||
| bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| using dt = dnnl::memory::data_type; | |||
| using tag = dnnl::memory::format_tag; | |||
| auto eng = MKLKernelEngine::Get().engine(); | |||
| auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); | |||
| auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); | |||
| user_weights_memory.set_data_handle(inputs[3]->addr); | |||
| user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | |||
| auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng); | |||
| auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng); | |||
| dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); | |||
| dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); | |||
| auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng); | |||
| Reorder(&user_weights_memory, &weights_memory); | |||
| Reorder(&user_weights_h_memory, &weights_h_memory); | |||
| auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); | |||
| if (has_bias_) { | |||
| auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||
| user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); | |||
| bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| } else { | |||
| std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f); | |||
| write_to_dnnl_memory(net_bias.data(), bias_memory); | |||
| std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size()); | |||
| } | |||
| auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||
| auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||
| auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||
| dnnl::lstm_forward fw_layer(prim_desc); | |||
| workspace_memory.set_data_handle(outputs[3]->addr); | |||
| dst_memory.set_data_handle(outputs[0]->addr); | |||
| dst_h_memory.set_data_handle(outputs[1]->addr); | |||
| dst_c_memory.set_data_handle(outputs[2]->addr); | |||
| fw_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, | |||
| {DNNL_ARG_SRC_ITER, src_h_memory}, | |||
| {DNNL_ARG_SRC_ITER_C, src_c_memory}, | |||
| {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, | |||
| {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, | |||
| {DNNL_ARG_BIAS, bias_memory}, | |||
| {DNNL_ARG_DST_LAYER, dst_memory}, | |||
| {DNNL_ARG_DST_ITER, dst_h_memory}, | |||
| {DNNL_ARG_DST_ITER_C, dst_c_memory}, | |||
| {DNNL_ARG_WORKSPACE, workspace_memory}}); | |||
| s.wait(); | |||
| // set handle | |||
| SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); | |||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); | |||
| SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); | |||
| ExecutePrimitive(); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" | |||
| @@ -41,6 +41,10 @@ class LstmCPUKernel : public MKLCPUKernel { | |||
| int num_directions_; | |||
| bool bidirectional_; | |||
| bool has_bias_; | |||
| dnnl::memory::dims weights_dims_; | |||
| dnnl::memory::dims weights_h_dims_; | |||
| dnnl::memory::dims bias_dims_; | |||
| dnnl::lstm_forward::primitive_desc prim_desc_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(LSTM, | |||
| @@ -24,9 +24,11 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| using tag = dnnl::memory::format_tag; | |||
| using dim = dnnl::memory::dims; | |||
| auto eng = MKLKernelEngine::Get().engine(); | |||
| std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | |||
| @@ -40,7 +42,9 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| if (bidirectional_) { | |||
| num_directions_ = 2; | |||
| } | |||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; | |||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { | |||
| MS_LOG(EXCEPTION) << "error iteration shape!"; | |||
| } | |||
| const int gate_size = 4 * hidden_size_; | |||
| for (int i = 0; i < num_layers_; ++i) { | |||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | |||
| @@ -48,18 +52,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| weight_size_ = weight_size_ * num_directions_; | |||
| weight_h_size_ = weight_h_size_ * num_directions_; | |||
| } | |||
| bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| using tag = dnnl::memory::format_tag; | |||
| using dt = dnnl::memory::data_type; | |||
| using dim = dnnl::memory::dims; | |||
| auto eng = MKLKernelEngine::Get().engine(); | |||
| dnnl::stream s(eng); | |||
| auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; | |||
| auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; | |||
| dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | |||
| if (bidirectional_) { | |||
| direction = dnnl::rnn_direction::bidirectional_concat; | |||
| @@ -67,128 +59,112 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| dim src_dims = {seq_len_, batch_size_, input_size_}; | |||
| dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | |||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | |||
| dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | |||
| dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | |||
| dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); | |||
| dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | |||
| dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | |||
| dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | |||
| dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc( | |||
| dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||
| dnnl::lstm_forward::desc forward_desc = | |||
| dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, | |||
| formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, | |||
| dst_desc, dst_h_desc, dst_c_desc); | |||
| auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); | |||
| dnnl::lstm_backward::desc backward_desc = | |||
| dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, | |||
| generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, | |||
| dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||
| auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); | |||
| dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc( | |||
| dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), | |||
| formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, | |||
| src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, | |||
| dst_h_desc, dst_c_desc); | |||
| prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); | |||
| primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_); | |||
| AddArgument(DNNL_ARG_SRC_LAYER, src_desc); | |||
| AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); | |||
| AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); | |||
| AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); | |||
| AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); | |||
| AddArgument(DNNL_ARG_BIAS, bias_desc); | |||
| AddArgument(DNNL_ARG_DST_LAYER, dst_desc); | |||
| AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); | |||
| AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); | |||
| AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); | |||
| AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); | |||
| AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); | |||
| AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); | |||
| AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); | |||
| AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); | |||
| AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); | |||
| AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); | |||
| AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); | |||
| AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); | |||
| } | |||
| bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| using dt = dnnl::memory::data_type; | |||
| using tag = dnnl::memory::format_tag; | |||
| auto eng = MKLKernelEngine::Get().engine(); | |||
| // construct fw memory | |||
| auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||
| src_memory.set_data_handle(inputs[0]->addr); | |||
| auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||
| auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||
| src_h_memory.set_data_handle(inputs[1]->addr); | |||
| src_c_memory.set_data_handle(inputs[2]->addr); | |||
| auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||
| auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||
| auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); | |||
| auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); | |||
| auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); | |||
| user_weights_memory.set_data_handle(inputs[3]->addr); | |||
| user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | |||
| auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng); | |||
| auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng); | |||
| dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); | |||
| dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); | |||
| // construct bias memory | |||
| auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng); | |||
| Reorder(&user_weights_memory, &weights_memory); | |||
| Reorder(&user_weights_h_memory, &weights_h_memory); | |||
| if (has_bias_) { | |||
| auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||
| user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); | |||
| bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| } else { | |||
| std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f); | |||
| write_to_dnnl_memory(net_bias.data(), bias_memory); | |||
| std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size()); | |||
| } | |||
| auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||
| dst_memory.set_data_handle(inputs[4]->addr); | |||
| auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||
| auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||
| dst_h_memory.set_data_handle(inputs[5]->addr); | |||
| dst_c_memory.set_data_handle(inputs[6]->addr); | |||
| auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng); | |||
| workspace_memory.set_data_handle(inputs[10]->addr); | |||
| // construct bw memory | |||
| std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f); | |||
| std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f); | |||
| auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||
| auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||
| auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||
| auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||
| auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||
| auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng); | |||
| auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng); | |||
| write_to_dnnl_memory(net_w.data(), diff_weights_memory); | |||
| write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory); | |||
| auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||
| auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng); | |||
| write_to_dnnl_memory(net_w.data(), diff_bias_memory); | |||
| auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||
| diff_dst_memory.set_data_handle(inputs[7]->addr); | |||
| auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||
| diff_dst_h_memory.set_data_handle(inputs[8]->addr); | |||
| auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||
| diff_dst_c_memory.set_data_handle(inputs[9]->addr); | |||
| diff_src_memory.set_data_handle(outputs[0]->addr); | |||
| diff_src_h_memory.set_data_handle(outputs[1]->addr); | |||
| diff_src_c_memory.set_data_handle(outputs[2]->addr); | |||
| auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); | |||
| auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); | |||
| auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); | |||
| auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||
| user_diff_weights_memory.set_data_handle(outputs[3]->addr); | |||
| user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_); | |||
| write_to_dnnl_memory(net_w.data(), user_diff_weights_memory); | |||
| write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory); | |||
| // construct bw bias memory | |||
| user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); | |||
| dnnl::lstm_backward bwd_layer(prim_backward_desc); | |||
| bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, | |||
| {DNNL_ARG_SRC_ITER, src_h_memory}, | |||
| {DNNL_ARG_SRC_ITER_C, src_c_memory}, | |||
| {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, | |||
| {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, | |||
| {DNNL_ARG_BIAS, bias_memory}, | |||
| {DNNL_ARG_DST_LAYER, dst_memory}, | |||
| {DNNL_ARG_DST_ITER, dst_h_memory}, | |||
| {DNNL_ARG_DST_ITER_C, dst_c_memory}, | |||
| {DNNL_ARG_DIFF_SRC_LAYER, diff_src_memory}, | |||
| {DNNL_ARG_DIFF_SRC_ITER, diff_src_h_memory}, | |||
| {DNNL_ARG_DIFF_SRC_ITER_C, diff_src_c_memory}, | |||
| {DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory}, | |||
| {DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory}, | |||
| {DNNL_ARG_DIFF_BIAS, diff_bias_memory}, | |||
| {DNNL_ARG_DIFF_DST_LAYER, diff_dst_memory}, | |||
| {DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory}, | |||
| {DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory}, | |||
| {DNNL_ARG_WORKSPACE, workspace_memory}}); | |||
| dnnl::reorder(diff_weights_memory, user_diff_weights_memory) | |||
| .execute(s, diff_weights_memory, user_diff_weights_memory); | |||
| dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory) | |||
| .execute(s, diff_weights_h_memory, user_diff_weights_h_memory); | |||
| std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size()); | |||
| std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size()); | |||
| if (has_bias_) { | |||
| dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory); | |||
| } else { | |||
| write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); | |||
| diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| } | |||
| s.wait(); | |||
| std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size()); | |||
| std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size()); | |||
| std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size()); | |||
| SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); | |||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); | |||
| SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); | |||
| ExecutePrimitive(); | |||
| Reorder(&diff_weights_memory, &user_diff_weights_memory); | |||
| Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -42,6 +42,10 @@ class LSTMGradCPUKernel : public MKLCPUKernel { | |||
| int num_directions_; | |||
| bool bidirectional_; | |||
| bool has_bias_; | |||
| dnnl::memory::dims weights_dims_; | |||
| dnnl::memory::dims weights_h_dims_; | |||
| dnnl::memory::dims bias_dims_; | |||
| dnnl::lstm_backward::primitive_desc prim_backward_desc_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(LSTMGrad, | |||
| @@ -64,5 +68,4 @@ MS_REG_CPU_KERNEL(LSTMGrad, | |||
| LSTMGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ | |||
| @@ -98,11 +98,9 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { | |||
| } | |||
| void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } | |||
| void MKLCPUKernel::write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||
| MKLKernelEngine::Get().write_to_dnnl_memory(handle, mem); | |||
| } | |||
| void MKLCPUKernel::read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||
| MKLKernelEngine::Get().read_from_dnnl_memory(handle, mem); | |||
| void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { | |||
| MKLKernelEngine::Get().Reorder(src_mem, dst_mem); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -39,10 +39,12 @@ class MKLCPUKernel : public CPUKernel { | |||
| dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; | |||
| dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape); | |||
| void ExecutePrimitive(); | |||
| void write_to_dnnl_memory(void *handle, const dnnl::memory &mem); | |||
| void read_from_dnnl_memory(void *handle, const dnnl::memory &mem); | |||
| std::unordered_map<int, dnnl::memory> arguments_; | |||
| std::shared_ptr<dnnl::primitive> primitive_{nullptr}; | |||
| inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { | |||
| return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; | |||
| } | |||
| void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,5 +33,8 @@ dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, b | |||
| return dnnl::memory(mem_desc, engine_, nullptr); | |||
| } | |||
| } | |||
| void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { | |||
| dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -41,30 +41,7 @@ class MKLKernelEngine { | |||
| void Execute(const std::shared_ptr<dnnl::primitive> &primitive, | |||
| const std::unordered_map<int, dnnl::memory> &arguments); | |||
| inline void read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||
| dnnl::engine eng = mem.get_engine(); | |||
| size_t bytes = mem.get_desc().get_size(); | |||
| if (eng.get_kind() == dnnl::engine::kind::cpu) { | |||
| auto dst = reinterpret_cast<uint8_t *>(handle); | |||
| uint8_t *src = reinterpret_cast<uint8_t *>(mem.get_data_handle()); | |||
| for (size_t i = 0; i < bytes; ++i) { | |||
| dst[i] = src[i]; | |||
| } | |||
| } | |||
| } | |||
| // Read from handle, write to memory | |||
| inline void write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||
| dnnl::engine eng = mem.get_engine(); | |||
| size_t bytes = mem.get_desc().get_size(); | |||
| if (eng.get_kind() == dnnl::engine::kind::cpu) { | |||
| auto src = reinterpret_cast<uint8_t *>(handle); | |||
| uint8_t *dst = reinterpret_cast<uint8_t *>(mem.get_data_handle()); | |||
| for (size_t i = 0; i < bytes; ++i) { | |||
| dst[i] = src[i]; | |||
| } | |||
| } | |||
| } | |||
| void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); | |||
| private: | |||
| MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} | |||
| @@ -13,43 +13,12 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LSTM.""" | |||
| import math | |||
| import numpy as np | |||
| from mindspore import Parameter, Tensor, nn, context, ParameterTuple | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore import Tensor, nn, context | |||
| from mindspore.ops import operations as P | |||
| def init_lstm_weight( | |||
| input_size, | |||
| hidden_size, | |||
| num_layers, | |||
| bidirectional, | |||
| has_bias=True): | |||
| """Initialize lstm weight.""" | |||
| num_directions = 1 | |||
| if bidirectional: | |||
| num_directions = 2 | |||
| weight_size = 0 | |||
| gate_size = 4 * hidden_size | |||
| for layer in range(num_layers): | |||
| for _ in range(num_directions): | |||
| input_layer_size = input_size if layer == 0 else hidden_size * num_directions | |||
| weight_size += gate_size * input_layer_size | |||
| weight_size += gate_size * hidden_size | |||
| if has_bias: | |||
| weight_size += 2 * gate_size | |||
| stdv = 1 / math.sqrt(hidden_size) | |||
| w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32) | |||
| w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight') | |||
| return w | |||
| # Initialize short-term memory (h) and long-term memory (c) to 0 | |||
| def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | |||
| """init default input.""" | |||
| @@ -60,19 +29,15 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | |||
| if context.get_context("device_target") == "CPU": | |||
| h_list = [] | |||
| c_list = [] | |||
| for i in range(num_layers): | |||
| hi = Parameter(initializer( | |||
| Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), | |||
| [num_directions, batch_size, hidden_size] | |||
| ), name='h' + str(i)) | |||
| i = 0 | |||
| while i < num_layers: | |||
| hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| h_list.append(hi) | |||
| ci = Parameter(initializer( | |||
| Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), | |||
| [num_directions, batch_size, hidden_size] | |||
| ), name='c' + str(i)) | |||
| ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| c_list.append(ci) | |||
| h = ParameterTuple(tuple(h_list)) | |||
| c = ParameterTuple(tuple(c_list)) | |||
| i = i + 1 | |||
| h = tuple(h_list) | |||
| c = tuple(c_list) | |||
| return h, c | |||
| h = Tensor( | |||
| @@ -108,12 +73,7 @@ class SentimentNet(nn.Cell): | |||
| has_bias=True, | |||
| bidirectional=bidirectional, | |||
| dropout=0.0) | |||
| w_init = init_lstm_weight( | |||
| embed_size, | |||
| num_hiddens, | |||
| num_layers, | |||
| bidirectional) | |||
| self.encoder.weight = w_init | |||
| self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) | |||
| self.concat = P.Concat(1) | |||
| @@ -20,7 +20,6 @@ import mindspore.context as context | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import ParameterTuple, Parameter | |||
| @@ -28,7 +27,7 @@ context.set_context(device_target='CPU') | |||
| class LstmNet(nn.Cell): | |||
| def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||
| def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||
| super(LstmNet, self).__init__() | |||
| num_directions = 1 | |||
| @@ -92,7 +91,7 @@ def test_lstm(): | |||
| num_directions = 1 | |||
| if bidirectional: | |||
| num_directions = 2 | |||
| net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) | |||
| net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) | |||
| y, (h, c) = net() | |||
| print(y) | |||
| print(c) | |||
| @@ -131,7 +130,7 @@ def test_lstm(): | |||
| class MultiLayerBiLstmNet(nn.Cell): | |||
| def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||
| def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||
| super(MultiLayerBiLstmNet, self).__init__() | |||
| num_directions = 1 | |||
| @@ -166,6 +165,17 @@ class MultiLayerBiLstmNet(nn.Cell): | |||
| self.h = tuple((self.h0, self.h1)) | |||
| self.c = tuple((self.c0, self.c1)) | |||
| input_size_list = [input_size, hidden_size * num_directions] | |||
| weights = [] | |||
| bias_size = 0 if not has_bias else num_directions * hidden_size * 4 | |||
| for i in range(num_layers): | |||
| weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4 | |||
| w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02 | |||
| if has_bias: | |||
| bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32) | |||
| w_np = np.concatenate([w_np, bias_np], axis=0) | |||
| weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i))) | |||
| self.lstm.weight = weights | |||
| @ms_function | |||
| def construct(self): | |||
| @@ -176,7 +186,6 @@ class MultiLayerBiLstmNet(nn.Cell): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_multi_layer_bilstm(): | |||
| seq_len = 5 | |||
| batch_size = 2 | |||
| input_size = 10 | |||
| hidden_size = 2 | |||
| @@ -185,7 +194,7 @@ def test_multi_layer_bilstm(): | |||
| bidirectional = True | |||
| dropout = 0.0 | |||
| net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, | |||
| net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, | |||
| dropout) | |||
| y, (h, c) = net() | |||
| print(y) | |||
| @@ -274,7 +283,7 @@ def test_grad(): | |||
| input_size = 3 | |||
| hidden_size = 2 | |||
| num_layers = 1 | |||
| has_bias = True | |||
| has_bias = False | |||
| bidirectional = False | |||
| dropout = 0.0 | |||
| net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)) | |||