6
from torch.types import Storage
8
__serialization_id_record_name__ = ".data/serialization_id"
13
def __init__(self, storage):
    """Wrap ``storage`` without copying or converting it.

    Args:
        storage: the object to hold; it is kept unchanged on ``self._storage``.
    """
    self._storage = storage
22
Class to allow PackageImporter to operate on unzipped packages. Methods
23
copy the behavior of the internal PyTorchFileReader class (which is used for
24
accessing packages in all other cases).
26
N.B.: ScriptObjects cannot be unpickled or accessed via this DirectoryReader
27
class due to ScriptObjects requiring an actual PyTorchFileReader instance.
30
def __init__(self, directory):
    """Create a reader rooted at ``directory``.

    Args:
        directory: path of the unzipped package directory. It is stored
            unchanged on ``self.directory``; record names passed to the
            other methods are resolved relative to it.
    """
    self.directory = directory
33
def get_record(self, name):
34
filename = f"{self.directory}/{name}"
35
with open(filename, "rb") as f:
38
def get_storage_from_record(self, name, numel, dtype):
    """Load record ``name`` from disk as an untyped storage.

    Args:
        name: record name, resolved as ``f"{self.directory}/{name}"``.
        numel: element count; multiplied by the element size of ``dtype``
            to get the byte count passed to ``from_file``.
        dtype: element dtype used to size the read.

    Returns:
        A ``_HasStorage`` wrapping the storage returned by
        ``torch.UntypedStorage.from_file``.
    """
    record_path = f"{self.directory}/{name}"
    element_bytes = torch._utils._element_size(dtype)
    # cast() is a no-op at runtime; it only tells the type checker to treat
    # the UntypedStorage class as a ``Storage``.
    untyped = cast(Storage, torch.UntypedStorage)
    raw = untyped.from_file(filename=record_path, nbytes=element_bytes * numel)
    return _HasStorage(raw)
44
def has_record(self, path):
    """Report whether ``path`` is a regular file under ``self.directory``.

    Args:
        path: record path relative to the reader's root directory.

    Returns:
        ``True`` only when ``os.path.isfile`` accepts the joined path, so
        directories and missing entries both yield ``False``.
    """
    return os.path.isfile(os.path.join(self.directory, path))
52
for filename in glob(f"{self.directory}/**", recursive=True):
53
if not os.path.isdir(filename):
54
files.append(filename[len(self.directory) + 1 :])
60
if self.has_record(__serialization_id_record_name__):
61
return self.get_record(__serialization_id_record_name__)