# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Training and evalution for score-based generative models."""
18

import functools
import gc
import io
import os
import time
from typing import Any

from . import datasets
from . import evaluation
from . import losses
from . import models  # Keep this import for registering all model definitions.
from . import sampling
from . import utils
from .models import utils as mutils
from absl import logging
import flax
import flax.deprecated.nn as nn
import flax.jax_utils as flax_utils
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import ml_collections
from .models import ddpm, ncsnv2, ncsnv3
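# Like `models` above, this import appears to be needed only for its side
# effect of registering the ddpm, ncsnv2 and ncsnv3 model classes, so that
# mutils.get_model(config.model.name) can look them up by name.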
import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan


def train(config, workdir):
  """Runs a training loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoints, training will be resumed from the latest one.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
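  # Flax's init_by_shape builds the initial parameters from dummy inputs: a
  # per-device batch of images plus an int32 vector that the models treat as
  # noise-level / timestep indices; when class-conditional, a second int32
  # vector of the same shape is appended for the class labels.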
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

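  # `state` bundles everything needed to resume a run: the optimizer (which
  # wraps the live parameters), the mutable model state (e.g. batch norm
  # statistics), an exponential moving average of the parameters, and the
  # host's PRNG key.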
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
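  # The loss name in the config doubles as a switch: "ddpm" selects the DDPM
  # perturbation below, anything else falls back to NCSN-style noise scales,
  # and the substrings "conditional" / "continuous" toggle class conditioning
  # and a continuous distribution of noise levels.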
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

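  # Training runs data-parallel across local devices: the step functions are
  # pmapped over the leading (device) axis and the optimizer/model state is
  # replicated onto every device.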
  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is a JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid a copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

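    # Split off one PRNG key per local device and stack them so pmap can hand
    # each device its own key (with the classic threefry keys, `next_rng` has
    # shape (num_devices, 2)).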
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

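        # `samples` appears to hold every intermediate batch produced by the
        # sampler, shaped (num_devices, per_device_batch, H, W, C); each
        # snapshot below is flattened to (N, H, W, C) and tiled into a roughly
        # square grid before being written out.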
        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)


def evaluate(config,
             workdir,
             eval_folder = "eval"):
  """Evaluate trained models.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Defaults to
      "eval".
  """
  # Create eval_dir
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

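  # Resume from where the previous run stopped: stay on the same checkpoint if
  # some sampling rounds are still missing for it, otherwise move on to the
  # next checkpoint.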
  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warn("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet.
    try:
      state = utils.load_state_dict(ckpt_filename, state)
    except:
      time.sleep(60)
      try:
        state = utils.load_state_dict(ckpt_filename, state)
      except:
        time.sleep(120)
        state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng, config, state, init_shape,
                                     scaler, inverse_scaler,
                                     class_conditional=class_conditional)
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(samples, inception_model,
                                                     inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warn("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
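          # As in the single-partition branch below, convert the pools to TF
          # tensors so the tfgan KID computation works under eager execution.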
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              ISs=inception_scores, fids=fids, kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to make tfgan KID work under eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
            ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              IS=inception_score, fid=fid, kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)

      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)