Getting Started
===============

In this document, we provide some toy examples for getting started. All
of the examples in this document, and many more, are available in
`examples/ <https://github.com/datamllab/rlcard/tree/master/examples>`__.

Playing with Random Agents
--------------------------

We have set up a random agent that can play randomly in each
environment. An example of applying a random agent to Blackjack is as
follows:

.. code:: python

    import rlcard
    from rlcard.agents import RandomAgent
    from rlcard.utils import set_global_seed

    # Make environment
    env = rlcard.make('blackjack', config={'seed': 0})
    episode_num = 2

    # Set a global seed
    set_global_seed(0)

    # Set up agents
    agent_0 = RandomAgent(action_num=env.action_num)
    env.set_agents([agent_0])

    for episode in range(episode_num):

        # Generate data from the environment
        trajectories, _ = env.run(is_training=False)

        # Print out the trajectories
        print('\nEpisode {}'.format(episode))
        for ts in trajectories[0]:
            print('State: {}, Action: {}, Reward: {}, Next State: {}, Done: {}'.format(ts[0], ts[1], ts[2], ts[3], ts[4]))

The expected output should look something like the following:

::

    Episode 0
    State: {'obs': array([20, 3]), 'legal_actions': [0, 1]}, Action: 0, Reward: 0, Next State: {'obs': array([15, 3]), 'legal_actions': [0, 1]}, Done: False
    State: {'obs': array([15, 3]), 'legal_actions': [0, 1]}, Action: 1, Reward: -1, Next State: {'obs': array([15, 20]), 'legal_actions': [0, 1]}, Done: True

    Episode 1
    State: {'obs': array([15, 5]), 'legal_actions': [0, 1]}, Action: 1, Reward: 1, Next State: {'obs': array([15, 23]), 'legal_actions': [0, 1]}, Done: True

Note that the states and actions are wrapped by ``env`` in Blackjack. In
this example, ``[20, 3]`` means that the current player’s score is 20,
while the dealer’s face-up card has a score of 3. Action 0 means “hit”
and action 1 means “stand”. A reward of 1 means the player wins, -1
means the dealer wins, and 0 means a tie. This data can be fed directly
into an RL algorithm for training.
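
For readability, each transition tuple can be rendered with a small
helper. The sketch below is our own illustration (the ``describe``
helper and its formatting are hypothetical); the action and reward
meanings are exactly those described above:

.. code:: python

    # Hypothetical helper for pretty-printing Blackjack transitions.
    # Action ids and reward semantics follow the description above.
    ACTION_NAMES = {0: 'hit', 1: 'stand'}

    def describe(ts):
        state, action, reward, next_state, done = ts
        return 'player {} vs dealer {} -> {} (reward {}, done {})'.format(
            state['obs'][0], state['obs'][1], ACTION_NAMES[action], reward, done)

    # Render the last episode generated by the example above
    for ts in trajectories[0]:
        print(describe(ts))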

Deep-Q Learning on Blackjack
----------------------------

The second example uses Deep-Q learning to train an agent on Blackjack.
We aim to use this example to show how reinforcement learning
algorithms can be developed and applied in our toolkit. We design a
``run`` function which plays one complete game and provides the data
for training RL agents. The example is shown below:

.. code:: python

    import tensorflow as tf
    import os

    import rlcard
    from rlcard.agents import DQNAgent
    from rlcard.utils import set_global_seed, tournament
    from rlcard.utils import Logger

    # Make environment
    env = rlcard.make('blackjack', config={'seed': 0})
    eval_env = rlcard.make('blackjack', config={'seed': 0})

    # Set the number of iterations and how frequently we evaluate/save the plot
    evaluate_every = 100
    evaluate_num = 10000
    episode_num = 100000

    # The initial memory size
    memory_init_size = 100

    # Train the agent every X steps
    train_every = 1

    # The paths for saving the logs and learning curves
    log_dir = './experiments/blackjack_dqn_result/'

    # Set a global seed
    set_global_seed(0)

    with tf.Session() as sess:

        # Initialize a global step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Set up the agents
        agent = DQNAgent(sess,
                         scope='dqn',
                         action_num=env.action_num,
                         replay_memory_init_size=memory_init_size,
                         train_every=train_every,
                         state_shape=env.state_shape,
                         mlp_layers=[10,10])
        env.set_agents([agent])
        eval_env.set_agents([agent])

        # Initialize global variables
        sess.run(tf.global_variables_initializer())

        # Initialize a Logger to plot the learning curve
        logger = Logger(log_dir)

        for episode in range(episode_num):

            # Generate data from the environment
            trajectories, _ = env.run(is_training=True)

            # Feed transitions into agent memory, and train the agent
            for ts in trajectories[0]:
                agent.feed(ts)

            # Evaluate the performance. Play with random agents.
            if episode % evaluate_every == 0:
                logger.log_performance(env.timestep, tournament(eval_env, evaluate_num)[0])

        # Close files in the logger
        logger.close_files()

        # Plot the learning curve
        logger.plot('DQN')

        # Save model
        save_dir = 'models/blackjack_dqn'
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(save_dir, 'model'))

The expected output is something like the following:

::

    ----------------------------------------
    timestep | 1
    reward   | -0.7342
    ----------------------------------------
    INFO - Agent dqn, step 100, rl-loss: 1.0042707920074463
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 136, rl-loss: 0.7888197302818298
    ----------------------------------------
    timestep | 136
    reward   | -0.1406
    ----------------------------------------
    INFO - Agent dqn, step 278, rl-loss: 0.6946825981140137
    ----------------------------------------
    timestep | 278
    reward   | -0.1523
    ----------------------------------------
    INFO - Agent dqn, step 412, rl-loss: 0.62268990278244025
    ----------------------------------------
    timestep | 412
    reward   | -0.088
    ----------------------------------------
    INFO - Agent dqn, step 544, rl-loss: 0.69050502777099616
    ----------------------------------------
    timestep | 544
    reward   | -0.08
    ----------------------------------------
    INFO - Agent dqn, step 681, rl-loss: 0.61789089441299444
    ----------------------------------------
    timestep | 681
    reward   | -0.0793
    ----------------------------------------

In Blackjack, the player receives a payoff at the end of the game: 1 if
the player wins, -1 if the player loses, and 0 if it is a tie.
Performance is measured by the average payoff the player obtains over
10000 episodes. The above example shows that the agent achieves better
and better performance during training. The logs and learning curves
are saved in ``./experiments/blackjack_dqn_result/``.
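
The script above saves the trained model with ``tf.train.Saver``. Below
is a minimal sketch of restoring it later for evaluation. This snippet
is our own (it is not one of the toolkit’s bundled examples); it reuses
the hyperparameters and paths from the script above:

.. code:: python

    import tensorflow as tf

    import rlcard
    from rlcard.agents import DQNAgent
    from rlcard.utils import set_global_seed, tournament

    env = rlcard.make('blackjack', config={'seed': 0})
    set_global_seed(0)

    with tf.Session() as sess:
        # The agent must be built with the same architecture it was trained with
        agent = DQNAgent(sess,
                         scope='dqn',
                         action_num=env.action_num,
                         replay_memory_init_size=100,
                         train_every=1,
                         state_shape=env.state_shape,
                         mlp_layers=[10,10])
        env.set_agents([agent])

        # Restore the weights saved by the training script
        saver = tf.train.Saver()
        saver.restore(sess, 'models/blackjack_dqn/model')

        # Report the average payoff over 10000 evaluation episodes
        print(tournament(env, 10000)[0])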

Running Multiple Processes
--------------------------

The environments can be run with multiple processes to accelerate
training. Below is an example of training DQN on Blackjack with
multiple processes.

.. code:: python

    ''' An example of learning a Deep-Q Agent on Blackjack with multiple processes
    Note that we must use if __name__ == '__main__' for multiprocessing
    '''

    import tensorflow as tf
    import os

    import rlcard
    from rlcard.agents import DQNAgent
    from rlcard.utils import set_global_seed, tournament
    from rlcard.utils import Logger

    def main():
        # Make environment
        env = rlcard.make('blackjack', config={'seed': 0, 'env_num': 4})
        eval_env = rlcard.make('blackjack', config={'seed': 0, 'env_num': 4})

        # Set the number of iterations and how frequently we evaluate performance
        evaluate_every = 100
        evaluate_num = 10000
        iteration_num = 100000

        # The initial memory size
        memory_init_size = 100

        # Train the agent every X steps
        train_every = 1

        # The paths for saving the logs and learning curves
        log_dir = './experiments/blackjack_dqn_result/'

        # Set a global seed
        set_global_seed(0)

        with tf.Session() as sess:

            # Initialize a global step
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Set up the agents
            agent = DQNAgent(sess,
                             scope='dqn',
                             action_num=env.action_num,
                             replay_memory_init_size=memory_init_size,
                             train_every=train_every,
                             state_shape=env.state_shape,
                             mlp_layers=[10,10])
            env.set_agents([agent])
            eval_env.set_agents([agent])

            # Initialize global variables
            sess.run(tf.global_variables_initializer())

            # Initialize a Logger to plot the learning curve
            logger = Logger(log_dir)

            for iteration in range(iteration_num):

                # Generate data from the environment
                trajectories, _ = env.run(is_training=True)

                # Feed transitions into agent memory, and train the agent
                for ts in trajectories[0]:
                    agent.feed(ts)

                # Evaluate the performance. Play with random agents.
                if iteration % evaluate_every == 0:
                    logger.log_performance(env.timestep, tournament(eval_env, evaluate_num)[0])

            # Close files in the logger
            logger.close_files()

            # Plot the learning curve
            logger.plot('DQN')

            # Save model
            save_dir = 'models/blackjack_dqn'
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            saver = tf.train.Saver()
            saver.save(sess, os.path.join(save_dir, 'model'))

    if __name__ == '__main__':
        main()

Example output is as follows:

::

    ----------------------------------------
    timestep | 17
    reward   | -0.7378
    ----------------------------------------
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 1100, rl-loss: 0.40940183401107797
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 2100, rl-loss: 0.44971221685409546
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 2225, rl-loss: 0.65466868877410897
    ----------------------------------------
    timestep | 2225
    reward   | -0.0658
    ----------------------------------------
    INFO - Agent dqn, step 3100, rl-loss: 0.48663979768753053
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 4100, rl-loss: 0.71293979883193974
    INFO - Copied model parameters to target network.
    INFO - Agent dqn, step 4440, rl-loss: 0.55871248245239263
    ----------------------------------------
    timestep | 4440
    reward   | -0.0736
    ----------------------------------------

Training CFR on Leduc Hold’em
-----------------------------

To show how we can use ``step`` and ``step_back`` to traverse the game
tree, we provide an example of solving Leduc Hold’em with CFR:

.. code:: python

    import numpy as np

    import rlcard
    from rlcard.agents import CFRAgent
    from rlcard import models
    from rlcard.utils import set_global_seed, tournament
    from rlcard.utils import Logger

    # Make environments. allow_step_back lets the CFR agent traverse the game tree
    env = rlcard.make('leduc-holdem', config={'seed': 0, 'allow_step_back':True})
    eval_env = rlcard.make('leduc-holdem', config={'seed': 0})

    # Set the number of iterations and how frequently we evaluate/save the plot
    evaluate_every = 100
    save_plot_every = 1000
    evaluate_num = 10000
    episode_num = 10000

    # The paths for saving the logs and learning curves
    log_dir = './experiments/leduc_holdem_cfr_result/'

    # Set a global seed
    set_global_seed(0)

    # Initialize CFR Agent
    agent = CFRAgent(env)
    agent.load()  # If a saved model exists, load it first

    # Evaluate CFR against a pre-trained NFSP model
    eval_env.set_agents([agent, models.load('leduc-holdem-nfsp').agents[0]])

    # Initialize a Logger to plot the learning curve
    logger = Logger(log_dir)

    for episode in range(episode_num):
        agent.train()
        print('\rIteration {}'.format(episode), end='')

        # Evaluate the performance. Play with NFSP agents.
        if episode % evaluate_every == 0:
            agent.save()  # Save model
            logger.log_performance(env.timestep, tournament(eval_env, evaluate_num)[0])

    # Close files in the logger
    logger.close_files()

    # Plot the learning curve
    logger.plot('CFR')

In the above example, the performance is measured by playing against a
pre-trained NFSP model. The expected output is as follows:

::

    Iteration 0
    ----------------------------------------
    timestep | 192
    reward   | -1.3662
    ----------------------------------------
    Iteration 100
    ----------------------------------------
    timestep | 19392
    reward   | 0.9462
    ----------------------------------------
    Iteration 200
    ----------------------------------------
    timestep | 38592
    reward   | 0.8591
    ----------------------------------------
    Iteration 300
    ----------------------------------------
    timestep | 57792
    reward   | 0.7861
    ----------------------------------------
    Iteration 400
    ----------------------------------------
    timestep | 76992
    reward   | 0.7752
    ----------------------------------------
    Iteration 500
    ----------------------------------------
    timestep | 96192
    reward   | 0.7215
    ----------------------------------------

We observe that CFR achieves better performance than NFSP. However, CFR
requires traversing the game tree, which is infeasible in large
environments.
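
The tree traversal itself happens inside ``CFRAgent``. To illustrate
the mechanism, here is a minimal, hypothetical sketch of exhaustively
walking the game tree with ``step`` and ``step_back``. It assumes the
environment API of this version (``init_game()`` returning a state and
a player id, ``step(action)`` returning the next state and player id,
and ``step_back()`` undoing the last move); treat it as a sketch rather
than a canonical example:

.. code:: python

    import rlcard

    # allow_step_back must be enabled for step_back() to work
    env = rlcard.make('leduc-holdem', config={'seed': 0, 'allow_step_back': True})

    def count_nodes(env, state):
        ''' Recursively count the nodes reachable from the current state. '''
        if env.is_over():
            return 1
        total = 1
        for action in state['legal_actions']:
            next_state, _ = env.step(action)
            total += count_nodes(env, next_state)
            env.step_back()  # Undo the move before trying the next sibling
        return total

    state, _ = env.init_game()
    print('Number of game tree nodes:', count_nodes(env, state))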

Having Fun with Pretrained Leduc Model
--------------------------------------

We have designed simple human interfaces for playing against the
pretrained model. Leduc Hold’em is a simplified version of Texas
Hold’em. The rules can be found `here <games.md#leduc-holdem>`__. An
example of playing against the Leduc Hold’em CFR model is as follows:

.. code:: python

    import rlcard
    from rlcard import models
    from rlcard.agents import LeducholdemHumanAgent as HumanAgent
    from rlcard.utils import print_card

    # Make environment
    # Set 'record_action' to True because we need it to print the results
    env = rlcard.make('leduc-holdem', config={'record_action': True})
    human_agent = HumanAgent(env.action_num)
    cfr_agent = models.load('leduc-holdem-cfr').agents[0]
    env.set_agents([human_agent, cfr_agent])

    print(">> Leduc Hold'em pre-trained model")

    while True:
        print(">> Start a new game")

        trajectories, payoffs = env.run(is_training=False)
        # If the human does not take the final action, we need to
        # print the other players' actions
        final_state = trajectories[0][-1][-2]
        action_record = final_state['action_record']
        state = final_state['raw_obs']
        _action_list = []
        for i in range(1, len(action_record)+1):
            if action_record[-i][0] == state['current_player']:
                break
            _action_list.insert(0, action_record[-i])
        for pair in _action_list:
            print('>> Player', pair[0], 'chooses', pair[1])

        # Let's take a look at what card the agent holds
        print('=============== CFR Agent ===============')
        print_card(env.get_perfect_information()['hand_cards'][1])

        print('=============== Result ===============')
        if payoffs[0] > 0:
            print('You win {} chips!'.format(payoffs[0]))
        elif payoffs[0] == 0:
            print('It is a tie.')
        else:
            print('You lose {} chips!'.format(-payoffs[0]))
        print('')
        input("Press any key to continue...")

Example output is as follows:

::

    >> Leduc Hold'em pre-trained model
    >> Start a new game!
    >> Agent 1 chooses raise

    =============== Community Card ===============
    ┌─────────┐
    │░░░░░░░░░│
    │░░░░░░░░░│
    │░░░░░░░░░│
    │░░░░░░░░░│
    │░░░░░░░░░│
    │░░░░░░░░░│
    │░░░░░░░░░│
    └─────────┘
    =============== Your Hand ===============
    ┌─────────┐
    │J        │
    │         │
    │         │
    │    ♥    │
    │         │
    │         │
    │        J│
    └─────────┘
    =============== Chips ===============
    Yours:   +
    Agent 1: +++
    =========== Actions You Can Choose ===========
    0: call, 1: raise, 2: fold

    >> You choose action (integer):

We also provide a running demo of a rule-based agent for UNO. Try it by
running ``examples/uno_human.py``.

Leduc Hold’em as Single-Agent Environment
-----------------------------------------

We have wrapped the environment as a single-agent environment by
assuming that the other players play with pre-trained models. The
interfaces are exactly the same as OpenAI Gym’s, so any single-agent
algorithm can be connected to the environment. An example on Leduc
Hold’em is given below:

.. code:: python

    import tensorflow as tf
    import os
    import numpy as np

    import rlcard
    from rlcard.agents import DQNAgent
    from rlcard.agents import RandomAgent
    from rlcard.utils import set_global_seed, tournament
    from rlcard.utils import Logger

    # Make environment
    env = rlcard.make('leduc-holdem', config={'seed': 0, 'single_agent_mode':True})
    eval_env = rlcard.make('leduc-holdem', config={'seed': 0, 'single_agent_mode':True})

    # Set the number of timesteps and how frequently we evaluate/save the plot
    evaluate_every = 1000
    evaluate_num = 10000
    timesteps = 100000

    # The initial memory size
    memory_init_size = 1000

    # Train the agent every X steps
    train_every = 1

    # The paths for saving the logs and learning curves
    log_dir = './experiments/leduc_holdem_single_dqn_result/'

    # Set a global seed
    set_global_seed(0)

    with tf.Session() as sess:

        # Initialize a global step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Set up the agents
        agent = DQNAgent(sess,
                         scope='dqn',
                         action_num=env.action_num,
                         replay_memory_init_size=memory_init_size,
                         train_every=train_every,
                         state_shape=env.state_shape,
                         mlp_layers=[128,128])

        # Initialize global variables
        sess.run(tf.global_variables_initializer())

        # Initialize a Logger to plot the learning curve
        logger = Logger(log_dir)

        state = env.reset()

        for timestep in range(timesteps):
            action = agent.step(state)
            next_state, reward, done = env.step(action)
            ts = (state, action, reward, next_state, done)
            agent.feed(ts)
            # In single-agent mode the environment starts a new game when the
            # previous one ends, so we simply continue from the returned state
            state = next_state

            if timestep % evaluate_every == 0:
                rewards = []
                eval_state = eval_env.reset()
                for _ in range(evaluate_num):
                    action, _ = agent.eval_step(eval_state)
                    eval_state, reward, done = eval_env.step(action)
                    if done:
                        rewards.append(reward)
                logger.log_performance(env.timestep, np.mean(rewards))

        # Close files in the logger
        logger.close_files()

        # Plot the learning curve
        logger.plot('DQN')

        # Save model
        save_dir = 'models/leduc_holdem_single_dqn'
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(save_dir, 'model'))
