import concurrent
import functools
import multiprocessing
import os
import struct
import sys
import time
from concurrent.futures import as_completed
from contextlib import contextmanager
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union
import anndata
import ngs_tools as ngs
import numpy as np
import pandas as pd
import psutil
from scipy import sparse
from . import config
from .logging import logger
# As of 0.0.1, these are provided by ngs_tools, but keep these here for now because
# they are imported from this file in many places.
[docs]run_executable = ngs.utils.run_executable
[docs]open_as_text = ngs.utils.open_as_text
[docs]decompress_gzip = ngs.utils.decompress_gzip
[docs]flatten_dict_values = ngs.utils.flatten_dict_values
[docs]mkstemp = ngs.utils.mkstemp
[docs]all_exists = ngs.utils.all_exists
[docs]flatten_dictionary = ngs.utils.flatten_dictionary
[docs]flatten_iter = ngs.utils.flatten_iter
[docs]merge_dictionaries = ngs.utils.merge_dictionaries
[docs]write_pickle = ngs.utils.write_pickle
[docs]read_pickle = ngs.utils.read_pickle
[docs]class UnsupportedOSException(Exception):
pass
[docs]class suppress_stdout_stderr:
"""A context manager for doing a "deep suppression" of stdout and stderr in
Python, i.e. will suppress all print, even if the print originates in a
compiled C/Fortran sub-function.
This will not suppress raised exceptions, since exceptions are printed
to stderr just before a script exits, and after the context manager has
exited (at least, I think that is why it lets exceptions through).
https://github.com/facebook/prophet/issues/223
"""
def __init__(self):
# Open a pair of null files
self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)]
# Save the actual stdout (1) and stderr (2) file descriptors.
self.save_fds = [os.dup(1), os.dup(2)]
[docs] def __enter__(self):
# Assign the null pointers to stdout and stderr.
os.dup2(self.null_fds[0], 1)
os.dup2(self.null_fds[1], 2)
[docs] def __exit__(self, *_):
# Re-assign the real stdout/stderr back to (1) and (2)
os.dup2(self.save_fds[0], 1)
os.dup2(self.save_fds[1], 2)
# Close the null files
for fd in self.null_fds + self.save_fds:
os.close(fd)
[docs]def get_STAR_binary_path() -> str:
"""Get the path to the platform-dependent STAR binary included with
the installation.
Returns:
Path to the binary
"""
bin_filename = 'STAR.exe' if config.PLATFORM == 'windows' else 'STAR'
path = os.path.join(config.BINS_DIR, config.PLATFORM, 'STAR', bin_filename)
if not os.path.exists(path):
raise UnsupportedOSException(f'This operating system ({config.PLATFORM}) is not supported.')
return path
[docs]def get_STAR_version() -> str:
"""Get the provided STAR version.
Returns:
Version string
"""
p, stdout, stderr = run_executable([get_STAR_binary_path(), '--version'], quiet=True, returncode=1)
version = stdout.strip()
return version
[docs]def combine_arguments(args: Dict[str, Any], additional: Dict[str, Any]) -> Dict[str, Any]:
"""Combine two dictionaries representing command-line arguments.
Any duplicate keys will be merged according to the following procedure:
1. If the value in both dictionaries are lists, the two lists are combined.
2. Otherwise, the value in the first dictionary is OVERWRITTEN.
Args:
args: Original command-line arguments
additional: Additional command-line arguments
Returns:
Combined command-line arguments
"""
new_args = args.copy()
for key, value in additional.items():
if key in new_args:
if isinstance(value, list) and isinstance(new_args[key], list):
new_args[key] += value
else:
new_args[key] = value
else:
new_args[key] = value
return new_args
[docs]def arguments_to_list(args: Dict[str, Any]) -> List[Any]:
"""Convert a dictionary of command-line arguments to a list.
Args:
args: Command-line arguments
Returns:
List of command-line arguments
"""
arguments = []
for key, value in args.items():
arguments.append(key)
if isinstance(value, list):
arguments.extend(value)
else:
arguments.append(value)
return arguments
[docs]def get_file_descriptor_limit() -> int:
"""Get the current value for the maximum number of open file descriptors
in a platform-dependent way.
Returns:
The current value of the maximum number of open file descriptors.
"""
if config.PLATFORM == 'windows':
import win32file
return win32file._getmaxstdio()
else:
import resource
return resource.getrlimit(resource.RLIMIT_NOFILE)[0]
[docs]def get_max_file_descriptor_limit() -> int:
"""Get the maximum allowed value for the maximum number of open file
descriptors.
Note that for Windows, there is not an easy way to get this,
as it requires reading from the registry. So, we just return the maximum for
a vanilla Windows installation, which is 8192.
https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/setmaxstdio?view=vs-2019
Similarly, on MacOS, we return a hardcoded 10240.
Returns:
Maximum allowed value for the maximum number of open file descriptors
"""
if config.PLATFORM == 'windows':
return 8192
elif config.PLATFORM == 'darwin':
return 10240
else:
import resource
return resource.getrlimit(resource.RLIMIT_NOFILE)[1]
@contextmanager
[docs]def increase_file_descriptor_limit(limit: int):
"""Context manager that can be used to temporarily increase the maximum
number of open file descriptors for the current process. The original
value is restored when execution exits this function.
This is required when running STAR with many threads.
Args:
limit: Maximum number of open file descriptors will be increased to
this value for the duration of the context
"""
old = None
if config.PLATFORM == 'windows':
import win32file
try:
old = win32file._getmaxstdio()
win32file._setmaxstdio(limit)
yield
finally:
if old is not None:
win32file._setmaxstdio(old)
else:
import resource
try:
old = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (limit, old[1]))
yield
finally:
if old is not None:
resource.setrlimit(resource.RLIMIT_NOFILE, old)
[docs]def get_available_memory() -> int:
"""Get total amount of available memory (total memory - used memory) in bytes.
Returns:
Available memory in bytes
"""
return psutil.virtual_memory().available
[docs]def make_pool_with_counter(n_threads: int) -> Tuple[multiprocessing.Pool, multiprocessing.Value, multiprocessing.Lock]:
"""Create a new Process pool with a shared progress counter.
Args:
n_threads: Number of processes
Returns:
Tuple of (Process pool, progress counter, lock)
"""
manager = multiprocessing.Manager()
counter = manager.Value('I', 0)
lock = manager.Lock()
pool = multiprocessing.Pool(n_threads)
return pool, counter, lock
[docs]def display_progress_with_counter(
counter: multiprocessing.Value, total: int, *async_results, desc: Optional[str] = None
):
"""Display progress bar for displaying multiprocessing progress.
Args:
counter: Progress counter
total: Maximum number of units of processing
*async_results: Multiprocessing results to monitor. These are used to
determine when all processes are done.
desc: Progress bar description
"""
with ngs.progress.progress(total=total, unit_scale=True, desc=desc) as pbar:
previous_progress = 0
while any(not async_result.ready() for async_result in async_results):
time.sleep(0.01)
progress = counter.value
pbar.update(progress - previous_progress)
pbar.refresh()
previous_progress = progress
[docs]def as_completed_with_progress(futures: Iterable[concurrent.futures.Future]):
"""Wrapper around `concurrent.futures.as_completed` that displays a progress bar.
Args:
Iterator of `concurrent.futures.Future` objects
"""
with ngs.progress.progress(total=len(futures)) as pbar:
for future in as_completed(futures):
yield future
pbar.update(1)
[docs]def split_index(index: List[Tuple[int, int, int]], n: int = 8) -> List[List[Tuple[int, int, int]]]:
"""Split a conversions index, which is a list of tuples (file position,
number of lines, alignment position), one for each read, into `n`
approximately equal parts. This function is used to split the conversions
CSV for multiprocessing.
Args:
index: index
n: Number of splits, defaults to `8`
Returns:
List of parts
"""
n_lines = sum(idx[1] for idx in index)
target = (n_lines // n) + 1 # add one to prevent underflow
# Split the index to "approximately" equal parts
parts = []
current_part = []
current_size = 0
for tup in index:
current_part.append(tup)
current_size += tup[1]
if current_size >= target:
parts.append(current_part)
current_size = 0
current_part = []
if current_part:
parts.append(current_part)
return parts
[docs]def downsample_counts(
df_counts: pd.DataFrame,
proportion: Optional[float] = None,
count: Optional[int] = None,
seed: Optional[int] = None,
group_by: Optional[List[str]] = None
) -> pd.DataFrame:
"""Downsample the given counts dataframe according to the ``proportion`` or
``count`` arguments. One of these two must be provided, but not both. The dataframe
is assumed to be UMI-deduplicated.
Args:
df_counts: Counts dataframe
proportion: Proportion of reads (UMIs) to keep
count: Absolute number of reads (UMIs) to keep
seed: Random seed
group_by: Columns in the counts dataframe to use to group entries.
When this is provided, UMIs are no longer sampled at random, but instead
grouped by this argument, and only groups that have more than ``count`` UMIs
are downsampled.
Returns:
Downsampled counts dataframe
"""
rng = np.random.default_rng(seed)
if not group_by:
if bool(proportion) == bool(count):
raise Exception('Only one of `proportion` or `count` must be provided.')
n_keep = int(df_counts.shape[0] * proportion) if proportion is not None else count
return df_counts.iloc[rng.choice(df_counts.shape[0], n_keep, shuffle=False, replace=False)]
else:
if not count:
raise Exception('`count` must be provided when using `group_by`')
dfs = []
for key, df_group in df_counts.groupby(group_by, sort=False, observed=True):
if df_group.shape[0] > count:
df_group = df_group.iloc[rng.choice(df_group.shape[0], count, shuffle=False, replace=False)]
dfs.append(df_group)
return pd.concat(dfs)
[docs]def counts_to_matrix(
df_counts: pd.DataFrame,
barcodes: List[str],
features: List[str],
barcode_column: str = 'barcode',
feature_column: str = 'GX'
) -> sparse.csr_matrix:
"""Convert a counts dataframe to a sparse counts matrix.
Counts are assumed to be appropriately deduplicated.
Args:
df_counts: Counts dataframe
barcodes: List of barcodes that will map to the rows
features: List of features (i.e. genes) that will map to the columns
barcode_column: Column in counts dataframe to use as barcodes, defaults to `barcode`
feature_column: Column in counts dataframe to use as features, defaults to `GX`
Returns:
Sparse counts matrix
"""
# Transform to index for fast lookup
barcode_indices = {barcode: i for i, barcode in enumerate(barcodes)}
feature_indices = {feature: i for i, feature in enumerate(features)}
matrix = sparse.lil_matrix((len(barcodes), len(features)), dtype=np.float32)
for (barcode, feature), count in df_counts.groupby([barcode_column, feature_column], sort=False,
observed=True).size().items():
matrix[barcode_indices[barcode], feature_indices[feature]] = count
return matrix.tocsr()
[docs]def split_counts(
df_counts: pd.DataFrame,
barcodes: List[str],
features: List[str],
barcode_column: str = 'barcode',
feature_column: str = 'GX',
conversions: FrozenSet[str] = frozenset({'TC'})
) -> Tuple[sparse.csr_matrix, sparse.csr_matrix]:
"""Split counts dataframe into two count matrices by a column.
Args:
df_counts: Counts dataframe
barcodes: List of barcodes that will map to the rows
features: List of features (i.e. genes) that will map to the columns
barcode_column: Column in counts dataframe to use as barcodes
feature_column: Column in counts dataframe to use as features
conversions: Conversion(s) in question
Returns:
count matrix of `conversion==0`, count matrix of `conversion>0`
"""
matrix_unlabeled = counts_to_matrix(
df_counts[(df_counts[list(conversions)] == 0).all(axis=1)],
barcodes,
features,
barcode_column=barcode_column,
feature_column=feature_column
)
matrix_labeled = counts_to_matrix(
df_counts[(df_counts[list(conversions)] > 0).any(axis=1)],
barcodes,
features,
barcode_column=barcode_column,
feature_column=feature_column
)
return matrix_unlabeled, matrix_labeled
[docs]def split_matrix_pi(
matrix: Union[np.ndarray, sparse.spmatrix], pis: Dict[Tuple[str, str], float], barcodes: List[str],
features: List[str]
) -> Tuple[sparse.csr_matrix, sparse.csr_matrix, sparse.csr_matrix]:
"""Split the given matrix based on provided fraction of new RNA.
Args:
matrix: Matrix to split
pis: Dictionary containing pi estimates
barcodes: All barcodes
features: All features (i.e. genes)
Returns:
matrix of pis, matrix of unlabeled RNA, matrix of labeled RNA
"""
unlabeled_matrix = sparse.lil_matrix((len(barcodes), len(features)), dtype=np.float32)
labeled_matrix = sparse.lil_matrix((len(barcodes), len(features)), dtype=np.float32)
pi_matrix = sparse.lil_matrix((len(barcodes), len(features)), dtype=np.float32)
barcode_indices = {barcode: i for i, barcode in enumerate(barcodes)}
feature_indices = {feature: i for i, feature in enumerate(features)}
for (barcode, gx), pi in pis.items():
if barcode not in barcode_indices or gx not in feature_indices:
continue
try:
pi = float(pi)
except ValueError:
continue
row, col = barcode_indices[barcode], feature_indices[gx]
val = matrix[row, col]
unlabeled_matrix[row, col] = val * (1 - pi)
labeled_matrix[row, col] = val * pi
pi_matrix[row, col] = pi
return pi_matrix.tocsr(), unlabeled_matrix.tocsr(), labeled_matrix.tocsr()
[docs]def split_matrix_alpha(
unlabeled_matrix: Union[np.ndarray, sparse.spmatrix], labeled_matrix: Union[np.ndarray, sparse.spmatrix],
alphas: Dict[str, float], barcodes: List[str]
) -> Tuple[sparse.csr_matrix, sparse.csr_matrix, sparse.csr_matrix]:
"""Split the given matrix based on provided fraction of new RNA.
Args:
unlabeled_matrix: unlabeled matrix
labeled_matrix: Labeled matrix
alphas: Dictionary containing alpha estimates
barcodes: All barcodes
features: All features (i.e. genes)
Returns:
matrix of pis, matrix of unlabeled RNA, matrix of labeled RNA
"""
total_matrix = unlabeled_matrix + labeled_matrix
est_unlabeled_matrix = sparse.lil_matrix(unlabeled_matrix.shape, dtype=np.float32)
est_labeled_matrix = sparse.lil_matrix(labeled_matrix.shape, dtype=np.float32)
barcode_indices = {barcode: i for i, barcode in enumerate(barcodes)}
module = np
if sparse.issparse(total_matrix):
module = sparse
for barcode, alpha in alphas.items():
if barcode not in barcode_indices:
continue
try:
alpha = np.clip(float(alpha), 0, 1)
except ValueError:
continue
row = barcode_indices[barcode]
l_alpha = labeled_matrix[row] / alpha
est_labeled_matrix[row] = module.vstack((l_alpha, total_matrix[row])).min(axis=0)
est_unlabeled_matrix[row] = total_matrix[row] - est_labeled_matrix[row]
return est_unlabeled_matrix.tocsr(), est_labeled_matrix.tocsr()
[docs]def results_to_adata(
df_counts: pd.DataFrame,
conversions: FrozenSet[FrozenSet[str]] = frozenset({frozenset({'TC'})}),
gene_infos: Optional[dict] = None,
pis: Optional[Dict[str, Dict[Tuple[str, ...], Dict[Tuple[str, str], float]]]] = None,
alphas: Optional[Dict[str, Dict[Tuple[str, ...], Dict[str, float]]]] = None,
) -> anndata.AnnData:
"""Compile all results to a single anndata.
Args:
df_counts: Counts dataframe, with complemented reverse strand bases
conversions: Conversion(s) in question
gene_infos: Dictionary containing gene information. If this is not provided,
the function assumes gene names are already in the Counts dataframe.
pis: Dictionary of estimated pis
alphas: Dictionary of estimated alphas
Returns:
Anndata containing all results
"""
if pis is not None and alphas is not None:
raise Exception('Only one of `pis` or `alphas` may be provided.')
pis = pis or {}
alphas = alphas or {}
all_conversions = sorted(flatten_iter(conversions))
transcriptome_exists = df_counts['transcriptome'].any()
transcriptome_only = df_counts['transcriptome'].all()
velocities = df_counts['velocity'].unique()
barcodes = sorted(df_counts['barcode'].unique())
features = sorted(df_counts['GX'].unique())
add_names = gene_infos is not None
obs = pd.DataFrame(index=pd.Series(barcodes, name='barcode'))
var = pd.DataFrame(index=pd.Series(features, name='gene_id' if add_names else 'gene_name'))
if add_names:
var['gene_name'] = pd.Categorical([gene_infos.get(feature, {}).get('gene_name') for feature in features])
df_counts_transcriptome = df_counts[df_counts['transcriptome']]
matrix = counts_to_matrix(df_counts_transcriptome, barcodes, features)
layers = {}
# Transcriptome reads
if transcriptome_exists:
for convs in conversions:
convs = sorted(convs)
# Ignore reads that have other conversions
other_convs = list(set(all_conversions) - set(convs))
join = '_'.join(convs)
# Counts for transcriptome reads (i.e. X_unlabeled + X_labeled = X)
layers[f'X_n_{join}'], layers[f'X_l_{join}'] = split_counts(
df_counts_transcriptome[(df_counts_transcriptome[other_convs] == 0).all(axis=1)],
barcodes,
features,
conversions=convs
)
pi = pis.get('transcriptome', {}).get(tuple(convs))
if pi is not None:
(
layers[f'X_{join}_pi_g'],
layers[f'X_n_{join}_est'],
layers[f'X_l_{join}_est'],
) = split_matrix_pi(layers[f'X_n_{join}'] + layers[f'X_l_{join}'], pi, barcodes, features)
alpha = alphas.get('transcriptome', {}).get(tuple(convs))
if alpha is not None:
obs[f'X_{join}_alpha'] = obs.index.map(alpha)
(
layers[f'X_n_{join}_est'],
layers[f'X_l_{join}_est'],
) = split_matrix_alpha(layers[f'X_n_{join}'], layers[f'X_l_{join}'], alpha, barcodes)
else:
logger.warning('No reads were assigned to `transcriptome`')
# Total reads
if not transcriptome_only:
layers['total'] = counts_to_matrix(df_counts, barcodes, features)
for convs in conversions:
convs = sorted(convs)
other_convs = list(set(all_conversions) - set(convs))
join = '_'.join(convs)
layers[f'unlabeled_{join}'], layers[f'labeled_{join}'] = split_counts(
df_counts[(df_counts[other_convs] == 0).all(axis=1)], barcodes, features, conversions=convs
)
pi = pis.get('total', {}).get(tuple(convs))
if pi is not None:
(
layers[f'total_{join}_pi_g'],
layers[f'unlabeled_{join}_est'],
layers[f'labeled_{join}_est'],
) = split_matrix_pi(layers[f'unlabeled_{join}'] + layers[f'labeled_{join}'], pi, barcodes, features)
alpha = alphas.get('total', {}).get(tuple(convs))
if alpha is not None:
obs[f'total_{join}_alpha'] = obs.index.map(alpha)
(
layers[f'unlabeled_{join}_est'],
layers[f'labeled_{join}_est'],
) = split_matrix_alpha(layers[f'unlabeled_{join}'], layers[f'labeled_{join}'], alpha, barcodes)
# Velocity reads
for key in velocities:
if key == 'unassigned':
continue
df_counts_velocity = df_counts[df_counts['velocity'] == key]
layers[key] = counts_to_matrix(df_counts_velocity, barcodes, features)
if key in config.VELOCITY_BLACKLIST:
continue
for convs in conversions:
convs = sorted(convs)
other_convs = list(set(all_conversions) - set(convs))
join = '_'.join(convs)
layers[f'{key[0]}n_{join}'], layers[f'{key[0]}l_{join}'] = split_counts(
df_counts_velocity[(df_counts_velocity[other_convs] == 0).all(axis=1)],
barcodes,
features,
conversions=convs
)
pi = pis.get(key, {}).get(tuple(convs))
if pi is not None:
(
layers[f'{key}_{join}_pi_g'],
layers[f'{key[0]}n_{join}_est'],
layers[f'{key[0]}l_{join}_est'],
) = split_matrix_pi(layers[f'{key[0]}n_{join}'] + layers[f'{key[0]}l_{join}'], pi, barcodes, features)
alpha = alphas.get(key, {}).get(tuple(convs))
if alpha is not None:
obs[f'{key}_{join}_alpha'] = obs.index.map(alpha)
(
layers[f'{key[0]}n_{join}_est'],
layers[f'{key[0]}l_{join}_est'],
) = split_matrix_alpha(layers[f'{key[0]}n_{join}'], layers[f'{key[0]}l_{join}'], alpha, barcodes)
# Construct anndata
return anndata.AnnData(X=matrix, obs=obs, var=var, layers=layers)
[docs]def patch_mp_connection_bpo_17560():
"""Apply PR-10305 / bpo-17560 connection send/receive max size update
See the original issue at https://bugs.python.org/issue17560 and
https://github.com/python/cpython/pull/10305 for the pull request.
This only supports Python versions 3.3 - 3.7, this function
does nothing for Python versions outside of that range.
Taken from https://stackoverflow.com/a/47776649
"""
patchname = "Multiprocessing connection patch for bpo-17560"
if not (3, 3) < sys.version_info < (3, 8):
return
logger.debug(f'Applying {patchname}')
from multiprocessing.connection import Connection
orig_send_bytes = Connection._send_bytes
orig_recv_bytes = Connection._recv_bytes
if (orig_send_bytes.__code__.co_filename == __file__ and orig_recv_bytes.__code__.co_filename == __file__):
logger.info(f'{patchname} already applied, skipping')
return
@functools.wraps(orig_send_bytes)
def send_bytes(self, buf):
n = len(buf)
if n > 0x7fffffff:
pre_header = struct.pack("!i", -1)
header = struct.pack("!Q", n)
self._send(pre_header)
self._send(header)
self._send(buf)
else:
orig_send_bytes(self, buf)
@functools.wraps(orig_recv_bytes)
def recv_bytes(self, maxsize=None):
buf = self._recv(4)
size, = struct.unpack("!i", buf.getvalue())
if size == -1:
buf = self._recv(8)
size, = struct.unpack("!Q", buf.getvalue())
if maxsize is not None and size > maxsize:
return None
return self._recv(size)
Connection._send_bytes = send_bytes
Connection._recv_bytes = recv_bytes
[docs]def dict_to_matrix(d: Dict[Tuple[str, str], float], rows: List[str], columns: List[str]) -> sparse.csr_matrix:
"""Convert a dictionary to a matrix.
Args:
d: Dictionary to convert
rows: Row names
columns: Column names
Returns:
A sparse matrix
"""
# Transform to index for fast lookup
row_indices = {col: i for i, col in enumerate(columns)}
column_indices = {row: i for i, row in enumerate(rows)}
matrix = sparse.lil_matrix((len(rows), len(columns)), dtype=np.float32)
for (row, col), value in d.items():
matrix[row_indices[row], column_indices[col]] = value
return matrix.tocsr()