CelestialSurveyor
176 строк · 7.7 Кб
1import numpy as np2import traceback3
4from functools import partial5from logging.handlers import QueueHandler6from multiprocessing import Manager, Pool, Queue, cpu_count7from typing import Optional8from threading import Event9
10from astropy.wcs import WCS11from backend.consuming_functions.measure_execution_time import measure_execution_time12from backend.data_classes import SharedMemoryParams13from backend.progress_bar import AbstractProgressBar14from logger.logger import get_logger15from reproject import reproject_interp16
17
# Project logger wrapper shared by every function in this module. Records are emitted through its `.log`
# attribute, and `start_process_listener` / `stop_process_listener` are used around the worker pool below.
logger = get_logger()
20
@measure_execution_time
def align_images_wcs(shm_params: SharedMemoryParams, all_wcs: list[WCS],
                     progress_bar: Optional[AbstractProgressBar] = None, stop_event: Optional[Event] = None
                     ) -> tuple[list[bool], np.ndarray]:
    """
    Align images with World Coordinate System (WCS).

    The frames stored in the memory-mapped file described by ``shm_params`` are split into contiguous chunks.
    Each chunk is reprojected in place onto the WCS of the first frame (``all_wcs[0]``) by a pool of at most
    4 worker processes running ``align_wcs_worker``. ``stop_event`` is polled during the whole run, whether or
    not a progress bar is given.

    Args:
        shm_params (SharedMemoryParams): Shared memory parameters where the images are stored.
        all_wcs (list[WCS]): List of WCS objects to be used for alignment; ``all_wcs[0]`` is the reference.
        progress_bar (Optional[AbstractProgressBar]): Progress bar object, updated once per processed frame.
        stop_event (Optional[Event]): Stop event object used to stop the child processes.

    Returns:
        tuple[list[bool], np.ndarray]: Success flags and footprints, both ordered by frame index. If alignment
        was stopped early, only the frames processed before the stop are included.
    """
    import queue  # only needed for ``queue.Empty``, raised by a timed-out ``get`` on the progress queue

    frames_num = shm_params.shm_shape[0]
    if frames_num == 0:
        # Nothing to align. A pool with 0 processes (or splitting into 0 chunks) would raise.
        return [], np.zeros((0, *shm_params.shm_shape[1:3]), dtype=bool)

    used_cpus = min(cpu_count(), 4, frames_num)
    logger.log.debug(f"Number of CPUs to be used for alignment: {used_cpus}")
    # The manager is a context manager as well, so its server process is shut down on exit.
    with Pool(processes=used_cpus) as pool, Manager() as manager:
        progress_queue = manager.Queue()
        stop_queue = manager.Queue(maxsize=1)
        log_queue = manager.Queue()
        logger.start_process_listener(log_queue)
        try:
            logger.log.debug(f"Starting alignment with {used_cpus} workers")
            results = pool.map_async(
                partial(align_wcs_worker, shm_params=shm_params, progress_queue=progress_queue, ref_wcs=all_wcs[0],
                        all_wcses=all_wcs, stop_queue=stop_queue, log_queue=log_queue),
                np.array_split(np.arange(frames_num), used_cpus))

            if progress_bar is not None:
                progress_bar.set_total(frames_num)
            reported = 0
            while reported < frames_num:
                # Checked regardless of the progress bar, so cancellation also works without one.
                if stop_event is not None and stop_event.is_set():
                    logger.log.debug("Stop event triggered")
                    stop_queue.put(True)
                    break
                try:
                    img_idx = progress_queue.get(timeout=0.5)
                except queue.Empty:
                    # A worker that fails outside its per-image handler never reports progress. Stop waiting
                    # once every worker has returned and nothing is left to read.
                    if results.ready() and progress_queue.empty():
                        break
                    continue
                reported += 1
                logger.log.debug(f"Aligned image at index '{img_idx}'")
                if progress_bar is not None:
                    progress_bar.update()
            if progress_bar is not None:
                progress_bar.complete()

            # Re-raises any exception that escaped a worker.
            chunks = results.get()
            # Zip each worker's result separately: if a worker stopped early, its three sequences may differ in
            # length, and a per-chunk zip keeps indexes, flags and footprints paired correctly.
            aligned = sorted(
                (item for idxs, successes, footprints in chunks for item in zip(idxs, successes, footprints)),
                key=lambda item: item[0],
            )
            success_map = [success for _, success, _ in aligned]
            footprint_map = np.array([footprint for _, _, footprint in aligned])
            logger.log.debug(f"Alignment finished. Success map: {success_map}")

            pool.close()
            pool.join()
            logger.log.debug("Alignment pool stopped.")
        finally:
            # Always stop the log listener, and do it before the manager owning ``log_queue`` shuts down.
            logger.stop_process_listener()
    return success_map, footprint_map
83
def align_wcs_worker(img_indexes: list[int], shm_params: SharedMemoryParams, progress_queue: Queue, ref_wcs: WCS,
                     all_wcses: list[WCS], stop_queue: Optional[Queue] = None, log_queue: Optional[Queue] = None
                     ) -> tuple[list[int], list[bool], list[np.ndarray]]:
    """
    Worker function for aligning images basing on the WCS information.

    Each frame listed in ``img_indexes`` is reprojected in place, inside the memory-mapped file, onto
    ``ref_wcs``, and its index is put on ``progress_queue``. Before each frame the worker checks ``stop_queue``
    and stops early if that queue is not empty.

    Args:
        img_indexes (list[int]): List of image indexes to align within this worker.
        shm_params (SharedMemoryParams): Shared memory parameters where the images are stored.
        progress_queue (Queue): Queue to report progress.
        ref_wcs (WCS): Reference WCS information to be used for alignment.
        all_wcses (list[WCS]): List of all WCS information.
        stop_queue (Optional[Queue], optional): Queue to stop the process. Defaults to None.
        log_queue (Optional[Queue], optional): Queue for logging. Defaults to None (log records are then not
            forwarded).

    Returns:
        tuple[list[int], list[bool], list[np.ndarray]]: The indexes of the frames this worker actually processed,
        their success flags, and their footprints. All three lists have the same length. If the worker was
        stopped, the indexes are a prefix of ``img_indexes``.
    """
    imgs = np.memmap(shm_params.shm_name, dtype=shm_params.shm_dtype, mode='r+', shape=shm_params.shm_shape)
    processed, successes, footprints = [], [], []
    # QueueHandler(None) would fail on the first record, so only forward logs when a queue is given.
    handler = QueueHandler(log_queue) if log_queue is not None else None
    if handler is not None:
        logger.log.addHandler(handler)
    try:
        logger.log.debug(f"Align worker started with {len(img_indexes)} images")
        logger.log.debug(f"Shared memory parameters: {shm_params}")
        for img_idx in img_indexes:
            if stop_queue is not None and not stop_queue.empty():
                logger.log.debug("Align worker detected stop event. Stopping.")
                break
            success, footprint = _align_single_image(imgs, img_idx, shm_params, ref_wcs, all_wcses)
            processed.append(img_idx)
            successes.append(success)
            footprints.append(footprint)
            progress_queue.put(img_idx)
    finally:
        # A pool process can run several tasks, so always detach the handler to avoid duplicated records.
        if handler is not None:
            logger.log.removeHandler(handler)
    return processed, successes, footprints


def _align_single_image(imgs: np.memmap, img_idx: int, shm_params: SharedMemoryParams, ref_wcs: WCS,
                        all_wcses: list[WCS]) -> tuple[bool, np.ndarray]:
    """
    Reproject frame ``img_idx`` of ``imgs`` in place onto ``ref_wcs``.

    Returns:
        tuple[bool, np.ndarray]: ``(success, footprint)``. On success the footprint is the boolean inverse of the
        footprint returned by ``reproject_interp``. When the frame is excluded (too far from the reference) or
        fails, the footprint is all ``True``.
    """
    try:
        distance = get_centre_distance(imgs[0].shape, ref_wcs, all_wcses[img_idx])
        if distance > 0.1 * min(imgs[0].shape[:2]):
            logger.log.warning(f"Align worker detected that image at index '{img_idx}' is too far "
                               f"from the reference solution. Pixel distance: {distance}. Excluding this image")
            return False, np.ones(shm_params.shm_shape[1:], dtype=bool)
        _, footprint = reproject_interp(
            (np.reshape(np.copy(imgs[img_idx]), shm_params.shm_shape[1:3]), all_wcses[img_idx]),
            ref_wcs,
            shape_out=shm_params.shm_shape[1:3],
            output_array=imgs[img_idx],
        )
        imgs.flush()
    except Exception:
        # If an error occurs, assume that the image is not aligned and mark it as such. The alignment process as a
        # whole is not considered failed; the failed image will not be used in the next steps.
        logger.log.error(f"Align worker failed to process image at index "
                         f"'{img_idx}' due to the following error:\n{traceback.format_exc()}")
        return False, np.ones(shm_params.shm_shape[1:], dtype=bool)
    # Inverting the footprint is needed to be consistent with the legacy code which does image cropping.
    # NOTE(review): this footprint has shape shm_shape[1:3] while the failure placeholders use shm_shape[1:];
    # they only match when the frames have no trailing axis — confirm the layout of shm_shape.
    return True, np.array(1 - footprint, dtype=bool)
154
def get_centre_distance(img_shape: tuple, wcs1: WCS, wcs2: WCS) -> float:
    """
    Calculate how far the centre of the second image lands from the centre of the first one.

    The centre pixel ``(img_shape[1] / 2, img_shape[0] / 2)`` is converted to sky coordinates with ``wcs2`` and
    projected back to pixel coordinates with ``wcs1``. The result is the Euclidean distance, in pixels of the
    first image, between that projection and the same centre pixel.

    Args:
        img_shape (tuple): Shape of the image; ``img_shape[0]`` is used as the height and ``img_shape[1]`` as
            the width.
        wcs1 (WCS): WCS information of the first (reference) image.
        wcs2 (WCS): WCS information of the second image.

    Returns:
        float: The distance between the centres of the two images, in pixels of the first image.

    Raises:
        ValueError: If the projected centre is not finite, i.e. it cannot be placed on the first image.
    """
    centre_x, centre_y = img_shape[1] / 2, img_shape[0] / 2
    sky_centre = wcs2.pixel_to_world(centre_x, centre_y)
    second_centre_on_ref_image = wcs1.world_to_pixel(sky_centre)
    logger.log.debug(f"Second centre: {second_centre_on_ref_image}")
    # Keep full float precision. Truncating only the projected centre (as int() did) biased the distance
    # inconsistently by up to ~1.4 px.
    second_x, second_y = map(float, second_centre_on_ref_image)
    if not (np.isfinite(second_x) and np.isfinite(second_y)):
        # Previously int(nan) raised implicitly; keep failing loudly so the caller marks the frame as failed.
        raise ValueError(f"Projected image centre is not finite: {second_centre_on_ref_image}")
    distance = np.hypot(centre_x - second_x, centre_y - second_y)
    logger.log.debug(f"Distance: {distance}")
    return distance