CelestialSurveyor

Форк
0
184 строки · 7.9 Кб
1
import numpy as np
2
import traceback
3
import twirl
4

5
from astropy import units as u
6
from astropy.wcs import WCS
7
from astropy.coordinates import SkyCoord
8
from astroquery.gaia import Gaia
9
from functools import partial
10
from logging.handlers import QueueHandler
11
from multiprocessing import Queue, cpu_count, Pool, Manager
12
from threading import Event
13
from typing import Optional
14

15
from backend.progress_bar import AbstractProgressBar
16
from backend.data_classes import Header
17
from logger.logger import get_logger
18
from backend.data_classes import SharedMemoryParams
19
from backend.consuming_functions.measure_execution_time import measure_execution_time
20

21

22
# Module-level logger; the functions below write through `logger.log` and use
# `logger.start_process_listener` / `logger.stop_process_listener` around the worker pool.
logger = get_logger()
23

24

25
def plate_solve_image(image: np.ndarray, header: Header,
                      sky_coord: Optional[np.ndarray] = None) -> tuple[WCS, np.ndarray]:
    """
    Compute a WCS solution for a single frame by matching detected peaks against GAIA sources.

    The frame is split into four quadrants. Up to 8 peaks are taken from each quadrant, keeping only
    peaks that fall inside the circle inscribed in the frame, so that the matched stars are spread
    over the whole image.

    Args:
        image (np.ndarray): Frame to solve. It is copied and reshaped to its first two dimensions.
        header (Header): Header of the frame; `header.solve_data` provides the pixel scale
            (multiplied by arcsec here) and the sky coordinates used for the GAIA query.
        sky_coord (Optional[np.ndarray]): Reference sources for the field. When None they are
            requested via `get_sources_from_gaia` (first 200 rows, sparsified with 0.1, first 25 kept).

    Returns:
        tuple[WCS, np.ndarray]: The computed WCS and the reference sources that were used.
    """
    solve_data = header.solve_data
    pixel = solve_data.pixel_scale * u.arcsec  # known pixel scale

    # Work on a 2-D copy so the caller's buffer is never touched.
    pixels = np.copy(image)
    img = np.reshape(pixels, (pixels.shape[0], pixels.shape[1]))
    shape = img.shape
    fov = np.min(shape[:2]) * pixel.to(u.deg)

    if sky_coord is None:
        gaia_sources = get_sources_from_gaia(solve_data.sky_coord, fov)[0:200]
        sky_coord = twirl.geometry.sparsify(gaia_sources, 0.1)[:25]

    half_rows = shape[0] // 2
    half_cols = shape[1] // 2
    # (row slice, column slice, (row offset, column offset)) for every quadrant, in the order:
    # top-left, bottom-left, top-right, bottom-right.
    quadrants = [
        (slice(None, half_rows), slice(None, half_cols), (0, 0)),
        (slice(half_rows, None), slice(None, half_cols), (half_rows, 0)),
        (slice(None, half_rows), slice(half_cols, None), (0, half_cols)),
        (slice(half_rows, None), slice(half_cols, None), (half_rows, half_cols)),
    ]
    max_radius = min(shape) // 2

    selected_peaks = []
    for row_slice, col_slice, (row_shift, col_shift) in quadrants:
        candidates = twirl.find_peaks(img[row_slice, col_slice], threshold=1)[0:200]
        inside_circle = []
        for peak_x, peak_y in candidates:
            # Translate quadrant-local coordinates back into full-frame coordinates.
            peak_y += row_shift
            peak_x += col_shift
            offset_x = peak_x - half_cols
            offset_y = peak_y - half_rows
            if np.sqrt(offset_x**2 + offset_y**2) < max_radius:
                inside_circle.append([peak_x, peak_y])
        selected_peaks.extend(inside_circle[:8])

    wcs = twirl.compute_wcs(np.array(selected_peaks), sky_coord, asterism=4)
    return wcs, sky_coord
70

71

72
def _track_plate_solve_progress(progress_bar, frames_num, progress_queue, stop_queue, stop_event, results) -> None:
    """
    Follow the workers' progress until every frame is reported, a stop is requested or a worker fails.

    Args:
        progress_bar: Progress bar to advance once per reported frame, or None to only watch for stops.
        frames_num: Number of frames the workers are expected to report.
        progress_queue: Queue on which workers put one item per solved frame.
        stop_queue: Single-slot queue used to signal "stop" to the workers and "error" to the parent.
        stop_event: Optional event set by the caller to cancel the run.
        results: The `AsyncResult` returned by `Pool.map_async`.
    """
    from queue import Empty, Full  # raised by the manager queue proxies

    poll_seconds = 0.1
    if progress_bar is not None:
        progress_bar.set_total(frames_num)
    for _ in range(frames_num):
        got_result = False
        while not got_result:
            if stop_event is not None and stop_event.is_set():
                # The queue holds one item only; if a worker already posted an error the
                # stop signal is effectively there, so never block on a full queue.
                try:
                    stop_queue.put_nowait(True)
                except Full:
                    pass
                logger.log.debug("Stop event triggered")
                break
            if not stop_queue.empty():
                logger.log.debug("Detected error from workers. Stopping.")
                break
            # Sample completion BEFORE polling: if the pool was already done and the queue is
            # still empty afterwards, no further progress items can ever arrive.
            workers_done = results.ready()
            try:
                # Blocking get with a timeout instead of spinning on empty().
                progress_queue.get(timeout=poll_seconds)
            except Empty:
                if workers_done:
                    break
                continue
            logger.log.debug("Got a result from the progress queue")
            got_result = True
        if not got_result:
            break
        if progress_bar is not None:
            progress_bar.update()
    if progress_bar is not None:
        progress_bar.complete()


@measure_execution_time
def plate_solve(shm_params: SharedMemoryParams, headers: list[Header],
                progress_bar: Optional[AbstractProgressBar] = None,
                stop_event: Optional[Event] = None) -> list[WCS]:
    """
    Plate solving images stored in shared memory.

    Reference stars are obtained once by solving the first frame; the remaining work is split into
    contiguous index chunks and solved in a process pool. All workers receive `headers[0]`.

    Args:
        shm_params (SharedMemoryParams): Shared memory parameters for accessing image data.
        headers (list[Header]): List of image headers.
        progress_bar (Optional[AbstractProgressBar]): Progress bar for tracking plate solving progress.
        stop_event (Optional[Event]): Event for stopping plate solving process. It is honoured
            whether or not a progress bar is supplied.

    Returns:
        list[WCS]: List of plate solved WCS coordinates ordered by frame index. When the run is
            stopped early only the frames solved so far are returned.

    Raises:
        Exception: Whatever a worker raised is re-raised here by `results.get()`.
    """
    logger.log.info("Plate solving...")
    imgs = np.memmap(shm_params.shm_name, dtype=shm_params.shm_dtype, mode='r+', shape=shm_params.shm_shape)
    # get reference stars from GAIA for the first image's FOV
    _, reference_stars = plate_solve_image(imgs[0], headers[0])
    frames_num = shm_params.shm_shape[0]
    # Leave one core free, but never ask for fewer than one worker (single-core hosts).
    used_cpus = max(1, min(cpu_count() - 1, frames_num))
    logger.log.debug(f"Number of CPUs to be used for plate solving: {used_cpus}")
    m = Manager()
    progress_queue = m.Queue()
    log_queue = m.Queue()
    stop_queue = m.Queue(maxsize=1)
    logger.start_process_listener(log_queue)
    try:
        with Pool(processes=used_cpus) as pool:
            logger.log.debug(f"Starting plate solving with {used_cpus} workers")
            results = pool.map_async(
                partial(plate_solve_worker, shm_params=shm_params, progress_queue=progress_queue,
                        reference_stars=reference_stars, header=headers[0], stop_queue=stop_queue,
                        log_queue=log_queue),
                np.array_split(np.arange(frames_num), used_cpus))
            if progress_bar is not None or stop_event is not None:
                _track_plate_solve_progress(progress_bar, frames_num, progress_queue, stop_queue,
                                            stop_event, results)
            res = results.get()
            pool.close()
            pool.join()
        logger.log.debug("Plate solve pool stopped.")
    finally:
        # Always stop the log listener, even when a worker error is re-raised above.
        logger.stop_process_listener()
    solved = sorted((pair for chunk in res for pair in chunk), key=lambda pair: pair[0])
    return [wcs for _, wcs in solved]
137

138

139
def plate_solve_worker(img_indexes: list[int], header: Header, shm_params: SharedMemoryParams,
                       reference_stars: np.ndarray, progress_queue: Queue, stop_queue: Optional[Queue] = None,
                       log_queue: Optional[Queue] = None) -> list[tuple[int, WCS]]:
    """
    Plate solve a chunk of frames inside a pool process.

    Args:
        img_indexes (list[int]): Indexes (first axis of the shared array) of the frames to solve.
        header (Header): Header passed to `plate_solve_image` for every frame of the chunk.
        shm_params (SharedMemoryParams): Shared memory parameters for accessing image data.
        reference_stars (np.ndarray): Reference sources passed to `plate_solve_image`.
        progress_queue (Queue): Receives the frame index after each solved frame.
        stop_queue (Optional[Queue]): When it is non-empty the worker stops before the next frame;
            on failure the worker posts "ERROR" to it.
        log_queue (Optional[Queue]): Queue the worker's log records are forwarded to.

    Returns:
        list[tuple[int, WCS]]: (frame index, WCS) pairs for the frames solved by this worker.

    Raises:
        Exception: Any error raised while solving is logged and re-raised.
    """
    from queue import Full  # raised by put_nowait() on the single-slot stop queue

    handler = QueueHandler(log_queue) if log_queue is not None else None
    if handler is not None:
        logger.log.addHandler(handler)
    try:
        logger.log.debug(f"Plate solve worker started with {len(img_indexes)} images")
        logger.log.debug(f"Shared memory parameters: {shm_params}")
        imgs = np.memmap(shm_params.shm_name, dtype=shm_params.shm_dtype, mode='r', shape=shm_params.shm_shape)
        res = []
        for img_idx in img_indexes:
            if stop_queue is not None and not stop_queue.empty():
                logger.log.debug("Plate solve worker detected stop event. Stopping.")
                break
            wcs, _ = plate_solve_image(imgs[img_idx], header, reference_stars)
            progress_queue.put(img_idx)
            res.append((img_idx, wcs))
        return res
    except Exception:
        logger.log.error(f"Plate solve worker failed due to the following error:\n{traceback.format_exc()}")
        if stop_queue is not None:
            # A blocking put() would hang forever once another worker (or the parent) has
            # already filled the single slot; a full queue already means "stop".
            try:
                stop_queue.put_nowait("ERROR")
            except Full:
                pass
        raise
    finally:
        # Pool processes are reused for several chunks; drop the handler so records are not
        # forwarded more than once by a later call in the same process.
        if handler is not None:
            logger.log.removeHandler(handler)
162

163

164
@measure_execution_time
def get_sources_from_gaia(center: SkyCoord, fov, limit: int = 10000, mag_limit: float = 10) -> np.ndarray:
    """
    Query the GAIA DR2 source table for sources in a circle around the given position.

    Args:
        center (SkyCoord): Centre of the field; its `ra` and `dec` in degrees are used.
        fov: Field of view. Either an astropy Quantity or a plain number / pair of numbers, which
            are interpreted as degrees. A 1-D value is unpacked as two extents (named ra_fov and
            dec_fov in the code), otherwise the same value is used for both.
            The search radius is half of the smaller extent.
        limit (int): Maximum number of rows requested (`select top {limit}`).
        mag_limit (float): Only sources with `phot_g_mean_mag` below this value are returned.

    Returns:
        np.ndarray: Array with one (ra, dec) row per source, ordered by ascending `phot_g_mean_mag`.
    """
    ra = center.ra.deg
    dec = center.dec.deg

    # Accept bare numbers as degrees so callers are not forced to build a Quantity.
    if not isinstance(fov, u.Quantity):
        fov = fov * u.deg
    fov_deg = fov.to(u.deg).value
    if fov.ndim == 1:
        ra_fov, dec_fov = fov_deg
    else:
        ra_fov = dec_fov = fov_deg
    radius = np.min([ra_fov, dec_fov]) / 2

    fields = "ra, dec"
    job = Gaia.launch_job(
        f"select top {limit} {fields} from gaiadr2.gaia_source where "
        "1=CONTAINS("
        f"POINT('ICRS', {ra}, {dec}), "
        f"CIRCLE('ICRS',ra, dec, {radius})) "
        f"and phot_g_mean_mag < {mag_limit} "
        "order by phot_g_mean_mag"
    )
    table = job.get_results()
    return np.array([table["ra"].value.data, table["dec"].value.data]).T
185

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.