google-research
75 строк · 2.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Script for running PeakVI on a 10x formatted dataset."""
17
18import os19from typing import Any, Sequence20
21from absl import app22from absl import flags23import anndata24import pandas as pd25import scipy.io26import scipy.sparse27import scvi28import tensorflow as tf29
# Global flag registry; values are populated when absl parses the command line.
FLAGS = flags.FLAGS

# Directory holding the 10x formatted input files.
flags.DEFINE_string(
    name='input_path',
    default=None,
    help='Path to the 10x formatted folder.',
)
# Directory that receives the output of the run.
flags.DEFINE_string(
    name='output_path',
    default=None,
    help='Path to the output directory.',
)
34
def create_anndata(path):
  """Builds an AnnData object from a 10x formatted directory.

  Three files are read from `path`: `matrix.mtx` (the matrix itself),
  `barcodes.tsv` and `bins.tsv`. The matrix is transposed after loading; the
  first column of `barcodes.tsv` then becomes the observation names and the
  first column of `bins.tsv` becomes the variable names.

  Args:
    path: Path to the 10x formatted input files.

  Returns:
    anndata object for the experiment.
  """

  def _first_column(filename):
    # Both TSV files are headerless; only their first column is used.
    with tf.io.gfile.GFile(os.path.join(path, filename), mode='r') as handle:
      return pd.read_csv(handle, sep='\t', header=None)[0]

  # The Matrix Market file is opened in binary mode and converted to CSR.
  with tf.io.gfile.GFile(os.path.join(path, 'matrix.mtx'), mode='rb') as handle:
    counts = scipy.sparse.csr_matrix(scipy.io.mmread(handle))

  # Transpose so that rows line up with barcodes and columns with bins.
  experiment = anndata.AnnData(counts).transpose()
  experiment.obs_names = _first_column('barcodes.tsv')
  experiment.var_names = _first_column('bins.tsv')
  return experiment
57
def main(argv):
  """Trains PeakVI on the input dataset and saves its latent representation.

  The dataset is loaded from `FLAGS.input_path`, a PEAKVI model is trained on
  it, and the latent representation (indexed by observation name) is written
  to `peakVI.csv` inside `FLAGS.output_path`.

  Args:
    argv: Command-line arguments left over after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If more than one positional argument is present.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Load the data and fit the model.
  dataset = create_anndata(FLAGS.input_path)
  scvi.model.PEAKVI.setup_anndata(dataset)
  model = scvi.model.PEAKVI(dataset)
  model.train()

  # One row per observation, labelled with the observation names.
  latent = model.get_latent_representation()
  embedding = pd.DataFrame(latent, index=dataset.obs_names)

  # Create the output directory before writing into it.
  output_dir = FLAGS.output_path
  tf.io.gfile.makedirs(output_dir)
  csv_path = os.path.join(output_dir, 'peakVI.csv')
  with tf.io.gfile.GFile(csv_path, 'w') as sink:
    embedding.to_csv(sink)
73
if __name__ == '__main__':
  # Both path flags default to None. Marking them as required makes absl
  # reject a missing flag during startup instead of failing later on a None
  # path (for --output_path, only after the model has already been trained).
  flags.mark_flags_as_required(['input_path', 'output_path'])
  app.run(main)