# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DCA run and evaluation.
17

18
Runs DCA with the specified input parameters and evaluates it.
19

20
This is weird 'hack' that launches a hyperparameter grid locally
21
because our grid is so large that we cannot launch one configuration
22
per machine.
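
Example invocation (the script name, bucket paths, and flag values below are
illustrative only, not taken from this file):

  python dca_run.py --input_path=gs://my-bucket/tissue.h5ad
      --output_csv=gs://my-bucket/results/dca_results.csv
      --log_path=gs://my-bucket/results/dca_runs_log.csv
      --ae_type=zinb-conddisp --hidden_size=64,32,64
      --hidden_dropout=0.1 --batch_size=32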
"""

import collections
import csv
import itertools
import os
import tempfile

from absl import app
from absl import flags
from absl import logging
import anndata
import numpy as np
import pandas as pd
import scanpy.api as sc
import sklearn.cluster
import sklearn.metrics
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string('input_path', None,
                    'Path to the input loom or anndata file.')
flags.DEFINE_string('output_csv', None,
                    'Path to the output CSV file with the results.')
flags.DEFINE_string('log_path', None,
                    'Path to the CSV file logging completed runs.')
flags.DEFINE_integer(
    'seed', None,
    'Random seed to use in the run. If no value is given we will run for 0..4')

# Flags about the DCA run hyperparameters.
flags.DEFINE_enum('ae_type', None,
                  ['zinb-conddisp', 'zinb', 'nb-conddisp', 'nb'],
                  'Type of autoencoder to use.')
flags.DEFINE_boolean(
    'normalize_per_cell', None,
    'If true, library size normalization is performed '
    'using the sc.pp.normalize_per_cell function in '
    'Scanpy. If no value is given it will run for both True and False.')
flags.DEFINE_boolean(
    'scale', None, 'If true, the input of the autoencoder is centered '
    'using the sc.pp.scale function of Scanpy. If no value is given it '
    'will run for both True and False.')
flags.DEFINE_boolean(
    'log1p', None, 'If true, the input of the autoencoder is log '
    'transformed with a pseudocount of one using the sc.pp.log1p function '
    'of Scanpy. If no value is given it will run for both True and False.')
flags.DEFINE_list('hidden_size', None, 'Width of hidden layers.')
flags.DEFINE_float('hidden_dropout', None,
                   'Probability of weight dropout in the autoencoder.')
flags.DEFINE_boolean(
    'batchnorm', None,
    'Whether to use batchnorm or not. If no value is given it will run for '
    'both True and False.')
flags.DEFINE_integer('batch_size', None, 'Batch size to use in training.')
flags.DEFINE_integer(
    'epochs', None,
    'Number of epochs to train on. If no value is given it will run for 20, '
    '50, 100, 200, 300, 500, and 1000.')

# Flags about the environment the code is executed in and its output.
flags.DEFINE_boolean('from_gcs', True, 'Whether the input is hosted on GCS.')
flags.DEFINE_boolean('run_info', False, 'Whether to store the whole run_info.')
flags.DEFINE_boolean('save_h5ad', False, 'Whether the anndata should be saved.')
flags.DEFINE_boolean('seurat_readable', False,
                     'Whether to make the file Seurat readable.')

Conf = collections.namedtuple(
    'Conf',
    ['log1p', 'normalize_per_cell', 'scale', 'batchnorm', 'epochs', 'seed'])
Metrics = collections.namedtuple(
    'Metrics', ['silhouette', 'kmeans_silhouette', 'ami', 'ari'])
RunResult = collections.namedtuple('RunResult', [
    'method', 'seed', 'ae_type', 'normalize_per_cell', 'scale', 'log1p',
    'hidden_size', 'hidden_dropout', 'batchnorm', 'batch_size', 'epochs',
    'silhouette', 'kmeans_silhouette', 'kmeans_ami', 'kmeans_ari', 'n_cells',
    'tissue', 'n_clusters', 'loss', 'val_loss', 'run_info_fname', 'h5ad_fname'
])


def evaluate_method(adata, n_clusters):
  """Runs the AMI, ARI, and silhouette computation."""
  # If the training diverged, the embedding will contain NaN values.
  if np.any(np.isnan(adata.obsm['X_dca'])):
    return Metrics(
        silhouette=float('nan'),
        kmeans_silhouette=float('nan'),
        ari=float('nan'),
        ami=float('nan'),
    )

  silhouette = sklearn.metrics.silhouette_score(adata.obsm['X_dca'],
                                                adata.obs['label'])

  kmeans = sklearn.cluster.KMeans(
      n_clusters=n_clusters, random_state=0).fit(adata.obsm['X_dca'])
  adata.obs['predicted_clusters'] = kmeans.labels_

  # If k-means puts all cells into a single cluster (failure to converge),
  # the silhouette computation will crash.
  if len(np.unique(adata.obs['predicted_clusters'])) < 2:
    kmeans_silhouette = float('nan')
  else:
    kmeans_silhouette = sklearn.metrics.silhouette_score(
        adata.obsm['X_dca'], adata.obs['predicted_clusters'])
  ari = sklearn.metrics.adjusted_rand_score(adata.obs['label'],
                                            adata.obs['predicted_clusters'])
  ami = sklearn.metrics.adjusted_mutual_info_score(
      adata.obs['label'], adata.obs['predicted_clusters'])

  return Metrics(
      silhouette=silhouette,
      kmeans_silhouette=kmeans_silhouette,
      ami=ami,
      ari=ari)


def dca_process(adata, ae_type, normalize_per_cell, scale, log1p, hidden_size,
                hidden_dropout, batchnorm, epochs, batch_size, seed,
                seurat_readable):
  """Runs dca from scanpy."""
  sc.pp.dca(
      adata,
      ae_type=ae_type,
      normalize_per_cell=normalize_per_cell,
      scale=scale,
      log1p=log1p,
      hidden_size=hidden_size,
      hidden_dropout=hidden_dropout,
      mode='latent',
      optimizer='Adam',
      batchnorm=batchnorm,
      epochs=epochs,
      batch_size=batch_size,
      random_state=seed,
      return_info=True)
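  # Expose gene names, cell IDs, and the DCA embedding under extra keys so the
  # saved file can also be read from Seurat.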
  if seurat_readable:
    adata.var['Gene'] = adata.var.index
    adata.obs['CellID'] = adata.obs['cell']
    adata.obsm['dca_cell_embeddings'] = adata.obsm['X_dca']
  return adata


def log_run(path, conf):
  """Logs a successful run in a CSV, as well as the header for new files."""
  conf_dict = dict(conf._asdict())
  # We need to check if the file exists before creating it.
  write_header = not tf.io.gfile.exists(path)
  with tf.io.gfile.GFile(path, 'a') as f:
    csv_writer = csv.DictWriter(f, fieldnames=conf_dict.keys())
    if write_header:
      csv_writer.writeheader()
    csv_writer.writerow(conf_dict)


def log_run_info(save_run_info, infos, log_folder, conf, tissue, ae_type,
                 hidden_size, hidden_dropout, batch_size):
  """Saves the training stats and returns the path to them."""
  if not save_run_info:
    return ''

  # Save the loss for this run.
  run_info_fname = os.path.join(
      log_folder, f'{tissue}.method=dca.seed={conf.seed}.ae_type={ae_type}.'
      f'normalize_per_cell={conf.normalize_per_cell}.scale={conf.scale}.'
      f'log1p={conf.log1p}.hidden_size={hidden_size}.'
      f'hidden_dropout={hidden_dropout}.batchnorm={conf.batchnorm}.'
      f'batch_size={batch_size}.epochs={conf.epochs}.runinfo.csv')
  run_info_df = pd.DataFrame.from_dict(infos)
  with tf.io.gfile.GFile(run_info_fname, 'w') as f:
    run_info_df.to_csv(f)
  return run_info_fname


def fetch_anndata(path, from_gcs):
  """Reads the input data and turns it into an anndata.AnnData object."""
  _, ext = os.path.splitext(path)

  # AnnData is based on HDF5 and doesn't have GCS file handlers,
  # so we have to copy the file locally before reading it.
  if from_gcs:
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
      tmp_path = tmp_file.name
    tf.io.gfile.copy(path, tmp_path, overwrite=True)
    path = tmp_path

  if ext == '.h5ad':
    adata = anndata.read_h5ad(path)
  elif ext == '.loom':
    adata = anndata.read_loom(path)
  else:
    raise app.UsageError('Only supports loom and h5ad files.')

  return adata


def write_anndata(save_h5ad, adata, log_folder, conf, tissue, ae_type,
                  hidden_size, hidden_dropout, batch_size):
  """Writes anndata object with the proper name on GCS and returns the name."""
  # We need to write the anndata locally and copy it to GCS for the same
  # reason as before.
  if not save_h5ad:
    return ''

  with tempfile.NamedTemporaryFile(suffix='.h5ad', delete=True) as tmp_file:
    adata.write(tmp_file.name)
    h5ad_fname = os.path.join(
        log_folder, f'{tissue}.method=dca.seed={conf.seed}.ae_type={ae_type}.'
        f'normalize_per_cell={conf.normalize_per_cell}.scale={conf.scale}.'
        f'log1p={conf.log1p}.hidden_size={hidden_size}.'
        f'hidden_dropout={hidden_dropout}.batchnorm={conf.batchnorm}.'
        f'batch_size={batch_size}.epochs={conf.epochs}.h5ad')
    tf.io.gfile.copy(tmp_file.name, h5ad_fname, overwrite=True)
  return h5ad_fname


def generate_conf(log1p, normalize_per_cell, scale, batchnorm, epochs, seed):
  """Generates the local parameter grid."""
  local_param_grid = {
      'log1p': [True, False] if log1p is None else [log1p],
      'normalize_per_cell': [True, False] if normalize_per_cell is None else
                            [normalize_per_cell],
      'scale': [True, False] if scale is None else [scale],
      'batchnorm': [True, False] if batchnorm is None else [batchnorm],
      'epochs': [20, 50, 100, 200, 300, 500, 1000]
                if epochs is None else [epochs],
      'seed': [0, 1, 2, 3, 4] if seed is None else [seed]
  }
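  # With every grid flag left unset this yields 2 * 2 * 2 * 2 * 7 * 5 = 560
  # configurations (log1p, normalize_per_cell, scale, batchnorm, epochs, seed).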

  return [Conf(*vals) for vals in itertools.product(*local_param_grid.values())]


def fetch_previous_runs(log_path):
  """Reads in the state in which the previous run stopped."""
  previous_runs = set()
  if tf.io.gfile.exists(log_path):
    with tf.io.gfile.GFile(log_path, mode='r') as f:
      reader = csv.DictReader(f)
      for row in reader:
        # Note: we need to do this conversion because DictReader creates an
        # OrderedDict, and reads all values as str instead of bool or int.
        previous_runs.add(
            str(
                Conf(
                    log1p=row['log1p'] == 'True',
                    normalize_per_cell=row['normalize_per_cell'] == 'True',
                    scale=row['scale'] == 'True',
                    batchnorm=row['batchnorm'] == 'True',
                    epochs=int(row['epochs']),
                    seed=int(row['seed']),
                )))
  logging.info('Previous runs:')
  for run in previous_runs:
    logging.info(run)

  return previous_runs


def main(unused_argv):
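  # hidden_size is a comma-separated list flag (e.g. --hidden_size=64,32,64),
  # so its entries arrive as strings and must be cast to int.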
  hidden_size = [int(l) for l in FLAGS.hidden_size]

  tissue, _ = os.path.splitext(os.path.basename(FLAGS.input_path))
  adata = fetch_anndata(FLAGS.input_path, FLAGS.from_gcs)

  confs = generate_conf(
      log1p=FLAGS.log1p,
      normalize_per_cell=FLAGS.normalize_per_cell,
      scale=FLAGS.scale,
      batchnorm=FLAGS.batchnorm,
      epochs=FLAGS.epochs,
      seed=FLAGS.seed)
  previous_runs = fetch_previous_runs(FLAGS.log_path)

  sc.pp.filter_genes(adata, min_cells=1)
  n_clusters = adata.obs['label'].nunique()
  total_runs = len(confs)

  for i, conf in enumerate(confs):
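    # Skip configurations already recorded in the run log so the grid search
    # can resume where it left off.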
    if str(conf) in previous_runs:
      logging.info('Skipped %s', conf)
      continue

    adata = dca_process(
        adata,
        ae_type=FLAGS.ae_type,
        normalize_per_cell=conf.normalize_per_cell,
        scale=conf.scale,
        log1p=conf.log1p,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batchnorm=conf.batchnorm,
        batch_size=FLAGS.batch_size,
        epochs=conf.epochs,
        seed=conf.seed,
        seurat_readable=FLAGS.seurat_readable)
    metrics = evaluate_method(adata, n_clusters)
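    # Training history stored in adata.uns by the DCA run (return_info=True);
    # it includes the per-epoch 'loss' and 'val_loss' used below.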
    infos = adata.uns['dca_loss_history']

    log_folder = os.path.dirname(FLAGS.output_csv)

    run_info_fname = log_run_info(
        save_run_info=FLAGS.run_info,
        infos=infos,
        log_folder=log_folder,
        conf=conf,
        tissue=tissue,
        ae_type=FLAGS.ae_type,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batch_size=FLAGS.batch_size)

    h5ad_fname = write_anndata(
        adata=adata,
        save_h5ad=FLAGS.save_h5ad,
        log_folder=log_folder,
        conf=conf,
        tissue=tissue,
        ae_type=FLAGS.ae_type,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batch_size=FLAGS.batch_size)

    run_result = RunResult(
        method='dca',
        seed=conf.seed,
        ae_type=FLAGS.ae_type,
        normalize_per_cell=conf.normalize_per_cell,
        scale=conf.scale,
        log1p=conf.log1p,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batchnorm=conf.batchnorm,
        batch_size=FLAGS.batch_size,
        epochs=conf.epochs,
        silhouette=metrics.silhouette,
        kmeans_silhouette=metrics.kmeans_silhouette,
        kmeans_ami=metrics.ami,
        kmeans_ari=metrics.ari,
        n_cells=adata.n_obs,
        tissue=tissue,
        n_clusters=n_clusters,
        loss=infos['loss'][-1],
        val_loss=infos['val_loss'][-1],
        run_info_fname=run_info_fname,
        h5ad_fname=h5ad_fname)
    log_run(FLAGS.output_csv, run_result)

    logging.info(conf)
    logging.info('Done with %s out of %s', i + 1, total_runs)
    log_run(FLAGS.log_path, conf)


if __name__ == '__main__':
  flags.mark_flag_as_required('input_path')
  flags.mark_flag_as_required('output_csv')
  flags.mark_flag_as_required('log_path')
  flags.mark_flag_as_required('ae_type')
  flags.mark_flag_as_required('hidden_size')
  flags.mark_flag_as_required('hidden_dropout')
  flags.mark_flag_as_required('batch_size')
  app.run(main)
