CelestialSurveyor

Форк
0
/
model_trainer_v2.py 
344 строки · 13.8 Кб
1

2
# Everything is kept under the __main__ guard to speed up importing this module in child processes
3
if __name__ == "__main__":
4
    import os

    # TensorFlow environment variables are configured before TensorFlow is imported,
    # so that they also apply to everything TensorFlow logs or initialises during the import itself.
    # os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # uncomment if you don't have enough GPU memory
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import tensorflow as tf

    from dataclasses import dataclass
    from model_builder import build_model, encrypt_model
    from training_dataset_v2 import TrainingDatasetV2, TrainingSourceDataV2
    from typing import Optional

    from backend.progress_bar import ProgressBarCli
    from logger.logger import Logger

    # Shared logger used by the training code below
    logger = Logger()
19

20

21
    @dataclass
    class SourceDataProperties:
        """
        Settings describing one source data set used by the training script.

        Besides the plain fields, the class exposes read-only properties that turn the configured folders
        into lists of file paths by means of TrainingSourceDataV2.make_file_paths.
        """
        # Folder with the images of the data set
        folder: str
        # The training entry point stretches the images when this flag is set
        linear: bool
        # The training entry point plate solves and aligns the images when this flag is set
        to_align: bool
        # Handed over to TrainingSourceDataV2 when the data set is created
        to_debayer: bool
        secondary_alignment: tuple[int, int] = (3, 3)
        # When given, only the first `number_of_images` headers are kept by the training entry point
        number_of_images: Optional[int] = None
        # Optional folders with calibration frames; None means "no such frames"
        dark_folder: Optional[str] = None
        flat_folder: Optional[str] = None
        dark_flats_folder: Optional[str] = None
        magnitude_limit: Optional[float] = 18.0

        @property
        def file_paths(self) -> list[str]:
            """
            File paths of the data set images.

            Returns:
                list[str]: Whatever TrainingSourceDataV2.make_file_paths produces for `folder`.
            """
            image_folder = self.folder
            return TrainingSourceDataV2.make_file_paths(image_folder)

        @property
        def dark_paths(self) -> Optional[list[str]]:
            """
            File paths of the dark frames.

            Returns:
                Optional[list[str]]: None when `dark_folder` is not set, otherwise the paths made from it.
            """
            calibration_folder = self.dark_folder
            return None if calibration_folder is None else TrainingSourceDataV2.make_file_paths(calibration_folder)

        @property
        def flat_paths(self) -> Optional[list[str]]:
            """
            File paths of the flat frames.

            Returns:
                Optional[list[str]]: None when `flat_folder` is not set, otherwise the paths made from it.
            """
            calibration_folder = self.flat_folder
            return None if calibration_folder is None else TrainingSourceDataV2.make_file_paths(calibration_folder)

        @property
        def dark_flat_paths(self) -> Optional[list[str]]:
            """
            File paths of the dark-flat frames.

            Returns:
                Optional[list[str]]: None when `dark_flats_folder` is not set, otherwise the paths made from it.
            """
            calibration_folder = self.dark_flats_folder
            return None if calibration_folder is None else TrainingSourceDataV2.make_file_paths(calibration_folder)
82

83

84
    def _prepare_source_data(properties: "SourceDataProperties") -> "TrainingSourceDataV2":
        """
        Load and pre-process one source data set described by `properties`.

        The steps are executed in a fixed order: headers are collected (and optionally truncated to
        `number_of_images`), images are loaded and calibrated, optionally plate solved and aligned, cropped,
        optionally stretched, and finally exclusion boxes are loaded and images are taken from the buffer.

        Args:
            properties: Settings of the data set to prepare.

        Returns:
            TrainingSourceDataV2: The prepared source data object.
        """
        source_data = TrainingSourceDataV2(properties.to_debayer)
        source_data.extend_headers(file_list=properties.file_paths)
        if properties.number_of_images is not None:
            # Keep only the first N headers so that only those images are loaded
            source_data.headers = source_data.headers[:properties.number_of_images]
        source_data.load_images(progress_bar=ProgressBarCli())
        logger.log.info(f"dtype = {source_data.original_frames.dtype}")
        source_data.calibrate_images(properties.dark_paths, properties.flat_paths, properties.dark_flat_paths,
                                     progress_bar=ProgressBarCli())
        if properties.to_align:
            source_data.plate_solve_all(progress_bar=ProgressBarCli())
            source_data.align_images_wcs(progress_bar=ProgressBarCli())

        source_data.crop_images()
        if properties.linear:
            source_data.stretch_images(progress_bar=ProgressBarCli())
        source_data.load_exclusion_boxes()
        source_data.images_from_buffer()
        return source_data


    def main() -> None:
        """
        Training entry point.

        Builds and compiles the model, prepares every configured source data set, trains the model with early
        stopping on `val_loss` and finally saves and encrypts the model under `save_model_name`.
        Pressing Ctrl+C during training stops the training; the model is still saved and encrypted (once).
        """
        import math  # local import: only needed for the memory usage report below

        logger.log.info(tf.__version__)
        load_model_name = "model161"  # only referenced by the commented-out "Load model" block below
        save_model_name = "model170"

        # Build the model

        # Compile the model (uncomment this block to train the model from scratch)
        model = build_model()
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

        # Load model (uncomment this block to continue model training)
        # model = tf.keras.models.load_model(
        #     f'{load_model_name}.h5'
        # )

        # Data sets to train on. Alternative data sets are kept commented out so they can be toggled quickly.
        source_data_properties = [
            SourceDataProperties(
                folder='D:\\git\\dataset\\NGC1333_RASA\\cropped',
                linear=False,
                to_align=False,
                number_of_images=None,
                to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Seahorse\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part1\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part2\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part3\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part4\\cropped1',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\NGC1333_RASA\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=5,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Seahorse\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=7,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part1\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=3,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part2\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=6,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part3\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=7,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part4\\cropped1',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=4,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\M78\\Light_BIN-1_EXPOSURE-120.00s_FILTER-NoFilter_RGB',
            # #     linear=True,
            # #     to_align=False,
            # #     number_of_images=None,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\M81\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=60,
            # #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo1',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo2',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Antares2024\\Night_1\\PART1',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            #     magnitude_limit=18.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Antares2024\\Night_1\\PART2',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_3\\PART1',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_3\\PART2',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_4\\PART1',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Andromeda\\Light',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Andromeda\\Dark_600',
            #     flat_folder='E:\\Astro\\Andromeda\\Flat\\Night2',
            #     dark_flats_folder='E:\\Astro\\Andromeda\\Dark_Flat',
            #     magnitude_limit=19.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Rosette\\Light',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Rosette\\Dark',
            #     flat_folder='E:\\Astro\\Rosette\\Flat',
            #     dark_flats_folder='E:\\Astro\\Rosette\\DarkFlat',
            #     magnitude_limit=19.0),
        ]

        source_datas = []
        for num, properties in enumerate(source_data_properties, start=1):
            logger.log.info(f"Processing {num} of {len(source_data_properties)} source datas")
            source_data = _prepare_source_data(properties)
            source_datas.append(source_data)

            # Report how many bytes the prepared images occupy: item size times the number of items
            size = source_data.images.itemsize * math.prod(source_data.images.shape)
            logger.log.info(f"dtype: {source_data.images.dtype}, shape: {source_data.images.shape}, size: {size}")
            logger.log.info(f"Allocated {size // (1024 * 1024)} Mb")

        training_dataset = TrainingDatasetV2(source_datas)
        training_generator = training_dataset.batch_generator(batch_size=20)
        # NOTE(review): validation batches come from the same dataset object as the training batches -
        # confirm that this is intended, otherwise val_loss is not an independent signal.
        val_generator = training_dataset.batch_generator(batch_size=20)

        early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=0,
            patience=100,
            verbose=1,
            mode='min',
            baseline=None,
            restore_best_weights=True
        )

        try:
            model.fit(
                training_generator,
                validation_data=val_generator,
                steps_per_epoch=500,
                validation_steps=1000,
                epochs=10000,
                callbacks=[early_stopping_monitor]
            )
        except KeyboardInterrupt:
            # Ctrl+C is the expected way to stop a long training run: fall through and save what we have.
            logger.log.info("Training interrupted, saving the current model")

        # Single save/encrypt for both the normal and the interrupted path
        model.save(f"{save_model_name}.h5")
        encrypt_model(save_model_name)
344

345
    main()
346

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.