|
|
|
@@ -22,6 +22,7 @@ from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithInfer |
|
|
|
|
|
|
|
|
|
|
|
class BufferSample(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
In reinforcement learning, the data is sampled from the replaybuffer randomly. |
|
|
|
@@ -134,6 +135,7 @@ class BufferSample(PrimitiveWithInfer): |
|
|
|
validator.check_type_name("head type", head_type, (mstype.int32), self.name) |
|
|
|
return tuple(self._buffer_dtype) |
|
|
|
|
|
|
|
|
|
|
|
class BufferAppend(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
In reinforcement learning, the experience data is collected in each step. We use `BufferAppend` to |
|
|
|
@@ -227,6 +229,7 @@ class BufferAppend(PrimitiveWithInfer): |
|
|
|
validator.check_type_name("head type", head_type, (mstype.int32), self.name) |
|
|
|
return count_type |
|
|
|
|
|
|
|
|
|
|
|
class BufferGetItem(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Get the data from buffer in the position of input inedx. |
|
|
|
|