Amazing-Python-Scripts

# -*- coding: utf-8 -*-
"""
Created on Sun Jun  6 10:57:01 2021

@author: Ayush
"""

# Import Dependencies
import random
import numpy as np
import flappy_bird_gym
from collections import deque
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import load_model, save_model, Sequential
from tensorflow.keras.optimizers import RMSprop

# Neural Network for Agent


def NeuralNetwork(input_shape, output_shape):
    model = Sequential()
    model.add(Input(input_shape))
    model.add(Dense(512, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(256, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(64, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(output_shape, activation='linear',
              kernel_initializer='he_uniform'))
    model.compile(loss='mse', optimizer=RMSprop(
        learning_rate=0.0001, rho=0.95, epsilon=0.01), metrics=['accuracy'])
    model.summary()
    return model
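
# A minimal sketch of how the network above is used (hypothetical shapes,
# assuming FlappyBird-v0's 2-value observation and 2 discrete actions):
#   q_net = NeuralNetwork(input_shape=(2,), output_shape=2)
#   q_values = q_net.predict(np.zeros((1, 2)))   # array of shape (1, 2)
# i.e. one Q-value per action for each state in the batch.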


# Brain of the agent: environment, replay memory, hyperparameters and Q-network

class DQNAgent:
    def __init__(self):
        self.env = flappy_bird_gym.make("FlappyBird-v0")
        self.episodes = 1000
        self.state_space = self.env.observation_space.shape[0]
        self.action_space = self.env.action_space.n
        self.memory = deque(maxlen=2000)

        # Hyperparameters
        self.gamma = 0.95
        self.epsilon = 1
        self.epsilon_decay = 0.9999
        self.epsilon_min = 0.01
        self.batch_number = 64  # 16, 32, 128, 256

        self.train_start = 1000  # minimum replay-memory size before learning
        self.jump_prob = 0.01  # chance of flapping when acting randomly
        self.model = NeuralNetwork(input_shape=(
            self.state_space,), output_shape=self.action_space)
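
    # For FlappyBird-v0 the observation is (to the best of my knowledge) a
    # 2-value vector, roughly the bird's horizontal and vertical distance to
    # the next pipe gap, and there are 2 discrete actions (0 = idle, 1 = flap),
    # so state_space and action_space both evaluate to 2 here.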

    def act(self, state):
        if np.random.random() > self.epsilon:
            return np.argmax(self.model.predict(state))
        return 1 if np.random.random() < self.jump_prob else 0
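
    # The policy above is epsilon-greedy with a biased random branch: with
    # probability epsilon the agent acts randomly, and that random action
    # flaps (action 1) only jump_prob of the time, presumably because flapping
    # on nearly every frame would end the episode almost immediately.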

    def learn(self):
        # Make sure we have enough data
        if len(self.memory) < self.train_start:
            return

        # Create minibatch
        minibatch = random.sample(self.memory, min(
            len(self.memory), self.batch_number))
        # Variables to store minibatch info
        state = np.zeros((self.batch_number, self.state_space))
        next_state = np.zeros((self.batch_number, self.state_space))

        action, reward, done = [], [], []

        # Store data in variables
        for i in range(self.batch_number):
            state[i] = minibatch[i][0]
            action.append(minibatch[i][1])
            reward.append(minibatch[i][2])
            next_state[i] = minibatch[i][3]
            done.append(minibatch[i][4])

        # Predict Q-values for the current and next states
        target = self.model.predict(state)
        target_next = self.model.predict(next_state)

        for i in range(self.batch_number):
            if done[i]:
                target[i][action[i]] = reward[i]
            else:
                target[i][action[i]] = reward[i] + \
                    self.gamma * np.amax(target_next[i])
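
        # The loop above builds the standard Q-learning regression target:
        #   y_i = r_i                                  if the episode ended
        #   y_i = r_i + gamma * max_a' Q(s'_i, a')     otherwise
        # Only the entry for the action actually taken is changed, so the
        # network is pulled toward y_i at that action alone.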
        print('training')
        self.model.fit(state, target, batch_size=self.batch_number, verbose=0)

    def train(self):
        # Iterate over n training episodes
        for i in range(self.episodes):
            # Environment variables for training
            state = self.env.reset()
            state = np.reshape(state, [1, self.state_space])
            done = False
            score = 0
            # Decay exploration rate, but never below epsilon_min
            self.epsilon = max(self.epsilon * self.epsilon_decay,
                               self.epsilon_min)

            while not done:
                self.env.render()
                action = self.act(state)
                next_state, reward, done, info = self.env.step(action)

                # Reshape next state
                next_state = np.reshape(next_state, [1, self.state_space])
                score += 1
                if done:
                    reward -= 100

                self.memory.append((state, action, reward, next_state, done))
                state = next_state

                if done:
                    print("Episode: {}\nScore: {}\nEpsilon: {:.2}".format(
                        i, score, self.epsilon))
                    # Save model
                    if score >= 1000:
                        save_model(self.model, 'flappybrain.h5')
                self.learn()

    def perform(self):
        self.model = load_model('flappybrain.h5')
        while True:
            state = self.env.reset()
            state = np.reshape(state, [1, self.state_space])
            done = False
            score = 0

            while not done:
                self.env.render()
                action = np.argmax(self.model.predict(state))
                next_state, reward, done, info = self.env.step(action)
                state = np.reshape(next_state, [1, self.state_space])
                score += 1

                print("Current Score: {}".format(score))

                if done:
                    print('DEAD')
                    break


if __name__ == '__main__':
    agent = DQNAgent()
    agent.train()
    # agent.perform()
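
# Usage note: train() writes 'flappybrain.h5' once an episode score reaches
# 1000; to watch that saved model play instead, swap the two calls above.
# Assumed environment: the flappy-bird-gym and tensorflow packages are
# installed (e.g. pip install flappy-bird-gym tensorflow); exact versions
# may vary.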