You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.3 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train DQN and get checkpoint files."""
  16. import os
  17. import argparse
  18. import gym
  19. from mindspore import context
  20. from mindspore.common import set_seed
  21. from mindspore.train.serialization import save_checkpoint
  22. from src.config import config_dqn as cfg
  23. from src.agent import Agent
  24. parser = argparse.ArgumentParser(description='MindSpore dqn Example')
  25. parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
  26. help='device where the code will be implemented (default: Ascend)')
  27. parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
  28. path where the trained ckpt file')
  29. args = parser.parse_args()
  30. set_seed(1)
  31. if __name__ == "__main__":
  32. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
  33. env = gym.make('CartPole-v1')
  34. cfg.state_space_dim = env.observation_space.shape[0]
  35. cfg.action_space_dim = env.action_space.n
  36. agent = Agent(**cfg)
  37. agent.load_dict()
  38. for episode in range(300):
  39. s0 = env.reset()
  40. total_reward = 1
  41. while True:
  42. a0 = agent.act(s0)
  43. s1, r1, done, _ = env.step(a0)
  44. if done:
  45. r1 = -1
  46. agent.put(s0, a0, r1, s1)
  47. if done:
  48. break
  49. total_reward += r1
  50. s0 = s1
  51. agent.learn()
  52. agent.load_dict()
  53. print("episode", episode, "total_reward", total_reward)
  54. path = os.path.realpath(args.ckpt_path)
  55. if not os.path.exists(path):
  56. os.makedirs(path)
  57. ckpt_name = path + "/dqn.ckpt"
  58. save_checkpoint(agent.policy_net, ckpt_name)