From a941422337635d5ae3ec44f759b6fe2fd536959a Mon Sep 17 00:00:00 2001
From: xiaoyisd
Date: Thu, 11 Mar 2021 18:59:08 +0800
Subject: [PATCH] dqn

---
 model_zoo/official/rl/dqn/README.md           | 128 ++++++++++++++++++
 model_zoo/official/rl/dqn/eval.py             |  67 +++++++++
 model_zoo/official/rl/dqn/requirements.txt    |   1 +
 .../dqn/scripts/run_standalone_eval_ascend.sh |  21 +++
 .../rl/dqn/scripts/run_standalone_eval_gpu.sh |  21 +++
 .../scripts/run_standalone_train_ascend.sh    |  21 +++
 .../dqn/scripts/run_standalone_train_gpu.sh   |  21 +++
 model_zoo/official/rl/dqn/src/agent.py        |  94 +++++++++++++
 model_zoo/official/rl/dqn/src/config.py       |  31 +++++
 model_zoo/official/rl/dqn/src/dqn.py          |  47 +++++++
 model_zoo/official/rl/dqn/train.py            |  69 ++++++++++
 11 files changed, 521 insertions(+)
 create mode 100644 model_zoo/official/rl/dqn/README.md
 create mode 100644 model_zoo/official/rl/dqn/eval.py
 create mode 100644 model_zoo/official/rl/dqn/requirements.txt
 create mode 100755 model_zoo/official/rl/dqn/scripts/run_standalone_eval_ascend.sh
 create mode 100755 model_zoo/official/rl/dqn/scripts/run_standalone_eval_gpu.sh
 create mode 100755 model_zoo/official/rl/dqn/scripts/run_standalone_train_ascend.sh
 create mode 100755 model_zoo/official/rl/dqn/scripts/run_standalone_train_gpu.sh
 create mode 100644 model_zoo/official/rl/dqn/src/agent.py
 create mode 100644 model_zoo/official/rl/dqn/src/config.py
 create mode 100644 model_zoo/official/rl/dqn/src/dqn.py
 create mode 100644 model_zoo/official/rl/dqn/train.py

diff --git a/model_zoo/official/rl/dqn/README.md b/model_zoo/official/rl/dqn/README.md
new file mode 100644
index 0000000000..b9a44bdbdd
--- /dev/null
+++ b/model_zoo/official/rl/dqn/README.md
@@ -0,0 +1,128 @@
+# Contents
+
+- [DQN Description](#dqn-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Requirements](#requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+    - [Evaluation Process](#evaluation-process)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [DQN Description](#contents)
+
+DQN is the first deep learning model to successfully learn control policies directly from high-dimensional sensory input using reinforcement learning.
+
+[Paper](https://www.nature.com/articles/nature14236): Mnih, Volodymyr, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves et al. "Human-level control through deep reinforcement learning." Nature 518, no. 7540 (2015): 529-533.
+
+## [Model Architecture](#contents)
+
+The overall network architecture of DQN is described in the [paper](https://www.nature.com/articles/nature14236).
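+
+In this implementation, the Q-network is a small multi-layer perceptron defined in `src/dqn.py`: one dense hidden layer with ReLU activation (256 units, set in `src/agent.py`), followed by a dense output layer that produces one Q-value per action. Abridged from `src/dqn.py`:
+
+```python
+import mindspore.nn as nn
+
+class DQN(nn.Cell):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(DQN, self).__init__()
+        self.linear1 = nn.Dense(input_size, hidden_size)   # state -> hidden features
+        self.linear2 = nn.Dense(hidden_size, output_size)  # hidden features -> Q(s, a) per action
+        self.relu = nn.ReLU()
+
+    def construct(self, x):
+        x = self.relu(self.linear1(x))
+        return self.linear2(x)
+```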
+
+## [Dataset](#contents)
+
+DQN does not train on a static dataset. Experience is generated online by interacting with the `CartPole-v0` environment from [Gym](https://gym.openai.com/), and transitions are stored in a replay buffer for training.
+
+## [Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
+- Third-party libraries:
+
+```bash
+pip install gym
+```
+
+## [Script Description](#contents)
+
+### [Script and Sample Code](#contents)
+
+```python
+├── dqn
+    ├── README.md                            # description of DQN
+    ├── scripts
+    │   ├── run_standalone_eval_ascend.sh    # shell script for evaluation on Ascend
+    │   ├── run_standalone_eval_gpu.sh       # shell script for evaluation on GPU
+    │   ├── run_standalone_train_ascend.sh   # shell script for training on Ascend
+    │   ├── run_standalone_train_gpu.sh      # shell script for training on GPU
+    ├── src
+    │   ├── agent.py                         # DQN agent
+    │   ├── config.py                        # parameter configuration
+    │   ├── dqn.py                           # DQN network architecture
+    ├── train.py                             # training script
+    ├── eval.py                              # evaluation script
+```
+
+### [Script Parameters](#contents)
+
+```python
+'gamma': 0.8              # discount factor for future rewards
+'epsi_high': 0.9          # initial (highest) exploration rate
+'epsi_low': 0.05          # final (lowest) exploration rate
+'decay': 200              # decay constant (in steps) of the exploration rate
+'lr': 0.001               # learning rate
+'capacity': 100000        # capacity of the replay buffer
+'batch_size': 512         # training batch size
+'state_space_dim': 4      # dimension of the environment state space
+'action_space_dim': 2     # dimension of the action space
+```
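+
+The exploration rate used in `Agent.act` (see `src/agent.py`) decays exponentially with the number of environment steps taken. A small illustrative helper (not part of the source tree), using the default configuration values:
+
+```python
+import math
+
+def exploration_rate(steps, epsi_high=0.9, epsi_low=0.05, decay=200):
+    """Exploration rate after `steps` environment steps."""
+    return epsi_low + (epsi_high - epsi_low) * math.exp(-1.0 * steps / decay)
+
+print(exploration_rate(0))     # 0.90  (almost always explore)
+print(exploration_rate(200))   # ~0.36
+print(exploration_rate(1000))  # ~0.056 (almost always exploit)
+```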
+
+### [Training Process](#contents)
+
+```shell
+# training example
+python:
+    Ascend: python train.py --device_target Ascend --ckpt_path ckpt > log.txt 2>&1 &
+    GPU: python train.py --device_target GPU --ckpt_path ckpt > log.txt 2>&1 &
+
+shell:
+    Ascend: sh run_standalone_train_ascend.sh ckpt
+    GPU: sh run_standalone_train_gpu.sh ckpt
+```
+
+### [Evaluation Process](#contents)
+
+```shell
+# evaluation example
+python:
+    Ascend: python eval.py --device_target Ascend --ckpt_path ./ckpt/dqn.ckpt
+    GPU: python eval.py --device_target GPU --ckpt_path ./ckpt/dqn.ckpt
+
+shell:
+    Ascend: sh run_standalone_eval_ascend.sh ./ckpt/dqn.ckpt
+    GPU: sh run_standalone_eval_gpu.sh ./ckpt/dqn.ckpt
+```
+
+## [Model Description](#contents)
+
+### [Performance](#contents)
+
+#### Inference Performance
+
+| Parameters          | DQN                                                          |
+| ------------------- | ------------------------------------------------------------ |
+| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G               |
+| Uploaded Date       | 03/10/2021 (month/day/year)                                   |
+| MindSpore Version   | 1.1.0                                                         |
+| Training Parameters | batch_size = 512, lr = 0.001                                  |
+| Optimizer           | RMSProp                                                       |
+| Loss Function       | MSELoss                                                       |
+| Outputs             | Q-values of the two actions                                   |
+| Params              | 7.3K                                                          |
+| Scripts             | [DQN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/rl/dqn) |
+
+## [Description of Random Situation](#contents)
+
+We use a fixed random seed in train.py (`set_seed(1)`).
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
diff --git a/model_zoo/official/rl/dqn/eval.py b/model_zoo/official/rl/dqn/eval.py
new file mode 100644
index 0000000000..59d51e9b71
--- /dev/null
+++ b/model_zoo/official/rl/dqn/eval.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation for DQN"""
+
+import argparse
+import gym
+from mindspore import context
+from mindspore.common import set_seed
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.config import config_dqn as cfg
+from src.agent import Agent
+
+parser = argparse.ArgumentParser(description='MindSpore DQN Example')
+parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                    help='device where the code will be implemented (default: Ascend)')
+parser.add_argument('--ckpt_path', type=str, default=None,
+                    help='path of the trained checkpoint file to evaluate')
+args = parser.parse_args()
+set_seed(1)
+
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    env = gym.make('CartPole-v0')
+    cfg.state_space_dim = env.observation_space.shape[0]
+    cfg.action_space_dim = env.action_space.n
+    agent = Agent(**cfg)
+
+    # load checkpoint
+    if args.ckpt_path:
+        param_dict = load_checkpoint(args.ckpt_path)
+        not_load_param = load_param_into_net(agent.policy_net, param_dict)
+        if not_load_param:
+            raise ValueError("Load param into net failed!")
+
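+    # Evaluate the greedy policy (argmax over Q-values, no exploration) for 50 episodes.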
+# ============================================================================ + +# an simple tutorial as follows, more parameters can be setting +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +CKPT_PATH=$1 +python -s ${self_path}/../eval.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 & diff --git a/model_zoo/official/rl/dqn/scripts/run_standalone_eval_gpu.sh b/model_zoo/official/rl/dqn/scripts/run_standalone_eval_gpu.sh new file mode 100755 index 0000000000..235b6e0160 --- /dev/null +++ b/model_zoo/official/rl/dqn/scripts/run_standalone_eval_gpu.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# an simple tutorial as follows, more parameters can be setting +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +CKPT_PATH=$1 +python -s ${self_path}/../eval.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 & diff --git a/model_zoo/official/rl/dqn/scripts/run_standalone_train_ascend.sh b/model_zoo/official/rl/dqn/scripts/run_standalone_train_ascend.sh new file mode 100755 index 0000000000..99939b4e66 --- /dev/null +++ b/model_zoo/official/rl/dqn/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# an simple tutorial as follows, more parameters can be setting +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +CKPT_PATH=$1 +python -s ${self_path}/../train.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/rl/dqn/scripts/run_standalone_train_gpu.sh b/model_zoo/official/rl/dqn/scripts/run_standalone_train_gpu.sh new file mode 100755 index 0000000000..f4bc654512 --- /dev/null +++ b/model_zoo/official/rl/dqn/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+CKPT_PATH=$1
+python -s ${self_path}/../train.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
diff --git a/model_zoo/official/rl/dqn/src/agent.py b/model_zoo/official/rl/dqn/src/agent.py
new file mode 100644
index 0000000000..c76841cdb5
--- /dev/null
+++ b/model_zoo/official/rl/dqn/src/agent.py
@@ -0,0 +1,94 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Agent of the DQN reinforcement learning network"""
+
+import random
+import math
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+from src.dqn import DQN, WithLossCell
+
+
+class Agent:
+    """
+    DQN Agent
+    """
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+        self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim)
+        self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim)
+        self.optimizer = nn.RMSProp(self.policy_net.trainable_params(), learning_rate=self.lr)
+        loss_fn = nn.MSELoss()
+        loss_q_net = WithLossCell(self.policy_net, loss_fn)
+        self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer)
+        self.policy_net_train.set_train(mode=True)
+        self.buffer = []
+        self.steps = 0
+
+    def act(self, s0):
+        """
+        Choose an action with an epsilon-greedy policy.
+        """
+        self.steps += 1
+        epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * math.exp(-1.0 * self.steps / self.decay)
+        if random.random() < epsi:
+            a0 = random.randrange(self.action_space_dim)  # explore: random action
+        else:
+            s0 = np.expand_dims(s0, axis=0)
+            s0 = Tensor(s0, mstype.float32)
+            a0 = self.policy_net(s0).asnumpy()
+            a0 = np.argmax(a0)  # exploit: greedy action from the policy network
+        return a0
+
+    def eval_act(self, s0):
+        """
+        Choose the greedy action (no exploration); used for evaluation.
+        """
+        s0 = np.expand_dims(s0, axis=0)
+        s0 = Tensor(s0, mstype.float32)
+        a0 = self.policy_net(s0).asnumpy()
+        a0 = np.argmax(a0)
+        return a0
+
+    def put(self, *transition):
+        """Store a (s0, a0, r1, s1) transition, evicting the oldest one when full."""
+        if len(self.buffer) == self.capacity:
+            self.buffer.pop(0)
+        self.buffer.append(transition)
+
+    def load_dict(self):
+        """Copy the policy network parameters into the target network."""
+        for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()):
+            target_param = self.target_net.parameters_dict()[target_item]
+            source_param = self.policy_net.parameters_dict()[source_item]
+            target_param.set_data(source_param.data)
+
+    def learn(self):
+        """
+        Update the policy network from a random batch of experience.
+        """
+        if len(self.buffer) < self.batch_size:
+            return
+
+        samples = random.sample(self.buffer, self.batch_size)
+        s0, a0, r1, s1 = zip(*samples)
+        s1 = Tensor(s1, mstype.float32)
+        s0 = Tensor(s0, mstype.float32)
+        a0 = Tensor(np.expand_dims(a0, axis=1))
+        next_state_values = self.target_net(s1).asnumpy()
+        next_state_values = np.max(next_state_values, axis=1)
+
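+        # One-step TD target from the frozen target network:
+        #     y = r + gamma * max_a' Q_target(s', a')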
+ """ + self.steps += 1 + epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * (math.exp(-1.0 * self.steps / self.decay)) + if random.random() < epsi: + a0 = random.randrange(self.action_space_dim) + else: + s0 = np.expand_dims(s0, axis=0) + s0 = Tensor(s0, mstype.float32) + a0 = self.policy_net(s0).asnumpy() + a0 = np.argmax(a0) + return a0 + + def eval_act(self, s0): + self.steps += 1 + s0 = np.expand_dims(s0, axis=0) + s0 = Tensor(s0, mstype.float32) + a0 = self.policy_net(s0).asnumpy() + a0 = np.argmax(a0) + return a0 + + def put(self, *transition): + if len(self.buffer) == self.capacity: + self.buffer.pop(0) + self.buffer.append(transition) + + def load_dict(self): + for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()): + target_param = self.target_net.parameters_dict()[target_item] + source_param = self.policy_net.parameters_dict()[source_item] + target_param.set_data(source_param.data) + + def learn(self): + """ + Agent learn from experience data. + """ + if (len(self.buffer)) < self.batch_size: + return + + samples = random.sample(self.buffer, self.batch_size) + s0, a0, r1, s1 = zip(*samples) + s1 = Tensor(s1, mstype.float32) + s0 = Tensor(s0, mstype.float32) + a0 = Tensor(np.expand_dims(a0, axis=1)) + next_state_values = self.target_net(s1).asnumpy() + next_state_values = np.max(next_state_values, axis=1) + + y_true = r1 + self.gamma * next_state_values + y_true = Tensor(np.expand_dims(y_true, axis=1), mstype.float32) + self.policy_net_train(s0, a0, y_true) diff --git a/model_zoo/official/rl/dqn/src/config.py b/model_zoo/official/rl/dqn/src/config.py new file mode 100644 index 0000000000..6d7a7ef53f --- /dev/null +++ b/model_zoo/official/rl/dqn/src/config.py @@ -0,0 +1,31 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" + +from easydict import EasyDict as edict + +config_dqn = edict({ + 'gamma': 0.8, + 'epsi_high': 0.9, + 'epsi_low': 0.05, + 'decay': 200, + 'lr': 0.001, + 'capacity': 100000, + 'batch_size': 512, + 'state_space_dim': 4, + 'action_space_dim': 2 +}) diff --git a/model_zoo/official/rl/dqn/src/dqn.py b/model_zoo/official/rl/dqn/src/dqn.py new file mode 100644 index 0000000000..1a3e0b2dd8 --- /dev/null +++ b/model_zoo/official/rl/dqn/src/dqn.py @@ -0,0 +1,47 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+        out = self.gather(out, 1, act)
+        loss = self._loss_fn(out, label)
+        return loss
diff --git a/model_zoo/official/rl/dqn/train.py b/model_zoo/official/rl/dqn/train.py
new file mode 100644
index 0000000000..919c7193fd
--- /dev/null
+++ b/model_zoo/official/rl/dqn/train.py
@@ -0,0 +1,69 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train DQN and get checkpoint files."""
+
+import os
+import argparse
+import gym
+from mindspore import context
+from mindspore.common import set_seed
+from mindspore.train.serialization import save_checkpoint
+from src.config import config_dqn as cfg
+from src.agent import Agent
+
+parser = argparse.ArgumentParser(description='MindSpore DQN Example')
+parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                    help='device where the code will be implemented (default: Ascend)')
+parser.add_argument('--ckpt_path', type=str, default="./ckpt",
+                    help='directory where the trained checkpoint file is saved')
+args = parser.parse_args()
+set_seed(1)
+
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    env = gym.make('CartPole-v0')
+    cfg.state_space_dim = env.observation_space.shape[0]
+    cfg.action_space_dim = env.action_space.n
+    agent = Agent(**cfg)
+    agent.load_dict()  # start with identical policy and target networks
+
+    for episode in range(150):
+        s0 = env.reset()
+        total_reward = 1
+        while True:
+            a0 = agent.act(s0)
+            s1, r1, done, _ = env.step(a0)
+
+            if done:
+                r1 = -1  # penalize the terminal transition before storing it
+
+            agent.put(s0, a0, r1, s1)
+
+            if done:
+                break
+
+            total_reward += r1
+            s0 = s1
+            agent.learn()
+        agent.load_dict()  # sync the target network after each episode
+        print("episode", episode, "total_reward", total_reward)
+
+    path = os.path.realpath(args.ckpt_path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    ckpt_name = os.path.join(path, "dqn.ckpt")
+    save_checkpoint(agent.policy_net, ckpt_name)