Amazing-Python-Scripts

Форк
0
76 строк · 2.2 Кб
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Thu Jun  3 13:06:20 2021
4
@author: Ayush
5
"""
6

7
import numpy as np
8
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
9
from rl.memory import SequentialMemory
10
from rl.agents import DQNAgent
11
from tensorflow.keras.optimizers import Adam
12
from tensorflow.keras.layers import Dense, Flatten, Conv2D
13
from tensorflow.keras.models import Sequential
14
import gym
15
# Atari environment shared by the random-play baseline below and by the DQN
# agent trained later in this script.
env = gym.make("SpaceInvaders-v0")

episodes = 10

# Baseline: play `episodes` episodes with uniformly random actions and print
# each episode's total reward. range() excludes its stop value, so the upper
# bound is `episodes + 1` in order to play episodes 1..episodes inclusive.
for episode in range(1, episodes + 1):
    state = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        # Classic gym API: step() returns (observation, reward, done, info).
        state, reward, done, info = env.step(env.action_space.sample())
        score += reward
    print('Episode: {}\nScore: {}'.format(episode, score))
env.close()
30

31
# Q-network definition (the Keras layer and model classes are imported at the top of the file)
32

33

34
def build_model(height, width, channels, actions, window_length=3):
    """Build the convolutional Q-network used by the DQN agent.

    Args:
        height: Height of a single observation frame.
        width: Width of a single observation frame.
        channels: Number of channels of a single observation frame.
        actions: Number of discrete actions; size of the output layer.
        window_length: Number of consecutive frames in one model input. This
            is the leading dimension of ``input_shape`` and must equal the
            ``window_length`` of the agent's ``SequentialMemory`` (3 in this
            script). The default keeps the original behaviour.

    Returns:
        An uncompiled ``Sequential`` model that ends in a linear layer with
        one output (Q-value) per action.
    """
    model = Sequential()
    # Convolutional feature extractor; the input is a window of
    # `window_length` frames of shape (height, width, channels).
    model.add(Conv2D(32, (8, 8), strides=(4, 4), activation='relu',
                     input_shape=(window_length, height, width, channels)))
    model.add(Conv2D(64, (4, 4), strides=(2, 2), activation='relu'))
    model.add(Conv2D(64, (4, 4), strides=(2, 2), activation='relu'))
    model.add(Flatten())
    # Fully connected head that maps features to per-action values.
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(64, activation='relu'))
    # Linear (unbounded) output: Q-values are not restricted to a range.
    model.add(Dense(actions, activation='linear'))
    return model
46

47

48
# Size the network from the environment: each observation frame is
# (height, width, channels), and the output layer has one unit per action.
height, width, channels = env.observation_space.shape
actions = env.action_space.n

model = build_model(height=height, width=width, channels=channels,
                    actions=actions)
52

53

54
# Agent definition from keras-rl2 components (policy, replay memory and DQNAgent are imported at the top of the file)
55

56
def build_agent(model, actions, window_length=3):
    """Wrap ``model`` in a keras-rl2 ``DQNAgent``.

    Args:
        model: Keras Q-network (e.g. from ``build_model``). The leading
            dimension of its input is expected to equal ``window_length``.
        actions: Number of discrete actions, passed as ``nb_actions``.
        window_length: Number of consecutive observations that
            ``SequentialMemory`` groups into one model input. The default
            keeps the original behaviour.

    Returns:
        An uncompiled ``DQNAgent``. This script calls ``compile()`` on it
        before ``fit()``.
    """
    # Epsilon-greedy exploration. `eps` is annealed from value_max=1.0 down
    # to value_min=0.1 over 10,000 steps; value_test=0.2 is used in testing.
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps',
                                  value_max=1., value_min=.1,
                                  value_test=.2, nb_steps=10000)
    # Replay memory capped at 2,000 entries.
    memory = SequentialMemory(limit=2000, window_length=window_length)
    # Dueling network with 'avg' aggregation and 1,000 warm-up steps.
    dqn = DQNAgent(model=model, memory=memory, policy=policy,
                   enable_dueling_network=True, dueling_type='avg',
                   nb_actions=actions, nb_steps_warmup=1000)
    return dqn
63

64

65
import os

dqn = build_agent(model, actions)

# Use the `learning_rate` keyword. `lr` is a deprecated alias that newer
# Keras versions ignore (with a warning) or reject outright.
dqn.compile(Adam(learning_rate=0.001))

# Train for 40,000 environment steps, rendering the environment while
# training.
dqn.fit(env, nb_steps=40000, visualize=True, verbose=1)

# Evaluate over 10 episodes and print the mean episode reward.
scores = dqn.test(env, nb_episodes=10, visualize=True)
print(np.mean(scores.history['episode_reward']))

# Create the output directory up front so saving does not depend on the
# underlying writer creating missing parent directories.
os.makedirs('models', exist_ok=True)
dqn.save_weights('models/dqn.h5f')
# Reload the checkpoint just written (demonstrates restoring saved weights).
dqn.load_weights('models/dqn.h5f')
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.