You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rl_buffer_net.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.common.api import ms_function
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.ops import operations as P
  23. import mindspore as ms
  24. def create_tensor(capcity, shapes, types):
  25. buffer = []
  26. for i in range(len(shapes)):
  27. buffer.append(Parameter(Tensor(np.zeros(((capcity,)+shapes[i])), types[i]), \
  28. name="buffer" + str(i)))
  29. return buffer
class RLBuffer(nn.Cell):
    """Cell exposing append / get / sample over a fixed-capacity buffer.

    Storage is the list of zero-initialized ``Parameter`` tensors produced by
    ``create_tensor`` (one per field, leading dimension ``capcity``). Two int32
    scalar Parameters, ``count`` and ``head``, start at 0 and are handed to
    every buffer primitive together with the storage.

    NOTE(review): presumably ``count`` tracks the number of stored entries and
    ``head`` the ring-buffer start, both updated in place by BufferAppend --
    confirm against the MindSpore primitive documentation.
    """

    def __init__(self, batch_size, capcity, shapes, types):
        """Allocate the storage and build the three buffer primitives.

        Args:
            batch_size: Batch size forwarded to ``P.BufferSample``.
            capcity: Buffer capacity, used as the leading dimension of every
                storage tensor and passed to each primitive; callers pass it
                by keyword, so the spelling is part of the interface.
            shapes: Per-field shape tuples, concatenated after ``(capcity,)``.
            types: Per-field MindSpore dtypes, aligned with ``shapes``.
        """
        super(RLBuffer, self).__init__()
        self.buffer = create_tensor(capcity, shapes, types)
        self._capacity = capcity
        # Stored for reference only; no method of this class reads it.
        self._batch_size = batch_size
        # Scalar bookkeeping shared by all three primitives (see class note).
        self.count = Parameter(Tensor(0, ms.int32), name="count")
        self.head = Parameter(Tensor(0, ms.int32), name="head")
        # Each primitive is configured with the same capacity/shapes/types
        # as the storage allocated above.
        self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
        self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
        self.buffer_sample = P.BufferSample(
            self._capacity, batch_size, shapes, types)

    # Append one experience (a list with one tensor per field) through
    # BufferAppend and return whatever the primitive returns.
    @ms_function
    def append(self, exps):
        return self.buffer_append(self.buffer, exps, self.count, self.head)

    # Read the entry at `index` through BufferGetItem.
    # NOTE(review): callers use get(-1); presumably negative indices count
    # back from the newest entry -- confirm against BufferGetItem docs.
    @ms_function
    def get(self, index):
        return self.buffer_get(self.buffer, self.count, self.head, index)

    # Draw a batch through BufferSample (configured with `batch_size` above).
    @ms_function
    def sample(self):
        return self.buffer_sample(self.buffer, self.count, self.head)
  51. s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
  52. a = Tensor(np.array([0, 1]), ms.int32)
  53. r = Tensor(np.array([1]), ms.float32)
  54. s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
  55. exp = [s, a, r, s_]
  56. exp1 = [s_, a, r, s]
  57. @ pytest.mark.level0
  58. @ pytest.mark.platform_x86_cpu
  59. @ pytest.mark.env_onecard
  60. def test_Buffer():
  61. context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
  62. buffer = RLBuffer(batch_size=32, capcity=100, shapes=[(4,), (2,), (1,), (4,)], types=[
  63. ms.float32, ms.int32, ms.float32, ms.float32])
  64. print("init buffer:\n", buffer.buffer)
  65. for _ in range(0, 110):
  66. buffer.append(exp)
  67. buffer.append(exp1)
  68. print("buffer append:\n", buffer.buffer)
  69. b = buffer.get(-1)
  70. print("buffer get:\n", b)
  71. bs = buffer.sample()
  72. print("buffer sample:\n", bs)