pytorch

Форк
0
/
_directory_reader.py 
63 строки · 1.8 Кб
1
import os.path
2
from glob import glob
3
from typing import cast
4

5
import torch
6
from torch.types import Storage
7

8
__serialization_id_record_name__ = ".data/serialization_id"
9

10

11
# because get_storage_from_record returns a tensor!?
12
class _HasStorage:
13
    def __init__(self, storage):
14
        self._storage = storage
15

16
    def storage(self):
17
        return self._storage
18

19

20
class DirectoryReader:
21
    """
22
    Class to allow PackageImporter to operate on unzipped packages. Methods
23
    copy the behavior of the internal PyTorchFileReader class (which is used for
24
    accessing packages in all other cases).
25

26
    N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
27
    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
28
    """
29

30
    def __init__(self, directory):
31
        self.directory = directory
32

33
    def get_record(self, name):
34
        filename = f"{self.directory}/{name}"
35
        with open(filename, "rb") as f:
36
            return f.read()
37

38
    def get_storage_from_record(self, name, numel, dtype):
39
        filename = f"{self.directory}/{name}"
40
        nbytes = torch._utils._element_size(dtype) * numel
41
        storage = cast(Storage, torch.UntypedStorage)
42
        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
43

44
    def has_record(self, path):
45
        full_path = os.path.join(self.directory, path)
46
        return os.path.isfile(full_path)
47

48
    def get_all_records(
49
        self,
50
    ):
51
        files = []
52
        for filename in glob(f"{self.directory}/**", recursive=True):
53
            if not os.path.isdir(filename):
54
                files.append(filename[len(self.directory) + 1 :])
55
        return files
56

57
    def serialization_id(
58
        self,
59
    ):
60
        if self.has_record(__serialization_id_record_name__):
61
            return self.get_record(__serialization_id_record_name__)
62
        else:
63
            return ""
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.