import multiprocessing
import re
import shutil
import tempfile
from functools import partial
from typing import Dict, FrozenSet, List, Optional, Set, Tuple
import ngs_tools as ngs
import numpy as np
import pandas as pd
from typing_extensions import Literal
from .. import config, utils
from ..logging import logger
# Matches one comma-separated row of the conversions CSV and exposes its six
# fields as named groups. The pattern is compiled in verbose mode, so the line
# breaks between the groups are not part of the expression; the row itself
# must end with a newline character.
CONVERSIONS_PARSER = re.compile(
    r'''^
(?P<read_id>[^,]*),
(?P<index>[^,]*),
(?P<contig>[^,]*),
(?P<genome_i>[^,]*),
(?P<conversion>[^,]*),
(?P<quality>[^,]*)\n
$''',
    flags=re.VERBOSE,
)
# Matches one comma-separated row of the alignments CSV and exposes its twelve
# fields as named groups. The pattern is compiled in verbose mode, so the line
# breaks between the groups are not part of the expression; the row itself
# must end with a newline character.
ALIGNMENTS_PARSER = re.compile(
    r'''^
(?P<read_id>[^,]*),
(?P<index>[^,]*),
(?P<barcode>[^,]*),
(?P<umi>[^,]*),
(?P<GX>[^,]*),
(?P<A>[^,]*),
(?P<C>[^,]*),
(?P<G>[^,]*),
(?P<T>[^,]*),
(?P<velocity>[^,]*),
(?P<transcriptome>[^,]*),
(?P<score>[^,]*)\n
$''',
    flags=re.VERBOSE,
)
# Position of every conversion inside a per-read counts row. Positions are
# assigned in alphabetical order, starting at zero.
CONVERSION_IDX = {
    conversion: position
    for position, conversion in enumerate((
        'AC',
        'AG',
        'AT',
        'CA',
        'CG',
        'CT',
        'GA',
        'GC',
        'GT',
        'TA',
        'TC',
        'TG',
    ))
}
# Position of every nucleotide total inside a per-read counts row. These
# positions (12-15) come directly after the twelve conversion positions.
BASE_IDX = dict(zip(('A', 'C', 'G', 'T'), range(12, 16)))
# For every known conversion, the value `ngs.sequence.complement_sequence`
# produces for it when the sequence is complemented but not reversed.
CONVERSION_COMPLEMENT = {
    key: ngs.sequence.complement_sequence(key, reverse=False)
    for key in CONVERSION_IDX
}
# Alphabetically sorted names of the conversion and nucleotide count columns.
CONVERSION_COLUMNS = sorted(CONVERSION_IDX)
BASE_COLUMNS = sorted(BASE_IDX)
# All count columns: conversions first, nucleotides after.
COLUMNS = [*CONVERSION_COLUMNS, *BASE_COLUMNS]
# Full column layout of an intermediate counts CSV.
CSV_COLUMNS = ['read_id', 'barcode', 'umi', 'GX', *COLUMNS, 'velocity', 'transcriptome', 'score']
def read_counts(counts_path: str, *args, **kwargs) -> pd.DataFrame:
    """Load a counts CSV into a pandas dataframe with fixed column dtypes.

    NA filtering is switched off, so fields are never converted to missing
    values. Extra positional and keyword arguments are forwarded unchanged to
    `pandas.read_csv`.

    Args:
        counts_path: Path to CSV
        *args: Additional positional arguments for `pandas.read_csv`
        **kwargs: Additional keyword arguments for `pandas.read_csv`

    Returns:
        Counts dataframe
    """
    column_dtypes = {
        'read_id': 'string',
        'barcode': 'category',
        'umi': 'category',
        'GX': 'category',
        'velocity': 'category',
        'transcriptome': bool,
        'score': np.uint16,
    }
    # Every conversion / nucleotide count column is stored as an unsigned byte.
    column_dtypes.update((column, np.uint8) for column in COLUMNS)
    return pd.read_csv(counts_path, *args, dtype=column_dtypes, na_filter=False, **kwargs)
def complement_counts(df_counts: pd.DataFrame, gene_infos: dict) -> pd.DataFrame:
    """Swap every count column with its complement for reads on reverse-strand genes.

    Rows whose gene has strand `+` are left untouched. For all other rows, the
    values of each conversion / nucleotide column are moved to the column of
    the complementary conversion / nucleotide. The returned dataframe lists
    the forward-strand rows first, followed by the remaining rows.

    Args:
        df_counts: Counts dataframe
        gene_infos: Dictionary keyed by gene ID (the values of the `GX` column)
            whose values provide a `strand` entry

    Returns:
        Counts dataframe with counts complemented for reads mapping to genes
        on the reverse strand

    Raises:
        ValueError: If the index of the combined dataframe contains duplicates
            (raised by `pandas.concat` with `verify_integrity=True`)
    """
    # Columns that hold no counts are carried over as they are.
    passthrough = [name for name in df_counts.columns if name not in COLUMNS]
    ordered = passthrough + COLUMNS

    strands = df_counts['GX'].map(lambda gx: gene_infos[gx]['strand'])
    is_forward = strands == '+'

    forward_part = df_counts[is_forward][ordered]
    reverse_part = df_counts[~is_forward][ordered]

    # Both column lists are sorted alphabetically, so reading them backwards
    # pairs every name with its complement (AC <-> TG, A <-> T, ...). Relabel
    # the columns that way, then put them back into the canonical order.
    reverse_part.columns = passthrough + CONVERSION_COLUMNS[::-1] + BASE_COLUMNS[::-1]
    reverse_part = reverse_part[ordered]

    return pd.concat((forward_part, reverse_part), verify_integrity=True)
def subset_counts(
    df_counts: pd.DataFrame,
    key: Literal['total', 'transcriptome', 'spliced', 'unspliced'],
) -> pd.DataFrame:
    """Restrict the counts dataframe to the reads belonging to `key`.

    Args:
        df_counts: Counts dataframe
        key: Read type to keep. `transcriptome` keeps rows whose
            `transcriptome` column is true; `spliced` / `unspliced` keep rows
            whose `velocity` column equals the key. Any other value (such as
            `total`) returns the dataframe unchanged.

    Returns:
        Subset dataframe
    """
    if key == 'transcriptome':
        return df_counts[df_counts['transcriptome']]
    if key in ('spliced', 'unspliced'):
        return df_counts[df_counts['velocity'] == key]
    return df_counts
def drop_multimappers(df_counts: pd.DataFrame, conversions: Optional[FrozenSet[str]] = None) -> pd.DataFrame:
    """Drop multimappings that have the same read ID where

    * some map to the transcriptome while some do not -- drop non-transcriptome alignments
    * none map to the transcriptome AND aligned to multiple genes -- drop all
    * none map to the transcriptome AND assigned multiple velocity types -- set to ambiguous

    Among the alignments that remain for a read, the one kept is the last one
    after sorting by transcriptome flag, alignment score and conversion sum.
    The input dataframe is not modified.

    TODO: This function can probably be removed because BAM parsing only considers
    primary alignments now.

    Args:
        df_counts: Counts dataframe. The `velocity` column must be categorical.
        conversions: Conversions to prioritize. When not provided, all
            conversions are summed to break ties.

    Returns:
        Counts dataframe with multimappers appropriately filtered, sorted by
        barcode and gene, with a fresh index
    """
    columns = list(df_counts.columns)
    convs = list(conversions) if conversions is not None else CONVERSION_COLUMNS

    # The helper sort key lives on a copy (`assign`), so the caller's dataframe
    # never sees the temporary column.
    df_sorted = df_counts.assign(
        conversion_sum=df_counts[convs].sum(axis=1)
    ).sort_values(['transcriptome', 'score', 'conversion_sum']).drop(columns='conversion_sum')
    df_sorted['not_transcriptome'] = ~df_sorted['transcriptome']
    read_id_grouped = df_sorted.groupby('read_id', sort=False, observed=True)

    # None map to the transcriptome
    not_transcriptome = read_id_grouped['not_transcriptome'].all()
    not_transcriptome_read_ids = not_transcriptome.index[not_transcriptome.to_numpy()]

    # Assigned to multiple genes
    multigene = read_id_grouped['GX'].nunique() > 1
    multigene_read_ids = multigene.index[multigene.to_numpy()]

    # Assigned to multiple velocity types
    multivelocity = read_id_grouped['velocity'].nunique() > 1
    multivelocity_read_ids = multivelocity.index[multivelocity.to_numpy()]

    # Rule 3. Note that we need to add ambiguous category if it doesn't exist.
    ambiguous_read_ids = not_transcriptome_read_ids.intersection(multivelocity_read_ids)
    if len(ambiguous_read_ids) > 0:
        if 'ambiguous' not in df_sorted['velocity'].cat.categories:
            # `add_categories` returns a new Series; its `inplace` option no
            # longer exists in pandas 2, so the result is assigned back.
            df_sorted['velocity'] = df_sorted['velocity'].cat.add_categories('ambiguous')
        df_sorted.loc[df_sorted['read_id'].isin(ambiguous_read_ids), 'velocity'] = 'ambiguous'

    # Rule 2
    multigene_only_read_ids = not_transcriptome_read_ids.intersection(multigene_read_ids)
    df_sorted = df_sorted[~df_sorted['read_id'].isin(multigene_only_read_ids)]

    # Rule 1. The sort above puts the preferred alignment last for every read.
    df_deduplicated = df_sorted.drop_duplicates(
        'read_id', keep='last'
    ).sort_values(['barcode', 'GX']).reset_index(drop=True)
    return df_deduplicated[columns]
def deduplicate_counts(
    df_counts: pd.DataFrame,
    conversions: Optional[FrozenSet[str]] = None,
    use_conversions: bool = True
) -> pd.DataFrame:
    """Deduplicate counts based on barcode, UMI, and gene.

    The order of priority is the following.
    1. If `use_conversions=True` and `conversions` is provided, reads that
       have at least one such conversion
    2. Reads that align to the transcriptome (exon only)
    3. Reads that have highest alignment score
    4. If `conversions` is provided, reads that have a larger sum of such conversions
       If `conversions` is not provided, reads that have larger sum of all conversions

    The input dataframe is not modified.

    Args:
        df_counts: Counts dataframe
        conversions: Conversions to prioritize, defaults to `None`
        use_conversions: Prioritize reads that have conversions first

    Returns:
        Deduplicated counts dataframe with a fresh index
    """
    convs = list(conversions) if conversions is not None else CONVERSION_COLUMNS
    conversion_sum = df_counts[convs].sum(axis=1)

    # Deduplication priority, least important key last. The helper keys are
    # attached to a copy (`assign`) so the caller's dataframe stays untouched.
    sort_keys = {'conversion_sum': conversion_sum}
    sort_order = ['transcriptome', 'score', 'conversion_sum']
    if use_conversions and conversions is not None:
        # Here `conversion_sum` already counts exactly the requested conversions.
        sort_keys['has_conversions'] = conversion_sum > 0
        sort_order.insert(0, 'has_conversions')

    # Sort by has desired conversion(s) last, transcriptome last,
    # best alignment last, most conversions last
    df_sorted = df_counts.assign(**sort_keys).sort_values(sort_order).drop(columns=list(sort_keys))
    return df_sorted.drop_duplicates(subset=['barcode', 'umi', 'GX'], keep='last').reset_index(drop=True)
def drop_multimappers_part(
    counter: multiprocessing.Value,
    lock: multiprocessing.Lock,
    split_path: str,
    out_path: str,
    conversions: Optional[FrozenSet[str]] = None
) -> str:
    """Helper function to parallelize :func:`drop_multimappers`.

    Reads the counts CSV at `split_path`, filters its multimappers and writes
    the result, without header and without the `read_id` column, to `out_path`.

    Args:
        counter: Counter incremented by one once this part has been written
        lock: Lock guarding `counter`
        split_path: Path to the counts CSV to filter
        out_path: Path to write the filtered CSV
        conversions: Conversions to prioritize, passed to :func:`drop_multimappers`

    Returns:
        `out_path`
    """
    df_filtered = drop_multimappers(read_counts(split_path), conversions=conversions)
    df_filtered[CSV_COLUMNS[1:]].to_csv(out_path, header=False, index=False)

    # The context manager releases the lock even if the update raises.
    with lock:
        counter.value += 1
    return out_path
def deduplicate_counts_part(
    counter: multiprocessing.Value,
    lock: multiprocessing.Lock,
    split_path: str,
    out_path: str,
    conversions: Optional[FrozenSet[str]],
    use_conversions: bool = True
) -> str:
    """Helper function to parallelize :func:`deduplicate_counts`.

    Reads the counts CSV at `split_path`, deduplicates it and writes the
    result, without header and without the `read_id` column, to `out_path`.

    Args:
        counter: Counter incremented by one once this part has been written
        lock: Lock guarding `counter`
        split_path: Path to the counts CSV to deduplicate
        out_path: Path to write the deduplicated CSV
        conversions: Conversions to prioritize, passed to :func:`deduplicate_counts`
        use_conversions: Passed to :func:`deduplicate_counts`

    Returns:
        `out_path`
    """
    df_deduplicated = deduplicate_counts(
        read_counts(split_path), conversions=conversions, use_conversions=use_conversions
    )
    df_deduplicated[CSV_COLUMNS[1:]].to_csv(out_path, header=False, index=False)

    # The context manager releases the lock even if the update raises.
    with lock:
        counter.value += 1
    return out_path
def split_counts_by_velocity(df_counts: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """Partition the counts dataframe according to its `velocity` column.

    Groups appear in order of first occurrence, and each part gets a fresh index.

    Args:
        df_counts: Counts dataframe

    Returns:
        Dictionary containing `velocity` column values as keys and the subset dataframe as values
    """
    grouped = df_counts.groupby('velocity', sort=False, observed=True)
    dfs = {velocity: df_part.reset_index(drop=True) for velocity, df_part in grouped}
    logger.debug(f'Found the following velocity assignments: {", ".join(dfs.keys())}')
    return dfs
def count_no_conversions(
    alignments_path: str,
    counter: multiprocessing.Value,
    lock: multiprocessing.Lock,
    index: List[Tuple[int, int, int]],
    barcodes: Optional[List[str]] = None,
    temp_dir: Optional[str] = None,
    update_every: int = 10000,
) -> str:
    """Count reads that have no conversion.

    Every alignment row whose file offset is not referenced by `index` is
    written to a temporary counts CSV with all conversion counts set to zero.

    Args:
        alignments_path: Alignments CSV path
        counter: Counter that keeps track of how many reads have been processed
        lock: Semaphore for the `counter` so that multiple processes do not
            modify it at the same time
        index: Index for conversions CSV. The third element of every entry is
            compared against offsets in the alignments CSV; rows at those
            offsets are skipped.
        barcodes: List of barcodes to be considered. All barcodes are considered if not provided
        temp_dir: Path to temporary directory
        update_every: Update the counter every this many reads

    Returns:
        Path to temporary counts CSV
    """
    count_path = utils.mkstemp(dir=temp_dir)
    positions = {tup[2] for tup in index}
    # Build the lookup set once; membership is tested for every alignment row.
    barcode_set = set(barcodes) if barcodes else None

    n = 0
    with open(alignments_path, 'r') as f, open(count_path, 'w') as out:
        f.readline()  # header
        # Explicit readline() calls (rather than `for line in f`) keep
        # `f.tell()` usable on a text-mode file.
        while True:
            pos = f.tell()
            line = f.readline()
            if not line:
                break

            n += 1
            if n == update_every:
                with lock:
                    counter.value += update_every
                n = 0

            # Rows listed in the index are skipped here.
            if pos in positions:
                continue

            groups = ALIGNMENTS_PARSER.match(line).groupdict()
            if barcode_set is not None and groups['barcode'] not in barcode_set:
                continue

            # Nucleotide totals come from the alignment row; conversion
            # columns are not among the parsed groups and default to "0".
            out.write(
                f'{groups["read_id"]},{groups["barcode"]},{groups["umi"]},{groups["GX"]},'
                f'{",".join(groups.get(key, "0") for key in COLUMNS)},'
                f'{groups["velocity"]},{groups["transcriptome"]},{groups["score"]}\n'
            )

    # Report the rows processed since the last periodic update.
    with lock:
        counter.value += n
    return count_path
def count_conversions_part(
    conversions_path: str,
    alignments_path: str,
    counter: multiprocessing.Value,
    lock: multiprocessing.Lock,
    index: List[Tuple[int, int, int]],
    barcodes: Optional[List[str]] = None,
    snps: Optional[Dict[str, Dict[str, Set[int]]]] = None,
    quality: int = 27,
    temp_dir: Optional[str] = None,
    update_every: int = 10000,
) -> str:
    """Count the number of conversions of each read per barcode and gene, along with
    the total nucleotide content of the region each read mapped to, also per barcode
    and gene. This function is used exclusively for multiprocessing.

    Args:
        conversions_path: Path to conversions CSV
        alignments_path: Path to alignments information about reads
        counter: Counter that keeps track of how many reads have been processed
        lock: Semaphore for the `counter` so that multiple processes do not
            modify it at the same time
        index: Index for conversions CSV. Every entry is a tuple of (offset of
            the read's first row in the conversions CSV, number of conversion
            rows to read, offset of the read's row in the alignments CSV).
        barcodes: List of barcodes to be considered. All barcodes are considered if not provided
        snps: Nested dictionary, looked up first by conversion and then by
            contig, yielding the genomic positions that are SNP locations.
            Conversions at those positions are not counted.
        quality: Only count conversions with PHRED quality greater than this value
        temp_dir: Path to temporary directory, defaults to `None`
        update_every: Update the counter every this many reads

    Returns:
        Path to temporary counts CSV
    """

    def is_snp(conversion, contig, genome_i):
        if not snps:
            return False
        return genome_i in snps.get(conversion, {}).get(contig, set())

    # Build the lookup set once; membership is tested for every indexed read.
    barcode_set = set(barcodes) if barcodes else None
    n_columns = len(CONVERSION_IDX) + len(BASE_IDX)

    count_path = utils.mkstemp(dir=temp_dir)
    n = 0
    with open(conversions_path, 'r') as f, open(alignments_path, 'r') as f_alignments, open(count_path, 'w') as out:
        for pos, n_lines, pos2 in index:
            f.seek(pos)
            f_alignments.seek(pos2)

            n += 1
            if n == update_every:
                with lock:
                    counter.value += update_every
                n = 0

            alignment = ALIGNMENTS_PARSER.match(f_alignments.readline()).groupdict()
            if barcode_set is not None and alignment['barcode'] not in barcode_set:
                continue

            # Conversion slots start at zero; nucleotide slots are copied
            # verbatim (as strings) from the alignment row.
            counts = [0] * n_columns
            for base, i in BASE_IDX.items():
                counts[i] = alignment[base]

            # Fall back to the alignment's read ID so an entry without any
            # conversion rows cannot reuse the ID parsed for a previous read.
            read_id = alignment['read_id']
            for _ in range(n_lines):
                groups = CONVERSIONS_PARSER.match(f.readline()).groupdict()
                read_id = groups['read_id']
                conversion = groups['conversion']
                if int(groups['quality']) > quality and not is_snp(conversion, groups['contig'],
                                                                   int(groups['genome_i'])):
                    counts[CONVERSION_IDX[conversion]] += 1

            out.write(
                f'{read_id},{alignment["barcode"]},{alignment["umi"]},'
                f'{alignment["GX"]},{",".join(str(c) for c in counts)},'
                f'{alignment["velocity"]},{alignment["transcriptome"]},{alignment["score"]}\n'
            )

    # Report the reads processed since the last periodic update.
    with lock:
        counter.value += n
    return count_path
def count_conversions(
    conversions_path: str,
    alignments_path: str,
    index_path: str,
    counts_path: str,
    gene_infos: dict,
    barcodes: Optional[List[str]] = None,
    snps: Optional[Dict[str, Dict[str, Set[int]]]] = None,
    quality: int = 27,
    conversions: Optional[FrozenSet[str]] = None,
    dedup_use_conversions: bool = True,
    n_threads: int = 8,
    temp_dir: Optional[str] = None
) -> str:
    """Count the number of conversions of each read per barcode and gene, along with
    the total nucleotide content of the region each read mapped to, also per barcode.

    The work happens in three stages:

    1. Reads are counted in parallel, both those listed in the index (see
       :func:`count_conversions_part`) and those that are not (see
       :func:`count_no_conversions`), and the partial CSVs are concatenated.
    2. The combined counts are complemented according to gene strand and split
       by barcode into files of roughly `config.COUNTS_SPLIT_THRESHOLD` reads.
    3. Every split is filtered in parallel, with :func:`deduplicate_counts`
       when no UMI equals `NA` and with :func:`drop_multimappers` otherwise,
       complemented back and appended to `counts_path`.

    Args:
        conversions_path: Path to conversions CSV
        alignments_path: Path to alignments information about reads
        index_path: Path to conversions index
        counts_path: Path to write counts CSV
        gene_infos: Dictionary containing gene information, as returned by
            `ngs.gtf.genes_and_transcripts_from_gtf`
        barcodes: List of barcodes to be considered. All barcodes are considered if not provided
        snps: Dictionary of SNP locations, passed to :func:`count_conversions_part`
        quality: Only count conversions with PHRED quality greater than this value
        conversions: Conversions to prioritize when filtering duplicate reads
        dedup_use_conversions: Prioritize reads that have at least one conversion
            when deduplicating
        n_threads: Number of threads
        temp_dir: Path to temporary directory

    Returns:
        Path to counts CSV
    """
    # Load index
    logger.debug(f'Loading index {index_path}')
    index = utils.read_pickle(index_path)

    # Split index into n contiguous pieces
    logger.debug(f'Splitting indices into {n_threads} parts')
    parts = utils.split_index(index, n=n_threads)

    # Parse each part in a different process
    logger.debug(f'Spawning {n_threads} processes')
    # Progress total: one tick per index entry plus one per alignment row
    # (the header line is excluded).
    with open(alignments_path, 'r') as f:
        total = len(index) + sum(1 for _ in f) - 1
    pool, counter, lock = utils.make_pool_with_counter(n_threads)
    no_async_result = pool.apply_async(
        partial(
            count_no_conversions,
            alignments_path,
            counter,
            lock,
            index,
            barcodes=barcodes,
            temp_dir=tempfile.mkdtemp(dir=temp_dir)
        )
    )
    async_result = pool.starmap_async(
        partial(
            count_conversions_part,
            conversions_path,
            alignments_path,
            counter,
            lock,
            barcodes=barcodes,
            snps=snps,
            quality=quality,
            temp_dir=tempfile.mkdtemp(dir=temp_dir)
        ), [(part,) for part in parts]
    )
    pool.close()

    # Display progress bar
    utils.display_progress_with_counter(counter, total, async_result, no_async_result, desc='counting')
    pool.join()

    # Combine csvs
    combined_path = utils.mkstemp(dir=temp_dir)
    logger.debug(f'Combining intermediate parts to {combined_path}')
    with open(combined_path, 'wb') as out:
        out.write(f'{",".join(CSV_COLUMNS)}\n'.encode())
        for counts_part_path in async_result.get():
            with open(counts_part_path, 'rb') as f:
                shutil.copyfileobj(f, out)
        with open(no_async_result.get(), 'rb') as f:
            shutil.copyfileobj(f, out)

    # Filter counts dataframe
    logger.debug(f'Loading combined counts from {combined_path}')
    df_counts = complement_counts(read_counts(combined_path), gene_infos)
    # UMI-based filtering is used only when no read carries the `NA` placeholder.
    umi = bool((df_counts['umi'] != 'NA').all())
    barcode_groupby = df_counts.groupby('barcode', sort=False, observed=True)
    barcode_counts = dict(barcode_groupby.size())

    # Split barcodes into approximately `config.COUNTS_SPLIT_THRESHOLD` bins.
    # Note that a single barcode may have more than this many reads.
    split_paths = []
    residual_f = None
    residual_size = 0
    try:
        for barcode, df_counts_barcode in barcode_groupby:
            # A barcode above the threshold gets a split of its own and does
            # not touch the shared residual file.
            if barcode_counts[barcode] > config.COUNTS_SPLIT_THRESHOLD:
                split_path = utils.mkstemp(dir=temp_dir)
                logger.debug(f'Splitting counts for barcode {barcode} to {split_path}')
                df_counts_barcode.to_csv(split_path, index=False)
                split_paths.append(split_path)
                continue

            if residual_f is None:
                residual_path = utils.mkstemp(dir=temp_dir)
                logger.debug(f'Splitting counts for residual barcodes to {residual_path}')
                residual_f = open(residual_path, 'w')
                split_paths.append(residual_path)
                # First barcode of a new residual file: write header
                df_counts_barcode.to_csv(residual_f, index=False)
            else:
                # Don't write header
                df_counts_barcode.to_csv(residual_f, index=False, header=False)

            # If we exceeded read threshold, close file & reset.
            residual_size += df_counts_barcode.shape[0]
            if residual_size > config.COUNTS_SPLIT_THRESHOLD:
                residual_f.close()
                residual_f = None
                residual_size = 0
    finally:
        if residual_f is not None:
            residual_f.close()
    # Free the combined dataframe before the worker processes are started.
    del df_counts

    logger.debug(f'Spawning {n_threads} processes')
    pool, counter, lock = utils.make_pool_with_counter(n_threads)
    paths = [(split_path, utils.mkstemp(dir=temp_dir)) for split_path in split_paths]
    async_result = pool.starmap_async(
        partial(
            deduplicate_counts_part,
            counter,
            lock,
            conversions=conversions,
            use_conversions=dedup_use_conversions,
        ) if umi else partial(drop_multimappers_part, counter, lock, conversions=conversions), paths
    )
    pool.close()

    # Display progress bar
    utils.display_progress_with_counter(counter, len(split_paths), async_result, desc='filtering')
    pool.join()

    # Need to complement counts again (to revert to original)
    # NOTE(review): the parts are re-read with pandas' default NA parsing, so
    # literal `NA` fields (e.g. the UMI placeholder) are written back as empty
    # fields -- confirm that downstream readers expect this.
    with open(counts_path, 'w') as out:
        out.write(f'{",".join(CSV_COLUMNS[1:])}\n')
        for counts_part_path in async_result.get():
            df_part = complement_counts(pd.read_csv(counts_part_path, names=CSV_COLUMNS[1:]), gene_infos)
            df_part[CSV_COLUMNS[1:]].to_csv(out, header=False, index=False)
    return counts_path