# CelestialSurveyor — source_data_v2.py
# (Residue of the repository web page header: "Fork 0", "840 lines · 33.2 KB".)
import json
2
import numpy as np
3
import os
4
import requests
5
import sys
6
import uuid
7

8
from astropy.wcs import WCS
9
from astropy.coordinates import Angle
10
from threading import Event
11
from typing import Optional, Union, Tuple, Generator
12

13
from backend.consuming_functions.load_headers import load_headers
14
from backend.consuming_functions.load_images import load_images, PIXEL_TYPE, load_image
15
from backend.consuming_functions.plate_solve_images import plate_solve, plate_solve_image
16
from backend.consuming_functions.align_images import align_images_wcs
17
from backend.consuming_functions.stretch_images import stretch_images
18
from backend.data_classes import SharedMemoryParams
19
from backend.progress_bar import AbstractProgressBar
20
from backend.data_classes import Header
21
from backend.known_object import KnownObject
22
from logger.logger import get_logger
23

24

25
logger = get_logger()

# Default side length, in pixels, of the square image chunks produced by
# SourceDataV2.get_number_of_chunks / SourceDataV2.generate_image_chunks.
CHUNK_SIZE = 64


class SourceDataV2:
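    """
    Manage a stack of image frames and their headers: loading, calibration, alignment,
    cropping, stretching, plate solving and splitting into chunks for processing.
    """
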
    def __init__(self, to_debayer: bool = False) -> None:
36
        self.headers = []
37
        self.original_frames = None
38
        self.shm = None
39
        if not os.path.exists(self.tmp_folder):
40
            os.mkdir(self.tmp_folder)
41
        else:
42
            # only one sourcedata instance to be loaded at the same time
43
            for item in os.listdir(self.tmp_folder):
44
                if item.endswith(".np"):
45
                    os.remove(os.path.join(self.tmp_folder, item))
46

47
        self.shm_name = self.__create_shm_name('images')
48
        self.shm_params = None
49
        self.footprint_map = None
50
        self.to_debayer = to_debayer
51
        self.y_borders: slice = slice(None, None)
52
        self.x_borders: slice = slice(None, None)
53
        self.__usage_map = None
54
        self.__chunk_len: int = 0
55
        self.__wcs: Optional[WCS] = None
56
        self.__cropped = False
57
        self.__shared = True
58
        self.__images = None
59
        self.__stop_event = Event()
60
        self.__used_images = None
61
        self.__usage_map_changed = True
62
        self.__original_frames = None
63

64
    def __create_shm_name(self, postfix: str = '') -> str:
65
        """
66
        Creates name for the shared memory file.
67

68
        Parameters:
69
        - postfix (str): Optional postfix to append to the shared memory file name.
70

71
        Returns:
72
        - str: The generated shared memory file name.
73
        """
74
        shm_name = os.path.join(self.tmp_folder, f"tmp_{uuid.uuid4().hex}_{postfix}.np")
75
        return shm_name
76

77
    def __clear_tmp_folder(self):
78
        """
79
        Clears temporary folder by removing all files with '.np' extension (shared memory files).
80
        """
81
        for item in os.listdir(self.tmp_folder):
82
            if item.endswith(".np"):
83
                os.remove(os.path.join(self.tmp_folder, item))
84

85
    def __reset_shm(self):
86
        """
87
        Resets the shared memory by clearing the temporary folder and creating a new shared memory file.
88
        Required in UI mode when user wants to add more images or stops loading data.
89
        """
90
        self.original_frames = None
91
        self.__clear_tmp_folder()
92
        self.shm_name = self.__create_shm_name('images')
93

94
    def raise_stop_event(self):
95
        """
96
        Raise the stop event to let the child processes to stop.
97
        """
98
        self.__stop_event.set()
99

100
    def clear_stop_event(self):
101
        """
102
        Raise the stop event to let the child processes that reloading may be done.
103
        """
104
        self.__stop_event.clear()
105

106
    @property
107
    def tmp_folder(self) -> str:
108
        """
109
        Get the path to the temporary folder where shared memory files are stored.
110

111
        Returns:
112
            str: The path to the temporary folder.
113
        """
114
        return os.path.join(sys.path[1], "tmp")
115

116
    @property
117
    def stop_event(self) -> Event:
118
        return self.__stop_event
119

120
    @staticmethod
121
    def filter_file_list(file_list: list[str]) -> list[str]:
122
        """
123
        Filter the file list to include only files with extensions .xisf, .fit, or .fits.
124

125
        Args:
126
            file_list (list[str]): List of file paths to filter.
127

128
        Returns:
129
            list[str]: Filtered list of file paths.
130
        """
131
        return [item for item in file_list if item.lower().endswith(".xisf") or item.lower().endswith(".fit")
132
                or item.lower().endswith(".fits")]
133

134
    def extend_headers(self, file_list: list[str], progress_bar: Optional[AbstractProgressBar] = None) -> None:
        """
        Load headers of the supported files in `file_list` and add them to the existing headers.

        Unsupported extensions are filtered out first; afterwards all headers are re-sorted by timestamp.

        Args:
            file_list (list[str]): List of file paths to load headers from.
            progress_bar (Optional[AbstractProgressBar]): An optional progress bar to show the loading progress.
        """
        supported_files = self.filter_file_list(file_list)
        new_headers = load_headers(supported_files, progress_bar, stop_event=self.stop_event)
        self.headers.extend(new_headers)
        self.headers.sort(key=lambda hdr: hdr.timestamp)

    def set_headers(self, headers: list[Header]) -> None:
150
        """
151
        Set the headers.
152

153
        Args:
154
            headers (list[Header]): List of headers to set.
155

156
        Returns:
157
            None
158
        """
159
        self.headers = headers
160
        self.headers.sort(key=lambda header: header.timestamp)
161
        self.__reset_shm()
162

163
    @property
164
    def shape(self) -> tuple:
165
        return self.images.shape
166

167
    @property
168
    def origional_shape(self) -> tuple:
169
        return self.original_frames.shape
170

171
    @property
172
    def usage_map(self) -> np.ndarray:
173
        if self.__usage_map is None:
174
            self.__usage_map = np.ones((len(self.__images), ), dtype=bool)
175
        return self.__usage_map
176

177
    @usage_map.setter
178
    def usage_map(self, usage_map):
179
        self.__usage_map = usage_map
180
        self.__usage_map_changed = True
181

182
    @property
183
    def images(self):
184
        if self.__shared:
185
            usage_map = self.__usage_map if self.__usage_map is not None else np.ones((len(self.headers), ), dtype=bool)
186
            return self.original_frames[
187
                usage_map, self.y_borders, self.x_borders] if self.original_frames is not None else None
188
        else:
189
            if self.__usage_map_changed:
190
                self.__used_images = self.__images[self.usage_map]
191
                self.__usage_map_changed = False
192
            return self.__used_images
193

194
    def images_from_buffer(self) -> None:
195
        """
196
        Copy images from the shared memory file to RAM, update headers based on usage map,
197
        and reset shared memory properties. Needs to be done after image loading, calibration and alignment to speed up
198
        processing.
199
        """
200
        self.__images = np.copy(self.images)
201
        self.headers = [header for idx, header in enumerate(self.headers) if self.usage_map[idx]]
202
        self.__usage_map_changed = True
203
        self.usage_map = np.ones((len(self.__images), ), dtype=bool)
204
        name = self.shm_name
205
        self.original_frames._mmap.close()
206
        del self.original_frames
207
        self.original_frames = None
208
        os.remove(name)
209
        self.__original_frames = None
210
        self.__shared = False
211

212
    @property
213
    def max_image(self) -> np.ndarray:
214
        return np.amax(self.images, axis=0)
215

216
    @property
217
    def wcs(self) -> WCS:
218
        if self.__wcs is None and self.__cropped is True:
219
            self.__wcs, _ = self.plate_solve()
220
        return self.__wcs
221

222
    @wcs.setter
223
    def wcs(self, value):
224
        self.__wcs = value
225

226
    def load_images(self, progress_bar: Optional[AbstractProgressBar] = None) -> None:
        """
        Load the images listed in the headers into a new memory-mapped file.

        The first file is read once to learn the frame shape. A memmap of shape (n_files, *frame_shape) and
        dtype PIXEL_TYPE is then created at `self.shm_name`, and the module-level `load_images` fills it.

        Args:
            progress_bar (Optional[AbstractProgressBar]): A progress bar to show the loading progress.
        """
        logger.log.info("Loading images...")
        file_list = [header.file_name for header in self.headers]
        first_image = load_image(file_list[0])
        shape = (len(file_list), *first_image.shape)
        # Describe the buffer with the dtype the memmap is actually created with (PIXEL_TYPE), as the
        # calibration-frame loaders do. Previously dtype/size came from the first image and could
        # disagree with the file on disk.
        pixel_dtype = np.dtype(PIXEL_TYPE)
        self.shm_params = SharedMemoryParams(
            shm_name=self.shm_name, shm_shape=shape, shm_size=pixel_dtype.itemsize * int(np.prod(shape)),
            shm_dtype=PIXEL_TYPE)
        self.original_frames = np.memmap(self.shm_name, dtype=PIXEL_TYPE, mode='w+', shape=shape)
        self.__shared = True
        load_images(
            file_list, self.shm_params, to_debayer=self.to_debayer, progress_bar=progress_bar,
            stop_event=self.stop_event)

    @staticmethod
249
    def calculate_raw_crop(footprint: np.ndarray) -> tuple[tuple[int, int], tuple[int, int]]:
250
        """
251
        Calculate the crop coordinates based on the footprint array.
252
        Raw crop means that the lines and columns which contain only zeros will be cut.
253

254
        Parameters:
255
        - footprint (np.ndarray): The input array representing the footprint.
256

257
        Returns:
258
        - Tuple[Tuple[int, int], Tuple[int, int]]: A tuple containing the crop coordinates for y-axis and x-axis.
259

260
        """
261
        y_top = x_left = 0
262
        y_bottom, x_right = footprint.shape[:2]
263
        for num, line in enumerate(footprint):
264
            if not np.all(line):
265
                y_top = num
266
                break
267
        for num, line in enumerate(footprint[::-1]):
268
            if not np.all(line):
269
                y_bottom -= num
270
                break
271

272
        for num, line in enumerate(footprint.T):
273
            if not np.all(line):
274
                x_left = num
275
                break
276

277
        for num, line in enumerate(footprint.T[::-1]):
278
            if not np.all(line):
279
                x_right -= num
280
                break
281

282
        return (y_top, y_bottom), (x_left, x_right)
283

284
    @staticmethod
285
    def crop_image(imgs: np.ndarray,
286
                   y_borders: Union[slice, Tuple[int, int]],
287
                   x_borders: Union[slice, Tuple[int, int]],
288
                   usage_mask: Optional[np.ndarray] = None) -> np.ndarray:
289
        """
290
        Crop the images based on the provided y and x borders.
291

292
        Args:
293
            imgs (np.ndarray): The input image array.
294
            y_borders (Union[slice, Tuple[int, int]]): The borders for the y-axis.
295
            x_borders (Union[slice, Tuple[int, int]]): The borders for the x-axis.
296
            usage_mask (Optional[np.ndarray]): Optional usage mask for cropping.
297

298
        Returns:
299
            np.ndarray: The cropped image array.
300
        """
301
        if isinstance(y_borders, slice):
302
            pass
303
        elif isinstance(y_borders, tuple) and len(y_borders) == 2:
304
            y_borders = slice(*y_borders)
305
        else:
306
            raise ValueError("y_borders must be a tuple of length 2 or a slice")
307
        if isinstance(x_borders, slice):
308
            pass
309
        elif isinstance(x_borders, tuple) and len(x_borders) == 2:
310
            x_borders = slice(*x_borders)
311
        else:
312
            raise ValueError("x_borders must be a tuple of length 2 or a slice")
313
        if usage_mask:
314
            return imgs[usage_mask, y_borders, x_borders]
315
        else:
316
            return imgs[:, y_borders, x_borders]
317

318
    @staticmethod
319
    def __get_num_of_corner_zeros(line: np.ndarray) -> int:
320
        """
321
        Count the number of zeros in the footprint line.
322

323
        Args:
324
            line (np.ndarray): The input line from the footprint.
325

326
        Returns:
327
            int: The number of zeros in the line.
328
        """
329
        # True means zero in footprint
330
        return np.count_nonzero(line)
331

332
    @classmethod
333
    def __fine_crop_border(cls, footprint: np.ndarray, direction: int, transpon: bool = True) -> Tuple[np.array, int]:
334
        """
335
        This method calculates the fine crop border based on the direction and whether to transpose the footprint.
336
        The goal is to leave areas where the image is without zeros after alignment.
337

338
        Args:
339
            footprint (np.ndarray): The input footprint array.
340
            direction (int): The direction to calculate the border.
341
            transpon (bool, optional): Whether to transpose the footprint. Defaults to True.
342

343
        Returns:
344
            Tuple[np.array, int]: The cropped footprint and the calculated border.
345
        """
346
        if transpon:
347
            footprint = footprint.T
348
        x = 0
349
        line: np.ndarray
350
        for num, line in enumerate(footprint[::direction]):
351
            if cls.__get_num_of_corner_zeros(line) <= cls.__get_num_of_corner_zeros(
352
                    footprint[::direction][num + 1]):
353
                x = num
354
                break
355
        if direction == -1:
356
            result_tmp = footprint[: (x + 1) * direction]
357
            x = footprint.shape[0] - x
358
        else:
359
            result_tmp = footprint[x:]
360
        return result_tmp.T if transpon else result_tmp, x
361

362
    @classmethod
363
    def calculate_crop(cls, footprint: np.ndarray) -> Tuple[Tuple[int, int], Tuple[int, int]]:
364
        """
365
        Calculate the crop coordinates based on the footprint array. The goal is to calculate rectangle without zero
366
        areas after alignment.
367

368
        Parameters:
369
        - footprint (np.ndarray): The input array representing the footprint.
370

371
        Returns:
372
        - Tuple[Tuple[int, int], Tuple[int, int]]: A tuple containing the crop coordinates for y-axis and x-axis.
373
        """
374

375
        y_pre_crop, x_pre_crop = cls.calculate_raw_crop(footprint)
376
        pre_cropped = footprint[slice(*y_pre_crop), slice(*x_pre_crop)]
377
        y_top_zeros = cls.__get_num_of_corner_zeros(pre_cropped[0])
378
        y_bottom_zeros = cls.__get_num_of_corner_zeros(pre_cropped[-1])
379
        x_left_zeros = cls.__get_num_of_corner_zeros(pre_cropped.T[0])
380
        x_right_zeros = cls.__get_num_of_corner_zeros(pre_cropped.T[-1])
381
        zeros = y_top_zeros, y_bottom_zeros, x_left_zeros, x_right_zeros
382
        trim_args = (1, False), (-1, False), (1, True), (-1, True)
383
        args_order = (item[1] for item in sorted(zip(zeros, trim_args), key=lambda x: x[0], reverse=True))
384
        border_map = {item: value for item, value in zip(trim_args, ["y_top", "y_bottom", "x_left", "x_right"])}
385
        result = {}
386
        cropped = pre_cropped
387
        for pair in args_order:
388
            boarder_name = border_map[pair]
389
            cropped, x = cls.__fine_crop_border(cropped, *pair)
390
            result.update({boarder_name: x})
391

392
        y_top = result["y_top"] + y_pre_crop[0]
393
        y_bottom = result["y_bottom"] + y_pre_crop[0]
394
        x_left = result["x_left"] + x_pre_crop[0]
395
        x_right = result["x_right"] + x_pre_crop[0]
396
        crop = (y_top, y_bottom), (x_left, x_right)
397
        return crop
398

399
    def align_images_wcs(self, progress_bar: Optional[AbstractProgressBar] = None) -> None:
        """
        Align the frames in shared memory using the WCS stored in each header.

        Stores the footprints returned by the aligner in `footprint_map` and the per-frame success flags
        as the usage map.

        Args:
            progress_bar (Optional[AbstractProgressBar]): Progress bar object.
        """
        logger.log.info("Aligning images...")
        frame_wcs = [header.wcs for header in self.headers]
        success_map, footprints = align_images_wcs(
            self.shm_params, frame_wcs, progress_bar=progress_bar, stop_event=self.stop_event)
        self.footprint_map = footprints
        self.__usage_map = success_map

    def crop_images(self) -> None:
        """
        Crop the images based on the footprint map and update borders.

        Keeps the area that is non-zero on all used images after alignment. A crop rectangle is computed
        per used frame with `calculate_crop`, and the borders become the intersection of these rectangles.
        If no footprint map is available, the footprint of each frame is built from its pixels (`== 0`).
        The stop event is checked before each frame; if set, the method returns without changing borders.
        After cropping, the reference frame is plate solved again.
        """
        logger.log.info("Cropping images...")
        use_pixels = self.footprint_map is None
        frames_source = self.original_frames if use_pixels else self.footprint_map
        # One mask for both sources. Previously `frames[mask] == 0` followed by `[0]` picked a single image
        # (and then crashed on its rows) whenever a usage map was set.
        if self.__usage_map is None:
            used = np.ones((len(frames_source), ), dtype=bool)
        else:
            used = np.asarray(self.__usage_map, dtype=bool)

        y_borders, x_borders = [], []
        for idx in np.flatnonzero(used):
            if self.stop_event.is_set():
                return
            if use_pixels:
                # Built per frame so the whole stack is never turned into one boolean array.
                footprint = self.original_frames[idx] == 0
                if footprint.ndim == 3:
                    # drop the trailing single-channel axis of (H, W, 1) frames
                    footprint = np.reshape(footprint, footprint.shape[:-1])
            else:
                footprint = self.footprint_map[idx]
            y_border, x_border = self.calculate_crop(footprint)
            y_borders.append(y_border)
            x_borders.append(x_border)

        y_borders = np.array(y_borders)
        x_borders = np.array(x_borders)
        self.y_borders = slice(int(np.max(y_borders[:, 0])), int(np.min(y_borders[:, 1])))
        self.x_borders = slice(int(np.max(x_borders[:, 0])), int(np.min(x_borders[:, 1])))
        self.__cropped = True
        self.footprint_map = None
        # Plate solve after cropping
        self.wcs, _ = self.plate_solve(0)

    def make_master_dark(self, filenames: list[str], progress_bar: Optional[AbstractProgressBar] = None) -> np.ndarray:
        """
        Create a master dark frame from a list of dark frame filenames.

        The darks are loaded into a temporary memory-mapped file and averaged pixel-wise. The mapping is
        closed and the temporary file deleted even if loading fails.

        Args:
            filenames (list[str]): List of dark frame filenames.
            progress_bar (Optional[AbstractProgressBar], optional): Progress bar instance. Defaults to None.

        Returns:
            np.ndarray: Master dark frame (mean of the darks), shaped like a single frame.
        """
        shape = (len(filenames), *self.origional_shape[1:])
        size = self.original_frames.itemsize * int(np.prod(shape))
        shm_name = self.__create_shm_name('darks')
        shm_params = SharedMemoryParams(
            shm_name=shm_name, shm_shape=shape, shm_size=size, shm_dtype=PIXEL_TYPE)
        darks = np.memmap(shm_params.shm_name, dtype=PIXEL_TYPE, mode='w+', shape=shape)
        try:
            load_images(filenames, shm_params, progress_bar=progress_bar, to_debayer=self.to_debayer)
            master_dark = np.average(darks, axis=0)
        finally:
            # Close the mapping before its backing file is removed.
            darks._mmap.close()
            del darks
            os.remove(shm_name)
        return master_dark

    def make_master_flat(self, flat_filenames: list[str], dark_flat_filenames: Optional[list[str]] = None,
                         progress_bar: Optional[AbstractProgressBar] = None) -> np.ndarray:
        """
        Create a master flat frame from a list of flat frame filenames.

        The flats are loaded into a temporary memory-mapped file. If dark flats are given, their master dark
        is subtracted from every flat. The flats are then averaged pixel-wise; the result is not normalized.
        The temporary file is deleted even if loading fails.

        Args:
            flat_filenames (list[str]): List of flat frame filenames.
            dark_flat_filenames (Optional[list[str]], optional): List of dark flat frame filenames. Defaults to None.
            progress_bar (Optional[AbstractProgressBar], optional): Progress bar instance. Defaults to None.

        Returns:
            np.ndarray: Master flat frame.
        """
        flat_shape = (len(flat_filenames), *self.origional_shape[1:])
        flat_size = self.original_frames.itemsize * int(np.prod(flat_shape))
        flat_shm_name = self.__create_shm_name('flats')
        flat_shm_params = SharedMemoryParams(
            shm_name=flat_shm_name, shm_shape=flat_shape, shm_size=flat_size, shm_dtype=PIXEL_TYPE)
        flats = np.memmap(flat_shm_params.shm_name, dtype=PIXEL_TYPE, mode='w+', shape=flat_shape)
        try:
            load_images(flat_filenames, flat_shm_params, progress_bar=progress_bar, to_debayer=self.to_debayer)
            if dark_flat_filenames is not None:
                master_dark_flat = self.make_master_dark(dark_flat_filenames, progress_bar=progress_bar)
                # Broadcast over the frame axis: same result as subtracting from each flat in a loop.
                flats -= master_dark_flat
            master_flat = np.average(flats, axis=0)
        finally:
            flats._mmap.close()
            del flats
            os.remove(flat_shm_name)
        return master_flat

    def load_flats(self, flat_filenames: list[str], progress_bar: Optional[AbstractProgressBar] = None) -> np.ndarray:
        """
        Load flat frames into memory.

        The flats are loaded through a temporary memory-mapped file and copied into RAM. The temporary file
        is deleted even if loading fails.

        Args:
            flat_filenames (list[str]): List of flat frame filenames.
            progress_bar (Optional[AbstractProgressBar], optional): Progress bar instance. Defaults to None.

        Returns:
            np.ndarray: Loaded flat frames, one per filename.
        """
        flat_shape = (len(flat_filenames), *self.origional_shape[1:])
        flat_size = self.original_frames.itemsize * int(np.prod(flat_shape))
        flat_shm_name = self.__create_shm_name('flats')
        flat_shm_params = SharedMemoryParams(
            shm_name=flat_shm_name, shm_shape=flat_shape, shm_size=flat_size, shm_dtype=PIXEL_TYPE)
        flats = np.memmap(flat_shm_params.shm_name, dtype=PIXEL_TYPE, mode='w+', shape=flat_shape)
        try:
            load_images(flat_filenames, flat_shm_params, progress_bar=progress_bar, to_debayer=self.to_debayer)
            res = np.copy(flats)
        finally:
            flats._mmap.close()
            del flats
            os.remove(flat_shm_name)
        return res

    def calibrate_images(self, dark_files: Optional[list[str]] = None, flat_files: Optional[list[str]] = None,
532
                         dark_flat_files: Optional[list[str]] = None, progress_bar: Optional[AbstractProgressBar] = None
533
                         ) -> None:
534
        """
535
        Calibrates images by subtracting master dark frames and dividing by master flat frames.
536

537
        Args:
538
            dark_files (Optional[list[str]]): List of dark frame filenames.
539
            flat_files (Optional[list[str]]): List of flat frame filenames.
540
            dark_flat_files (Optional[list[str]]): List of dark flat frame filenames.
541
            progress_bar (Optional[AbstractProgressBar]): Progress bar instance.
542

543
        Returns:
544
            None
545
        """
546
        if dark_files is not None:
547
            master_dark = self.make_master_dark(dark_files, progress_bar=progress_bar)
548
            self.original_frames -= master_dark
549
        if flat_files is not None:
550
            master_flat = self.make_master_flat(flat_files, dark_flat_files, progress_bar=progress_bar)
551
            self.original_frames /= master_flat
552

553
    def stretch_images(self, progress_bar: Optional[AbstractProgressBar] = None) -> None:
        """
        Stretch the images stored in shared memory, limited to the current crop borders.

        The crop borders are written into `shm_params` (y_slice / x_slice) before the stretcher runs.

        Args:
            progress_bar (Optional[AbstractProgressBar]): Progress bar to track the stretching progress.
        """
        logger.log.info("Stretching images...")
        params = self.shm_params
        params.y_slice, params.x_slice = self.y_borders, self.x_borders
        stretch_images(params, progress_bar=progress_bar, stop_event=self.stop_event)

    def get_number_of_chunks(self, size: tuple[int, int] = (CHUNK_SIZE, CHUNK_SIZE),
570
                             overlap: float = 0.5) -> tuple[np.ndarray, np.ndarray]:
571
        """
572
        Calculate the number of image chunks based on the specified size and overlap.
573

574
        Args:
575
            size (tuple[int, int], optional): The size of the image chunks in the format (height, width).
576
                Defaults to (CHUNK_SIZE, CHUNK_SIZE).
577
            overlap (float, optional): The overlap percentage between image chunks. Defaults to 0.5.
578

579
        Returns:
580
            tuple[np.ndarray, np.ndarray]: Two arrays representing the y and x coordinates of the image chunks.
581
        """
582
        size_y, size_x = size
583
        ys = np.arange(0, self.shape[1] - size_y * overlap, size_y * overlap)
584
        ys[-1] = self.shape[1] - size_y
585
        xs = np.arange(0, self.shape[2] - size_x * overlap, size_x * overlap)
586
        xs[-1] = self.shape[2] - size_x
587
        return ys, xs
588

589
    def generate_image_chunks(self, size: tuple[int, int] = (CHUNK_SIZE, CHUNK_SIZE), overlap: float = 0.5):
590
        """
591
        Generate image chunks based on the specified size and overlap.
592

593
        Args:
594
            size (tuple[int, int], optional): The size of the image chunks in the format (height, width).
595
                Defaults to (CHUNK_SIZE, CHUNK_SIZE).
596
            overlap (float, optional): The overlap percentage between image chunks. Defaults to 0.5.
597

598
        Yields:
599
            tuple: A tuple containing the coordinates and prepared images of the generated image chunks.
600
        """
601
        size_y, size_x = size
602
        ys, xs = self.get_number_of_chunks(size, overlap)
603
        coordinates = ((y, x) for y in ys for x in xs)
604
        for y, x in coordinates:
605
            y, x = int(y), int(x)
606
            imgs = np.copy(self.images[:, y:y + size_y, x:x + size_x])
607
            yield (y, x), self.prepare_images(np.copy(imgs))
608

609
    @staticmethod
610
    def generate_batch(chunk_generator: Generator, batch_size: int) -> tuple[tuple[int, int], np.ndarray]:
611
        """
612
        Generate batches of chunks for processing with the given batch size.
613

614
        Args:
615
            chunk_generator (Generator): Generator that yields chunks and coordinates.
616
            batch_size (int): The size of each batch.
617

618
        Yields:
619
            tuple[tuple[int, int], np.ndarray]: A tuple containing the coordinates and batch of chunks.
620
        """
621
        batch = []
622
        coords = []
623
        for coord, chunk in chunk_generator:
624
            batch.append(chunk)
625
            coords.append(coord)
626
            if len(batch) == batch_size:
627
                yield coords, np.array(batch)
628
                batch = []
629
                coords = []
630
        if len(batch) > 0:
631
            yield coords, np.array(batch)
632

633
    @staticmethod
634
    def estimate_image_noize_level(imgs: np.ndarray) -> float:
635
        """
636
        Estimate the noise level of the given images.
637

638
        Args:
639
            imgs (np.ndarray): The images to estimate the noise level for.
640

641
        Returns:
642
            float: The estimated noise level.
643
        """
644
        return np.mean(np.var(imgs, axis=0))
645

646
    @classmethod
647
    def prepare_images(cls, images: np.ndarray) -> np.ndarray:
648
        """
649
        Prepare the given images for processing by AI model.
650

651
        Args:
652
            images (np.ndarray): The images to prepare.
653

654
        Returns:
655
            np.ndarray: The prepared images.
656
        """
657
        # normalize images
658
        images -= cls.estimate_image_noize_level(images)
659
        images = images - np.min(images)
660
        images = images / np.max(images)
661
        images = np.reshape(images, (*images.shape[:3], 1))
662
        return images
663

664
    @staticmethod
665
    def adjust_chunks_to_min_len(imgs: np.ndarray, timestamps: list, min_len: int = 8) -> tuple[np.ndarray, list]:
666
        """
667
        Adjust the given chunks to the minimum length. To be used if there are fewer images than the minimum length.
668

669
        Args:
670
            imgs (np.ndarray): The images to adjust.
671
            timestamps (list): The timestamps of the images.
672
            min_len (int, optional): The minimum length of the chunks. Defaults to 8.
673

674
        Returns:
675
            tuple[np.ndarray, list]: The adjusted images and timestamps.
676
        """
677
        assert len(imgs) == len(timestamps), \
678
            f"Images and timestamp amount mismatch: len(imgs)={len(imgs)}. len(timestamps)={len(timestamps)}"
679

680
        if len(imgs) >= min_len:
681
            return imgs, timestamps
682
        new_imgs = []
683
        new_timestamps = []
684
        while len(new_imgs) < min_len:
685
            new_imgs.extend(list(imgs))
686
            new_timestamps.extend(timestamps)
687
        new_imgs = new_imgs[:8]
688
        new_timestamps = new_timestamps[:8]
689
        timestamped_images = list(zip(new_imgs, new_timestamps))
690
        timestamped_images.sort(key=lambda x: x[1])
691
        new_imgs = [item[0] for item in timestamped_images]
692
        new_timestamps = [item[1] for item in timestamped_images]
693
        new_imgs = np.array(new_imgs)
694
        return new_imgs, new_timestamps
695

696
    @staticmethod
697
    def make_file_paths(folder: str) -> list[str]:
698
        """
699
        Create a list of file paths for files in the specified folder that end with specific extensions.
700

701
        Parameters:
702
        - folder (str): The folder path to search for files.
703

704
        Returns:
705
        - list[str]: A list of file paths with extensions '.xisf', '.fit', or '.fits'.
706
        """
707
        return [os.path.join(folder, item) for item in os.listdir(folder) if item.lower().endswith(
708
            ".xisf") or item.lower().endswith(".fit") or item.lower().endswith(".fits")]
709

710
    def plate_solve(self, ref_idx: int = 0, sky_coord: Optional[np.ndarray] = None) -> tuple[WCS, np.ndarray]:
        """
        Plate solve one of the used frames and remember the result as this instance's WCS.

        Args:
            ref_idx (int, optional): The index of the reference image. Defaults to 0.
            sky_coord (np.ndarray, optional): The sky coordinates of the reference image. Defaults to None.

        Returns:
            tuple[WCS, np.ndarray]: The plate solved WCS and the sky coordinates of the reference image.
        """
        logger.log.info("Plate solving...")
        reference_image = self.images[ref_idx]
        solved_wcs, solved_coord = plate_solve_image(
            reference_image, header=self.headers[ref_idx], sky_coord=sky_coord)
        self.__wcs = solved_wcs
        return solved_wcs, solved_coord

    def plate_solve_all(self, progress_bar: Optional[AbstractProgressBar] = None) -> None:
        """
        Plate solve every frame in shared memory and store each solution in its header's `wcs`.

        Results are paired with headers in order. If fewer results than headers come back, only the
        first headers are updated.

        Args:
            progress_bar (AbstractProgressBar, optional): The progress bar to display the progress. Defaults to None.
        """
        logger.log.info("Plate solving all images...")
        solutions = plate_solve(self.shm_params, self.headers, progress_bar=progress_bar, stop_event=self.stop_event)
        for solution, header in zip(solutions, self.headers):
            header.wcs = solution

    @staticmethod
742
    def convert_ra(ra: Angle) -> str:
743
        """
744
        Convert right ascension from astropy Angle format to a string representation suitable for Small Body Api.
745
        https://ssd-api.jpl.nasa.gov/doc/sb_ident.html
746

747
        Args:
748
            ra (Angle): The right ascension angle to convert.
749

750
        Returns:
751
            str: The string representation of the right ascension.
752
        """
753
        minus_substr = "M" if int(ra.h) < 0 else ""
754
        hour = f"{minus_substr}{abs(int(ra.h)):02d}"
755
        return f"{hour}-{abs(int(ra.m)):02d}-{abs(int(ra.s)):02d}"
756

757
    @staticmethod
758
    def convert_dec(dec: Angle) -> str:
759
        """
760
        Convert declination from astropy Angle format to a string representation suitable for Small Body Api.
761
        https://ssd-api.jpl.nasa.gov/doc/sb_ident.html
762

763
        Args:
764
            dec (Angle): The declination angle to convert.
765

766
        Returns:
767
            str: The string representation of the declination.
768
        """
769
        minus_substr = "M" if int(dec.d) < 0 else ""
770
        hour = f"{minus_substr}{abs(int(dec.d)):02d}"
771
        return f"{hour}-{abs(int(dec.m)):02d}-{abs(int(dec.s)):02d}"
772

773
    def fetch_known_asteroids_for_image(self, img_idx: int, magnitude_limit: float = 18.0, *, timeout: float = 120.0
                                        ) -> tuple[list[KnownObject], list[KnownObject]]:
        """
        Fetch known asteroids and comets within the image's field of view based on the specified magnitude limit.
        Request data from JPL Small Body Api (https://ssd-api.jpl.nasa.gov/doc/sb_ident.html).

        One request is made per small-body kind (asteroids first, then comets). Objects returned by the API are
        kept only if their pixel coordinates, computed with this instance's WCS, fall inside the image.

        Args:
            img_idx (int): The index of the image; its header supplies the observation time and site location.
            magnitude_limit (float): The visual magnitude limit passed to the API as ``vmag-lim``.
            timeout (float): Seconds to wait for each HTTP request to the Small Body API. Defaults to 120.

        Returns:
            tuple[list[KnownObject], list[KnownObject]]: A tuple containing lists of KnownObject instances for asteroids
                and comets. A kind whose request returns an HTTP error status contributes an empty list (the
                error is logged).

        Raises:
            requests.RequestException: On network-level failures, including a request exceeding ``timeout``.
        """
        logger.log.debug("Requesting visible targets")
        header = self.headers[img_idx]
        timestamp = header.timestamp
        obs_time = (f"{timestamp.year:04d}-{timestamp.month:02d}-{timestamp.day:02d}_{timestamp.hour:02d}:"
                    f"{timestamp.minute:02d}:{timestamp.second:02d}")

        wcs = self.wcs
        # shape[2] bounds x and shape[1] bounds y (see the in-frame check below); pixel_to_world takes (x, y).
        height, width = self.shape[1], self.shape[2]
        corner_points = [wcs.pixel_to_world(x, y) for x, y in ((0, 0), (0, height), (width, 0), (width, height))]
        # NOTE(review): a plain min/max over the corners does not handle a field straddling RA 0h (the query then
        # spans almost the whole RA range; harmless thanks to the in-frame filter, just slower) or one containing
        # a celestial pole (the query would miss part of the field).
        ra_values = [point.ra.hms for point in corner_points]
        dec_values = [point.dec.dms for point in corner_points]
        fov_ra_lim = f"{self.convert_ra(min(ra_values))},{self.convert_ra(max(ra_values))}"
        fov_dec_lim = f"{self.convert_dec(min(dec_values))},{self.convert_dec(max(dec_values))}"

        site = header.site_location
        # 'a' = asteroids, 'c' = comets (the API's sb-kind values); dict order keeps asteroids queried first.
        found_by_kind = {'a': [], 'c': []}
        for sb_kind, found in found_by_kind.items():
            params = {
                "sb-kind": sb_kind,
                "lat": round(site.lat, 3),
                "lon": round(site.long, 4),
                "alt": 0,
                "obs-time": obs_time,
                "mag-required": True,
                "two-pass": True,
                "suppress-first-pass": True,
                "req-elem": False,
                "vmag-lim": magnitude_limit,
                "fov-ra-lim": fov_ra_lim,
                "fov-dec-lim": fov_dec_lim,
            }
            logger.log.debug(f"Params: {params}")
            # Without a timeout a stalled connection would block this call forever.
            response = requests.get("https://ssd-api.jpl.nasa.gov/sb_ident.api", params=params, timeout=timeout)
            if not response.ok:
                # Previously an error body was parsed and silently yielded no objects; make the failure visible.
                logger.log.error(f"Small Body API request failed (HTTP {response.status_code}): {response.text}")
                continue
            payload = response.json()
            # Field names are only needed (and only looked up) when the matching data rows are present.
            second_pass = [dict(zip(payload["fields_second"], row)) for row in payload.get("data_second_pass", [])]
            first_pass = [dict(zip(payload["fields_first"], row)) for row in payload.get("data_first_pass", [])]
            candidates = [KnownObject(item, wcs=wcs) for item in second_pass + first_pass]
            for candidate in candidates:
                x, y = candidate.pixel_coordinates
                if 0 <= x < width and 0 <= y < height:
                    found.append(candidate)

        known_asteroids, known_comets = found_by_kind['a'], found_by_kind['c']
        logger.log.info(f"Found {len(known_asteroids)} known asteroids and {len(known_comets)} known comets in the FOV")
        for title, objects in (("Known asteroids:", known_asteroids), ("Known comets:", known_comets)):
            if objects:
                logger.log.info(title)
                for item in objects:
                    logger.log.info(str(item))
        return known_asteroids, known_comets
841

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.