CelestialSurveyor

Форк
0
176 строк · 7.7 Кб
1
import numpy as np
2
import traceback
3

4
from functools import partial
5
from logging.handlers import QueueHandler
6
from multiprocessing import Manager, Pool, Queue, cpu_count
7
from typing import Optional
8
from threading import Event
9

10
from astropy.wcs import WCS
11
from backend.consuming_functions.measure_execution_time import measure_execution_time
12
from backend.data_classes import SharedMemoryParams
13
from backend.progress_bar import AbstractProgressBar
14
from logger.logger import get_logger
15
from reproject import reproject_interp
16

17

18
# Project logger wrapper shared by every function in this module: `logger.log` is the logger the records are
# emitted to, and `start_process_listener` / `stop_process_listener` are called by `align_images_wcs` around the
# worker pool (presumably to drain the log queue the worker processes write to — confirm in logger.logger).
logger = get_logger()
19

20

21
@measure_execution_time
def align_images_wcs(shm_params: SharedMemoryParams, all_wcs: list[WCS],
                     progress_bar: Optional[AbstractProgressBar] = None, stop_event: Optional[Event] = None
                     ) -> tuple[list[bool], np.ndarray]:
    """
    Align images with World Coordinate System (WCS).

    The frames stored in shared memory are split into contiguous chunks. A pool of at most 4 worker processes
    reprojects the chunks in place onto the WCS of the first frame (see ``align_wcs_worker``). Meanwhile the
    parent process consumes per-frame progress notifications, forwards them to the optional progress bar and
    polls the optional stop event. The stop event is honoured whether or not a progress bar is given.

    Args:
        shm_params (SharedMemoryParams): Shared memory parameters where the images are stored; ``shm_shape[0]``
            is the number of frames.
        all_wcs (list[WCS]): One WCS per frame; ``all_wcs[0]`` is the reference all frames are aligned to.
        progress_bar (Optional[AbstractProgressBar]): Progress bar object.
        stop_event (Optional[Event]): Stop event object used to stop the child processes.

    Returns:
        tuple[list[bool], np.ndarray]: Success flags and a numpy array of footprints, both ordered by frame
        index. After a stop request only the frames handled before the stop are included.

    Raises:
        Exception: Any exception that escapes a worker process is re-raised here by ``AsyncResult.get``.
    """
    frames_num = shm_params.shm_shape[0]
    if frames_num == 0:
        # Pool(processes=0) and np.array_split(..., 0) would raise; there is simply nothing to align.
        return [], np.empty((0, *shm_params.shm_shape[1:3]), dtype=bool)
    used_cpus = min(cpu_count(), 4, frames_num)
    logger.log.debug(f"Number of CPUs to be used for alignment: {used_cpus}")
    # The manager is entered first so that it is shut down only after the pool has been torn down.
    with Manager() as manager, Pool(processes=used_cpus) as pool:
        progress_queue = manager.Queue()
        stop_queue = manager.Queue(maxsize=1)
        log_queue = manager.Queue()
        logger.start_process_listener(log_queue)
        try:
            logger.log.debug(f"Starting alignment with {used_cpus} workers")
            worker = partial(align_wcs_worker, shm_params=shm_params, progress_queue=progress_queue,
                             ref_wcs=all_wcs[0], all_wcses=all_wcs, stop_queue=stop_queue, log_queue=log_queue)
            results = pool.map_async(worker, np.array_split(np.arange(frames_num), used_cpus))
            _track_alignment_progress(results, progress_queue, stop_queue, frames_num, progress_bar, stop_event)
            success_map, footprint_map = _merge_worker_results(results.get())
            pool.close()
            pool.join()
            logger.log.debug("Alignment pool stopped.")
        finally:
            # Stop the log listener even when a worker error is re-raised by results.get().
            logger.stop_process_listener()
    return success_map, footprint_map


def _track_alignment_progress(results, progress_queue, stop_queue, frames_num: int,
                              progress_bar: Optional[AbstractProgressBar] = None,
                              stop_event: Optional[Event] = None, poll_interval: float = 0.5) -> None:
    """
    Consume worker progress notifications until every frame is reported, the workers finish, or a stop is requested.

    Args:
        results: ``AsyncResult`` of the ``map_async`` call running the workers.
        progress_queue: Queue the workers put the index of every handled frame into.
        stop_queue: Queue used to ask the workers to stop; a single ``True`` is put into it on stop.
        frames_num (int): Total number of frames being aligned.
        progress_bar (Optional[AbstractProgressBar]): Progress bar to update, if any.
        stop_event (Optional[Event]): Event requesting the alignment to stop, if any.
        poll_interval (float): Seconds to wait for a notification before re-checking the stop event and the
            worker state.
    """
    from queue import Empty  # manager queue proxies raise queue.Empty when get() times out

    if progress_bar is not None:
        progress_bar.set_total(frames_num)
    reported = 0
    while reported < frames_num:
        if stop_event is not None and stop_event.is_set():
            logger.log.debug("Stop event triggered")
            stop_queue.put(True)
            break
        try:
            img_idx = progress_queue.get(timeout=poll_interval)
        except Empty:
            # Workers put their notifications before returning. Once the result is ready and the queue is
            # drained, nothing more can arrive (e.g. a worker crashed), and blocking further would hang forever.
            if results.ready() and progress_queue.empty():
                break
            continue
        reported += 1
        logger.log.debug(f"Aligned image at index '{img_idx}'")
        if progress_bar is not None:
            progress_bar.update()
    if progress_bar is not None:
        progress_bar.complete()


def _merge_worker_results(worker_results) -> tuple[list[bool], np.ndarray]:
    """
    Combine per-worker ``(indexes, successes, footprints)`` triples into frame-ordered outputs.

    Each worker's lists are zipped separately. A worker that stopped early (fewer successes than indexes)
    therefore cannot shift the results of the following workers onto the wrong frames.

    Args:
        worker_results: Iterable of the triples returned by ``align_wcs_worker``.

    Returns:
        tuple[list[bool], np.ndarray]: Success flags and stacked footprints, ordered by frame index.
    """
    ordered = sorted(
        (entry
         for indexes, successes, footprints in worker_results
         for entry in zip(indexes, successes, footprints)),
        key=lambda entry: entry[0],
    )
    success_map = [success for _, success, _ in ordered]
    logger.log.debug(f"Alignment finished. Success map: {success_map}")
    footprint_map = np.array([footprint for _, _, footprint in ordered])
    return success_map, footprint_map
82

83

84
def align_wcs_worker(img_indexes: list[int], shm_params: SharedMemoryParams, progress_queue: Queue, ref_wcs: WCS,
                     all_wcses: list[WCS], stop_queue: Optional[Queue] = None, log_queue: Optional[Queue] = None
                     ) -> tuple[list[int], list[bool], list[np.ndarray]]:
    """
    Worker function that aligns a chunk of frames onto a reference WCS, in place in shared memory.

    Each frame is first checked with ``get_centre_distance``. If its centre projects more than 10% of the
    smaller image side away from the reference centre, the frame is skipped. Otherwise it is reprojected with
    ``reproject_interp`` directly into its slot of the memory-mapped array. An exception raised while handling
    a single frame marks only that frame as failed; the rest of the chunk is still processed.

    Args:
        img_indexes (list[int]): Indexes of the frames this worker must align.
        shm_params (SharedMemoryParams): Shared memory parameters where the images are stored.
        progress_queue (Queue): Receives the index of every frame once it has been handled (successfully or not).
        ref_wcs (WCS): Reference WCS the frames are aligned to.
        all_wcses (list[WCS]): WCS of every frame, indexed by frame index.
        stop_queue (Optional[Queue], optional): A non-empty queue asks the worker to stop before the next frame.
            Defaults to None.
        log_queue (Optional[Queue], optional): Queue that receives this worker's log records. Defaults to None,
            in which case no queue handler is attached.

    Returns:
        tuple[list[int], list[bool], list[np.ndarray]]: Three lists in the same order:
        - the indexes of the frames actually handled (a prefix of ``img_indexes``, shorter only when a stop was
          requested);
        - their success flags;
        - their boolean footprints, True where the aligned frame received no data. Failed or skipped frames
          get an all-True footprint.
    """
    imgs = np.memmap(shm_params.shm_name, dtype=shm_params.shm_dtype, mode='r+', shape=shm_params.shm_shape)
    # Every footprint uses the same 2-D frame shape as reproject's output. Successful and failed frames can
    # therefore be stacked into one array by the caller.
    frame_shape = shm_params.shm_shape[1:3]
    footprints = []
    successes = []
    handler = QueueHandler(log_queue) if log_queue is not None else None
    if handler is not None:
        logger.log.addHandler(handler)
    try:
        logger.log.debug(f"Align worker started with {len(img_indexes)} images")
        logger.log.debug(f"Shared memory parameters: {shm_params}")
        for img_idx in img_indexes:
            if stop_queue is not None and not stop_queue.empty():
                logger.log.debug("Align worker detected stop event. Stopping.")
                break
            try:
                distance = get_centre_distance(imgs[0].shape, ref_wcs, all_wcses[img_idx])
                if distance > 0.1 * min(imgs[0].shape[:2]):
                    logger.log.warning(f"Align worker detected that image at index '{img_idx}' is too far "
                                       f"from the reference solution.  Pixel distance: {distance}. Excluding this image")
                    success, footprint = False, np.ones(frame_shape, dtype=bool)
                else:
                    _, coverage = reproject_interp(
                        (np.reshape(np.copy(imgs[img_idx]), frame_shape), all_wcses[img_idx]),
                        ref_wcs,
                        shape_out=frame_shape,
                        output_array=imgs[img_idx],
                    )
                    imgs.flush()
                    # reproject's footprint is 1 where the output received input data. It is inverted here to be
                    # consistent with the legacy code, which does image cropping based on the uncovered area.
                    success, footprint = True, np.array(1 - coverage, dtype=bool)
            except Exception:
                # If an error occurs, the image is marked as not aligned. The alignment process as a whole is not
                # considered failed; the failed image is just not used in the next steps.
                success, footprint = False, np.ones(frame_shape, dtype=bool)
                logger.log.error(f"Align worker failed to process image at index "
                                 f"'{img_idx}' due to the following error:\n{traceback.format_exc()}")
            footprints.append(footprint)
            successes.append(success)
            progress_queue.put(img_idx)
    finally:
        # Pool processes may be reused for further tasks, so never leave the handler attached.
        if handler is not None:
            logger.log.removeHandler(handler)
    # Report only the frames actually handled, so indexes stay paired with successes/footprints after a stop.
    return img_indexes[:len(successes)], successes, footprints
153

154

155
def get_centre_distance(img_shape: tuple, wcs1: WCS, wcs2: WCS):
    """
    Measure how far the centre of a second image lands from the centre of the reference image, in pixels.

    The procedure is:
    1. Take the pixel centre (width / 2, height / 2) from ``img_shape``.
    2. Convert it to world coordinates with ``wcs2``.
    3. Project it back to pixel coordinates with ``wcs1``.
    4. Truncate the projected position with ``int()`` and compute its Euclidean distance to the centre.

    Args:
        img_shape (tuple): Shape of the image; only the first two entries (rows, columns) are used.
        wcs1 (WCS): WCS whose pixel frame the distance is measured in (the reference WCS in this module).
        wcs2 (WCS): WCS of the second image.

    Returns:
        float: Distance between the image centre and the projected centre of the second image.
    """
    rows, cols = img_shape[0], img_shape[1]
    mid_x = cols / 2
    mid_y = rows / 2
    sky_point = wcs2.pixel_to_world(mid_x, mid_y)
    projected = wcs1.world_to_pixel(sky_point)
    logger.log.debug(f"Second centre: {projected}")
    proj_x, proj_y = projected
    # int() truncation is kept on purpose: a non-finite projection raises here instead of yielding a NaN distance.
    dx = mid_x - int(proj_x)
    dy = mid_y - int(proj_y)
    distance = np.sqrt(dx ** 2 + dy ** 2)
    logger.log.debug(f"Distance: {distance}")
    return distance
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.