
test_sparse_attention.py

import numpy as np
from mindspore import Tensor
from mindspore.parallel.nn.layers import FixedSparseAttention
import mindspore.context as context

# FixedSparseAttention is implemented for the Ascend backend.
context.set_context(device_target="Ascend")


def test_net():
    np.random.seed(0)
    bs = 2               # batch size
    heads = 2            # number of attention heads
    seq_len = 1024       # this op is designed for seq_len = 1024
    size_per_head = 128  # maximum size per head value is 128
    block_size = 64      # block size is designed to be 64
    fixed_sparse = FixedSparseAttention(bs, heads, size_per_head, block_size)
    # Query, key, and value share the shape (bs, seq_len, heads * size_per_head)
    # and must be float16.
    q = np.random.rand(bs, seq_len, heads * size_per_head).astype(np.float16)
    k = np.random.rand(bs, seq_len, heads * size_per_head).astype(np.float16)
    v = np.random.rand(bs, seq_len, heads * size_per_head).astype(np.float16)
    # The attention mask is float32 with shape (bs, seq_len, seq_len);
    # ones mark positions that are visible to attention.
    attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32)
    out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask))
    out_np = out.asnumpy()
    print("local output: ", out_np[0, 0])


if __name__ == "__main__":
    test_net()