| @@ -23,6 +23,8 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| using tag = dnnl::memory::format_tag; | |||||
| using dim = dnnl::memory::dims; | |||||
| std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | ||||
| std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | ||||
| bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | ||||
| @@ -36,7 +38,9 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (bidirectional_) { | if (bidirectional_) { | ||||
| num_directions_ = 2; | num_directions_ = 2; | ||||
| } | } | ||||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; | |||||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { | |||||
| MS_LOG(EXCEPTION) << "error iteration shape!"; | |||||
| } | |||||
| const int gate_size = 4 * hidden_size_; | const int gate_size = 4 * hidden_size_; | ||||
| for (int i = 0; i < num_layers_; ++i) { | for (int i = 0; i < num_layers_; ++i) { | ||||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | ||||
| @@ -44,18 +48,8 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| weight_size_ = weight_size_ * num_directions_; | weight_size_ = weight_size_ * num_directions_; | ||||
| weight_h_size_ = weight_h_size_ * num_directions_; | weight_h_size_ = weight_h_size_ * num_directions_; | ||||
| } | |||||
| bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| using dt = dnnl::memory::data_type; | |||||
| using tag = dnnl::memory::format_tag; | |||||
| using dim = dnnl::memory::dims; | |||||
| auto eng = MKLKernelEngine::Get().engine(); | auto eng = MKLKernelEngine::Get().engine(); | ||||
| dnnl::stream s(eng); | dnnl::stream s(eng); | ||||
| auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; | |||||
| auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; | |||||
| dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | ||||
| if (bidirectional_) { | if (bidirectional_) { | ||||
| direction = dnnl::rnn_direction::bidirectional_concat; | direction = dnnl::rnn_direction::bidirectional_concat; | ||||
| @@ -63,68 +57,69 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| dim src_dims = {seq_len_, batch_size_, input_size_}; | dim src_dims = {seq_len_, batch_size_, input_size_}; | ||||
| dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||||
| dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||||
| dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; | |||||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | ||||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | ||||
| dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | ||||
| dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | ||||
| dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); | |||||
| dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | ||||
| dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | ||||
| dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | ||||
| dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc( | |||||
| dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||||
| auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); | |||||
| dnnl::lstm_forward::desc desc = | |||||
| dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, | |||||
| formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, | |||||
| dst_desc, dst_h_desc, dst_c_desc); | |||||
| prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng); | |||||
| primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_); | |||||
| AddArgument(DNNL_ARG_SRC_LAYER, src_desc); | |||||
| AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); | |||||
| AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); | |||||
| AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); | |||||
| AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); | |||||
| AddArgument(DNNL_ARG_BIAS, bias_desc); | |||||
| AddArgument(DNNL_ARG_DST_LAYER, dst_desc); | |||||
| AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); | |||||
| AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); | |||||
| AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); | |||||
| } | |||||
| // construct fw memory | |||||
| auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng); | |||||
| auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||||
| src_memory.set_data_handle(inputs[0]->addr); | |||||
| auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||||
| auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||||
| src_h_memory.set_data_handle(inputs[1]->addr); | |||||
| src_c_memory.set_data_handle(inputs[2]->addr); | |||||
| auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||||
| auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||||
| bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| using dt = dnnl::memory::data_type; | |||||
| using tag = dnnl::memory::format_tag; | |||||
| auto eng = MKLKernelEngine::Get().engine(); | |||||
| auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); | |||||
| auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); | |||||
| user_weights_memory.set_data_handle(inputs[3]->addr); | user_weights_memory.set_data_handle(inputs[3]->addr); | ||||
| user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | ||||
| auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng); | |||||
| auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng); | |||||
| dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); | |||||
| dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); | |||||
| auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng); | |||||
| Reorder(&user_weights_memory, &weights_memory); | |||||
| Reorder(&user_weights_h_memory, &weights_h_memory); | |||||
| auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); | |||||
| if (has_bias_) { | if (has_bias_) { | ||||
| auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||||
| user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); | |||||
| bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| } else { | } else { | ||||
| std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f); | |||||
| write_to_dnnl_memory(net_bias.data(), bias_memory); | |||||
| std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size()); | |||||
| } | } | ||||
| auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||||
| auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||||
| auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||||
| dnnl::lstm_forward fw_layer(prim_desc); | |||||
| workspace_memory.set_data_handle(outputs[3]->addr); | |||||
| dst_memory.set_data_handle(outputs[0]->addr); | |||||
| dst_h_memory.set_data_handle(outputs[1]->addr); | |||||
| dst_c_memory.set_data_handle(outputs[2]->addr); | |||||
| fw_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, | |||||
| {DNNL_ARG_SRC_ITER, src_h_memory}, | |||||
| {DNNL_ARG_SRC_ITER_C, src_c_memory}, | |||||
| {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, | |||||
| {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, | |||||
| {DNNL_ARG_BIAS, bias_memory}, | |||||
| {DNNL_ARG_DST_LAYER, dst_memory}, | |||||
| {DNNL_ARG_DST_ITER, dst_h_memory}, | |||||
| {DNNL_ARG_DST_ITER_C, dst_c_memory}, | |||||
| {DNNL_ARG_WORKSPACE, workspace_memory}}); | |||||
| s.wait(); | |||||
| // set handle | |||||
| SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); | |||||
| ExecutePrimitive(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H | |||||
| #define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" | #include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" | ||||
| @@ -41,6 +41,10 @@ class LstmCPUKernel : public MKLCPUKernel { | |||||
| int num_directions_; | int num_directions_; | ||||
| bool bidirectional_; | bool bidirectional_; | ||||
| bool has_bias_; | bool has_bias_; | ||||
| dnnl::memory::dims weights_dims_; | |||||
| dnnl::memory::dims weights_h_dims_; | |||||
| dnnl::memory::dims bias_dims_; | |||||
| dnnl::lstm_forward::primitive_desc prim_desc_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(LSTM, | MS_REG_CPU_KERNEL(LSTM, | ||||
| @@ -24,9 +24,11 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| using tag = dnnl::memory::format_tag; | |||||
| using dim = dnnl::memory::dims; | |||||
| auto eng = MKLKernelEngine::Get().engine(); | |||||
| std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | ||||
| std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | ||||
| bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); | ||||
| @@ -40,7 +42,9 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (bidirectional_) { | if (bidirectional_) { | ||||
| num_directions_ = 2; | num_directions_ = 2; | ||||
| } | } | ||||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; | |||||
| if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { | |||||
| MS_LOG(EXCEPTION) << "error iteration shape!"; | |||||
| } | |||||
| const int gate_size = 4 * hidden_size_; | const int gate_size = 4 * hidden_size_; | ||||
| for (int i = 0; i < num_layers_; ++i) { | for (int i = 0; i < num_layers_; ++i) { | ||||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | ||||
| @@ -48,18 +52,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| weight_size_ = weight_size_ * num_directions_; | weight_size_ = weight_size_ * num_directions_; | ||||
| weight_h_size_ = weight_h_size_ * num_directions_; | weight_h_size_ = weight_h_size_ * num_directions_; | ||||
| } | |||||
| bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &workspace /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| using tag = dnnl::memory::format_tag; | |||||
| using dt = dnnl::memory::data_type; | |||||
| using dim = dnnl::memory::dims; | |||||
| auto eng = MKLKernelEngine::Get().engine(); | |||||
| dnnl::stream s(eng); | |||||
| auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; | |||||
| auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; | |||||
| dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; | ||||
| if (bidirectional_) { | if (bidirectional_) { | ||||
| direction = dnnl::rnn_direction::bidirectional_concat; | direction = dnnl::rnn_direction::bidirectional_concat; | ||||
| @@ -67,128 +59,112 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| dim src_dims = {seq_len_, batch_size_, input_size_}; | dim src_dims = {seq_len_, batch_size_, input_size_}; | ||||
| dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||||
| dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||||
| dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; | |||||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | ||||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | ||||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | ||||
| dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); | ||||
| dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); | ||||
| dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); | |||||
| dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); | ||||
| dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); | ||||
| dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); | ||||
| dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc( | |||||
| dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||||
| dnnl::lstm_forward::desc forward_desc = | |||||
| dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, | |||||
| formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, | |||||
| dst_desc, dst_h_desc, dst_c_desc); | |||||
| auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); | auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); | ||||
| dnnl::lstm_backward::desc backward_desc = | |||||
| dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, | |||||
| generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, | |||||
| dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), | |||||
| generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); | |||||
| auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); | |||||
| dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc( | |||||
| dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), | |||||
| formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, | |||||
| src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, | |||||
| dst_h_desc, dst_c_desc); | |||||
| prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); | |||||
| primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_); | |||||
| AddArgument(DNNL_ARG_SRC_LAYER, src_desc); | |||||
| AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); | |||||
| AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); | |||||
| AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); | |||||
| AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); | |||||
| AddArgument(DNNL_ARG_BIAS, bias_desc); | |||||
| AddArgument(DNNL_ARG_DST_LAYER, dst_desc); | |||||
| AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); | |||||
| AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); | |||||
| AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); | |||||
| AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); | |||||
| AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); | |||||
| AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); | |||||
| AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); | |||||
| } | |||||
| bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &workspace /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| using dt = dnnl::memory::data_type; | |||||
| using tag = dnnl::memory::format_tag; | |||||
| auto eng = MKLKernelEngine::Get().engine(); | |||||
| // construct fw memory | // construct fw memory | ||||
| auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||||
| src_memory.set_data_handle(inputs[0]->addr); | |||||
| auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||||
| auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||||
| src_h_memory.set_data_handle(inputs[1]->addr); | |||||
| src_c_memory.set_data_handle(inputs[2]->addr); | |||||
| auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||||
| auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||||
| auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); | |||||
| auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); | |||||
| auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); | |||||
| user_weights_memory.set_data_handle(inputs[3]->addr); | user_weights_memory.set_data_handle(inputs[3]->addr); | ||||
| user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); | ||||
| auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng); | |||||
| auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng); | |||||
| dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); | |||||
| dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); | |||||
| // construct bias memory | |||||
| auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng); | |||||
| Reorder(&user_weights_memory, &weights_memory); | |||||
| Reorder(&user_weights_h_memory, &weights_h_memory); | |||||
| if (has_bias_) { | if (has_bias_) { | ||||
| auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||||
| user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); | |||||
| bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| } else { | } else { | ||||
| std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f); | |||||
| write_to_dnnl_memory(net_bias.data(), bias_memory); | |||||
| std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size()); | |||||
| } | } | ||||
| auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||||
| dst_memory.set_data_handle(inputs[4]->addr); | |||||
| auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||||
| auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||||
| dst_h_memory.set_data_handle(inputs[5]->addr); | |||||
| dst_c_memory.set_data_handle(inputs[6]->addr); | |||||
| auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng); | |||||
| workspace_memory.set_data_handle(inputs[10]->addr); | |||||
| // construct bw memory | // construct bw memory | ||||
| std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f); | |||||
| std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f); | |||||
| auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); | |||||
| auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); | |||||
| auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); | |||||
| auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); | |||||
| auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); | |||||
| auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng); | |||||
| auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng); | |||||
| write_to_dnnl_memory(net_w.data(), diff_weights_memory); | |||||
| write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory); | |||||
| auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); | |||||
| auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng); | |||||
| write_to_dnnl_memory(net_w.data(), diff_bias_memory); | |||||
| auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); | |||||
| diff_dst_memory.set_data_handle(inputs[7]->addr); | |||||
| auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); | |||||
| diff_dst_h_memory.set_data_handle(inputs[8]->addr); | |||||
| auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); | |||||
| diff_dst_c_memory.set_data_handle(inputs[9]->addr); | |||||
| diff_src_memory.set_data_handle(outputs[0]->addr); | |||||
| diff_src_h_memory.set_data_handle(outputs[1]->addr); | |||||
| diff_src_c_memory.set_data_handle(outputs[2]->addr); | |||||
| auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); | |||||
| auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); | |||||
| auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); | |||||
| auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); | |||||
| user_diff_weights_memory.set_data_handle(outputs[3]->addr); | user_diff_weights_memory.set_data_handle(outputs[3]->addr); | ||||
| user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_); | user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_); | ||||
| write_to_dnnl_memory(net_w.data(), user_diff_weights_memory); | |||||
| write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory); | |||||
| // construct bw bias memory | |||||
| user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); | |||||
| dnnl::lstm_backward bwd_layer(prim_backward_desc); | |||||
| bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, | |||||
| {DNNL_ARG_SRC_ITER, src_h_memory}, | |||||
| {DNNL_ARG_SRC_ITER_C, src_c_memory}, | |||||
| {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, | |||||
| {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, | |||||
| {DNNL_ARG_BIAS, bias_memory}, | |||||
| {DNNL_ARG_DST_LAYER, dst_memory}, | |||||
| {DNNL_ARG_DST_ITER, dst_h_memory}, | |||||
| {DNNL_ARG_DST_ITER_C, dst_c_memory}, | |||||
| {DNNL_ARG_DIFF_SRC_LAYER, diff_src_memory}, | |||||
| {DNNL_ARG_DIFF_SRC_ITER, diff_src_h_memory}, | |||||
| {DNNL_ARG_DIFF_SRC_ITER_C, diff_src_c_memory}, | |||||
| {DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory}, | |||||
| {DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory}, | |||||
| {DNNL_ARG_DIFF_BIAS, diff_bias_memory}, | |||||
| {DNNL_ARG_DIFF_DST_LAYER, diff_dst_memory}, | |||||
| {DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory}, | |||||
| {DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory}, | |||||
| {DNNL_ARG_WORKSPACE, workspace_memory}}); | |||||
| dnnl::reorder(diff_weights_memory, user_diff_weights_memory) | |||||
| .execute(s, diff_weights_memory, user_diff_weights_memory); | |||||
| dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory) | |||||
| .execute(s, diff_weights_h_memory, user_diff_weights_h_memory); | |||||
| std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size()); | |||||
| std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size()); | |||||
| if (has_bias_) { | if (has_bias_) { | ||||
| dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory); | |||||
| } else { | |||||
| write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); | |||||
| diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_); | |||||
| } | } | ||||
| s.wait(); | |||||
| std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size()); | |||||
| std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size()); | |||||
| std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size()); | |||||
| SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); | |||||
| SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); | |||||
| ExecutePrimitive(); | |||||
| Reorder(&diff_weights_memory, &user_diff_weights_memory); | |||||
| Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -42,6 +42,10 @@ class LSTMGradCPUKernel : public MKLCPUKernel { | |||||
| int num_directions_; | int num_directions_; | ||||
| bool bidirectional_; | bool bidirectional_; | ||||
| bool has_bias_; | bool has_bias_; | ||||
| dnnl::memory::dims weights_dims_; | |||||
| dnnl::memory::dims weights_h_dims_; | |||||
| dnnl::memory::dims bias_dims_; | |||||
| dnnl::lstm_backward::primitive_desc prim_backward_desc_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(LSTMGrad, | MS_REG_CPU_KERNEL(LSTMGrad, | ||||
| @@ -64,5 +68,4 @@ MS_REG_CPU_KERNEL(LSTMGrad, | |||||
| LSTMGradCPUKernel); | LSTMGradCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ | ||||
| @@ -98,11 +98,9 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { | |||||
| } | } | ||||
| void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } | void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } | ||||
| void MKLCPUKernel::write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||||
| MKLKernelEngine::Get().write_to_dnnl_memory(handle, mem); | |||||
| } | |||||
| void MKLCPUKernel::read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||||
| MKLKernelEngine::Get().read_from_dnnl_memory(handle, mem); | |||||
| void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { | |||||
| MKLKernelEngine::Get().Reorder(src_mem, dst_mem); | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -39,10 +39,12 @@ class MKLCPUKernel : public CPUKernel { | |||||
| dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; | dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; | ||||
| dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape); | dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape); | ||||
| void ExecutePrimitive(); | void ExecutePrimitive(); | ||||
| void write_to_dnnl_memory(void *handle, const dnnl::memory &mem); | |||||
| void read_from_dnnl_memory(void *handle, const dnnl::memory &mem); | |||||
| std::unordered_map<int, dnnl::memory> arguments_; | std::unordered_map<int, dnnl::memory> arguments_; | ||||
| std::shared_ptr<dnnl::primitive> primitive_{nullptr}; | std::shared_ptr<dnnl::primitive> primitive_{nullptr}; | ||||
| inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { | |||||
| return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; | |||||
| } | |||||
| void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,5 +33,8 @@ dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, b | |||||
| return dnnl::memory(mem_desc, engine_, nullptr); | return dnnl::memory(mem_desc, engine_, nullptr); | ||||
| } | } | ||||
| } | } | ||||
| void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { | |||||
| dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,30 +41,7 @@ class MKLKernelEngine { | |||||
| void Execute(const std::shared_ptr<dnnl::primitive> &primitive, | void Execute(const std::shared_ptr<dnnl::primitive> &primitive, | ||||
| const std::unordered_map<int, dnnl::memory> &arguments); | const std::unordered_map<int, dnnl::memory> &arguments); | ||||
| inline void read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||||
| dnnl::engine eng = mem.get_engine(); | |||||
| size_t bytes = mem.get_desc().get_size(); | |||||
| if (eng.get_kind() == dnnl::engine::kind::cpu) { | |||||
| auto dst = reinterpret_cast<uint8_t *>(handle); | |||||
| uint8_t *src = reinterpret_cast<uint8_t *>(mem.get_data_handle()); | |||||
| for (size_t i = 0; i < bytes; ++i) { | |||||
| dst[i] = src[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| // Read from handle, write to memory | |||||
| inline void write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { | |||||
| dnnl::engine eng = mem.get_engine(); | |||||
| size_t bytes = mem.get_desc().get_size(); | |||||
| if (eng.get_kind() == dnnl::engine::kind::cpu) { | |||||
| auto src = reinterpret_cast<uint8_t *>(handle); | |||||
| uint8_t *dst = reinterpret_cast<uint8_t *>(mem.get_data_handle()); | |||||
| for (size_t i = 0; i < bytes; ++i) { | |||||
| dst[i] = src[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); | |||||
| private: | private: | ||||
| MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} | MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} | ||||
| @@ -13,43 +13,12 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """LSTM.""" | """LSTM.""" | ||||
| import math | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Parameter, Tensor, nn, context, ParameterTuple | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore import Tensor, nn, context | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| def init_lstm_weight( | |||||
| input_size, | |||||
| hidden_size, | |||||
| num_layers, | |||||
| bidirectional, | |||||
| has_bias=True): | |||||
| """Initialize lstm weight.""" | |||||
| num_directions = 1 | |||||
| if bidirectional: | |||||
| num_directions = 2 | |||||
| weight_size = 0 | |||||
| gate_size = 4 * hidden_size | |||||
| for layer in range(num_layers): | |||||
| for _ in range(num_directions): | |||||
| input_layer_size = input_size if layer == 0 else hidden_size * num_directions | |||||
| weight_size += gate_size * input_layer_size | |||||
| weight_size += gate_size * hidden_size | |||||
| if has_bias: | |||||
| weight_size += 2 * gate_size | |||||
| stdv = 1 / math.sqrt(hidden_size) | |||||
| w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32) | |||||
| w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight') | |||||
| return w | |||||
| # Initialize short-term memory (h) and long-term memory (c) to 0 | # Initialize short-term memory (h) and long-term memory (c) to 0 | ||||
| def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | ||||
| """init default input.""" | """init default input.""" | ||||
| @@ -60,19 +29,15 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | |||||
| if context.get_context("device_target") == "CPU": | if context.get_context("device_target") == "CPU": | ||||
| h_list = [] | h_list = [] | ||||
| c_list = [] | c_list = [] | ||||
| for i in range(num_layers): | |||||
| hi = Parameter(initializer( | |||||
| Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), | |||||
| [num_directions, batch_size, hidden_size] | |||||
| ), name='h' + str(i)) | |||||
| i = 0 | |||||
| while i < num_layers: | |||||
| hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||||
| h_list.append(hi) | h_list.append(hi) | ||||
| ci = Parameter(initializer( | |||||
| Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), | |||||
| [num_directions, batch_size, hidden_size] | |||||
| ), name='c' + str(i)) | |||||
| ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||||
| c_list.append(ci) | c_list.append(ci) | ||||
| h = ParameterTuple(tuple(h_list)) | |||||
| c = ParameterTuple(tuple(c_list)) | |||||
| i = i + 1 | |||||
| h = tuple(h_list) | |||||
| c = tuple(c_list) | |||||
| return h, c | return h, c | ||||
| h = Tensor( | h = Tensor( | ||||
| @@ -108,12 +73,7 @@ class SentimentNet(nn.Cell): | |||||
| has_bias=True, | has_bias=True, | ||||
| bidirectional=bidirectional, | bidirectional=bidirectional, | ||||
| dropout=0.0) | dropout=0.0) | ||||
| w_init = init_lstm_weight( | |||||
| embed_size, | |||||
| num_hiddens, | |||||
| num_layers, | |||||
| bidirectional) | |||||
| self.encoder.weight = w_init | |||||
| self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) | self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) | ||||
| self.concat = P.Concat(1) | self.concat = P.Concat(1) | ||||
| @@ -20,7 +20,6 @@ import mindspore.context as context | |||||
| from mindspore.common.api import ms_function | from mindspore.common.api import ms_function | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.parameter import ParameterTuple, Parameter | from mindspore.common.parameter import ParameterTuple, Parameter | ||||
| @@ -28,7 +27,7 @@ context.set_context(device_target='CPU') | |||||
| class LstmNet(nn.Cell): | class LstmNet(nn.Cell): | ||||
| def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||||
| def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||||
| super(LstmNet, self).__init__() | super(LstmNet, self).__init__() | ||||
| num_directions = 1 | num_directions = 1 | ||||
| @@ -92,7 +91,7 @@ def test_lstm(): | |||||
| num_directions = 1 | num_directions = 1 | ||||
| if bidirectional: | if bidirectional: | ||||
| num_directions = 2 | num_directions = 2 | ||||
| net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) | |||||
| net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) | |||||
| y, (h, c) = net() | y, (h, c) = net() | ||||
| print(y) | print(y) | ||||
| print(c) | print(c) | ||||
| @@ -131,7 +130,7 @@ def test_lstm(): | |||||
| class MultiLayerBiLstmNet(nn.Cell): | class MultiLayerBiLstmNet(nn.Cell): | ||||
| def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||||
| def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): | |||||
| super(MultiLayerBiLstmNet, self).__init__() | super(MultiLayerBiLstmNet, self).__init__() | ||||
| num_directions = 1 | num_directions = 1 | ||||
| @@ -166,6 +165,17 @@ class MultiLayerBiLstmNet(nn.Cell): | |||||
| self.h = tuple((self.h0, self.h1)) | self.h = tuple((self.h0, self.h1)) | ||||
| self.c = tuple((self.c0, self.c1)) | self.c = tuple((self.c0, self.c1)) | ||||
| input_size_list = [input_size, hidden_size * num_directions] | |||||
| weights = [] | |||||
| bias_size = 0 if not has_bias else num_directions * hidden_size * 4 | |||||
| for i in range(num_layers): | |||||
| weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4 | |||||
| w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02 | |||||
| if has_bias: | |||||
| bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32) | |||||
| w_np = np.concatenate([w_np, bias_np], axis=0) | |||||
| weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i))) | |||||
| self.lstm.weight = weights | |||||
| @ms_function | @ms_function | ||||
| def construct(self): | def construct(self): | ||||
| @@ -176,7 +186,6 @@ class MultiLayerBiLstmNet(nn.Cell): | |||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_multi_layer_bilstm(): | def test_multi_layer_bilstm(): | ||||
| seq_len = 5 | |||||
| batch_size = 2 | batch_size = 2 | ||||
| input_size = 10 | input_size = 10 | ||||
| hidden_size = 2 | hidden_size = 2 | ||||
| @@ -185,7 +194,7 @@ def test_multi_layer_bilstm(): | |||||
| bidirectional = True | bidirectional = True | ||||
| dropout = 0.0 | dropout = 0.0 | ||||
| net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, | |||||
| net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, | |||||
| dropout) | dropout) | ||||
| y, (h, c) = net() | y, (h, c) = net() | ||||
| print(y) | print(y) | ||||
| @@ -274,7 +283,7 @@ def test_grad(): | |||||
| input_size = 3 | input_size = 3 | ||||
| hidden_size = 2 | hidden_size = 2 | ||||
| num_layers = 1 | num_layers = 1 | ||||
| has_bias = True | |||||
| has_bias = False | |||||
| bidirectional = False | bidirectional = False | ||||
| dropout = 0.0 | dropout = 0.0 | ||||
| net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)) | net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)) | ||||