
update cpu lstm

tags/v0.5.0-beta
baihuawei 5 years ago
parent commit 9c74e39b12
10 changed files with 189 additions and 262 deletions
  1. +53 -58   mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc
  2. +6 -2     mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h
  3. +90 -114  mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc
  4. +4 -1     mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h
  5. +3 -5     mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc
  6. +4 -2     mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h
  7. +3 -0     mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc
  8. +1 -24    mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h
  9. +9 -49    mindspore/model_zoo/lstm.py
  10. +16 -7   tests/st/ops/cpu/test_lstm_op.py

+53 -58  mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc

@@ -23,6 +23,8 @@ namespace mindspore {
namespace kernel {
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
@@ -36,7 +38,9 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (bidirectional_) {
num_directions_ = 2;
}
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
@@ -44,18 +48,8 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
}
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
auto eng = MKLKernelEngine::Get().engine();
dnnl::stream s(eng);
auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; };
auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; };
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
@@ -63,68 +57,69 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_};
weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_};
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
dnnl::lstm_forward::desc desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng);
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc());
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc());
AddArgument(DNNL_ARG_BIAS, bias_desc);
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
}
// construct fw memory
auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng);
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
auto eng = MKLKernelEngine::Get().engine();
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory);
auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size());
}
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dnnl::lstm_forward fw_layer(prim_desc);
workspace_memory.set_data_handle(outputs[3]->addr);
dst_memory.set_data_handle(outputs[0]->addr);
dst_h_memory.set_data_handle(outputs[1]->addr);
dst_c_memory.set_data_handle(outputs[2]->addr);
fw_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
{DNNL_ARG_SRC_ITER, src_h_memory},
{DNNL_ARG_SRC_ITER_C, src_c_memory},
{DNNL_ARG_WEIGHTS_LAYER, weights_memory},
{DNNL_ARG_WEIGHTS_ITER, weights_h_memory},
{DNNL_ARG_BIAS, bias_memory},
{DNNL_ARG_DST_LAYER, dst_memory},
{DNNL_ARG_DST_ITER, dst_h_memory},
{DNNL_ARG_DST_ITER_C, dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
s.wait();
// set handle
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr);
ExecutePrimitive();
return true;
}
} // namespace kernel
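Note: the main change in this file is structural. All oneDNN setup (memory descriptors, the lstm_forward primitive descriptor, and the argument registration via AddArgument) now happens once in InitKernel, while Launch only binds the raw input/output pointers with SetArgumentHandle and runs the cached primitive through ExecutePrimitive. Below is a minimal self-contained sketch of that pattern, using a hypothetical ToyReluKernel and the oneDNN 1.x API this revision builds against; it illustrates the idea and is not MindSpore code.

#include <dnnl.hpp>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical kernel mirroring the MKLCPUKernel pattern used in this commit:
// describe and build the primitive once, then only rebind data handles per launch.
class ToyReluKernel {
 public:
  void InitKernel(const dnnl::memory::dims &shape) {
    dnnl::memory::desc md(shape, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
    dnnl::eltwise_forward::desc desc(dnnl::prop_kind::forward_inference,
                                     dnnl::algorithm::eltwise_relu, md, 0.0f);
    auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, engine_);
    primitive_ = std::make_shared<dnnl::eltwise_forward>(prim_desc);
    AddArgument(DNNL_ARG_SRC, md);  // register argument memories without buffers
    AddArgument(DNNL_ARG_DST, md);
  }
  void Launch(void *input, void *output) {
    SetArgumentHandle(DNNL_ARG_SRC, input);   // per-call work is pointer binding only
    SetArgumentHandle(DNNL_ARG_DST, output);
    primitive_->execute(stream_, arguments_); // equivalent of ExecutePrimitive()
    stream_.wait();
  }
 private:
  void AddArgument(int key, const dnnl::memory::desc &md) {
    arguments_[key] = dnnl::memory(md, engine_, nullptr);
  }
  void SetArgumentHandle(int key, void *ptr) { arguments_[key].set_data_handle(ptr); }
  dnnl::engine engine_{dnnl::engine::kind::cpu, 0};
  dnnl::stream stream_{engine_};
  std::shared_ptr<dnnl::primitive> primitive_;
  std::unordered_map<int, dnnl::memory> arguments_;
};

int main() {
  std::vector<float> x(8, -1.0f), y(8, 0.0f);
  ToyReluKernel kernel;
  kernel.InitKernel({2, 4});
  kernel.Launch(x.data(), y.data());  // y becomes all zeros (ReLU of -1)
  return 0;
}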


+6 -2  mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h"
@@ -41,6 +41,10 @@ class LstmCPUKernel : public MKLCPUKernel {
int num_directions_;
bool bidirectional_;
bool has_bias_;
dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;
dnnl::lstm_forward::primitive_desc prim_desc_;
};
MS_REG_CPU_KERNEL(LSTM,


+90 -114  mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc

@@ -24,9 +24,11 @@
namespace mindspore {
namespace kernel {
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
auto eng = MKLKernelEngine::Get().engine();
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
@@ -40,7 +42,9 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (bidirectional_) {
num_directions_ = 2;
}
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
@@ -48,18 +52,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_;
}
bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;
using dim = dnnl::memory::dims;
auto eng = MKLKernelEngine::Get().engine();
dnnl::stream s(eng);
auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; };
auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; };
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
@@ -67,128 +59,112 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_};
weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_};
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_forward::desc forward_desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
dst_desc, dst_h_desc, dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
dnnl::lstm_backward::desc backward_desc =
dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc,
generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc,
dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc());
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc());
AddArgument(DNNL_ARG_BIAS, bias_desc);
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc());
AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc());
AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);
AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc);
}
bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
auto eng = MKLKernelEngine::Get().engine();
// construct fw memory
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng);
auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
// construct bias memory
auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng);
Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size());
}
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
dst_memory.set_data_handle(inputs[4]->addr);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dst_h_memory.set_data_handle(inputs[5]->addr);
dst_c_memory.set_data_handle(inputs[6]->addr);
auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng);
workspace_memory.set_data_handle(inputs[10]->addr);
// construct bw memory
std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f);
std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f);
auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng);
write_to_dnnl_memory(net_w.data(), diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory);
auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng);
write_to_dnnl_memory(net_w.data(), diff_bias_memory);
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
diff_dst_memory.set_data_handle(inputs[7]->addr);
auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
diff_dst_h_memory.set_data_handle(inputs[8]->addr);
auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
diff_dst_c_memory.set_data_handle(inputs[9]->addr);
diff_src_memory.set_data_handle(outputs[0]->addr);
diff_src_h_memory.set_data_handle(outputs[1]->addr);
diff_src_c_memory.set_data_handle(outputs[2]->addr);
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng);
auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
write_to_dnnl_memory(net_w.data(), user_diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory);
// construct bw bias memory
user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
dnnl::lstm_backward bwd_layer(prim_backward_desc);
bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
{DNNL_ARG_SRC_ITER, src_h_memory},
{DNNL_ARG_SRC_ITER_C, src_c_memory},
{DNNL_ARG_WEIGHTS_LAYER, weights_memory},
{DNNL_ARG_WEIGHTS_ITER, weights_h_memory},
{DNNL_ARG_BIAS, bias_memory},
{DNNL_ARG_DST_LAYER, dst_memory},
{DNNL_ARG_DST_ITER, dst_h_memory},
{DNNL_ARG_DST_ITER_C, dst_c_memory},
{DNNL_ARG_DIFF_SRC_LAYER, diff_src_memory},
{DNNL_ARG_DIFF_SRC_ITER, diff_src_h_memory},
{DNNL_ARG_DIFF_SRC_ITER_C, diff_src_c_memory},
{DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory},
{DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory},
{DNNL_ARG_DIFF_BIAS, diff_bias_memory},
{DNNL_ARG_DIFF_DST_LAYER, diff_dst_memory},
{DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory},
{DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
dnnl::reorder(diff_weights_memory, user_diff_weights_memory)
.execute(s, diff_weights_memory, user_diff_weights_memory);
dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory)
.execute(s, diff_weights_h_memory, user_diff_weights_h_memory);
std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size());
std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size());
if (has_bias_) {
dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory);
} else {
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
}
s.wait();
std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size());
std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size());
std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size());
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
ExecutePrimitive();
Reorder(&diff_weights_memory, &user_diff_weights_memory);
Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);
return true;
}
} // namespace kernel
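Note: the backward kernel follows the same split. InitKernel builds the lstm_backward primitive descriptor once and registers every forward and diff argument through AddArgument; Launch zeroes the diff-weights and diff-bias buffers, binds all handles with SetArgumentHandle, executes the primitive, and then uses Reorder to copy the packed diff weights back into the flat ldgoi output buffer at outputs[3]. The explicit zeroing before execution appears to match oneDNN's documented expectation that RNN backward accumulates into the diff-weight buffers, so they must start from zero; that rationale is not stated in the commit itself.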


+4 -1  mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h

@@ -42,6 +42,10 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
int num_directions_;
bool bidirectional_;
bool has_bias_;
dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;
dnnl::lstm_backward::primitive_desc prim_backward_desc_;
};
MS_REG_CPU_KERNEL(LSTMGrad,
@@ -64,5 +68,4 @@ MS_REG_CPU_KERNEL(LSTMGrad,
LSTMGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_

+3 -5  mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc

@@ -98,11 +98,9 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) {
}


void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); }
void MKLCPUKernel::write_to_dnnl_memory(void *handle, const dnnl::memory &mem) {
MKLKernelEngine::Get().write_to_dnnl_memory(handle, mem);
}
void MKLCPUKernel::read_from_dnnl_memory(void *handle, const dnnl::memory &mem) {
MKLKernelEngine::Get().read_from_dnnl_memory(handle, mem);

void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
MKLKernelEngine::Get().Reorder(src_mem, dst_mem);
}
} // namespace kernel
} // namespace mindspore

+4 -2  mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h

@@ -39,10 +39,12 @@ class MKLCPUKernel : public CPUKernel {
dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const;
dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape);
void ExecutePrimitive();
void write_to_dnnl_memory(void *handle, const dnnl::memory &mem);
void read_from_dnnl_memory(void *handle, const dnnl::memory &mem);
std::unordered_map<int, dnnl::memory> arguments_;
std::shared_ptr<dnnl::primitive> primitive_{nullptr};
inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) {
return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
}
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
};
} // namespace kernel
} // namespace mindspore


+3 -0  mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc

@@ -33,5 +33,8 @@ dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, b
return dnnl::memory(mem_desc, engine_, nullptr);
}
}
void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
}
} // namespace kernel
} // namespace mindspore
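Note: the new MKLKernelEngine::Reorder helper (with the matching MKLCPUKernel::Reorder wrapper in mkl_cpu_kernel.cc) replaces the byte-copy helpers write_to_dnnl_memory/read_from_dnnl_memory: layout conversion, such as repacking user-visible ldgoi weights into the layout the LSTM primitive prefers, is now done by a dnnl::reorder primitive. A minimal standalone example of what the wrapper does, against the oneDNN 1.x API; the shapes and format tags here are arbitrary and not taken from this commit.

#include <dnnl.hpp>
#include <vector>

int main() {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream stream(eng);

  dnnl::memory::dims dims = {2, 3, 4, 5};
  std::vector<float> data(2 * 3 * 4 * 5, 1.0f);

  // Source memory in one layout, destination in another; the reorder primitive converts.
  auto src_md = dnnl::memory::desc(dims, dt::f32, tag::nchw);
  auto dst_md = dnnl::memory::desc(dims, dt::f32, tag::nhwc);
  auto src = dnnl::memory(src_md, eng, data.data());
  auto dst = dnnl::memory(dst_md, eng);  // library-allocated buffer
  dnnl::reorder(src, dst).execute(stream, src, dst);  // what Reorder(&src, &dst) wraps
  stream.wait();
  return 0;
}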

+1 -24  mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h

@@ -41,30 +41,7 @@ class MKLKernelEngine {


void Execute(const std::shared_ptr<dnnl::primitive> &primitive,
const std::unordered_map<int, dnnl::memory> &arguments);

inline void read_from_dnnl_memory(void *handle, const dnnl::memory &mem) {
dnnl::engine eng = mem.get_engine();
size_t bytes = mem.get_desc().get_size();
if (eng.get_kind() == dnnl::engine::kind::cpu) {
auto dst = reinterpret_cast<uint8_t *>(handle);
uint8_t *src = reinterpret_cast<uint8_t *>(mem.get_data_handle());
for (size_t i = 0; i < bytes; ++i) {
dst[i] = src[i];
}
}
}
// Read from handle, write to memory
inline void write_to_dnnl_memory(void *handle, const dnnl::memory &mem) {
dnnl::engine eng = mem.get_engine();
size_t bytes = mem.get_desc().get_size();
if (eng.get_kind() == dnnl::engine::kind::cpu) {
auto src = reinterpret_cast<uint8_t *>(handle);
uint8_t *dst = reinterpret_cast<uint8_t *>(mem.get_data_handle());
for (size_t i = 0; i < bytes; ++i) {
dst[i] = src[i];
}
}
}
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);


private:
MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {}


+9 -49  mindspore/model_zoo/lstm.py

@@ -13,43 +13,12 @@
# limitations under the License.
# ============================================================================
"""LSTM."""
import math


import numpy as np


from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore import Tensor, nn, context
from mindspore.ops import operations as P



def init_lstm_weight(
input_size,
hidden_size,
num_layers,
bidirectional,
has_bias=True):
"""Initialize lstm weight."""
num_directions = 1
if bidirectional:
num_directions = 2

weight_size = 0
gate_size = 4 * hidden_size
for layer in range(num_layers):
for _ in range(num_directions):
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
weight_size += gate_size * input_layer_size
weight_size += gate_size * hidden_size
if has_bias:
weight_size += 2 * gate_size

stdv = 1 / math.sqrt(hidden_size)
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')

return w


# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
"""init default input."""
@@ -60,19 +29,15 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
i = 0
while i < num_layers:
hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
i = i + 1
h = tuple(h_list)
c = tuple(c_list)
return h, c


h = Tensor(
@@ -108,12 +73,7 @@ class SentimentNet(nn.Cell):
has_bias=True,
bidirectional=bidirectional,
dropout=0.0)
w_init = init_lstm_weight(
embed_size,
num_hiddens,
num_layers,
bidirectional)
self.encoder.weight = w_init

self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)


self.concat = P.Concat(1)


+16 -7  tests/st/ops/cpu/test_lstm_op.py

@@ -20,7 +20,6 @@ import mindspore.context as context
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple, Parameter
@@ -28,7 +27,7 @@ context.set_context(device_target='CPU')
class LstmNet(nn.Cell):
def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
super(LstmNet, self).__init__()
num_directions = 1
@@ -92,7 +91,7 @@ def test_lstm():
num_directions = 1
if bidirectional:
num_directions = 2
net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
y, (h, c) = net()
print(y)
print(c)
@@ -131,7 +130,7 @@ def test_lstm():
class MultiLayerBiLstmNet(nn.Cell):
def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
super(MultiLayerBiLstmNet, self).__init__()
num_directions = 1
@@ -166,6 +165,17 @@ class MultiLayerBiLstmNet(nn.Cell):
self.h = tuple((self.h0, self.h1))
self.c = tuple((self.c0, self.c1))
input_size_list = [input_size, hidden_size * num_directions]
weights = []
bias_size = 0 if not has_bias else num_directions * hidden_size * 4
for i in range(num_layers):
weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02
if has_bias:
bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
w_np = np.concatenate([w_np, bias_np], axis=0)
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
self.lstm.weight = weights
@ms_function
def construct(self):
@@ -176,7 +186,6 @@ class MultiLayerBiLstmNet(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_layer_bilstm():
seq_len = 5
batch_size = 2
input_size = 10
hidden_size = 2
@@ -185,7 +194,7 @@ def test_multi_layer_bilstm():
bidirectional = True
dropout = 0.0
net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional,
net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional,
dropout)
y, (h, c) = net()
print(y)
@@ -274,7 +283,7 @@ def test_grad():
input_size = 3
hidden_size = 2
num_layers = 1
has_bias = True
has_bias = False
bidirectional = False
dropout = 0.0
net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout))
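Note on the new weight construction in MultiLayerBiLstmNet: the flattened per-layer weight size follows (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4. With the parameters visible in this test (input_size=10, hidden_size=2, two layers, bidirectional so num_directions=2), layer 0 needs (10 + 2) * 2 * 2 * 4 = 192 floats, and layer 1, whose input width is hidden_size * num_directions = 4, needs (4 + 2) * 2 * 2 * 4 = 96 floats; when has_bias is set, each layer appends num_directions * hidden_size * 4 = 16 bias entries, all stored with shape [size, 1, 1].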

