| @@ -0,0 +1,128 @@ | |||
| # Contents | |||
- [DQN Description](#dqn-description)
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
- [Requirements](#requirements)
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Script Parameters](#script-parameters) | |||
| - [Training Process](#training-process) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| # [DQN Description](#contents) | |||
| DQN is the first deep learning model to successfully learn control policies directly from high-dimensional sensory input using reinforcement learning. | |||
| [Paper](https://www.nature.com/articles/nature14236) Mnih, Volodymyr, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves et al. "Human-level control through deep reinforcement learning." nature 518, no. 7540 (2015): 529-533. | |||
## [Model Architecture](#contents)
The overall network architecture of DQN is described in the paper below:
| [Paper](https://www.nature.com/articles/nature14236) | |||
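Concretely, the network used in this example (see `src/dqn.py`) is a small multilayer perceptron: one hidden fully connected layer with ReLU activation, followed by a linear output layer that emits one Q-value per action:
```python
import mindspore.nn as nn

class DQN(nn.Cell):
    """Two-layer MLP mapping a state vector to one Q-value per action."""
    def __init__(self, input_size, hidden_size, output_size):
        super(DQN, self).__init__()
        self.linear1 = nn.Dense(input_size, hidden_size)
        self.linear2 = nn.Dense(hidden_size, output_size)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.linear1(x))
        return self.linear2(x)
```
For CartPole the input size is 4 and the output size is 2; the hidden size is fixed to 256 in `src/agent.py`.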
## [Dataset](#contents)
DQN is trained online on the OpenAI Gym `CartPole-v0` environment, so no offline dataset needs to be downloaded.
## [Requirements](#contents)
- Hardware (Ascend/GPU)
| - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
- Third-party libraries
| ```bash | |||
| pip install gym | |||
| ``` | |||
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```text
| ├── dqn | |||
| ├── README.md # descriptions about DQN | |||
| ├── scripts | |||
| │ ├──run_standalone_eval_ascend.sh # shell script for evaluation with Ascend | |||
| │ ├──run_standalone_eval_gpu.sh # shell script for evaluation with GPU | |||
│ ├──run_standalone_train_ascend.sh # shell script for training with Ascend
│ ├──run_standalone_train_gpu.sh # shell script for training with GPU
| ├── src | |||
| │ ├──agent.py # model agent | |||
| │ ├──config.py # parameter configuration | |||
| │ ├──dqn.py # dqn architecture | |||
| ├── train.py # training script | |||
| ├── eval.py # evaluation script | |||
| ``` | |||
### [Script Parameters](#contents)
| ```python | |||
'gamma': 0.8             # discount factor for future rewards
'epsi_high': 0.9         # upper bound of the exploration rate
'epsi_low': 0.05         # lower bound of the exploration rate
'decay': 200             # decay period (in steps) of the exploration rate
'lr': 0.001              # learning rate
'capacity': 100000       # capacity of the replay buffer
'batch_size': 512        # training batch size
'state_space_dim': 4     # dimension of the environment state space
'action_space_dim': 2    # dimension of the action space
| ``` | |||
### [Training Process](#contents)
| ```shell | |||
# training example
python:
  Ascend: python train.py --device_target Ascend --ckpt_path ckpt > log.txt 2>&1 &
  GPU: python train.py --device_target GPU --ckpt_path ckpt > log.txt 2>&1 &
shell:
  Ascend: bash run_standalone_train_ascend.sh ckpt
  GPU: bash run_standalone_train_gpu.sh ckpt
| ``` | |||
### [Evaluation Process](#contents)
| ```shell | |||
# evaluation example
python:
  Ascend: python eval.py --device_target Ascend --ckpt_path ./ckpt/dqn.ckpt
  GPU: python eval.py --device_target GPU --ckpt_path ./ckpt/dqn.ckpt
shell:
  Ascend: bash run_standalone_eval_ascend.sh ./ckpt/dqn.ckpt
  GPU: bash run_standalone_eval_gpu.sh ./ckpt/dqn.ckpt
| ``` | |||
## [Performance](#contents)
| ### Inference Performance | |||
| Parameters                 | DQN                                                          |
| -------------------------- | ------------------------------------------------------------ |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; Memory 755 GB           |
| Uploaded Date              | 03/10/2021 (month/day/year)                                  |
| MindSpore Version          | 1.1.0                                                        |
| Training Parameters        | batch_size = 512, lr = 0.001                                 |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | MSELoss                                                      |
| Outputs                    | Q-value for each action                                      |
| Params                     | 7.3 K                                                        |
| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/rl/dqn |
## [Description of Random Situation](#contents)
A fixed random seed is set in train.py and eval.py (`set_seed(1)`).
## [ModelZoo Homepage](#contents)
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation for DQN""" | |||
| import argparse | |||
| import gym | |||
| from mindspore import context | |||
| from mindspore.common import set_seed | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import config_dqn as cfg | |||
| from src.agent import Agent | |||
| parser = argparse.ArgumentParser(description='MindSpore dqn Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
parser.add_argument('--ckpt_path', type=str, default=None,
                    help='path of the trained checkpoint file used for evaluation')
| args = parser.parse_args() | |||
| set_seed(1) | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| env = gym.make('CartPole-v0') | |||
| cfg.state_space_dim = env.observation_space.shape[0] | |||
| cfg.action_space_dim = env.action_space.n | |||
| agent = Agent(**cfg) | |||
| # load checkpoint | |||
| if args.ckpt_path: | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| not_load_param = load_param_into_net(agent.policy_net, param_dict) | |||
        if not_load_param:
            raise ValueError("Failed to load parameters into the network!")
| score = 0 | |||
| agent.load_dict() | |||
| for episode in range(50): | |||
| s0 = env.reset() | |||
| total_reward = 1 | |||
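        # CartPole yields a reward of 1 per step, so total_reward equals the episode length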
| while True: | |||
| a0 = agent.eval_act(s0) | |||
| s1, r1, done, _ = env.step(a0) | |||
            if done:
                break
| total_reward += r1 | |||
| s0 = s1 | |||
| score += total_reward | |||
| print("episode", episode, "total_reward", total_reward) | |||
| print("mean_reward", score/50) | |||
| @@ -0,0 +1 @@ | |||
| gym | |||
| @@ -0,0 +1,21 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# A simple usage example follows; more parameters can be set.
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| CKPT_PATH=$1 | |||
python -s ${self_path}/../eval.py --device_target="Ascend" --ckpt_path="$CKPT_PATH" > log.txt 2>&1 &
| @@ -0,0 +1,21 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# A simple usage example follows; more parameters can be set.
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| CKPT_PATH=$1 | |||
python -s ${self_path}/../eval.py --device_target="GPU" --ckpt_path="$CKPT_PATH" > log.txt 2>&1 &
| @@ -0,0 +1,21 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# A simple usage example follows; more parameters can be set.
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| CKPT_PATH=$1 | |||
python -s ${self_path}/../train.py --device_target="Ascend" --ckpt_path="$CKPT_PATH" > log.txt 2>&1 &
| @@ -0,0 +1,21 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# A simple usage example follows; more parameters can be set.
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| CKPT_PATH=$1 | |||
python -s ${self_path}/../train.py --device_target="GPU" --ckpt_path="$CKPT_PATH" > log.txt 2>&1 &
| @@ -0,0 +1,94 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Agent of reinforcement learning network""" | |||
| import random | |||
| import math | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from src.dqn import DQN, WithLossCell | |||
| class Agent: | |||
| """ | |||
| DQN Agent | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| for key, value in kwargs.items(): | |||
| setattr(self, key, value) | |||
| self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim) | |||
| self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim) | |||
| self.optimizer = nn.RMSProp(self.policy_net.trainable_params(), learning_rate=self.lr) | |||
| loss_fn = nn.MSELoss() | |||
| loss_q_net = WithLossCell(self.policy_net, loss_fn) | |||
| self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer) | |||
| self.policy_net_train.set_train(mode=True) | |||
| self.buffer = [] | |||
| self.steps = 0 | |||
    def act(self, s0):
        """
        Choose an action with an epsilon-greedy policy.
        """
| self.steps += 1 | |||
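        # the exploration rate decays exponentially from epsi_high towards epsi_low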
| epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * (math.exp(-1.0 * self.steps / self.decay)) | |||
| if random.random() < epsi: | |||
| a0 = random.randrange(self.action_space_dim) | |||
| else: | |||
| s0 = np.expand_dims(s0, axis=0) | |||
| s0 = Tensor(s0, mstype.float32) | |||
| a0 = self.policy_net(s0).asnumpy() | |||
| a0 = np.argmax(a0) | |||
| return a0 | |||
    def eval_act(self, s0):
        """Choose the greedy action (no exploration) for evaluation."""
| self.steps += 1 | |||
| s0 = np.expand_dims(s0, axis=0) | |||
| s0 = Tensor(s0, mstype.float32) | |||
| a0 = self.policy_net(s0).asnumpy() | |||
| a0 = np.argmax(a0) | |||
| return a0 | |||
    def put(self, *transition):
        """Store a transition (s0, a0, r1, s1) in the replay buffer, dropping the oldest when full."""
| if len(self.buffer) == self.capacity: | |||
| self.buffer.pop(0) | |||
| self.buffer.append(transition) | |||
    def load_dict(self):
        """Copy the policy network parameters into the target network."""
| for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()): | |||
| target_param = self.target_net.parameters_dict()[target_item] | |||
| source_param = self.policy_net.parameters_dict()[source_item] | |||
| target_param.set_data(source_param.data) | |||
    def learn(self):
        """
        Sample a batch of transitions from the replay buffer and update the policy network.
        """
| if (len(self.buffer)) < self.batch_size: | |||
| return | |||
| samples = random.sample(self.buffer, self.batch_size) | |||
| s0, a0, r1, s1 = zip(*samples) | |||
| s1 = Tensor(s1, mstype.float32) | |||
| s0 = Tensor(s0, mstype.float32) | |||
| a0 = Tensor(np.expand_dims(a0, axis=1)) | |||
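        # TD target: y = r1 + gamma * max_a' Q_target(s1, a')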
| next_state_values = self.target_net(s1).asnumpy() | |||
| next_state_values = np.max(next_state_values, axis=1) | |||
| y_true = r1 + self.gamma * next_state_values | |||
| y_true = Tensor(np.expand_dims(y_true, axis=1), mstype.float32) | |||
| self.policy_net_train(s0, a0, y_true) | |||
| @@ -0,0 +1,31 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
Network configuration used in train.py and eval.py.
| """ | |||
| from easydict import EasyDict as edict | |||
| config_dqn = edict({ | |||
| 'gamma': 0.8, | |||
| 'epsi_high': 0.9, | |||
| 'epsi_low': 0.05, | |||
| 'decay': 200, | |||
| 'lr': 0.001, | |||
| 'capacity': 100000, | |||
| 'batch_size': 512, | |||
| 'state_space_dim': 4, | |||
| 'action_space_dim': 2 | |||
| }) | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DQN net""" | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
class DQN(nn.Cell):
    """Two-layer MLP mapping a state vector to one Q-value per action."""
| def __init__(self, input_size, hidden_size, output_size): | |||
| super(DQN, self).__init__() | |||
| self.linear1 = nn.Dense(input_size, hidden_size) | |||
| self.linear2 = nn.Dense(hidden_size, output_size) | |||
| self.relu = nn.ReLU() | |||
| def construct(self, x): | |||
| x = self.relu(self.linear1(x)) | |||
| return self.linear2(x) | |||
| class WithLossCell(nn.Cell): | |||
| """ | |||
    Wraps the Q-network with a loss computed on the Q-value of the taken action.
| """ | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| self.gather = ops.GatherD() | |||
| def construct(self, x, act, label): | |||
| out = self._backbone(x) | |||
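        # select the predicted Q-value of the action actually taken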
| out = self.gather(out, 1, act) | |||
| loss = self._loss_fn(out, label) | |||
| return loss | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Train DQN and get checkpoint files.""" | |||
| import os | |||
| import argparse | |||
| import gym | |||
| from mindspore import context | |||
| from mindspore.common import set_seed | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from src.config import config_dqn as cfg | |||
| from src.agent import Agent | |||
| parser = argparse.ArgumentParser(description='MindSpore dqn Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
parser.add_argument('--ckpt_path', type=str, default="./ckpt",
                    help='directory where the trained checkpoint file will be saved')
| args = parser.parse_args() | |||
| set_seed(1) | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| env = gym.make('CartPole-v0') | |||
| cfg.state_space_dim = env.observation_space.shape[0] | |||
| cfg.action_space_dim = env.action_space.n | |||
| agent = Agent(**cfg) | |||
| agent.load_dict() | |||
| for episode in range(150): | |||
| s0 = env.reset() | |||
| total_reward = 1 | |||
| while True: | |||
| a0 = agent.act(s0) | |||
| s1, r1, done, _ = env.step(a0) | |||
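            # penalize termination so the agent learns to keep the pole upright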
| if done: | |||
| r1 = -1 | |||
| agent.put(s0, a0, r1, s1) | |||
| if done: | |||
| break | |||
| total_reward += r1 | |||
| s0 = s1 | |||
| agent.learn() | |||
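            # synchronize the target network with the updated policy network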
| agent.load_dict() | |||
| print("episode", episode, "total_reward", total_reward) | |||
| path = os.path.realpath(args.ckpt_path) | |||
| if not os.path.exists(path): | |||
| os.makedirs(path) | |||
    ckpt_name = os.path.join(path, "dqn.ckpt")
| save_checkpoint(agent.policy_net, ckpt_name) | |||