|
|
|
@@ -34,13 +34,13 @@ set_seed(1) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
env = gym.make('CartPole-v0') |
|
|
|
env = gym.make('CartPole-v1') |
|
|
|
cfg.state_space_dim = env.observation_space.shape[0] |
|
|
|
cfg.action_space_dim = env.action_space.n |
|
|
|
agent = Agent(**cfg) |
|
|
|
agent.load_dict() |
|
|
|
|
|
|
|
for episode in range(150): |
|
|
|
for episode in range(300): |
|
|
|
s0 = env.reset() |
|
|
|
total_reward = 1 |
|
|
|
while True: |
|
|
|
|