google-research
384 lines · 13.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""DCA run and evaluation.
17
18Runs DCA with the specified input parameters and evaluates it.
19
This is a weird 'hack' that launches a hyperparameter grid locally
21because our grid is so large that we cannot launch one configuration
22per machine.
23"""
24
25import collections26import csv27import itertools28import os29import tempfile30
31from absl import app32from absl import flags33from absl import logging34import anndata35import numpy as np36import pandas as pd37import scanpy.api as sc38import sklearn.cluster39import sklearn.metrics40import tensorflow as tf41
FLAGS = flags.FLAGS

# Required input/output locations (all may live on GCS, see --from_gcs).
flags.DEFINE_string('input_path', None,
                    'Path to the input loom or anndata file.')
flags.DEFINE_string('output_csv', None,
                    'Path to the folder containing the csv results.')
flags.DEFINE_string('log_path', None,
                    'Path to the folder containing the runs log.')
flags.DEFINE_integer(
    'seed', None,
    'Random seed to use in the run. If no value is given we will run for 1..5')

# Flags about the DCA run hyperparameters.
# Flags left as None are swept over a default grid in generate_conf().
flags.DEFINE_enum('ae_type', None,
                  ['zinb-conddisp', 'zinb', 'nb-conddisp', 'nb'],
                  'Type of autoencoder to use.')
flags.DEFINE_boolean(
    'normalize_per_cell', None,
    'If true, library size normalization is performed '
    'using the sc.pp.normalize_per_cell function in '
    'Scanpy. If no value is given it will run for both True and False.')
flags.DEFINE_boolean(
    'scale', None, 'If true, the input of the autoencoder is centered '
    'using sc.pp.scale function of Scanpy. If no value is given it will '
    'run for both True and False.')
flags.DEFINE_boolean(
    'log1p', None, 'If true, the input of the autoencoder is log '
    'transformed with a pseudocount of one using sc.pp.log1p function of '
    'Scanpy. If no value is given it will run for both True and False.')
flags.DEFINE_list('hidden_size', None, 'Width of hidden layers.')
flags.DEFINE_float('hidden_dropout', None,
                   'Probability of weight dropout in the autoencoder.')
flags.DEFINE_boolean(
    'batchnorm', None,
    'Whether to use batchnorm or not. If no value is given it will run for '
    'both True and False.')
flags.DEFINE_integer('batch_size', None, 'Batch size to use in training.')
flags.DEFINE_integer(
    'epochs', None,
    'Number of epochs to train on. If no value is given it will run for 20, 50,'
    '100, 200, 300, 500, and 1000.')

# Flags about the environment the code is executed in and its output.
flags.DEFINE_boolean('from_gcs', True, 'Whether the input is hosted on GCS.')
flags.DEFINE_boolean('run_info', False, 'Whether to store the whole run_info.')
flags.DEFINE_boolean('save_h5ad', False, 'Whether the anndata should be saved.')
flags.DEFINE_boolean('seurat_readable', False,
                     'Whether to make the file Seurat readable.')
# One grid point of the locally-swept hyperparameters (see generate_conf).
Conf = collections.namedtuple(
    'Conf',
    ['log1p', 'normalize_per_cell', 'scale', 'batchnorm', 'epochs', 'seed'])
# Clustering-quality metrics computed on the DCA embedding.
Metrics = collections.namedtuple(
    'Metrics', ['silhouette', 'kmeans_silhouette', 'ami', 'ari'])
# One CSV row of the results file: configuration + metrics + artifact paths.
RunResult = collections.namedtuple('RunResult', [
    'method', 'seed', 'ae_type', 'normalize_per_cell', 'scale', 'log1p',
    'hidden_size', 'hidden_dropout', 'batchnorm', 'batch_size', 'epochs',
    'silhouette', 'kmeans_silhouette', 'kmeans_ami', 'kmeans_ari', 'n_cells',
    'tissue', 'n_clusters', 'loss', 'val_loss', 'run_info_fname', 'h5ad_fname'
])
103
def evaluate_method(adata, n_clusters):
  """Computes silhouette, k-means silhouette, AMI, and ARI on the embedding.

  Also stores the k-means assignment in adata.obs['predicted_clusters'].
  """
  embedding = adata.obsm['X_dca']
  labels = adata.obs['label']

  # A diverged training run leaves NaNs in the embedding; every metric is
  # then undefined.
  if np.isnan(embedding).any():
    nan = float('nan')
    return Metrics(silhouette=nan, kmeans_silhouette=nan, ari=nan, ami=nan)

  silhouette = sklearn.metrics.silhouette_score(embedding, labels)

  kmeans = sklearn.cluster.KMeans(n_clusters=n_clusters, random_state=0)
  adata.obs['predicted_clusters'] = kmeans.fit(embedding).labels_
  predicted = adata.obs['predicted_clusters']

  # Silhouette is undefined (sklearn raises) when k-means collapsed every
  # point into a single cluster.
  if len(np.unique(predicted)) >= 2:
    kmeans_silhouette = sklearn.metrics.silhouette_score(embedding, predicted)
  else:
    kmeans_silhouette = float('nan')

  return Metrics(
      silhouette=silhouette,
      kmeans_silhouette=kmeans_silhouette,
      ami=sklearn.metrics.adjusted_mutual_info_score(labels, predicted),
      ari=sklearn.metrics.adjusted_rand_score(labels, predicted))
140
def dca_process(adata, ae_type, normalize_per_cell, scale, log1p, hidden_size,
                hidden_dropout, batchnorm, epochs, batch_size, seed,
                seurat_readable):
  """Runs DCA from scanpy in latent mode and returns the annotated adata."""
  dca_kwargs = dict(
      ae_type=ae_type,
      normalize_per_cell=normalize_per_cell,
      scale=scale,
      log1p=log1p,
      hidden_size=hidden_size,
      hidden_dropout=hidden_dropout,
      mode='latent',
      optimizer='Adam',
      batchnorm=batchnorm,
      epochs=epochs,
      batch_size=batch_size,
      random_state=seed,
      return_info=True)
  sc.pp.dca(adata, **dca_kwargs)

  if seurat_readable:
    # Mirror the fields Seurat's readers look for.
    adata.var['Gene'] = adata.var.index
    adata.obs['CellID'] = adata.obs['cell']
    adata.obsm['dca_cell_embeddings'] = adata.obsm['X_dca']

  return adata
166
def log_run(path, conf):
  """Appends one namedtuple as a CSV row, writing a header for new files."""
  row = dict(conf._asdict())
  # Existence must be checked before GFile('a') implicitly creates the file.
  is_new_file = not tf.io.gfile.exists(path)
  with tf.io.gfile.GFile(path, 'a') as f:
    writer = csv.DictWriter(f, fieldnames=row.keys())
    if is_new_file:
      writer.writeheader()
    writer.writerow(row)
178
def log_run_info(save_run_info, infos, log_folder, conf, tissue, ae_type,
                 hidden_size, hidden_dropout, batch_size):
  """Saves the training stats and returns the path to them ('' if disabled)."""
  if not save_run_info:
    return ''

  # The full hyperparameter configuration is encoded in the file name so a
  # run-info file can be matched back to its settings.
  basename = (
      f'{tissue}.method=dca.seed={conf.seed}.ae_type={ae_type}.'
      f'normalize_per_cell={conf.normalize_per_cell}.scale={conf.scale}.'
      f'log1p={conf.log1p}.hidden_size={hidden_size}.'
      f'hidden_dropout={hidden_dropout}.batchnorm={conf.batchnorm}.'
      f'batch_size={batch_size}.epochs={conf.epochs}.runinfo.csv')
  run_info_fname = os.path.join(log_folder, basename)
  with tf.io.gfile.GFile(run_info_fname, 'w') as f:
    pd.DataFrame.from_dict(infos).to_csv(f)
  return run_info_fname
197
def fetch_anndata(path, from_gcs):
  """Reads the input data and turns it into an anndata.AnnData object.

  Args:
    path: Path to a .h5ad or .loom file, possibly on GCS.
    from_gcs: If True, the file is first copied to local disk, because
      AnnData is based on HDF5 and doesn't have GCS file handlers.

  Returns:
    The loaded anndata.AnnData object.

  Raises:
    app.UsageError: If the file extension is neither .h5ad nor .loom.
  """
  _, ext = os.path.splitext(path)

  tmp_path = None
  if from_gcs:
    # Create a placeholder temp file, then overwrite it with the remote copy.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
      tmp_path = tmp_file.name
    tf.io.gfile.copy(path, tmp_path, overwrite=True)
    path = tmp_path

  try:
    if ext == '.h5ad':
      adata = anndata.read_h5ad(path)
    elif ext == '.loom':
      adata = anndata.read_loom(path)
    else:
      raise app.UsageError('Only supports loom and h5ad files.')
  finally:
    # The local copy is only needed while reading. The original code never
    # removed it, leaking a full dataset copy per process; clean it up here
    # (also on the UsageError path).
    if tmp_path is not None:
      os.remove(tmp_path)

  return adata
219
def write_anndata(save_h5ad, adata, log_folder, conf, tissue, ae_type,
                  hidden_size, hidden_dropout, batch_size):
  """Writes the anndata to GCS and returns its path ('' if disabled).

  AnnData has no GCS handler, so the object is written locally first and
  the file is then copied over.
  """
  if not save_h5ad:
    return ''

  # Like the run-info file, the name encodes the full configuration.
  basename = (
      f'{tissue}.method=dca.seed={conf.seed}.ae_type={ae_type}.'
      f'normalize_per_cell={conf.normalize_per_cell}.scale={conf.scale}.'
      f'log1p={conf.log1p}.hidden_size={hidden_size}.'
      f'hidden_dropout={hidden_dropout}.batchnorm={conf.batchnorm}.'
      f'batch_size={batch_size}.epochs={conf.epochs}.h5ad')
  h5ad_fname = os.path.join(log_folder, basename)
  with tempfile.NamedTemporaryFile(suffix='.h5ad', delete=True) as tmp_file:
    adata.write(tmp_file.name)
    tf.io.gfile.copy(tmp_file.name, h5ad_fname, overwrite=True)
  return h5ad_fname
239
def generate_conf(log1p, normalize_per_cell, scale, batchnorm, epochs, seed):
  """Generates the local parameter grid as a list of Conf tuples."""

  def axis(flag_value, sweep):
    # A flag left unset (None) means "sweep every value"; otherwise pin it.
    return sweep if flag_value is None else [flag_value]

  # Axis order must match the field order of the Conf namedtuple.
  grid = [
      axis(log1p, [True, False]),
      axis(normalize_per_cell, [True, False]),
      axis(scale, [True, False]),
      axis(batchnorm, [True, False]),
      axis(epochs, [20, 50, 100, 200, 300, 500, 1000]),
      axis(seed, [0, 1, 2, 3, 4]),
  ]
  return [Conf(*combo) for combo in itertools.product(*grid)]
255
def fetch_previous_runs(log_path):
  """Reads the already-completed configurations from the run log, if any."""
  previous_runs = set()
  if tf.io.gfile.exists(log_path):
    with tf.io.gfile.GFile(log_path, mode='r') as f:
      for row in csv.DictReader(f):
        # DictReader yields every field as str; rebuild a typed Conf so its
        # str() matches freshly generated configurations exactly.
        conf = Conf(
            log1p=row['log1p'] == 'True',
            normalize_per_cell=row['normalize_per_cell'] == 'True',
            scale=row['scale'] == 'True',
            batchnorm=row['batchnorm'] == 'True',
            epochs=int(row['epochs']),
            seed=int(row['seed']),
        )
        previous_runs.add(str(conf))
  logging.info('Previous runs:')
  for run in previous_runs:
    logging.info(run)

  return previous_runs
281
def main(unused_argv):
  """Runs the local DCA hyperparameter grid and logs every result."""
  hidden_size = [int(l) for l in FLAGS.hidden_size]

  # The tissue name is derived from the input file name and used to tag
  # every output artifact.
  tissue, _ = os.path.splitext(os.path.basename(FLAGS.input_path))
  adata = fetch_anndata(FLAGS.input_path, FLAGS.from_gcs)

  confs = generate_conf(
      log1p=FLAGS.log1p,
      normalize_per_cell=FLAGS.normalize_per_cell,
      scale=FLAGS.scale,
      batchnorm=FLAGS.batchnorm,
      epochs=FLAGS.epochs,
      seed=FLAGS.seed)
  previous_runs = fetch_previous_runs(FLAGS.log_path)

  sc.pp.filter_genes(adata, min_cells=1)
  n_clusters = adata.obs['label'].nunique()
  total_runs = len(confs)
  # Loop-invariant: all artifacts land next to the results CSV.
  log_folder = os.path.dirname(FLAGS.output_csv)

  for i, conf in enumerate(confs):
    # Resume support: skip configurations already recorded in the run log.
    if str(conf) in previous_runs:
      logging.info('Skipped %s', conf)
      continue

    # NOTE(review): adata is mutated by DCA and reused across configurations;
    # confirm this matches the intended preprocessing semantics.
    adata = dca_process(
        adata,
        ae_type=FLAGS.ae_type,
        normalize_per_cell=conf.normalize_per_cell,
        scale=conf.scale,
        log1p=conf.log1p,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batchnorm=conf.batchnorm,
        batch_size=FLAGS.batch_size,
        epochs=conf.epochs,
        seed=conf.seed,
        seurat_readable=FLAGS.seurat_readable)
    metrics = evaluate_method(adata, n_clusters)
    infos = adata.uns['dca_loss_history']

    run_info_fname = log_run_info(
        save_run_info=FLAGS.run_info,
        infos=infos,
        log_folder=log_folder,
        conf=conf,
        tissue=tissue,
        ae_type=FLAGS.ae_type,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batch_size=FLAGS.batch_size)

    h5ad_fname = write_anndata(
        adata=adata,
        save_h5ad=FLAGS.save_h5ad,
        log_folder=log_folder,
        conf=conf,
        tissue=tissue,
        ae_type=FLAGS.ae_type,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batch_size=FLAGS.batch_size)

    run_result = RunResult(
        method='dca',
        seed=conf.seed,
        ae_type=FLAGS.ae_type,
        normalize_per_cell=conf.normalize_per_cell,
        scale=conf.scale,
        log1p=conf.log1p,
        hidden_size=hidden_size,
        hidden_dropout=FLAGS.hidden_dropout,
        batchnorm=conf.batchnorm,
        batch_size=FLAGS.batch_size,
        epochs=conf.epochs,
        silhouette=metrics.silhouette,
        kmeans_silhouette=metrics.kmeans_silhouette,
        kmeans_ami=metrics.ami,
        kmeans_ari=metrics.ari,
        n_cells=adata.n_obs,
        tissue=tissue,
        n_clusters=n_clusters,
        loss=infos['loss'][-1],
        val_loss=infos['val_loss'][-1],
        run_info_fname=run_info_fname,
        h5ad_fname=h5ad_fname)
    log_run(FLAGS.output_csv, run_result)

    logging.info(conf)
    # i is 0-based: report 1-based progress so the final run logs N out of N
    # (the original logged "0 out of N" after the first run).
    logging.info('Done with %s out of %s', i + 1, total_runs)
    # Record the finished configuration so a restart can skip it.
    log_run(FLAGS.log_path, conf)
375
if __name__ == '__main__':
  # These flags have no usable defaults; fail fast when they are missing.
  flags.mark_flag_as_required('input_path')
  flags.mark_flag_as_required('output_csv')
  flags.mark_flag_as_required('log_path')
  flags.mark_flag_as_required('ae_type')
  flags.mark_flag_as_required('hidden_size')
  flags.mark_flag_as_required('hidden_dropout')
  flags.mark_flag_as_required('batch_size')
  app.run(main)