CelestialSurveyor
108 строк · 4.5 Кб
import queue
import traceback
from functools import partial
from logging.handlers import QueueHandler
from multiprocessing import Queue, cpu_count, Pool, Manager
from threading import Event
from typing import Optional

import numpy as np
from auto_stretch.stretch import Stretch

from backend.consuming_functions.measure_execution_time import measure_execution_time
from backend.data_classes import SharedMemoryParams
from backend.progress_bar import AbstractProgressBar
from logger.logger import get_logger
15
16
# Module-level logger wrapper used by both the parent process and the pool workers.
# Workers attach a ``QueueHandler(log_queue)`` to ``logger.log``; the parent presumably drains
# that queue via ``start_process_listener`` / ``stop_process_listener`` -- confirm in logger.logger.
logger = get_logger()
18
19
@measure_execution_time
def stretch_images(shm_params: SharedMemoryParams, progress_bar: Optional[AbstractProgressBar] = None,
                   stop_event: Optional[Event] = None) -> None:
    """
    Stretch images stored in shared memory in parallel using multiprocessing.

    Frame indexes along the first axis of ``shm_params.shm_shape`` are split into one contiguous chunk per
    worker and processed by ``stretch_worker`` in a process pool. Workers report each finished frame through
    a progress queue. A size-1 "stop" queue acts as a cross-process flag. It is raised either by this function
    (when ``stop_event`` fires) or by a worker that failed, and every worker checks it before each frame.

    Args:
        shm_params (SharedMemoryParams): Shared memory parameters for the images.
        progress_bar (Optional[AbstractProgressBar]): Progress bar to track the stretching progress.
        stop_event (Optional[Event]): Event to stop the stretching process.

    Returns:
        None

    Raises:
        Exception: whatever a worker raised is re-raised here by ``AsyncResult.get()``.
    """
    frames_num = shm_params.shm_shape[0]
    if frames_num == 0:
        # Pool(processes=0) and np.array_split(..., 0) would both raise ValueError.
        logger.log.debug("No images to stretch")
        return
    # Keep one core for the parent process, but never ask for zero workers (single-core machines).
    available_cpus = max(1, cpu_count() - 1)
    used_cpus = min(available_cpus, frames_num)
    logger.log.debug(f"Number of CPUs to be used for stretching images: {used_cpus}")
    # The manager is entered first so that it is shut down last: the pool is terminated before the
    # manager-backed queues the workers use disappear.
    with Manager() as manager, Pool(processes=used_cpus) as pool:
        progress_queue = manager.Queue()
        stop_queue = manager.Queue(maxsize=1)
        log_queue = manager.Queue()
        logger.start_process_listener(log_queue)
        try:
            logger.log.debug(f"Starting stretching images with {used_cpus} workers")
            results = pool.map_async(
                partial(stretch_worker, shm_params=shm_params, progress_queue=progress_queue,
                        stop_queue=stop_queue, log_queue=log_queue),
                np.array_split(np.arange(frames_num), used_cpus))
            if progress_bar is not None:
                progress_bar.set_total(frames_num)
            _wait_for_stretch_progress(frames_num, results, progress_queue, stop_queue, progress_bar, stop_event)
            if progress_bar is not None:
                progress_bar.complete()
            # Waits for all workers and re-raises the first worker exception, if any.
            results.get()
            pool.close()
            pool.join()
            logger.log.debug("Stretch pool stopped.")
        finally:
            # Stop the log listener even when a worker failed and results.get() raised.
            logger.stop_process_listener()


def _wait_for_stretch_progress(frames_num: int, results, progress_queue: Queue, stop_queue: Queue,
                               progress_bar: Optional[AbstractProgressBar], stop_event: Optional[Event],
                               poll_interval: float = 0.1) -> None:
    """
    Consume per-frame progress messages until every frame is reported or processing has to stop.

    Blocks on the progress queue with a short timeout instead of busy-polling ``empty()``. This keeps the
    parent from spinning a CPU core while it still reacts to ``stop_event`` and to the workers' stop flag.

    Args:
        frames_num (int): Number of frames the workers are expected to report.
        results: ``AsyncResult`` returned by ``Pool.map_async`` for the stretch workers.
        progress_queue (Queue): Queue on which workers put one item per stretched frame.
        stop_queue (Queue): Size-1 queue used as a cross-process stop flag.
        progress_bar (Optional[AbstractProgressBar]): Updated once per reported frame, if given.
        stop_event (Optional[Event]): When set, the stop flag is raised for the workers and waiting ends.
        poll_interval (float): Seconds to block on the progress queue before re-checking the stop conditions.

    Returns:
        None
    """
    finished = 0
    while finished < frames_num:
        if stop_event is not None and stop_event.is_set():
            try:
                stop_queue.put_nowait(True)
            except queue.Full:
                pass  # the flag is already raised by a failing worker; don't block on the full queue
            logger.log.debug("Stop event triggered")
            return
        if not stop_queue.empty():
            logger.log.debug("Detected error from workers. Stopping.")
            return
        try:
            progress_queue.get(timeout=poll_interval)
        except queue.Empty:
            # Workers put their progress before returning. Once the pool result is ready, an empty queue
            # therefore means nothing else will arrive (safety net against waiting forever).
            if results.ready() and progress_queue.empty():
                return
            continue
        finished += 1
        logger.log.debug("Got a result from the progress queue")
        if progress_bar is not None:
            progress_bar.update()
74
75
def stretch_worker(img_indexes: list[int], shm_params: SharedMemoryParams, progress_queue: Queue,
                   stop_queue: Optional[Queue] = None, log_queue: Optional[Queue] = None) -> None:
    """
    Worker function to stretch images with the provided indexes in shared memory.

    Each selected frame is cropped to ``shm_params.y_slice`` / ``shm_params.x_slice``, stretched with
    ``Stretch().stretch`` and written back in place into the memory-mapped array (opened with ``mode='r+'``).
    The frame index is put on ``progress_queue`` after each finished frame.

    Args:
        img_indexes (list[int]): List of image indexes to stretch.
        shm_params (SharedMemoryParams): Shared memory parameters for images.
        progress_queue (Queue): Queue for reporting progress.
        stop_queue (Optional[Queue], optional): Size-1 queue used as a stop flag. The worker stops when it is
            non-empty and raises it when the worker fails. Defaults to None.
        log_queue (Optional[Queue], optional): Queue for logging messages. Defaults to None.

    Returns:
        None

    Raises:
        Exception: any error raised while stretching is logged, signalled on ``stop_queue`` and re-raised.
    """
    # Pool processes can run several tasks, so the handler is removed again in ``finally`` to avoid
    # duplicated log records. Without a log queue no handler is attached at all.
    handler = QueueHandler(log_queue) if log_queue is not None else None
    if handler is not None:
        logger.log.addHandler(handler)
    try:
        logger.log.debug(f"Stretch worker started with {len(img_indexes)} images")
        logger.log.debug(f"Shared memory parameters: {shm_params}")
        imgs = np.memmap(shm_params.shm_name, dtype=shm_params.shm_dtype, mode='r+', shape=shm_params.shm_shape)
        for img_idx in img_indexes:
            if stop_queue is not None and not stop_queue.empty():
                logger.log.debug("Stretch worker detected stop event. Stopping.")
                break
            img = imgs[img_idx, shm_params.y_slice, shm_params.x_slice]
            imgs[img_idx, shm_params.y_slice, shm_params.x_slice] = Stretch().stretch(img)
            progress_queue.put(img_idx)
        imgs.flush()
    except Exception:
        logger.log.error(f"Stretch worker failed due to the following error:\n{traceback.format_exc()}")
        if stop_queue is not None:
            try:
                # Non-blocking on purpose: with maxsize=1, a second failing worker would otherwise block here
                # forever and the pool result would never become ready.
                stop_queue.put_nowait("ERROR")
            except queue.Full:
                pass  # the stop flag is already raised by another worker or by the parent
        raise
    finally:
        if handler is not None:
            logger.log.removeHandler(handler)
109