@@ -0,0 +1,128 @@
# Contents

- [DQN Description](#dqn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Requirements](#requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Performance](#performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DQN Description](#contents)

DQN is the first deep learning model to successfully learn control policies directly from high-dimensional sensory input using reinforcement learning.

[Paper](https://www.nature.com/articles/nature14236): Mnih, Volodymyr, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves et al. "Human-level control through deep reinforcement learning." Nature 518, no. 7540 (2015): 529-533.
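At its core, DQN acts epsilon-greedily on a learned Q-network and regresses that network toward a bootstrapped Bellman target computed with a separate target network, which is what `Agent.learn` in `src/agent.py` does. A minimal NumPy sketch of that target (the `q_target` function below is a hypothetical stand-in for the target network):

```python
import numpy as np

def td_target(reward, next_state, q_target, gamma=0.8):
    """Bellman target y = r + gamma * max_a' Q_target(s', a'), as in Agent.learn()."""
    next_q = q_target(next_state)            # Q-values of the next state from the target network
    return reward + gamma * np.max(next_q)   # bootstrapped one-step return

# Hypothetical target network: two action values for a 4-dim CartPole state.
q_target = lambda s: np.array([0.3, 0.7])
print(td_target(1.0, np.zeros(4), q_target))  # 1.0 + 0.8 * 0.7 = 1.56
```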
## [Model Architecture](#contents)

The overall network architecture of DQN is shown below:

[Paper](https://www.nature.com/articles/nature14236)
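In this implementation (see `src/dqn.py`), the Q-network is a small multilayer perceptron rather than the convolutional network from the paper: a Dense-ReLU-Dense stack that maps the 4-dimensional CartPole state to one Q-value per action. A minimal MindSpore sketch of the same structure:

```python
import mindspore.nn as nn

class DQN(nn.Cell):
    """Two-layer MLP: state (4) -> hidden (256) -> Q-values (2)."""
    def __init__(self, input_size=4, hidden_size=256, output_size=2):
        super(DQN, self).__init__()
        self.linear1 = nn.Dense(input_size, hidden_size)
        self.linear2 = nn.Dense(hidden_size, output_size)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.linear1(x))
        return self.linear2(x)   # raw Q-values, one per action
```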
## [Dataset](#contents)

DQN does not train on a static dataset. The agent generates its own experience by interacting with the OpenAI Gym `CartPole-v0` environment (4-dimensional observation, 2 discrete actions); transitions are stored in a replay buffer and sampled for learning.
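As a quick orientation, the environment's observation and action spaces match the defaults in `src/config.py` (a minimal sketch; requires `gym` to be installed):

```python
import gym

env = gym.make('CartPole-v0')
print(env.observation_space.shape[0])  # 4 -> state_space_dim
print(env.action_space.n)              # 2 -> action_space_dim
```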
## [Requirements](#contents)

- Hardware (Ascend/GPU/CPU)
    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Third-party libraries (a quick import check is sketched after this list)

    ```bash
    pip install gym
    ```
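A simple way to confirm that the framework and the third-party library are importable (a minimal sketch; it prints whatever versions happen to be installed):

```python
import gym
import mindspore

print("MindSpore:", mindspore.__version__)
print("Gym:", gym.__version__)
```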
## [Script Description](#contents)

### [Script and Sample Code](#contents)

```text
├── dqn
    ├── README.md                          # descriptions about DQN
    ├── scripts
    │   ├── run_standalone_eval_ascend.sh  # shell script for evaluation with Ascend
    │   ├── run_standalone_eval_gpu.sh     # shell script for evaluation with GPU
    │   ├── run_standalone_train_ascend.sh # shell script for training with Ascend
    │   ├── run_standalone_train_gpu.sh    # shell script for training with GPU
    ├── src
    │   ├── agent.py                       # model agent
    │   ├── config.py                      # parameter configuration
    │   ├── dqn.py                         # DQN architecture
    ├── train.py                           # training script
    ├── eval.py                            # evaluation script
```
### [Script Parameters](#contents)

```python
'gamma': 0.8             # discount factor applied to the next-state value
'epsi_high': 0.9         # the highest exploration rate
'epsi_low': 0.05         # the lowest exploration rate
'decay': 200             # decay constant (in steps) of the exploration rate
'lr': 0.001              # learning rate
'capacity': 100000       # capacity of the replay buffer
'batch_size': 512        # training batch size
'state_space_dim': 4     # dimension of the environment state space
'action_space_dim': 2    # number of discrete actions
```
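`epsi_high`, `epsi_low`, and `decay` define the exploration schedule used in `Agent.act`: the exploration rate starts at `epsi_high` and decays exponentially toward `epsi_low` with a time constant of `decay` steps. A small sketch of the schedule with the default values:

```python
import math

def exploration_rate(steps, epsi_high=0.9, epsi_low=0.05, decay=200):
    """Epsilon schedule from Agent.act: epsi_low + (epsi_high - epsi_low) * exp(-steps / decay)."""
    return epsi_low + (epsi_high - epsi_low) * math.exp(-1.0 * steps / decay)

for steps in (0, 200, 1000):
    print(steps, round(exploration_rate(steps), 3))  # 0.9, ~0.363, ~0.056
```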
### [Training Process](#contents)

```shell
# training example
python:
    Ascend: python train.py --device_target Ascend --ckpt_path ckpt > log.txt 2>&1 &
    GPU: python train.py --device_target GPU --ckpt_path ckpt > log.txt 2>&1 &

shell:
    Ascend: sh run_standalone_train_ascend.sh ckpt
    GPU: sh run_standalone_train_gpu.sh ckpt
```

Training output is redirected to `log.txt`, and the trained checkpoint is saved as `dqn.ckpt` under the directory passed to `--ckpt_path`.
### [Evaluation Process](#contents)

```shell
# evaluation example
python:
    Ascend: python eval.py --device_target Ascend --ckpt_path ./ckpt/dqn.ckpt
    GPU: python eval.py --device_target GPU --ckpt_path ./ckpt/dqn.ckpt

shell:
    Ascend: sh run_standalone_eval_ascend.sh ./ckpt/dqn.ckpt
    GPU: sh run_standalone_eval_gpu.sh ./ckpt/dqn.ckpt
```
## [Model Description](#contents)

### [Performance](#contents)

#### Inference Performance

| Parameters                 | DQN                                                          |
| -------------------------- | ------------------------------------------------------------ |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G              |
| uploaded Date              | 03/10/2021 (month/day/year)                                  |
| MindSpore Version          | 1.1.0                                                        |
| Training Parameters        | batch_size = 512, lr = 0.001                                 |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | MSELoss                                                      |
| outputs                    | Q-value for each action                                      |
| Params                     | 7.3K                                                         |
| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/rl/dqn |
## [Description of Random Situation](#contents)

We set a fixed random seed (`set_seed(1)`) in train.py and eval.py.

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """Evaluation for DQN""" | |||||
| import argparse | |||||
| import gym | |||||
| from mindspore import context | |||||
| from mindspore.common import set_seed | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import config_dqn as cfg | |||||
| from src.agent import Agent | |||||
| parser = argparse.ArgumentParser(description='MindSpore dqn Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--ckpt_path', type=str, default=None, help='if is test, must provide\ | |||||
| path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| set_seed(1) | |||||
if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    env = gym.make('CartPole-v0')
    cfg.state_space_dim = env.observation_space.shape[0]
    cfg.action_space_dim = env.action_space.n
    agent = Agent(**cfg)

    # load checkpoint
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        not_load_param = load_param_into_net(agent.policy_net, param_dict)
        if not_load_param:
            raise ValueError("Failed to load parameters into the network!")

    score = 0
    agent.load_dict()  # sync the target network with the (loaded) policy network

    for episode in range(50):
        s0 = env.reset()
        total_reward = 1
        while True:
            a0 = agent.eval_act(s0)
            s1, r1, done, _ = env.step(a0)
            if done:
                break
            total_reward += r1
            s0 = s1
        score += total_reward
        print("episode", episode, "total_reward", total_reward)

    print("mean_reward", score / 50)
@@ -0,0 +1 @@
gym
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../eval.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../eval.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../train.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../train.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """Agent of reinforcement learning network""" | |||||
| import random | |||||
| import math | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| import mindspore.common.dtype as mstype | |||||
| from src.dqn import DQN, WithLossCell | |||||
| class Agent: | |||||
| """ | |||||
| DQN Agent | |||||
| """ | |||||
| def __init__(self, **kwargs): | |||||
| for key, value in kwargs.items(): | |||||
| setattr(self, key, value) | |||||
| self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim) | |||||
| self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim) | |||||
| self.optimizer = nn.RMSProp(self.policy_net.trainable_params(), learning_rate=self.lr) | |||||
| loss_fn = nn.MSELoss() | |||||
| loss_q_net = WithLossCell(self.policy_net, loss_fn) | |||||
| self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer) | |||||
| self.policy_net_train.set_train(mode=True) | |||||
| self.buffer = [] | |||||
| self.steps = 0 | |||||
| def act(self, s0): | |||||
| """ | |||||
| Agent choose action. | |||||
| """ | |||||
| self.steps += 1 | |||||
| epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * (math.exp(-1.0 * self.steps / self.decay)) | |||||
| if random.random() < epsi: | |||||
| a0 = random.randrange(self.action_space_dim) | |||||
| else: | |||||
| s0 = np.expand_dims(s0, axis=0) | |||||
| s0 = Tensor(s0, mstype.float32) | |||||
| a0 = self.policy_net(s0).asnumpy() | |||||
| a0 = np.argmax(a0) | |||||
| return a0 | |||||
| def eval_act(self, s0): | |||||
| self.steps += 1 | |||||
| s0 = np.expand_dims(s0, axis=0) | |||||
| s0 = Tensor(s0, mstype.float32) | |||||
| a0 = self.policy_net(s0).asnumpy() | |||||
| a0 = np.argmax(a0) | |||||
| return a0 | |||||
| def put(self, *transition): | |||||
| if len(self.buffer) == self.capacity: | |||||
| self.buffer.pop(0) | |||||
| self.buffer.append(transition) | |||||
| def load_dict(self): | |||||
| for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()): | |||||
| target_param = self.target_net.parameters_dict()[target_item] | |||||
| source_param = self.policy_net.parameters_dict()[source_item] | |||||
| target_param.set_data(source_param.data) | |||||
| def learn(self): | |||||
| """ | |||||
| Agent learn from experience data. | |||||
| """ | |||||
| if (len(self.buffer)) < self.batch_size: | |||||
| return | |||||
| samples = random.sample(self.buffer, self.batch_size) | |||||
| s0, a0, r1, s1 = zip(*samples) | |||||
| s1 = Tensor(s1, mstype.float32) | |||||
| s0 = Tensor(s0, mstype.float32) | |||||
| a0 = Tensor(np.expand_dims(a0, axis=1)) | |||||
| next_state_values = self.target_net(s1).asnumpy() | |||||
| next_state_values = np.max(next_state_values, axis=1) | |||||
| y_true = r1 + self.gamma * next_state_values | |||||
| y_true = Tensor(np.expand_dims(y_true, axis=1), mstype.float32) | |||||
| self.policy_net_train(s0, a0, y_true) | |||||
@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """ | |||||
| network config setting, will be used in train.py and eval.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| config_dqn = edict({ | |||||
| 'gamma': 0.8, | |||||
| 'epsi_high': 0.9, | |||||
| 'epsi_low': 0.05, | |||||
| 'decay': 200, | |||||
| 'lr': 0.001, | |||||
| 'capacity': 100000, | |||||
| 'batch_size': 512, | |||||
| 'state_space_dim': 4, | |||||
| 'action_space_dim': 2 | |||||
| }) | |||||
@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """DQN net""" | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops as ops | |||||
| class DQN(nn. Cell): | |||||
| def __init__(self, input_size, hidden_size, output_size): | |||||
| super(DQN, self).__init__() | |||||
| self.linear1 = nn.Dense(input_size, hidden_size) | |||||
| self.linear2 = nn.Dense(hidden_size, output_size) | |||||
| self.relu = nn.ReLU() | |||||
| def construct(self, x): | |||||
| x = self.relu(self.linear1(x)) | |||||
| return self.linear2(x) | |||||
| class WithLossCell(nn.Cell): | |||||
| """ | |||||
| network with loss function | |||||
| """ | |||||
| def __init__(self, backbone, loss_fn): | |||||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||||
| self._backbone = backbone | |||||
| self._loss_fn = loss_fn | |||||
| self.gather = ops.GatherD() | |||||
| def construct(self, x, act, label): | |||||
| out = self._backbone(x) | |||||
| out = self.gather(out, 1, act) | |||||
| loss = self._loss_fn(out, label) | |||||
| return loss | |||||
@@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """Train DQN and get checkpoint files.""" | |||||
| import os | |||||
| import argparse | |||||
| import gym | |||||
| from mindspore import context | |||||
| from mindspore.common import set_seed | |||||
| from mindspore.train.serialization import save_checkpoint | |||||
| from src.config import config_dqn as cfg | |||||
| from src.agent import Agent | |||||
| parser = argparse.ArgumentParser(description='MindSpore dqn Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||||
| path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| set_seed(1) | |||||
| if __name__ == "__main__": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
| env = gym.make('CartPole-v0') | |||||
| cfg.state_space_dim = env.observation_space.shape[0] | |||||
| cfg.action_space_dim = env.action_space.n | |||||
| agent = Agent(**cfg) | |||||
| agent.load_dict() | |||||
| for episode in range(150): | |||||
| s0 = env.reset() | |||||
| total_reward = 1 | |||||
| while True: | |||||
| a0 = agent.act(s0) | |||||
| s1, r1, done, _ = env.step(a0) | |||||
| if done: | |||||
| r1 = -1 | |||||
| agent.put(s0, a0, r1, s1) | |||||
| if done: | |||||
| break | |||||
| total_reward += r1 | |||||
| s0 = s1 | |||||
| agent.learn() | |||||
| agent.load_dict() | |||||
| print("episode", episode, "total_reward", total_reward) | |||||
| path = os.path.realpath(args.ckpt_path) | |||||
| if not os.path.exists(path): | |||||
| os.makedirs(path) | |||||
| ckpt_name = path + "/dqn.ckpt" | |||||
| save_checkpoint(agent.policy_net, ckpt_name) | |||||