CelestialSurveyor

Форк
0
/
find_asteroids.py 
270 строк · 11.4 Кб
1
import cv2
2
import datetime
3
import h5py
4
import numpy as np
5
import os
6
import sys
7
import tensorflow as tf
8

9
from cryptography.fernet import Fernet
10
from io import BytesIO
11
from multiprocessing import cpu_count
12
from PIL import Image
13
from typing import Optional
14

15
from backend.consuming_functions.measure_execution_time import measure_execution_time
16
from backend.progress_bar import AbstractProgressBar
17
from backend.source_data_v2 import SourceDataV2, CHUNK_SIZE
18
from logger.logger import get_logger
19

20

21
# Process-wide TensorFlow settings, passed through environment variables:
#   TF_GPU_ALLOCATOR     - selects the "cuda_malloc_async" GPU memory allocator
#   TF_CPP_MIN_LOG_LEVEL - "3" is intended to silence TensorFlow's native logging
# NOTE(review): tensorflow is imported above this point; confirm both variables are still picked up.
os.environ.update({
    "TF_GPU_ALLOCATOR": "cuda_malloc_async",
    'TF_CPP_MIN_LOG_LEVEL': '3',
})

# Use one inter-op thread fewer than the number of CPU cores reported by the OS.
tf.config.threading.set_inter_op_parallelism_threads(cpu_count() - 1)


# Module-wide logger shared by every function in this file.
logger = get_logger()
28

29

30
@measure_execution_time
def predict_asteroids(source_data: SourceDataV2, progress_bar: Optional[AbstractProgressBar] = None,
                      model_path: Optional[str] = None) -> list[tuple[int, int, float]]:
    """
    Predict asteroids in the given source data using AI model.

    The source data is split into chunks which are fed to the model in batches. Every chunk whose model output
    exceeds the confidence threshold (0.8) is reported. Processing stops early when ``source_data.stop_event``
    is set; the detections collected so far are still returned.

    Args:
        source_data (SourceDataV2): The source data containing calibrated and aligned monochromatic images.
        progress_bar (Optional[AbstractProgressBar], optional): Progress bar for tracking prediction progress.
            Defaults to None, in which case no progress is reported.
        model_path (Optional[str], optional): Path to the AI model for prediction. If the path is not provided, the
            model with the highest version will be used. Defaults to None.

    Returns:
        List[Tuple[int, int, float]]: A list of (y, x, confidence) tuples for the predicted asteroids. The
            confidence is the raw per-chunk model output (``save_results`` reads its first element).
    """
    logger.log.info("Finding moving objects...")
    model_path = model_path if model_path else get_model_path()
    logger.log.info(f"Loading model: {model_path}")
    model = decrypt_model(model_path)
    batch_size = 10
    confidence_threshold = 0.8
    logger.log.debug(f"Batch size: {batch_size}")
    chunk_generator = source_data.generate_image_chunks()
    batch_generator = source_data.generate_batch(chunk_generator, batch_size=batch_size)
    ys, xs = source_data.get_number_of_chunks()
    # One progress step per batch: ceiling division of the chunk count by the batch size.
    total_chunks = len(ys) * len(xs)
    total_batches = (total_chunks + batch_size - 1) // batch_size
    if progress_bar:
        progress_bar.set_total(total_batches)
    objects_coords = []
    for coords, batch in batch_generator:
        if source_data.stop_event.is_set():
            break
        results = model.predict(batch, verbose=0)
        for res, (y, x) in zip(results, coords):
            if res > confidence_threshold:
                objects_coords.append((y, x, res))
        if progress_bar:
            progress_bar.update()
    # The progress bar is optional, so it must be guarded here like everywhere else in this function.
    if progress_bar:
        progress_bar.complete()
    return objects_coords
72

73

74
def save_results(source_data: SourceDataV2, results, output_folder) -> np.ndarray:
    """
    Save the results of the object recognition process.

    Every detection is outlined on a copy of the max image and labelled with its probability. Detections are also
    grouped into larger square areas: the first detection of an area gets a numbered red frame and an animated GIF
    ("<area number>.gif") built from all source frames cropped to that area, each captioned with its timestamp.
    Later detections that fall inside an already exported area only get their outline. The annotated image is
    written to "results.png".

    Args:
        source_data (SourceDataV2): The source data object containing image data.
        results: The results of the object recognition process: an iterable of (y, x, probability) items where
            (y, x) is the top-left corner of the detected chunk and ``probability[0]`` is the confidence.
        output_folder (str): The folder where the results will be saved.

    Returns:
        np.ndarray: The annotated BGR copy of the max image (the same data that is written to "results.png").
    """
    logger.log.info("Saving results...")
    # NOTE(review): presumably max_image is normalised to [0, 1] and is scaled to 0..255 here - confirm.
    max_image = np.copy(source_data.max_image) * 255.
    max_image = cv2.cvtColor(max_image, cv2.COLOR_GRAY2BGR)
    gif_size = 5
    # A detection closer than this (on both axes) to an already exported one is covered by that one's GIF.
    merge_distance = (gif_size // 2) * CHUNK_SIZE
    # Height of the black strip appended to every GIF frame for the timestamp caption.
    caption_height = 20
    # The frame timestamps are the same for every GIF, so collect them once.
    used_timestamps = [item.timestamp for item in source_data.headers]
    processed = []
    for y, x, probability in results:
        probability = probability[0]
        # Yellow outline for moderate confidence, green for confidence above 0.9.
        color = (0, 255, 255) if probability <= 0.9 else (0, 255, 0)
        max_image = cv2.rectangle(max_image, (x, y), (x + CHUNK_SIZE, y + CHUNK_SIZE), color, 2)
        max_image = cv2.putText(max_image, "{:.2f}".format(probability), org=(x, y - 10),
                                fontFace=1, fontScale=1, color=(0, 0, 255), thickness=0)

        already_covered = any(
            abs(x - x_pr) <= merge_distance and abs(y - y_pr) <= merge_distance
            for y_pr, x_pr in processed)
        if already_covered:
            continue

        processed.append((y, x))
        area_number = len(processed)
        y_new, x_new, size = get_big_rectangle_coords(y, x, max_image.shape, gif_size)
        max_image = cv2.rectangle(max_image, (x_new, y_new), (x_new + size, y_new + size), (0, 0, 255), 4)
        max_image = cv2.putText(max_image, str(area_number), org=(x_new + 20, y_new + 60),
                                fontFace=1, fontScale=3, color=(0, 0, 255), thickness=2)

        frames = source_data.crop_image(
            source_data.images,
            (y_new, y_new + size),
            (x_new, x_new + size))
        frames = frames * 255
        # Extend axis 1 of the frame stack by the caption strip; the cropped frames fill the rest.
        # NOTE(review): assumes axis 1 is the image row axis, i.e. the strip ends up below each frame - confirm.
        new_shape = list(frames.shape)
        new_shape[1] += caption_height
        new_frames = np.zeros(new_shape)
        new_frames[:, :-caption_height, :] = frames
        for frame, original_ts in zip(new_frames, used_timestamps):
            # Each `frame` is a view into `new_frames`, so the caption is drawn in place.
            cv2.putText(frame, text=original_ts.strftime("%d/%m/%Y %H:%M:%S %Z"), org=(70, size + 16),
                        fontFace=1, fontScale=1, color=(255, 255, 255), thickness=0)
        gif_frames = [
            Image.fromarray(frame.reshape(frame.shape[0], frame.shape[1])).convert('L').convert('P')
            for frame in new_frames
        ]
        gif_frames[0].save(
            os.path.join(output_folder, f"{area_number}.gif"),
            save_all=True,
            append_images=gif_frames[1:],
            duration=200,
            loop=0)
    cv2.imwrite(os.path.join(output_folder, "results.png"), max_image)
    return max_image
127

128

129
def annotate_results(source_data: SourceDataV2, img_to_be_annotated: np.ndarray, output_folder: str,
                     magnitude_limit: float) -> None:
    """
    Mark known asteroids and comets on the given image and save it as "results_annotated.png".

    The frames are split into imaging sessions: a new session starts when a frame is more than 12 hours later
    than the first frame of the current session. Known objects are fetched for the first frame of every session,
    and each of them gets a short vertical tick next to its position and a "<name>: <magnitude>" label.

    Args:
        source_data (SourceDataV2): The source data object containing image data.
        img_to_be_annotated (np.ndarray): The image to be annotated.
        output_folder (str): The folder where the annotated results will be saved.
        magnitude_limit (float): The magnitude limit for known asteroids.
    """
    logger.log.info("Annotating results...")
    marker_color = (0, 165, 255)
    session_gap = datetime.timedelta(hours=12)

    # Index of the first frame of every imaging session.
    session_starts = [0]
    session_start_ts = source_data.headers[0].timestamp
    for frame_idx, header in enumerate(source_data.headers[1:], start=1):
        if header.timestamp - session_start_ts > session_gap:
            session_starts.append(frame_idx)
            session_start_ts = header.timestamp

    for session_no, first_frame in enumerate(session_starts, start=1):
        logger.log.info(f"Fetching known objects for session number {session_no}")
        known_objects = source_data.fetch_known_asteroids_for_image(first_frame, magnitude_limit=magnitude_limit)
        for object_group in known_objects:
            for known in object_group:
                pixel_x, pixel_y = known.pixel_coordinates
                pixel_x = round(float(pixel_x))
                pixel_y = round(float(pixel_y))

                # Objects close to the top edge get the tick and the label below them, all others above.
                if pixel_y < 50:
                    tick_from, tick_to = pixel_y + 4, pixel_y + 14
                    label_y = pixel_y + 20 + 20
                else:
                    tick_from, tick_to = pixel_y - 4, pixel_y - 14
                    label_y = pixel_y - 20
                img_to_be_annotated = cv2.line(
                    img_to_be_annotated, (pixel_x, tick_from), (pixel_x, tick_to), marker_color, 2)

                # Shift the label to the left when the object is within 300 px of the right edge.
                near_right_edge = pixel_x > source_data.shape[1] - 300
                label_x = pixel_x - 300 if near_right_edge else pixel_x
                img_to_be_annotated = cv2.putText(
                    img_to_be_annotated, f"{known.name}: {known.magnitude}", org=(label_x, label_y),
                    fontFace=0, fontScale=1, color=marker_color, thickness=2)

    cv2.imwrite(os.path.join(output_folder, "results_annotated.png"), img_to_be_annotated)
177

178

179
# Decrypt the model weights
def decrypt_model(encrypted_model_path: str,
                  key: Optional[bytes] = b'J17tdv3zz2nemLNwd17DV33-sQbo52vFzl2EOYgtScw=') -> tf.keras.Model:
    """
    Decrypts the model weights using the provided key.
    Preliminary version of model encryption. It was done when I was thinking to make this project open source or not.

    The encrypted file is read completely, decrypted with Fernet and loaded from memory as an HDF5 model;
    nothing decrypted is written to disk.

    Args:
        encrypted_model_path (str): The path to the encrypted model file.
        key (bytes): The key used for decryption. Defaults to a preset key.

    Returns:
        tf.keras.Model: The decrypted model.
    """
    # NOTE(review): the default key is embedded in the source, so it offers no real protection of the model.
    # Read the encrypted weights from the file
    with open(encrypted_model_path, "rb") as file:
        encrypted_model_data = file.read()

    # Use the provided key to create a cipher
    cipher = Fernet(key)

    # Decrypt the entire model
    decrypted_model_data = cipher.decrypt(encrypted_model_data)

    # Load the decrypted model directly from memory. The HDF5 handle is opened by this function, so it is
    # closed here as well once the model has been built, instead of being left open.
    with h5py.File(BytesIO(decrypted_model_data), 'r') as h5_file:
        loaded_model = tf.keras.models.load_model(h5_file)

    return loaded_model
209

210

211
def _latest_model_file(directory: str) -> Optional[str]:
    """Return the name of the highest-numbered "model<N>.bin" file in ``directory``, or None if there is none."""
    if not os.path.isdir(directory):
        return None
    prefix, suffix = 'model', '.bin'
    candidates = []
    for file_name in os.listdir(directory):
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            continue
        version = file_name[len(prefix):-len(suffix)]
        # Ignore look-alikes such as "model_final.bin" or "model1.2.bin" instead of crashing on int().
        if not version.isdecimal():
            continue
        # Keep the real file name so that e.g. "model05.bin" is not turned into "model5.bin".
        candidates.append((int(version), file_name))
    if not candidates:
        return None
    return max(candidates)[1]


def get_model_path() -> str:
    """
    Get the path to the latest model file.

    Model files are named "model<N>.bin"; the one with the highest N wins. The "model" directory under
    ``sys.path[1]`` is searched first, and ``sys.path[1]`` itself is used as a fallback when that directory is
    missing or holds no model.

    Returns:
        str: The path to the latest model file.

    Raises:
        FileNotFoundError: If neither location contains a model file.
    """
    model_dir = os.path.join(sys.path[1], "model")
    secondary_dir = os.path.split(model_dir)[0]
    for directory in (model_dir, secondary_dir):
        file_name = _latest_model_file(directory)
        if file_name is not None:
            return os.path.join(directory, file_name)
    raise FileNotFoundError("AI model was not found.")
247

248

249
def get_big_rectangle_coords(y: int, x: int, image_shape: tuple, gif_size: int) -> tuple[int, int, int]:
    """
    Calculate the coordinates of the large square used for the GIF animation of a detection.

    The square has a side of ``CHUNK_SIZE * gif_size`` pixels and is placed so that it starts ``gif_size // 2``
    chunks before the given position on both axes. Near the image borders the square is shifted so that it does
    not leave the image; if the image is smaller than the square on some axis, the offset on that axis is 0.

    Args:
        y (int): The y-coordinate of the detection (``save_results`` passes the top-left corner of the chunk).
        x (int): The x-coordinate of the detection (``save_results`` passes the top-left corner of the chunk).
        image_shape (Tuple[int, int, int]): The shape of the image (height, width, channels).
        gif_size (int): The side of the square, measured in chunks.

    Returns:
        Tuple[int, int, int]: The coordinates of the top-left corner of the rectangle (box_y, box_x)
        and the size of the rectangle.
    """
    box_size = CHUNK_SIZE * gif_size
    chunks_before = CHUNK_SIZE * (gif_size // 2)
    # Extent from the position to the far edge of the square: the chunk itself plus the chunks after it.
    chunks_through = CHUNK_SIZE * (gif_size // 2 + 1)
    image_size_y, image_size_x = image_shape[:2]

    def _origin(position: int, image_size: int) -> int:
        """Return the start of the square along one axis, kept inside [0, image_size]."""
        if position + chunks_through > image_size:
            # The square would cross the far border: align it with that border instead.
            origin = image_size - box_size
        else:
            origin = position - chunks_before
        # Never return a negative offset (near border, or image smaller than the square).
        return max(origin, 0)

    return _origin(y, image_size_y), _origin(x, image_size_x), box_size
271

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.