CelestialSurveyor

Форк
0
229 строк · 8.4 Кб
1
import astropy.io.fits
2
import cv2
3
import numpy as np
4
import traceback
5

6
from xisf import XISF
7
from multiprocessing import Queue, cpu_count, Pool, Manager
8
from typing import Optional
9
from functools import partial
10

11
from backend.progress_bar import AbstractProgressBar
12
from logger.logger import get_logger
13
from backend.data_classes import SharedMemoryParams
14
from threading import Event
15
from backend.consuming_functions.measure_execution_time import measure_execution_time
16
from logging.handlers import QueueHandler
17

18

19
# Floating-point dtype that loaded frames are converted to and stored as in the shared memmap.
PIXEL_TYPE = np.float32
# Project logger used by this module, including inside the loader worker processes.
logger = get_logger()
21

22

23
def debayer(img_data: np.ndarray) -> np.ndarray:
    """
    Perform debayering on the input image data.

    The raw mosaic is converted with OpenCV's ``cv2.COLOR_BayerBG2GRAY`` code, so the Bayer
    layout is hard-coded to OpenCV's "BG" pattern.

    Args:
        img_data: Single-channel raw Bayer data, ``(h, w)`` or ``(h, w, 1)``. OpenCV documents
            its Bayer conversions for 8/16-bit input (NOTE(review): confirm for other dtypes).

    Returns:
        np.ndarray: Debayered single-channel image of shape ``(h, w)``.
    """
    # cvtColor already allocates a fresh output array, so no extra copy is needed. The former
    # ``res.reshape(...)`` call discarded its result and never changed the shape; callers that
    # need a trailing channel axis must add it themselves.
    return cv2.cvtColor(img_data, cv2.COLOR_BayerBG2GRAY)
36

37

38
def to_gray(img_data: np.ndarray) -> np.ndarray:
    """
    Collapse a three-channel image into a single grayscale plane.

    Channels are interpreted in BGR order (OpenCV's ``cv2.COLOR_BGR2GRAY``).

    Args:
        img_data (np.ndarray): Colour image with the three channels on the last axis.

    Returns:
        np.ndarray: The grayscale image data.
    """
    gray = cv2.cvtColor(img_data, cv2.COLOR_BGR2GRAY)
    return np.array(gray)
49

50

51
def load_image_fits(filename: str, to_debayer: bool = False) -> np.ndarray:
    """
    Load image data from a FIT(S) file and optionally debayer it.

    The primary HDU is read and reduced to a single grayscale plane:

    * 2-D data ``(h, w)`` is used directly, or debayered when ``to_debayer`` is set;
    * 3-D channel-first data ``(c, h, w)`` with ``c`` in ``{1, 3}`` is moved to channel-last
      ``(h, w, c)``; three-channel images are converted with ``to_gray``.

    Pixel values are divided by 65535 (``256 * 256 - 1``) and returned as ``PIXEL_TYPE``.

    Args:
        filename (str): The path to the FIT(S) file.
        to_debayer (bool): Whether to debayer single-channel data.

    Returns:
        np.ndarray: 2-D ``(h, w)`` array containing the image data converted to grayscale.

    Raises:
        ValueError: If the primary HDU has no data, or the data shape / channel count is
            not supported.
    """
    with astropy.io.fits.open(filename) as hdul:
        raw = hdul[0].data
        if raw is None:
            raise ValueError(f"No image data in the primary HDU of {filename}")
        # Copy while the file is still open: the HDU data may be memory-mapped.
        img_data = np.array(raw)

    # FITS data is stored big-endian and may come back with a non-native dtype; convert it
    # to native byte order before handing it to OpenCV (values are unchanged).
    if not img_data.dtype.isnative:
        img_data = img_data.astype(img_data.dtype.newbyteorder("="))

    if img_data.ndim == 2:
        img_data = img_data[:, :, np.newaxis]
    elif img_data.ndim == 3 and img_data.shape[0] in (1, 3):
        # astropy orders axes in reverse FITS order, so the colour axis (NAXIS3) is first.
        # moveaxis keeps (h, w) intact; the former swapaxes(0, 2) transposed the image.
        img_data = np.ascontiguousarray(np.moveaxis(img_data, 0, -1))
    if img_data.ndim != 3:
        raise ValueError(f"Unsupported image shape {img_data.shape} in {filename}")

    channels = img_data.shape[2]
    if channels == 1:
        if to_debayer:
            img_data = debayer(img_data)  # returns 2-D (h, w)
    elif channels == 3:
        # Previously an unconditional reshape to (h, w, 1) ran before this point, making the
        # colour path unreachable and crashing on every three-channel file.
        img_data = to_gray(img_data)
    else:
        raise ValueError(f"Unsupported number of channels ({channels}) in {filename}")

    # Normalize (assumes a 16-bit pixel range)
    img_data = img_data.astype(np.float32) / (256 * 256 - 1)
    return np.ascontiguousarray(img_data.reshape(img_data.shape[:2]), dtype=PIXEL_TYPE)
80

81

82
def load_image_xisf(filename: str, to_debayer: bool = False) -> np.ndarray:
    """
    Load image data from a XISF file.

    The first image in the file is read and reduced to a single grayscale plane. Pixel values
    are cast to ``PIXEL_TYPE`` without rescaling.

    Args:
        filename (str): The path to the XISF file.
        to_debayer (bool): Accepted for interface compatibility; debayering is not
            performed for XISF files.

    Returns:
        np.ndarray: 2-D ``(h, w)`` array containing the image data converted to grayscale.

    Raises:
        ValueError: If the image shape or channel count is not supported.
    """
    _ = to_debayer  # not implemented for XISF
    img_data = np.array(XISF(filename).read_image(0))

    if img_data.ndim == 2:
        img_data = img_data[:, :, np.newaxis]
    elif img_data.ndim == 3 and img_data.shape[0] in (1, 3) and img_data.shape[2] not in (1, 3):
        # Channel-first data: move the channels last without transposing (h, w), which the
        # former swapaxes(0, 2) did. Already channel-last data is left alone.
        img_data = np.ascontiguousarray(np.moveaxis(img_data, 0, -1))
    if img_data.ndim != 3:
        raise ValueError(f"Unsupported image shape {img_data.shape} in {filename}")

    channels = img_data.shape[2]
    if channels == 3:
        img_data = to_gray(img_data)
    elif channels != 1:
        raise ValueError(f"Unsupported number of channels ({channels}) in {filename}")

    # Always drop the channel axis: the former ``len(shape) == 2`` check made this step a
    # no-op and mono images were returned as (h, w, 1).
    return np.ascontiguousarray(img_data.reshape(img_data.shape[:2]), dtype=PIXEL_TYPE)
107

108

109
def load_image(file_path: str, to_debayer: bool = False) -> np.ndarray:
110
    """
111
    Load an image file and optionally perform debayering.
112

113
    Args:
114
        file_path (str): The path to the image file.
115
        to_debayer (bool): Whether to perform debayering.
116

117
    Returns:
118
        np.ndarray: numpy array containing the image data converted to grayscale.
119
    """
120
    if file_path.lower().endswith(".fits") or file_path.lower().endswith(".fit"):
121
        return load_image_fits(file_path, to_debayer)
122
    elif file_path.lower().endswith(".xisf"):
123
        return load_image_xisf(file_path, to_debayer)
124
    else:
125
        raise ValueError(f"Unsupported file format: {file_path}")
126

127

128
def load_worker(indexes: list[int], file_list: list[str], shm_params: SharedMemoryParams, progress_queue: Queue,
                to_debayer: bool = False, stop_queue: Optional[Queue] = None, log_queue: Optional[Queue] = None
                ) -> None:
    """
    Load images into shared memory based on the provided indexes and file list.

    Each image ``file_list[idx]`` is loaded with ``load_image`` and written to ``imgs[idx]`` of
    the ``np.memmap`` described by ``shm_params``; ``idx`` is then put on ``progress_queue``.
    Before each image the worker checks ``stop_queue`` and stops if it is non-empty.

    Args:
        indexes (list[int]): Indexes into ``file_list`` and the shared array. ``load_images``
            passes the integer arrays produced by ``np.array_split``.
        file_list (list[str]): List of file paths for the images.
        shm_params (SharedMemoryParams): Name and shape of the memory-mapped image array.
        progress_queue (Queue): Receives the index of every image that has been stored.
        to_debayer (bool, optional): Whether to perform debayering. Defaults to False.
        stop_queue (Queue, optional): Polled for a stop request; on failure the worker puts
            ``"ERROR"`` on it. Defaults to None (no stop handling).
        log_queue (Queue, optional): Log records are forwarded here through a
            ``QueueHandler``. Defaults to None (no forwarding).

    Returns:
        None

    Raises:
        Exception: Any error raised while loading is logged, signalled on ``stop_queue`` and
            re-raised.
    """
    # Attach the forwarding handler only when there is a queue to forward to, and detach it
    # in ``finally`` so a pool process that may run several chunks does not stack handlers.
    handler = QueueHandler(log_queue) if log_queue is not None else None
    if handler is not None:
        logger.log.addHandler(handler)
    try:
        logger.log.debug(f"Load worker started with {len(indexes)} images")
        logger.log.debug(f"Shared memory parameters: {shm_params}")
        imgs = np.memmap(shm_params.shm_name, dtype=PIXEL_TYPE, mode='r+', shape=shm_params.shm_shape)
        for img_idx in indexes:
            if stop_queue is not None and not stop_queue.empty():
                logger.log.debug("Load images worker detected stop event. Stopping.")
                break
            imgs[img_idx] = load_image(file_list[img_idx], to_debayer)
            progress_queue.put(img_idx)
    except Exception:
        logger.log.error(f"Load worker failed due to the following error:\n{traceback.format_exc()}")
        # Without the None check an AttributeError here would mask the original error.
        if stop_queue is not None:
            stop_queue.put("ERROR")
        raise
    finally:
        if handler is not None:
            logger.log.removeHandler(handler)
164

165

166
@measure_execution_time
def load_images(file_list: list[str], shm_params: SharedMemoryParams, to_debayer: bool = False,
                progress_bar: Optional[AbstractProgressBar] = None, stop_event: Optional[Event] = None):
    """
    Load images from the provided file list into shared memory using multiple workers.

    The indexes of ``file_list`` are split into one contiguous chunk per worker process, and
    each chunk is handled by ``load_worker``. While the workers run, progress is reported to
    ``progress_bar`` and ``stop_event`` is polled. Setting the event asks the workers to stop
    before their next image.

    Args:
        file_list (list[str]): List of file paths for the images to load.
        shm_params (SharedMemoryParams): Shared memory parameters for loading images.
        to_debayer (bool, optional): Whether to perform debayering. Defaults to False.
        progress_bar (Optional[AbstractProgressBar], optional): Progress bar for tracking the loading progress.
            Defaults to None.
        stop_event (Optional[Event], optional): Event to signal stopping the loading process. Defaults to None.

    Returns:
        None

    Raises:
        Exception: An error raised inside a worker is re-raised here by ``AsyncResult.get``.
    """
    if not file_list:
        # Pool(processes=0) and np.array_split(..., 0) would both raise ValueError.
        logger.log.debug("No images to load")
        return
    # Leave one CPU for the main process, but never request zero workers on a single-core host.
    available_cpus = max(1, cpu_count() - 1)
    used_cpus = min(available_cpus, len(file_list))
    logger.log.debug(f"Number of CPUs to be used for loading images: {used_cpus}")
    # The Manager is entered first so it shuts down only after the pool, whose workers hold
    # proxies to its queues, has been torn down.
    with Manager() as manager, Pool(processes=used_cpus) as pool:
        progress_queue = manager.Queue()
        log_queue = manager.Queue()
        stop_queue = manager.Queue()
        logger.start_process_listener(log_queue)
        try:
            logger.log.debug(f"Starting loading images with {used_cpus} workers")
            results = pool.map_async(
                partial(
                    load_worker,
                    file_list=file_list,
                    shm_params=shm_params,
                    to_debayer=to_debayer,
                    progress_queue=progress_queue,
                    stop_queue=stop_queue,
                    log_queue=log_queue
                ),
                np.array_split(np.arange(len(file_list)), used_cpus))
            if progress_bar is not None or stop_event is not None:
                _track_load_progress(len(file_list), results, progress_queue, stop_queue,
                                     progress_bar, stop_event)
            results.get()
            pool.close()
            pool.join()
            logger.log.debug("Load images pool stopped.")
        finally:
            logger.stop_process_listener()


def _track_load_progress(total: int, results, progress_queue: Queue, stop_queue: Queue,
                         progress_bar: Optional[AbstractProgressBar], stop_event: Optional[Event],
                         poll_interval: float = 0.1) -> None:
    """Advance ``progress_bar`` per loaded image until all are done, a worker fails, or a stop is requested."""
    # ``results`` is the AsyncResult returned by ``Pool.map_async``.
    if progress_bar is not None:
        progress_bar.set_total(total)
    for _ in range(total):
        if stop_event is not None and stop_event.is_set():
            stop_queue.put(True)
            logger.log.debug("Stop event triggered")
            break
        if not _wait_for_loaded_image(results, progress_queue, stop_queue, stop_event, poll_interval):
            break
        if not stop_queue.empty():
            break
        if progress_bar is not None:
            progress_bar.update()
    if progress_bar is not None:
        progress_bar.complete()


def _wait_for_loaded_image(results, progress_queue: Queue, stop_queue: Queue,
                           stop_event: Optional[Event], poll_interval: float) -> bool:
    """Block until a worker reports one stored image; return False if waiting should be abandoned."""
    from queue import Empty

    while True:
        # Sample completion *before* waiting: if all workers had already returned, every
        # report is already queued, so a timeout means none are left to come.
        finished = results.ready()
        try:
            # Blocking get with a timeout instead of spinning on ``empty()`` round-trips.
            progress_queue.get(timeout=poll_interval)
        except Empty:
            if not stop_queue.empty():
                logger.log.debug("Detected error from workers. Stopping.")
                return False
            if stop_event is not None and stop_event.is_set():
                stop_queue.put(True)
                logger.log.debug("Stop event triggered")
                return False
            if finished:
                return False
            continue
        logger.log.debug("Got a result from the progress queue")
        return True
230

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.