# russian_art_2024/make_submission.py
"""Script for generating the competition submission.

All constants comply with the competition rules (CPU usage).
"""

import os

import numpy as np
import torch
from torchvision import transforms

from src.initial_model_utils import init_model
from src.train_utils import ArtDataset

# All paths are relative to the working directory the script is launched from.
# Model weights: a state_dict, loaded with model.load_state_dict().
MODEL_WEIGHTS = "./data/weights/resnet50_tl_68.pt"
# Directory of test images, passed to ArtDataset.
TEST_DATASET = "./data/test/"
# Output file. Despite the .csv extension, rows are tab-separated.
SUBMISSION_PATH = "./data/submission.csv"


def _to_unit_float_array(img):
    """Convert an image to a float32 NumPy array divided by 255.

    Defined at module level rather than as a ``lambda`` so the transform
    pipeline can be pickled when DataLoader workers are started with the
    ``spawn`` method (the default on Windows and macOS).
    """
    return np.array(img, dtype="float32") / 255


def build_transform(img_size=224):
    """Return the preprocessing pipeline applied to every test image.

    Args:
        img_size: side length, in pixels, of the square resize target.

    Returns:
        A ``transforms.Compose`` that resizes, scales by 1/255, converts to a
        CHW tensor and normalizes with the ImageNet per-channel mean/std.
    """
    return transforms.Compose(
        [
            transforms.Resize(size=(img_size, img_size)),
            # Scale manually: ToTensor only divides by 255 for uint8 input,
            # so the float32 array below is not rescaled a second time.
            transforms.Lambda(_to_unit_float_array),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )


def predict(model, loader, device):
    """Return the arg-max class index for every batch yielded by *loader*.

    Args:
        model: network producing one score per class along dim 1.
        loader: iterable of ``(images, _)`` pairs; the second item is ignored.
        device: device the images are moved to before the forward pass.

    Returns:
        A flat list of Python ints, one per sample, in loader order.
    """
    preds = []
    with torch.no_grad():
        for images, _ in loader:
            outputs = model(images.to(device))
            preds.extend(outputs.argmax(dim=1).tolist())
    return preds


def write_submission(path, image_names, labels):
    """Write a tab-separated submission with an ``image_name``/``label_id`` header.

    Raises:
        ValueError: if *image_names* and *labels* differ in length (``zip``
            would otherwise silently drop the surplus rows).
    """
    if len(image_names) != len(labels):
        raise ValueError(
            f"{len(image_names)} image names but {len(labels)} predictions"
        )
    # newline="" prevents "\n" -> "\r\n" translation on Windows; utf-8 keeps
    # non-ASCII file names writable regardless of the platform locale.
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write("image_name\tlabel_id\n")
        f.writelines(f"{name}\t{label}\n" for name, label in zip(image_names, labels))


def main():
    """Run CPU inference over TEST_DATASET and write the result to SUBMISSION_PATH."""
    device = torch.device("cpu")  # competition rules: CPU usage only
    model = init_model(device, num_classes=35)
    model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=device))
    model.eval()

    art_data = ArtDataset(TEST_DATASET, transform=build_transform(img_size=224))
    batch_size = 16
    num_workers = 4
    testloader = torch.utils.data.DataLoader(
        art_data, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    # shuffle=False keeps loader order; this pairing assumes ArtDataset's
    # sample i corresponds to art_data.files[i] -- TODO confirm in ArtDataset.
    image_names = [os.path.basename(item) for item in art_data.files]
    preds = predict(model, testloader, device)
    write_submission(SUBMISSION_PATH, image_names, preds)


if __name__ == "__main__":
    main()