CelestialSurveyor
344 строки · 13.8 Кб
1
# Everything lives under the __main__ guard: it's done to speedup importing in child processes.
# Child processes (presumably spawned via multiprocessing, which re-imports the main module under
# a different __name__) therefore skip the heavy imports and the training code entirely.
if __name__ == "__main__":
    import os

    # TF_CPP_MIN_LOG_LEVEL only silences TensorFlow's native logging if it is set before TensorFlow
    # is loaded, so configure the environment before importing tensorflow or any project module
    # that may import it (e.g. model_builder).
    # os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # uncomment if you don't have enough GPU memory
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import math
    from dataclasses import dataclass
    from typing import Optional

    import tensorflow as tf

    from model_builder import build_model, encrypt_model
    from training_dataset_v2 import TrainingDatasetV2, TrainingSourceDataV2

    from backend.progress_bar import ProgressBarCli
    from logger.logger import Logger

    logger = Logger()

    @dataclass
    class SourceDataProperties:
        """
        Dataclass to represent source data properties.

        Attributes:
            folder (str): Folder with the light frames of the dataset.
            linear (bool): If True, main() stretches the frames after cropping (presumably the frames
                are linear, i.e. unstretched, data).
            to_align (bool): If True, the frames are plate solved and aligned by WCS.
            to_debayer (bool): Passed to the TrainingSourceDataV2 constructor.
            secondary_alignment (tuple[int, int]): Not referenced by this script.
            number_of_images (Optional[int]): If set, only the first N frame headers are used.
            dark_folder (Optional[str]): Folder with dark frames, or None to skip them.
            flat_folder (Optional[str]): Folder with flat frames, or None to skip them.
            dark_flats_folder (Optional[str]): Folder with dark-flat frames, or None to skip them.
            magnitude_limit (Optional[float]): Not referenced by this script.
        """
        folder: str
        linear: bool
        to_align: bool
        to_debayer: bool
        secondary_alignment: tuple[int, int] = (3, 3)
        number_of_images: Optional[int] = None
        dark_folder: Optional[str] = None
        flat_folder: Optional[str] = None
        dark_flats_folder: Optional[str] = None
        magnitude_limit: Optional[float] = 18.0

        @staticmethod
        def _paths_or_none(folder: Optional[str]) -> Optional[list[str]]:
            """
            Return the file paths found in folder, or None when no folder is configured.

            Args:
                folder (Optional[str]): Folder to list, or None.

            Returns:
                Optional[list[str]]: List of file paths, or None if folder is None.
            """
            if folder is None:
                return None
            return TrainingSourceDataV2.make_file_paths(folder)

        @property
        def file_paths(self) -> list[str]:
            """
            Property to get the file paths based on the folder.

            Returns:
                list[str]: List of file paths.
            """
            return TrainingSourceDataV2.make_file_paths(self.folder)

        @property
        def dark_paths(self) -> Optional[list[str]]:
            """
            Property to get the dark file paths based on the dark folder.

            Returns:
                Optional[list[str]]: List of dark file paths, or None if no dark folder is set.
            """
            return self._paths_or_none(self.dark_folder)

        @property
        def flat_paths(self) -> Optional[list[str]]:
            """
            Property to get the flat file paths based on the flat folder.

            Returns:
                Optional[list[str]]: List of flat file paths, or None if no flat folder is set.
            """
            return self._paths_or_none(self.flat_folder)

        @property
        def dark_flat_paths(self) -> Optional[list[str]]:
            """
            Property to get the dark-flat file paths based on the dark-flats folder.

            Returns:
                Optional[list[str]]: List of dark flat file paths, or None if no dark-flats folder is set.
            """
            return self._paths_or_none(self.dark_flats_folder)

    def _prepare_source_data(properties: SourceDataProperties) -> TrainingSourceDataV2:
        """
        Load and preprocess one training source dataset.

        The frames listed in properties.folder are loaded and calibrated with the optional dark,
        flat and dark-flat frames. They are plate solved and WCS-aligned when properties.to_align
        is set, then cropped. They are stretched when properties.linear is set. Finally the
        exclusion boxes are loaded and images_from_buffer() is called.

        Args:
            properties (SourceDataProperties): Description of the dataset to prepare.

        Returns:
            TrainingSourceDataV2: The prepared source data.
        """
        # NOTE(review): properties.magnitude_limit and properties.secondary_alignment are not passed
        # anywhere here - confirm whether TrainingSourceDataV2 is supposed to receive them.
        source_data = TrainingSourceDataV2(properties.to_debayer)
        source_data.extend_headers(file_list=properties.file_paths)
        if properties.number_of_images is not None:
            # Restrict the dataset to its first N frames.
            source_data.headers = source_data.headers[:properties.number_of_images]
        source_data.load_images(progress_bar=ProgressBarCli())
        logger.log.info(f"dtype = {source_data.original_frames.dtype}")
        source_data.calibrate_images(properties.dark_paths, properties.flat_paths, properties.dark_flat_paths,
                                     progress_bar=ProgressBarCli())
        if properties.to_align:
            source_data.plate_solve_all(progress_bar=ProgressBarCli())
            source_data.align_images_wcs(progress_bar=ProgressBarCli())

        # Cropping is applied to every source, aligned or not.
        source_data.crop_images()
        if properties.linear:
            source_data.stretch_images(progress_bar=ProgressBarCli())
        source_data.load_exclusion_boxes()
        source_data.images_from_buffer()
        return source_data

    def _images_size_bytes(source_data: TrainingSourceDataV2) -> int:
        """
        Return the size in bytes of source_data.images (item size times the number of elements).

        Args:
            source_data (TrainingSourceDataV2): Prepared source data.

        Returns:
            int: Size of the image buffer in bytes.
        """
        images = source_data.images
        return images.itemsize * math.prod(images.shape)

    def _save_model(model: tf.keras.Model, model_name: str) -> None:
        """
        Save the model to '<model_name>.h5' and then call encrypt_model(model_name).

        Args:
            model (tf.keras.Model): Model to save.
            model_name (str): File name of the saved model without the '.h5' extension.
        """
        model.save(f"{model_name}.h5")
        encrypt_model(model_name)

    def main() -> None:
        """
        Training entry point.

        Builds (or loads) the model and prepares every configured source dataset. It then trains
        the model on batches from TrainingDatasetV2 with early stopping on val_loss, and saves and
        encrypts the model as '<save_model_name>.h5'. Ctrl+C stops training early; the model is
        still saved (exactly once).
        """
        logger.log.info(tf.__version__)
        load_model_name = "model161"  # used only by the commented-out "load model" block below
        save_model_name = "model170"

        # Build the model

        # Train from scratch (comment out this block when continuing training from a saved model)
        model = build_model()
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

        # Load model (uncomment this block to continue model training)
        # model = tf.keras.models.load_model(
        #     f'{load_model_name}.h5'
        # )

        source_data_properties = [
            SourceDataProperties(
                folder='D:\\git\\dataset\\NGC1333_RASA\\cropped',
                linear=False,
                to_align=False,
                number_of_images=None,
                to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Seahorse\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part1\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part2\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part3\\cropped',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Orion\\Part4\\cropped1',
            #     linear=False,
            #     to_align=False,
            #     number_of_images=None,
            #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\NGC1333_RASA\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=5,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Seahorse\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=7,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part1\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=3,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part2\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=6,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part3\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=7,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\Orion\\Part4\\cropped1',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=4,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\M78\\Light_BIN-1_EXPOSURE-120.00s_FILTER-NoFilter_RGB',
            # #     linear=True,
            # #     to_align=False,
            # #     number_of_images=None,
            # #     to_debayer=False),
            # # SourceDataProperties(
            # #     folder='D:\\git\\dataset\\M81\\cropped',
            # #     linear=False,
            # #     to_align=False,
            # #     number_of_images=60,
            # #     to_debayer=False),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo1',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='D:\\git\\dataset\\Virgo2',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='D:\\git\\dataset\\Virgo\\Dark',
            #     flat_folder='D:\\git\\dataset\\Virgo\\Flat',
            #     dark_flats_folder='D:\\git\\dataset\\Virgo\\DarkFlat'),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Antares2024\\Night_1\\PART1',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            #     magnitude_limit=18.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Antares2024\\Night_1\\PART2',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_3\\PART1',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_3\\PART2',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # # SourceDataProperties(
            # #     folder='E:\\Astro\\Antares2024\\Night_4\\PART1',
            # #     linear=True,
            # #     to_align=True,
            # #     number_of_images=None,
            # #     to_debayer=True,
            # #     dark_folder='E:\\Astro\\Antares2024\\DARK',
            # #     flat_folder='E:\\Astro\\Antares2024\\Night_1\\FLAT',
            # #     dark_flats_folder='E:\\Astro\\Antares2024\\Night_1\\DARKFLAT',
            # #     magnitude_limit=18.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Andromeda\\Light',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Andromeda\\Dark_600',
            #     flat_folder='E:\\Astro\\Andromeda\\Flat\\Night2',
            #     dark_flats_folder='E:\\Astro\\Andromeda\\Dark_Flat',
            #     magnitude_limit=19.0),
            # SourceDataProperties(
            #     folder='E:\\Astro\\Rosette\\Light',
            #     linear=True,
            #     to_align=True,
            #     number_of_images=None,
            #     to_debayer=True,
            #     dark_folder='E:\\Astro\\Rosette\\Dark',
            #     flat_folder='E:\\Astro\\Rosette\\Flat',
            #     dark_flats_folder='E:\\Astro\\Rosette\\DarkFlat',
            #     magnitude_limit=19.0),
        ]

        source_datas = []
        total_size = 0
        for num, properties in enumerate(source_data_properties, start=1):
            logger.log.info(f"Processing {num} of {len(source_data_properties)} source datas")
            source_data = _prepare_source_data(properties)
            source_datas.append(source_data)

            # Report the memory of every dataset (not just the last one) and keep a running total.
            size = _images_size_bytes(source_data)
            total_size += size
            logger.log.info(f"dtype: {source_data.images.dtype}, shape: {source_data.images.shape}, size: {size}")
            logger.log.info(f"Allocated {size // (1024 * 1024)} Mb")
        logger.log.info(f"Allocated {total_size // (1024 * 1024)} Mb in total")

        training_dataset = TrainingDatasetV2(source_datas)
        # NOTE(review): validation batches come from the same TrainingDatasetV2 as training batches;
        # confirm batch_generator yields fresh/independent samples, otherwise val_loss does not
        # measure generalization.
        training_generator = training_dataset.batch_generator(batch_size=20)
        val_generator = training_dataset.batch_generator(batch_size=20)

        # Stop once val_loss has not decreased for 100 epochs and roll back to the best weights.
        early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=0,
            patience=100,
            verbose=1,
            mode='min',
            baseline=None,
            restore_best_weights=True
        )

        try:
            model.fit(
                training_generator,
                validation_data=val_generator,
                steps_per_epoch=500,
                validation_steps=1000,
                epochs=10000,
                callbacks=[early_stopping_monitor]
            )
        except KeyboardInterrupt:
            # Ctrl+C ends training early; fall through so the model is saved exactly once below.
            logger.log.info("Training interrupted by user, saving the model")

        _save_model(model, save_model_name)

    main()
346