google-research
99 строк · 3.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Resize bins for the 10X formatted dataset."""
17
18import os
19from typing import Sequence, Any
20
21from absl import app
22from absl import flags
23import anndata
24import pandas as pd
25import scipy.io
26import scipy.sparse
27import tensorflow as tf
28
29from schptm_benchmark import resize_bins_lib
30
# Command-line flags. `input_path`, `output_dir` and `binsize` default to None
# and are marked as required in the `__main__` guard at the bottom of the file.
FLAGS = flags.FLAGS
flags.DEFINE_string('input_path', None, 'Path to the 10x formatted folder.')
flags.DEFINE_string('output_dir', None, 'Path to the output directory.')
# The value is given in kbp; main() multiplies it by 10**3 before use.
flags.DEFINE_integer('binsize', None, 'Number of bp per bin (in kbp).')
# The help text used to be a copy of the `binsize` help; it now describes the
# two modes that main() dispatches on.
flags.DEFINE_enum(
    'mode', 'bins', ['bins', 'annotation'],
    'Resizing mode: "bins" uses --binsize, "annotation" uses --annotation.')
# Only read by main() when --mode=annotation.
flags.DEFINE_string('annotation', None, 'Path to the annotation.')
38
39
def create_anndata(path: str):
  """Builds an AnnData object from a 10x formatted directory.

  Reads `matrix.mtx`, `barcodes.tsv` and `bins.tsv` from `path`. The matrix is
  converted to CSR and transposed, after which the first column of
  `barcodes.tsv` labels the observations and the first column of `bins.tsv`
  labels the variables.

  Args:
    path: Path to the 10x formatted input files.

  Returns:
    anndata object for the experiment.
  """

  def _first_column(filename):
    """Returns column 0 of a headerless, tab-separated file under `path`."""
    with tf.io.gfile.GFile(os.path.join(path, filename), mode='r') as handle:
      return pd.read_csv(handle, sep='\t', header=None)[0]

  # Matrix Market data is opened in binary mode, as in the original reader.
  with tf.io.gfile.GFile(os.path.join(path, 'matrix.mtx'), mode='rb') as handle:
    counts = scipy.io.mmread(handle)

  # Transpose so that rows line up with barcodes and columns with bins.
  experiment = anndata.AnnData(scipy.sparse.csr_matrix(counts)).transpose()
  experiment.obs_names = _first_column('barcodes.tsv')
  experiment.var_names = _first_column('bins.tsv')
  return experiment
61
62
def save_anndata(adata, output_dir, input_path):
  """Writes an AnnData object to `output_dir` as 10X formatted files.

  Produces three files in `output_dir` (created if needed):
    * `matrix.mtx`: the transpose of `adata.X` in Matrix Market format.
    * `bins.tsv`: `adata.var_names`, written twice per row (two identical
      tab-separated columns), without header or index.
    * `barcodes.tsv`: copied unchanged from `input_path`, overwriting any
      existing file.

  Args:
    adata: AnnData object to save.
    output_dir: Directory that receives the output files.
    input_path: 10x formatted input directory that supplies `barcodes.tsv`.
  """
  tf.io.gfile.makedirs(output_dir)

  matrix_file = os.path.join(output_dir, 'matrix.mtx')
  with tf.io.gfile.GFile(matrix_file, mode='w') as out:
    scipy.io.mmwrite(out, adata.X.transpose())

  bins_file = os.path.join(output_dir, 'bins.tsv')
  bins_table = pd.DataFrame(adata.var_names, columns=['var_names'])
  with tf.io.gfile.GFile(bins_file, mode='w') as out:
    # Listing the same column twice emits each bin name in two columns.
    bins_table.to_csv(
        out,
        sep='\t',
        index=False,
        header=False,
        columns=['var_names', 'var_names'])

  barcodes_src = os.path.join(input_path, 'barcodes.tsv')
  barcodes_dst = os.path.join(output_dir, 'barcodes.tsv')
  tf.io.gfile.copy(barcodes_src, barcodes_dst, overwrite=True)
81
82
def main(argv: Sequence[str]) -> None:
  """Loads a 10x formatted dataset, resizes its bins and saves the result.

  In `bins` mode the data is passed to `resize_bins_lib.merge_bins` with
  `--binsize` converted from kbp to bp. In `annotation` mode it is passed to
  `resize_bins_lib.bins_from_annotation` together with `--annotation`.

  Args:
    argv: Positional command-line arguments; unused.

  Raises:
    app.UsageError: If the flag needed by the selected mode was not provided.
  """
  del argv  # Unused.

  # Check mode-specific flags before any data is read so that a missing flag
  # is reported as a usage error instead of failing after the matrix is loaded.
  if FLAGS.mode == 'bins' and FLAGS.binsize is None:
    raise app.UsageError('--binsize is required when --mode=bins.')
  if FLAGS.mode == 'annotation' and FLAGS.annotation is None:
    raise app.UsageError('--annotation is required when --mode=annotation.')

  adata = create_anndata(FLAGS.input_path)
  if FLAGS.mode == 'bins':
    # --binsize is expressed in kbp; the library call receives bp.
    adata = resize_bins_lib.merge_bins(adata, FLAGS.binsize * (10**3))
  elif FLAGS.mode == 'annotation':
    adata = resize_bins_lib.bins_from_annotation(adata, FLAGS.annotation)

  save_anndata(adata, FLAGS.output_dir, FLAGS.input_path)
93
94
if __name__ == '__main__':
  # These flags default to None, so absl is told to reject runs that omit them.
  for required_flag in ('input_path', 'output_dir', 'binsize'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
100