import datetime as dt
import os
from typing import FrozenSet, List, Optional
import ngs_tools as ngs
import pandas as pd
from typing_extensions import Literal
from . import config, constants, preprocessing, utils
from .logging import logger
from .stats import Stats
@logger.namespaced('count')
def count(
    bam_path: str,
    gtf_path: str,
    out_dir: str,
    strand: Literal['forward', 'reverse', 'unstranded'] = 'forward',
    umi_tag: Optional[str] = None,
    barcode_tag: Optional[str] = None,
    gene_tag: str = 'GX',
    barcodes: Optional[List[str]] = None,
    control: bool = False,
    quality: int = 27,
    conversions: FrozenSet[FrozenSet[str]] = frozenset({frozenset({'TC'})}),
    snp_threshold: Optional[float] = None,
    snp_min_coverage: int = 1,
    snp_csv: Optional[str] = None,
    n_threads: int = 8,
    temp_dir: Optional[str] = None,
    velocity: bool = True,
    strict_exon_overlap: bool = False,
    dedup_mode: Literal['auto', 'exon', 'conversion'] = 'auto',
    by_name: bool = False,
    nasc: bool = False,
    overwrite: bool = False,
):
    """Main interface for the `count` command.

    The steps, in order: sort/index the BAM, validate its tags, parse read
    conversions (or reuse previously written files), optionally detect SNPs,
    count conversions, calculate mutation rates and, for non-control samples,
    write the combined results as an Anndata object. Run statistics are saved
    to a timestamped JSON file in `out_dir` at the end.

    Args:
        bam_path: Path to BAM
        gtf_path: Path to GTF
        out_dir: Path to output directory
        strand: Strandedness of technology
        umi_tag: BAM tag to use as UMIs
        barcode_tag: BAM tag to use as barcodes
        gene_tag: BAM tag to use as genes
        barcodes: List of barcodes to consider
        control: Whether this is a control sample
        quality: Quality threshold in detecting conversions
        conversions: Set of conversions to quantify
        snp_threshold: Call genomic locations that have greater than this proportion of
            specific conversions as a SNP
        snp_min_coverage: Only consider genomic locations with at least this many mapping
            reads for SNP calling
        snp_csv: CSV containing SNPs
        n_threads: Number of threads to use
        temp_dir: Temporary directory
        velocity: Whether to quantify spliced/unspliced RNA
        strict_exon_overlap: Whether spliced/unspliced RNA quantification is strict
        dedup_mode: UMI deduplication mode
        by_name: Whether to group counts by gene name instead of ID
        nasc: Whether to match NASC-seq pipeline behavior
        overwrite: Overwrite existing files

    Raises:
        Exception: If the first `config.BAM_PEEK_READS` reads of the BAM lack a
            required tag, or if previously parsed files are reused and their
            barcode/UMI columns disagree with the `barcode_tag`/`umi_tag`
            arguments of this run.
    """
    stats = Stats()
    stats.start()
    stats_path = os.path.join(
        out_dir, f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json'
    )
    os.makedirs(out_dir, exist_ok=True)

    # Flattened, sorted view of the conversion groups; used for file naming,
    # logging and the counting step.
    all_conversions = sorted(utils.flatten_iter(conversions))

    # Check memory.
    available_memory = utils.get_available_memory()
    if available_memory < config.RECOMMENDED_MEMORY:
        logger.warning(
            f'There is only {available_memory / (1024 ** 3):.2f} GB of free memory on the machine. '
            f'It is highly recommended to have at least {config.RECOMMENDED_MEMORY // (1024 ** 3)} GB '
            'free when running dynast. Continuing may cause dynast to crash with an out-of-memory error.'
        )

    # Sort and index bam.
    bam_path = preprocessing.sort_and_index_bam(
        bam_path, '{}.sortedByCoord{}'.format(*os.path.splitext(bam_path)), n_threads=n_threads
    )

    # Check BAM tags
    tags = preprocessing.get_tags_from_bam(bam_path, config.BAM_PEEK_READS, n_threads=n_threads)
    required_tags = config.BAM_REQUIRED_TAGS.copy()
    if barcode_tag:
        required_tags.append(barcode_tag)
    elif config.BAM_BARCODE_TAG in tags:
        logger.warning(
            f'BAM contains reads with {config.BAM_BARCODE_TAG} tag. Are you sure '
            f'you didn\'t mean to provide `--barcode-tag {config.BAM_BARCODE_TAG}`?'
        )
    elif config.BAM_READGROUP_TAG in tags:
        logger.warning(
            f'BAM contains reads with {config.BAM_READGROUP_TAG} tag. Are you sure '
            f'you didn\'t mean to provide `--barcode-tag {config.BAM_READGROUP_TAG}`?'
        )
    if umi_tag:
        required_tags.append(umi_tag)
    elif config.BAM_UMI_TAG in tags:
        logger.warning(
            f'BAM contains reads with {config.BAM_UMI_TAG} tag. Are you sure '
            f'you didn\'t mean to provide `--umi-tag {config.BAM_UMI_TAG}`?'
        )
    if gene_tag:
        required_tags.append(gene_tag)
    elif config.BAM_GENE_TAG in tags:
        logger.warning(
            f'BAM contains reads with {config.BAM_GENE_TAG} tag. Are you sure '
            f'you didn\'t mean to provide `--gene-tag {config.BAM_GENE_TAG}`?'
        )
    missing_tags = set(required_tags) - tags
    if missing_tags:
        # Sorted so that the error message is deterministic across runs.
        raise Exception(
            f'First {config.BAM_PEEK_READS} reads in the BAM do not contain the following required tags: '
            f'{", ".join(sorted(missing_tags))}. '
        )

    # Check BAM alignments
    if preprocessing.check_bam_contains_secondary(bam_path, config.BAM_PEEK_READS, n_threads=n_threads):
        logger.warning(
            'BAM contains secondary alignments, which will be ignored. Only primary '
            'alignments are considered.'
        )
    if preprocessing.check_bam_contains_unmapped(bam_path):
        logger.warning('BAM contains unmapped reads, which will be ignored.')
    if preprocessing.check_bam_contains_duplicate(bam_path, config.BAM_PEEK_READS, n_threads=n_threads):
        logger.warning('BAM contains duplicate reads, which will be ignored.')

    # Parse BAM and save results
    conversions_path = os.path.join(out_dir, constants.CONVERSIONS_FILENAME)
    index_path = os.path.join(out_dir, constants.CONVERSIONS_INDEX_FILENAME)
    alignments_path = os.path.join(out_dir, constants.ALIGNMENTS_FILENAME)
    genes_path = os.path.join(out_dir, constants.GENES_FILENAME)
    conversions_required = [conversions_path, index_path, alignments_path, genes_path]
    # Remembered so that SNP detection is redone whenever the BAM is re-parsed.
    bam_parsed = False
    if not utils.all_exists(*conversions_required) or overwrite:
        logger.info('Parsing gene and transcript information from GTF')
        gene_infos, transcript_infos = ngs.gtf.genes_and_transcripts_from_gtf(gtf_path, use_version=False)
        utils.write_pickle(gene_infos, genes_path)

        logger.info(f'Parsing read conversion information from BAM to {conversions_path}')
        conversions_path, alignments_path, index_path = preprocessing.parse_all_reads(
            bam_path,
            conversions_path,
            alignments_path,
            index_path,
            gene_infos,
            transcript_infos,
            strand=strand,
            umi_tag=umi_tag,
            barcode_tag=barcode_tag,
            gene_tag=gene_tag,
            barcodes=barcodes,
            n_threads=n_threads,
            temp_dir=temp_dir,
            nasc=nasc,
            velocity=velocity,
            strict_exon_overlap=strict_exon_overlap,
        )
        bam_parsed = True
    else:
        logger.warning('Skipped BAM parsing because files already exist. Use `--overwrite` to re-parse the BAM.')
        gene_infos = utils.read_pickle(genes_path)

        # Check consistency of alignments
        # Only the first `config.BAM_PEEK_READS` rows of the existing
        # alignments file are inspected.
        small_alignments = preprocessing.read_alignments(alignments_path, nrows=config.BAM_PEEK_READS)
        barcode_is_na = small_alignments['barcode'] == 'NA'
        umi_is_na = small_alignments['umi'] == 'NA'
        if barcode_tag and barcode_is_na.any():
            raise Exception(
                "`--barcode-tag` was provided but existing files contain NA barcodes. "
                'Re-run `dynast count` with `--overwrite` to fix this inconsistency.'
            )
        elif not barcode_tag and (~barcode_is_na).any():
            raise Exception(
                "`--barcode-tag` was not provided but existing files contain barcodes. "
                'Re-run `dynast count` with `--overwrite` to fix this inconsistency.'
            )
        if umi_tag and umi_is_na.any():
            raise Exception(
                "`--umi-tag` was provided but existing files contain NA UMIs. "
                'Re-run `dynast count` with `--overwrite` to fix this inconsistency.'
            )
        elif not umi_tag and (~umi_is_na).any():
            raise Exception(
                "`--umi-tag` was not provided but existing files contain UMIs. "
                'Re-run `dynast count` with `--overwrite` to fix this inconsistency.'
            )

    # Save conversions
    # SNP detection must be redone when no conversions were recorded by a
    # previous run, or when the recorded conversions differ from this run's.
    redo_snp = False
    convs_path = os.path.join(out_dir, constants.CONVS_FILENAME)
    if utils.all_exists(convs_path):
        prev_conversions = utils.read_pickle(convs_path)
        if conversions != prev_conversions:
            logger.warning(f'Conversions changed from {prev_conversions} in previous run to {conversions}.')
            redo_snp = True
    else:
        redo_snp = True

    # Detect SNPs
    coverage_path = os.path.join(out_dir, constants.COVERAGE_FILENAME)
    snps_path = os.path.join(out_dir, constants.SNPS_FILENAME)
    snp_required = [convs_path, coverage_path, snps_path]
    if snp_threshold is not None:
        if not control:
            # If SNP filtering is used with a non-control sample, there are some
            # inconsistencies in what particular reads (among duplicated ones) are used for
            # coverage/SNP-calling and conversion-calling.
            logger.warning(
                "Reads used for coverage calculation and SNP-calling may differ from those "
                "that will be used for conversion-calling. This is due to using the "
                "`--snp-threshold` option on a non `--control` sample."
            )

        if not utils.all_exists(*snp_required) or redo_snp or bam_parsed:
            logger.info('Selecting alignments to use for SNP detection')
            alignments = preprocessing.select_alignments(preprocessing.read_alignments(alignments_path))

            # Each requested conversion together with its complement.
            snp_conversions = set(
                all_conversions + [preprocessing.CONVERSION_COMPLEMENT[conv] for conv in all_conversions]
            )
            logger.info(f'Selecting genomic locations with {snp_conversions} conversions in forward strand.')
            df_conversions = preprocessing.read_conversions(conversions_path)
            # Subset to selected alignments.
            df_conversions = df_conversions[[
                key in alignments for key in df_conversions[['read_id', 'index']].itertuples(index=False, name=None)
            ]]
            # Subset to conversions of interest
            df_conversions = df_conversions.loc[df_conversions['conversion'].isin(snp_conversions),
                                                ['contig', 'genome_i']]

            logger.info(f'Calculating coverage and outputting to {coverage_path}')
            coverage_path = preprocessing.calculate_coverage(
                bam_path,
                {
                    contig: set(df_part['genome_i'])
                    for contig, df_part in df_conversions.groupby('contig', sort=False, observed=True)
                },
                coverage_path,
                alignments=alignments,
                umi_tag=umi_tag,
                barcode_tag=barcode_tag,
                gene_tag=gene_tag,
                barcodes=barcodes,
                temp_dir=temp_dir,
                velocity=velocity,
            )
            coverage = preprocessing.read_coverage(coverage_path)

            logger.info(f'Detecting SNPs with threshold {snp_threshold} to {snps_path}')
            snps_path = preprocessing.detect_snps(
                conversions_path,
                index_path,
                coverage,
                snps_path,
                alignments=alignments,
                conversions=snp_conversions,
                quality=quality,
                threshold=snp_threshold,
                min_coverage=snp_min_coverage,
                n_threads=n_threads,
            )
            # Record the conversions only after SNP detection has completed.
            utils.write_pickle(conversions, convs_path)
        else:
            logger.warning(
                'Skipped SNP detection because files already exist. '
                f'Remove {convs_path} to run SNP detection again.'
            )
    else:
        utils.write_pickle(conversions, convs_path)

    # Count conversions and calculate mutation rates
    counts_path = os.path.join(out_dir, f'{constants.COUNTS_PREFIX}_{"_".join(all_conversions)}.csv')
    logger.info(f'Counting conversions to {counts_path}')
    # Combine SNPs detected in this run with any provided through `snp_csv`.
    snps = utils.merge_dictionaries(
        preprocessing.read_snps(snps_path) if snp_threshold is not None else {},
        preprocessing.read_snp_csv(snp_csv) if snp_csv else {},
        f=set.union,
        default=set,
    )

    # Figure out deduplication priority if set to auto
    if umi_tag and dedup_mode == 'auto':
        if config.BAM_CONSENSUS_READ_COUNT_TAG in tags:
            dedup_mode = 'exon'
            logger.info(f'Auto-detected deduplication mode: `{dedup_mode}`. Exonic reads will be prioritized. ')
        else:
            dedup_mode = 'conversion'
            logger.info(
                f'Auto-detected deduplication mode: `{dedup_mode}`. '
                f'Reads with at least one of {all_conversions} conversions will be prioritized.'
            )
    counts_path = preprocessing.count_conversions(
        conversions_path,
        alignments_path,
        index_path,
        counts_path,
        gene_infos,
        barcodes=barcodes,
        snps=snps,
        quality=quality,
        conversions=all_conversions,
        dedup_use_conversions=dedup_mode == 'conversion',
        n_threads=n_threads,
        temp_dir=temp_dir
    )
    df_counts_uncomplemented = preprocessing.read_counts(counts_path)
    df_counts_complemented = preprocessing.complement_counts(df_counts_uncomplemented, gene_infos)

    if barcodes:
        count_barcodes = set(df_counts_uncomplemented['barcode'])
        # `barcodes` is documented as a list, and `list - set` raises TypeError,
        # so convert it to a set before taking the difference.
        missing_barcodes = set(barcodes) - count_barcodes
        if missing_barcodes:
            logger.warning(
                f'{len(missing_barcodes)} barcodes are missing from {counts_path}. '
                'Re-run `dynast count` with `--overwrite` to fix this inconsistency. '
                'Otherwise, all missing barcodes will be ignored. '
            )

    # Calculate mutation rates for each group
    transcriptome_exists = df_counts_complemented['transcriptome'].any()
    velocities = df_counts_complemented['velocity'].unique()
    # The uncomplemented counts are used for rates only when `nasc` is set.
    df_counts = df_counts_uncomplemented if nasc else df_counts_complemented
    rates_paths = {}

    rates_all_path = os.path.join(out_dir, f'{constants.RATES_PREFIX}.csv')
    logger.info(f'Calculating mutation rates for all reads to {rates_all_path}.')
    rates_all_path = preprocessing.calculate_mutation_rates(df_counts, rates_all_path, group_by=['barcode'])
    rates_paths['all'] = rates_all_path

    if transcriptome_exists:
        rates_X_path = os.path.join(out_dir, f'{constants.RATES_PREFIX}_X.csv')
        logger.info(f'Calculating mutation rates for X reads to {rates_X_path}.')
        rates_X_path = preprocessing.calculate_mutation_rates(
            df_counts[df_counts['transcriptome']], rates_X_path, group_by=['barcode']
        )
        rates_paths['X'] = rates_X_path

    for key in velocities:
        if key in config.VELOCITY_BLACKLIST:
            continue
        rates_velocity_path = os.path.join(out_dir, f'{constants.RATES_PREFIX}_{key}.csv')
        logger.info(f'Calculating mutation rates for {key} reads to {rates_velocity_path}.')
        rates_velocity_path = preprocessing.calculate_mutation_rates(
            df_counts[df_counts['velocity'] == key], rates_velocity_path, group_by=['barcode']
        )
        rates_paths[key] = rates_velocity_path

    if control:
        logger.info('Downstream processing skipped for controls')
        if snp_threshold is not None:
            logger.info(f'Use `--snp-csv {snps_path}` to run test samples')
    else:
        if by_name:
            logger.info('Collapsing counts by gene name.')
            # Fall back to the gene ID when the gene has no name.
            df_counts_complemented['GX'] = df_counts_complemented['GX'].apply(
                lambda gx: gene_infos[gx]['gene_name'] or gx
            )

        adata_path = os.path.join(out_dir, constants.ADATA_FILENAME)
        logger.info(f'Combining results into Anndata object at {adata_path}')
        adata = utils.results_to_adata(
            df_counts_complemented, conversions, gene_infos=gene_infos if not by_name else None
        )

        # Add rates to obsm
        for key, rates_path in rates_paths.items():
            obsm = 'rates'
            if key != 'all':
                obsm += f'_{key}'
            logger.debug(f'Adding {obsm} to obsm.')
            rates = preprocessing.read_rates(rates_path)
            if isinstance(rates, pd.Series):
                # A single set of rates is repeated for every observation.
                expanded = pd.DataFrame(columns=rates.index)
                for obs in adata.obs_names:
                    expanded.loc[obs] = rates
                adata.obsm[obsm] = expanded
            else:
                adata.obsm[obsm] = rates.set_index('barcode').reindex(adata.obs_names, fill_value=0.0)
        adata.write(adata_path, compression='gzip')

    stats.end()
    stats.save(stats_path)