CelestialSurveyor
import cv2
import datetime
import h5py
import numpy as np
import os
import sys
import tensorflow as tf

from cryptography.fernet import Fernet
from io import BytesIO
from multiprocessing import cpu_count
from PIL import Image
from typing import Optional

from backend.consuming_functions.measure_execution_time import measure_execution_time
from backend.progress_bar import AbstractProgressBar
from backend.source_data_v2 import SourceDataV2, CHUNK_SIZE
from logger.logger import get_logger


# Use the asynchronous CUDA allocator and suppress TensorFlow C++ log output.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Leave one CPU core free for the rest of the application.
tf.config.threading.set_inter_op_parallelism_threads(cpu_count() - 1)


logger = get_logger()


@measure_execution_time
def predict_asteroids(source_data: SourceDataV2, progress_bar: Optional[AbstractProgressBar] = None,
                      model_path: Optional[str] = None) -> list[tuple[int, int, float]]:
    """
    Predict asteroids in the given source data using the AI model.

    Args:
        source_data (SourceDataV2): The source data containing calibrated and aligned monochromatic images.
        progress_bar (Optional[AbstractProgressBar], optional): Progress bar for tracking prediction progress.
            Defaults to None.
        model_path (Optional[str], optional): Path to the AI model for prediction. If the path is not provided, the
            model with the highest version will be used. Defaults to None.

    Returns:
        list[tuple[int, int, float]]: A list of tuples containing the coordinates and confidence level of predicted
            asteroids.
    """
    logger.log.info("Finding moving objects...")
    model_path = get_model_path() if not model_path else model_path
    logger.log.info(f"Loading model: {model_path}")
    model = decrypt_model(model_path)
    batch_size = 10
    logger.log.debug(f"Batch size: {batch_size}")
    chunk_generator = source_data.generate_image_chunks()
    batch_generator = source_data.generate_batch(chunk_generator, batch_size=batch_size)
    ys, xs = source_data.get_number_of_chunks()
    # One progress step per batch of chunks.
    progress_bar_len = len(ys) * len(xs)
    progress_bar_len = progress_bar_len // batch_size + (1 if progress_bar_len % batch_size != 0 else 0)
    if progress_bar:
        progress_bar.set_total(progress_bar_len)
    objects_coords = []
    for coords, batch in batch_generator:
        if source_data.stop_event.is_set():
            break
        results = model.predict(batch, verbose=0)
        # Keep only chunks where the model is reasonably confident a moving object is present.
        for res, (y, x) in zip(results, coords):
            if res > 0.8:
                objects_coords.append((y, x, res))
        if progress_bar:
            progress_bar.update()
    if progress_bar:
        progress_bar.complete()
    return objects_coords
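
# Usage sketch for predict_asteroids (illustrative only). How a SourceDataV2 instance is built and loaded with
# calibrated, aligned frames depends on backend.source_data_v2 and is assumed here rather than shown:
#
#     candidates = predict_asteroids(source_data, progress_bar=my_progress_bar)
#     for y, x, confidence in candidates:
#         # (y, x) is the top-left corner of a CHUNK_SIZE x CHUNK_SIZE chunk; confidence is the model score (> 0.8)
#         print(y, x, float(confidence))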


def save_results(source_data: SourceDataV2, results, output_folder: str) -> np.ndarray:
    """
    Save the results of the object recognition process. It marks the areas where the model located probable
    asteroids and creates GIF animations of these areas.

    Args:
        source_data (SourceDataV2): The source data object containing image data.
        results: The results of the object recognition process.
        output_folder (str): The folder where the results will be saved.

    Returns:
        np.ndarray: The overview image with the marked areas (also written to "results.png").
    """
    logger.log.info("Saving results...")
    max_image = np.copy(source_data.max_image) * 255.
    max_image = cv2.cvtColor(max_image, cv2.COLOR_GRAY2BGR)
    gif_size = 5
    processed = []
    for coord_num, (y, x, probability) in enumerate(results):
        probability = probability[0]
        # Yellow box for moderate confidence, green box for high confidence.
        color = (0, 255, 255) if probability <= 0.9 else (0, 255, 0)
        max_image = cv2.rectangle(max_image, (x, y), (x + CHUNK_SIZE, y + CHUNK_SIZE), color, 2)
        max_image = cv2.putText(max_image, "{:.2f}".format(probability), org=(x, y - 10),
                                fontFace=1, fontScale=1, color=(0, 0, 255), thickness=0)
        # Skip detections that already fall inside a previously created GIF area.
        for y_pr, x_pr in processed:
            if x_pr - (gif_size // 2) * 64 <= x <= x_pr + (gif_size // 2) * 64 and \
                    y_pr - (gif_size // 2) * 64 <= y <= y_pr + (gif_size // 2) * 64:
                break
        else:
            processed.append((y, x))
            y_new, x_new, size = get_big_rectangle_coords(y, x, max_image.shape, gif_size)
            max_image = cv2.rectangle(max_image, (x_new, y_new), (x_new + size, y_new + size), (0, 0, 255), 4)
            max_image = cv2.putText(max_image, str(len(processed)), org=(x_new + 20, y_new + 60),
                                    fontFace=1, fontScale=3, color=(0, 0, 255), thickness=2)
            frames = source_data.crop_image(
                source_data.images,
                (y_new, y_new + size),
                (x_new, x_new + size))
            frames = frames * 255
            # Add 20 extra pixel rows at the bottom of each frame to hold the timestamp caption.
            new_shape = list(frames.shape)
            new_shape[1] += 20
            new_frames = np.zeros(new_shape)
            new_frames[:, :-20, :] = frames
            used_timestamps = [item.timestamp for item in source_data.headers]
            for frame, original_ts in zip(new_frames, used_timestamps):
                cv2.putText(frame, text=original_ts.strftime("%d/%m/%Y %H:%M:%S %Z"), org=(70, 64 * gif_size + 16),
                            fontFace=1, fontScale=1, color=(255, 255, 255), thickness=0)
            # PIL cannot build an image from a float64 array, so cast to 8-bit before converting to GIF palette mode.
            new_frames = [Image.fromarray(frame.reshape(frame.shape[0], frame.shape[1]).astype(np.uint8))
                          .convert('L').convert('P') for frame in new_frames]
            new_frames[0].save(
                os.path.join(output_folder, f"{len(processed)}.gif"),
                save_all=True,
                append_images=new_frames[1:],
                duration=200,
                loop=0)
    cv2.imwrite(os.path.join(output_folder, "results.png"), max_image)
    return max_image


def annotate_results(source_data: SourceDataV2, img_to_be_annotated: np.ndarray, output_folder: str,
                     magnitude_limit: float) -> None:
    """
    Annotates the results on the input image. If there are known asteroids or comets within the Field of View,
    they will be marked. Annotation is done for the first timestamp of each imaging session.

    Args:
        source_data (SourceDataV2): The source data object containing image data.
        img_to_be_annotated (np.ndarray): The image to be annotated.
        output_folder (str): The folder where the annotated results will be saved.
        magnitude_limit (float): The magnitude limit for known asteroids.
    """
    logger.log.info("Annotating results...")
    start_session_frame_nums = [0]

    # A frame more than 12 hours after the start of the current session begins a new session.
    start_ts = source_data.headers[0].timestamp
    for num, header in enumerate(source_data.headers[1:], start=1):
        if header.timestamp - start_ts > datetime.timedelta(hours=12):
            start_session_frame_nums.append(num)
            start_ts = header.timestamp
    for num, start_frame_num in enumerate(start_session_frame_nums, start=1):
        logger.log.info(f"Fetching known objects for session number {num}")
        for obj_type in source_data.fetch_known_asteroids_for_image(start_frame_num, magnitude_limit=magnitude_limit):
            for item in obj_type:
                target_x, target_y = item.pixel_coordinates
                target_x = round(float(target_x))
                target_y = round(float(target_y))
                # Draw a short vertical pointer line next to the object, below it when it sits near the top edge.
                x = (target_x, target_x)
                if target_y < 50:
                    y = (target_y + 4, target_y + 14)
                else:
                    y = (target_y - 4, target_y - 14)
                img_to_be_annotated = cv2.line(
                    img_to_be_annotated, (x[0], y[0]), (x[1], y[1]), (0, 165, 255), 2)

                if target_y < 50:
                    text_y = target_y + 20 + 20
                else:
                    text_y = target_y - 20

                # Keep the label inside the right image border.
                if target_x > source_data.shape[1] - 300:
                    text_x = target_x - 300
                else:
                    text_x = target_x
                img_to_be_annotated = cv2.putText(img_to_be_annotated, f"{item.name}: {item.magnitude}",
                                                  org=(text_x, text_y), fontFace=0, fontScale=1,
                                                  color=(0, 165, 255), thickness=2)

    cv2.imwrite(os.path.join(output_folder, "results_annotated.png"), img_to_be_annotated)
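
# End-to-end sketch of how the three steps above fit together (illustrative only; how SourceDataV2 is built, the
# output folder and the magnitude limit value are assumptions, not taken from this module):
#
#     candidates = predict_asteroids(source_data)
#     overview = save_results(source_data, candidates, "output")            # writes results.png and one GIF per area
#     annotate_results(source_data, overview, "output", magnitude_limit=18.0)  # writes results_annotated.png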


# Decrypt the model weights
def decrypt_model(encrypted_model_path: str,
                  key: Optional[bytes] = b'J17tdv3zz2nemLNwd17DV33-sQbo52vFzl2EOYgtScw=') -> tf.keras.Model:
    """
    Decrypts the model weights using the provided key.
    Preliminary version of model encryption; it was added while deciding whether to make this project open source.

    Args:
        encrypted_model_path (str): The path to the encrypted model file.
        key (bytes): The key used for decryption. Defaults to a preset key.

    Returns:
        tf.keras.Model: The decrypted model.
    """
    # Read the encrypted weights from the file
    with open(encrypted_model_path, "rb") as file:
        encrypted_model_data = file.read()

    # Use the provided key to create a cipher
    cipher = Fernet(key)

    # Decrypt the entire model
    decrypted_model_data = cipher.decrypt(encrypted_model_data)

    # Load the decrypted model directly into memory as an HDF5 file
    decrypted_model_data = BytesIO(decrypted_model_data)
    h = h5py.File(decrypted_model_data, 'r')
    loaded_model = tf.keras.models.load_model(h)

    return loaded_model
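
# For reference, a minimal sketch of the inverse operation (not part of the original module; the file layout, a
# whole Keras HDF5 model encrypted with Fernet, is inferred from decrypt_model above):
def _encrypt_model_sketch(model_path: str, encrypted_model_path: str, key: bytes) -> None:
    # model_path would point at a model saved with model.save("modelN.h5"); key comes from Fernet.generate_key().
    with open(model_path, "rb") as file:
        model_data = file.read()
    encrypted_model_data = Fernet(key).encrypt(model_data)
    with open(encrypted_model_path, "wb") as file:
        file.write(encrypted_model_data)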


def get_model_path() -> str:
    """
    Get the path to the latest model file.

    Returns:
        str: The path to the latest model file.
    """
    model_dir = os.path.join(sys.path[1], "model")
    file_list = []
    if os.path.exists(model_dir):
        file_list = os.listdir(model_dir)
    # Model files are named "model<N>.bin"; pick the highest version number N.
    models = [item for item in file_list if item.startswith('model') and item.endswith('bin')]
    model_nums = []
    for model in models:
        name, _ = model.split('.')
        num = name[5:]
        model_nums.append(int(num))
    if model_nums:
        model_num = max(model_nums)
        model_path = model_dir
    else:
        # Fall back to the parent directory of the "model" folder.
        secondary_dir = os.path.split(model_dir)[0]
        file_list = os.listdir(secondary_dir)
        models = [item for item in file_list if item.startswith('model') and item.endswith('bin')]
        model_nums = []
        for model in models:
            name, _ = model.split('.')
            num = name[5:]
            model_nums.append(int(num))
        if model_nums:
            model_num = max(model_nums)
            model_path = secondary_dir
        else:
            raise Exception("AI model was not found.")
    model_path = os.path.join(model_path, f"model{model_num}.bin")
    return model_path
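
# Example of the layout get_model_path() expects (hypothetical file names): with sys.path[1] as the application
# root, it scans <root>/model for files matching "model<N>.bin" and returns the one with the highest N, e.g.
#     <root>/model/model1.bin
#     <root>/model/model12.bin   -> ".../model12.bin" is returned
# If that folder is missing or holds no such files, its parent directory is searched the same way instead.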


def get_big_rectangle_coords(y: int, x: int, image_shape: tuple, gif_size: int) -> tuple[int, int, int]:
    """
    Calculate the coordinates of a large rectangle based on the center coordinates and image properties.
    The large rectangle is used for the GIF animation.

    Args:
        y (int): The y-coordinate of the center point.
        x (int): The x-coordinate of the center point.
        image_shape (Tuple[int, int, int]): The shape of the image (height, width, channels).
        gif_size (int): The size of the GIF, in chunks per side.

    Returns:
        Tuple[int, int, int]: The coordinates of the top-left corner of the rectangle (box_y, box_x)
            and the size of the rectangle.
    """
    size = CHUNK_SIZE
    # Center a gif_size x gif_size window of chunks on (y, x), then clamp it so it stays inside the image.
    box_x = 0 if x - size * (gif_size // 2) < 0 else x - size * (gif_size // 2)
    box_y = 0 if y - size * (gif_size // 2) < 0 else y - size * (gif_size // 2)
    image_size_y, image_size_x = image_shape[:2]
    box_x = image_size_x - size * gif_size if x + size * (gif_size // 2 + 1) > image_size_x else box_x
    box_y = image_size_y - size * gif_size if y + size * (gif_size // 2 + 1) > image_size_y else box_y
    return box_y, box_x, size * gif_size
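
if __name__ == "__main__":
    # Tiny self-check of the GIF box clamping (illustrative only). Near the top edge the y coordinate is clamped
    # to 0 so the gif_size x gif_size chunk window stays inside the image.
    box_y, box_x, box_size = get_big_rectangle_coords(y=10, x=2000, image_shape=(2000, 3000, 3), gif_size=5)
    print(box_y, box_x, box_size)  # prints "0 1872 320" if CHUNK_SIZE is 64, as the constants in save_results suggest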