import array
import multiprocessing
import queue
from hashlib import sha256
from typing import Any, Dict, List, Optional
import ngs_tools as ngs
import numpy as np
import pysam
from tqdm import tqdm
from typing_extensions import Literal
from .. import config, utils
from ..logging import logger
from . import bam
# Nucleotides in the column order used by the consensus support matrices.
BASES = ('A', 'C', 'G', 'T')
# Maps each nucleotide letter to its column index in ``BASES``.
BASE_IDX = dict(zip(BASES, range(len(BASES))))
def call_consensus_from_reads(
    reads: List[pysam.AlignedSegment],
    header: pysam.AlignmentHeader,
    quality: int = 27,
    tags: Optional[Dict[str, Any]] = None,
) -> pysam.AlignedSegment:
    """Call a single consensus alignment given a list of aligned reads.

    All reads must map to the same contig; an exception is raised otherwise.
    Consensus bases are called only for positions that align to the reference
    (insertions relative to the reference are dropped).

    At every reference position the base qualities of all reads supporting
    each of A/C/G/T are summed. The base with the highest summed quality is
    called, preferring the reference base on ties and otherwise the first base
    in ``BASES`` order. If the highest summed quality is below ``quality``,
    the reference base is called instead. The consensus base quality is the
    summed quality clipped to 42. Positions where every read has a deletion
    become deletions; positions covered by no read become reference skips
    (``N`` CIGAR operations).

    This function only sets the minimal amount of attributes such that the
    alignment is valid. These include:
    * read name -- SHA256 hash of the concatenated read names
    * read sequence and qualities
    * reference name and ID
    * reference start
    * mapping quality (MAPQ)
    * cigarstring
    * MD tag
    * NM tag (mismatches plus deleted reference bases)
    * Not unmapped, paired, duplicate, qc fail, secondary, nor supplementary
    The caller is expected to further populate the alignment
    with additional tags, flags, and name.

    Args:
        reads: List of reads to call a consensus sequence from
        header: header to use when creating the new pysam alignment
        quality: minimum summed base quality required to call a base other
            than the reference base
        tags: additional tags to set. The given dictionary is not modified.

    Returns:
        New pysam alignment of the consensus sequence

    Raises:
        ValueError: If ``reads`` is empty.
        Exception: If the reads map to more than one contig.
    """
    if not reads:
        raise ValueError("Can not call consensus from an empty list of reads.")
    if len(set(read.reference_name for read in reads)) > 1:
        raise Exception("Can not call consensus from reads mapping to multiple contigs.")

    # Pysam coordinates are [start, end)
    left_pos = min(read.reference_start for read in reads)
    right_pos = max(read.reference_end for read in reads)
    length = right_pos - left_pos

    # A consensus sequence is internally represented as a L x 4 matrix,
    # where L is the length of the sequence and the columns correspond to
    # each of the four bases. The values indicate the (quality-weighted)
    # support of each base.
    # It's possible to switch these to sparse matrices if memory becomes an issue.
    sequence = np.zeros((length, len(BASES)), dtype=np.uint32)
    reference = np.full(length, -1, dtype=np.int8)  # -1 means unobserved
    for read in reads:
        read_sequence = read.query_sequence.upper()
        read_qualities = read.query_qualities
        for read_i, genome_i, genome_base in read.get_aligned_pairs(matches_only=False, with_seq=True):
            # Insertions/soft clips have no reference position; positions
            # reported without a reference base are treated as uncovered.
            if genome_i is None or genome_base is None:
                continue
            genome_base = genome_base.upper()
            # Skip N and any other ambiguous (non-ACGT) reference base.
            if genome_base not in BASE_IDX:
                continue
            i = genome_i - left_pos
            # Deletion: record the reference base but add no base support.
            if read_i is None:
                if reference[i] < 0:
                    reference[i] = BASE_IDX[genome_base]
                continue
            read_base = read_sequence[read_i]
            # Ambiguous read bases contribute no support.
            if read_base not in BASE_IDX:
                continue
            if reference[i] < 0:
                reference[i] = BASE_IDX[genome_base]
            sequence[i, BASE_IDX[read_base]] += read_qualities[read_i]

    # Determine consensus. Note that we ignore any insertions.
    consensus_length = (sequence > 0).any(axis=1).sum()
    consensus = np.zeros(consensus_length, dtype=np.uint8)
    qualities = np.zeros(consensus_length, dtype=np.uint8)
    cigar = []
    last_cigar_op = None
    cigar_n = 0
    # MD tag grammar: [0-9]+(([A-Z]|\^[A-Z]+)[0-9]+)*. A match count is
    # emitted (possibly 0) before every mismatch base and before every
    # deletion run, and the tag always ends with a count.
    md = []
    md_n = 0  # matching bases since the last MD mismatch/deletion entry
    in_deletion = False  # whether the previous position was a deletion
    nm = 0
    consensus_i = 0
    for i in range(length):
        ref = reference[i]
        if ref < 0:
            # Region not present in any read. MD only describes aligned
            # bases, but a skip ends any deletion run (a later deletion is
            # a separate CIGAR D operation and gets its own '^').
            cigar_op = 'N'
            in_deletion = False
        elif not sequence[i].any():
            # Every covering read has a deletion here.
            cigar_op = 'D'
            if not in_deletion:
                md.append(str(md_n))
                md.append('^')
                md_n = 0
                in_deletion = True
            md.append(BASES[ref])
            nm += 1  # deleted reference bases count towards edit distance
        else:
            in_deletion = False
            seq = sequence[i]
            # On ties, select reference if present. Otherwise, choose lexicographically.
            base_q = seq.max()
            if base_q < quality:
                base = ref
            else:
                bases = (seq == base_q).nonzero()[0]
                base = ref if ref in bases else bases[0]
            # We use the STAR convention of using M cigar operation to mean
            # both matches AND mismatches, ignoring the X cigar operation exists.
            cigar_op = 'M'
            if base == ref:
                md_n += 1
            else:
                md.append(str(md_n))
                md.append(BASES[ref])
                md_n = 0
                nm += 1
            consensus[consensus_i] = base
            qualities[consensus_i] = min(base_q, 42)  # Clip to maximum PHRED score
            consensus_i += 1

        if cigar_op == last_cigar_op:
            cigar_n += 1
        else:
            if last_cigar_op is not None:
                cigar.append(f'{cigar_n}{last_cigar_op}')
            last_cigar_op = cigar_op
            cigar_n = 1
    md.append(str(md_n))  # MD tag always ends with a number
    cigar.append(f'{cigar_n}{last_cigar_op}')

    al = pysam.AlignedSegment(header)
    al.query_name = sha256(''.join(read.query_name for read in reads).encode('utf-8')).hexdigest()
    al.query_sequence = ''.join(BASES[i] for i in consensus)
    al.query_qualities = array.array('B', qualities)
    al.reference_name = reads[0].reference_name
    al.reference_id = reads[0].reference_id
    al.reference_start = left_pos
    al.mapping_quality = 255
    al.cigarstring = ''.join(cigar)

    # Set tags on a copy so the caller's dictionary is left untouched.
    all_tags = dict(tags) if tags else {}
    all_tags.update({'MD': ''.join(md), 'NM': nm})
    al.set_tags(list(all_tags.items()))

    # Make sure these are False
    al.is_unmapped = False
    al.is_paired = False
    al.is_duplicate = False
    al.is_qcfail = False
    al.is_secondary = False
    al.is_supplementary = False
    return al
def call_consensus_from_reads_process(reads, header, tags, strand=None, quality=27):
    """Helper function to call :func:`call_consensus_from_reads` from a subprocess.

    Inputs and output are plain Python objects (SAM strings and a header
    dictionary) so they can travel through multiprocessing queues.

    Args:
        reads: Reads as SAM-formatted strings.
        header: Alignment header as a dictionary.
        tags: Extra tags forwarded to :func:`call_consensus_from_reads`.
        strand: When ``'-'``, the consensus alignment is flagged as reverse.
        quality: Quality threshold forwarded to :func:`call_consensus_from_reads`.

    Returns:
        The consensus alignment as a SAM-formatted string.
    """
    alignment_header = pysam.AlignmentHeader.from_dict(header)
    result = call_consensus_from_reads(
        [pysam.AlignedSegment.fromstring(sam_line, alignment_header) for sam_line in reads],
        alignment_header,
        quality=quality,
        tags=tags,
    )
    # Consensus alignments are always emitted as single-end records.
    result.is_paired = False
    # Only ever switch the reverse flag on; any other strand leaves it as-is.
    result.is_reverse = result.is_reverse or strand == '-'
    return result.to_string()
def consensus_worker(args_q, results_q, *args, **kwargs):
    """Multiprocessing worker.

    Repeatedly takes argument tuples from ``args_q``, calls
    :func:`call_consensus_from_reads_process` with them (followed by any extra
    ``args``/``kwargs``), and puts each result on ``results_q``. A ``None``
    item on ``args_q`` stops the worker.
    """
    while True:
        try:
            job = args_q.get(timeout=1)
        except queue.Empty:
            # Nothing queued yet; keep polling until the sentinel arrives.
            continue
        if job is None:
            break
        results_q.put(call_consensus_from_reads_process(*job, *args, **kwargs))
def call_consensus(
    bam_path: str,
    out_path: str,
    gene_infos: dict,
    strand: Literal['forward', 'reverse', 'unstranded'] = 'forward',
    umi_tag: Optional[str] = None,
    barcode_tag: Optional[str] = None,
    gene_tag: str = 'GX',
    barcodes: Optional[List[str]] = None,
    quality: int = 27,
    add_RS_RI: bool = False,
    temp_dir: Optional[str] = None,
    n_threads: int = 8
) -> str:
    """Call consensus sequences from BAM.

    Alignments are grouped by (gene, barcode, UMI). Groups containing a single
    alignment are written with their gene tags updated; larger groups are
    collapsed into one consensus alignment by worker processes. Alignments
    that are not compatible with exactly one gene, and paired alignments whose
    mate never appears, are written unchanged. For paired reads without a UMI,
    the (start, end) span of the pair is used in place of the UMI.

    Args:
        bam_path: Path to BAM
        out_path: Output BAM path
        gene_infos: Gene information, as parsed from the GTF
        strand: Protocol strandedness
        umi_tag: BAM tag containing the UMI
        barcode_tag: BAM tag containing the barcode
        gene_tag: BAM tag containing the assigned gene
        barcodes: List of barcodes to consider
        quality: Quality threshold
        add_RS_RI: Add RS and RI BAM tags for debugging
        temp_dir: Temporary directory
        n_threads: Number of threads

    Returns:
        Path to sorted and indexed consensus BAM
    """
    # Check for genes whose bodies have been fully passed after roughly this
    # many input alignments.
    flush_interval = 10000
    # Set for O(1) membership tests; an empty/None list disables filtering.
    barcode_set = set(barcodes) if barcodes else None

    def skip_alignment(read, tags):
        return read.is_secondary or read.is_unmapped or any(not read.has_tag(tag) for tag in tags)

    def find_genes(contig, start, end, read_strand=None):
        """Return genes on ``contig`` (matching ``read_strand`` if given) whose segment contains [start, end)."""
        genes = []
        # NOTE: contigs may not have any genes
        for gene in contig_gene_order.get(contig, []):
            if read_strand and read_strand != gene_infos[gene]['strand']:
                continue
            gene_segment = gene_infos[gene]['segment']
            # Genes are sorted by segment, so later genes start even further right.
            if end <= gene_segment.start:
                break
            if start >= gene_segment.start and end <= gene_segment.end:
                genes.append(gene)
        return genes

    def swap_gene_tags(read, gene):
        """Replace the read's gene tag with GX (and GN, if known) for ``gene``."""
        tags = dict(read.get_tags())
        if gene_tag and read.has_tag(gene_tag):
            del tags[gene_tag]
        tags['GX'] = gene
        gn = gene_infos.get(gene, {}).get('gene_name')
        if gn:
            tags['GN'] = gn
        read.set_tags(list(tags.items()))
        return read

    def create_tags_and_strand(barcode, umi, reads, gene):
        """Build the consensus tags for a group and the strand it should map to."""
        gene_info = gene_infos[gene]
        tags = {
            'AS': sum(read.get_tag('AS') for read in reads),
            'NH': 1,
            'HI': 1,
            config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads),
        }
        if barcode_tag:
            tags[barcode_tag] = barcode
        if umi_tag:
            tags[umi_tag] = umi
        # The group's gene is authoritative whether it came from the gene tag
        # or from find_genes, matching what swap_gene_tags does for singletons.
        tags['GX'] = gene
        gn = gene_info.get('gene_name')
        if gn:
            tags['GN'] = gn
        if add_RS_RI:
            tags.update({
                'RS': ';'.join(read.query_name for read in reads),
                'RI': ';'.join(str(read.get_tag('HI')) for read in reads),
            })
        # Figure out what strand the consensus should map to
        gene_strand = gene_info['strand']
        if strand == 'reverse':
            consensus_strand = '-' if gene_strand == '+' else '+'
        else:
            consensus_strand = gene_strand
        return tags, consensus_strand

    def submit_gene(gene, barcode_umi_groups, out, header_dict):
        """Write singleton groups of ``gene`` and queue larger groups for consensus calling."""
        for barcode, umi_groups in barcode_umi_groups.items():
            for umi, reads in umi_groups.items():
                if len(reads) == 1:
                    out.write(swap_gene_tags(reads[0], gene))
                    continue
                tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene)
                args_q.put(([read.to_string() for read in reads], header_dict, tags, consensus_strand))

    def drain_results(out, header):
        """Write every consensus alignment currently available from the workers."""
        while True:
            try:
                result = results_q.get_nowait()
            except queue.Empty:
                return
            if result:
                out.write(pysam.AlignedSegment.fromstring(result, header))

    if add_RS_RI:
        logger.warning('RS and RI tags will be added to the BAM. This may dramatically increase the BAM size.')

    contig_gene_order = {}
    for gene_id, gene_info in gene_infos.items():
        contig_gene_order.setdefault(gene_info['chromosome'], []).append(gene_id)
    for contig in list(contig_gene_order.keys()):
        contig_gene_order[contig] = sorted(
            contig_gene_order[contig], key=lambda gene: tuple(gene_infos[gene]['segment'])
        )

    gx_barcode_umi_groups = {}
    # First-seen mate of each pair, keyed by (read name, HI), until its mate arrives.
    paired = {}
    required_tags = []
    if umi_tag:
        required_tags.append(umi_tag)
    if barcode_tag:
        required_tags.append(barcode_tag)

    # Start processes for consensus calling
    logger.debug(f'Spawning {n_threads} processes')
    manager = multiprocessing.Manager()
    args_q = manager.Queue(1000 * n_threads)
    results_q = manager.Queue()
    workers = [
        multiprocessing.Process(
            target=consensus_worker, args=(args_q, results_q), kwargs=dict(quality=quality), daemon=True
        ) for _ in range(n_threads)
    ]
    for worker in workers:
        worker.start()

    temp_out_path = utils.mkstemp(dir=temp_dir)
    with pysam.AlignmentFile(bam_path, 'rb') as f:
        # Get header dict and update sort order to unsorted.
        header_dict = f.header.to_dict()
        hd = header_dict.setdefault('HD', {'VN': '1.4', 'SO': 'unsorted'})
        hd['SO'] = 'unsorted'
        header = pysam.AlignmentHeader.from_dict(header_dict)
        with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out:
            next_flush = flush_interval
            for i, read in tqdm(enumerate(f.fetch()), total=ngs.bam.count_bam(bam_path), ascii=True, smoothing=0.01,
                                desc='Calling consensus'):
                if skip_alignment(read, required_tags):
                    continue
                barcode = read.get_tag(barcode_tag) if barcode_tag else None
                if barcode == '-' or (barcode_set and barcode not in barcode_set):
                    continue
                contig = read.reference_name
                umi = read.get_tag(umi_tag) if umi_tag else None
                read_id = read.query_name
                alignment_index = read.get_tag('HI')
                start = read.reference_start
                end = read.reference_end
                key = (read_id, alignment_index)
                mate = None
                if read.is_paired:
                    if key not in paired:
                        paired[key] = read
                        continue
                    mate = paired.pop(key)
                    # Use alignment start and end as UMI for paired reads without UMI
                    if not umi:
                        start = mate.reference_start
                        umi = (start, end)

                # Determine read strand
                read_strand = None
                if read.is_paired:
                    if read.is_read1:  # R1 is mapped after R2
                        if strand == 'forward':
                            read_strand = '+' if read.is_reverse else '-'
                        elif strand == 'reverse':
                            read_strand = '-' if read.is_reverse else '+'
                    else:  # R1 is mapped before R2
                        if strand == 'forward':
                            read_strand = '-' if read.is_reverse else '+'
                        elif strand == 'reverse':
                            read_strand = '+' if read.is_reverse else '-'
                elif strand == 'forward':
                    read_strand = '-' if read.is_reverse else '+'
                elif strand == 'reverse':
                    read_strand = '+' if read.is_reverse else '-'

                # Find compatible genes
                gx_assigned = read.has_tag(gene_tag) if gene_tag else False
                genes = [read.get_tag(gene_tag)] if gx_assigned else find_genes(contig, start, end, read_strand)
                # If there isn't exactly one compatible gene, do nothing and
                # write to BAM.
                if len(genes) != 1:
                    out.write(read)
                    if read.is_paired:
                        out.write(mate)
                    continue

                # Add read to group
                gx_barcode_umi_groups.setdefault(genes[0], {}).setdefault(barcode, {}).setdefault(umi, []).append(read)
                if read.is_paired:
                    gx_barcode_umi_groups[genes[0]][barcode][umi].append(mate)

                if i >= next_flush:
                    next_flush = i + flush_interval
                    tid = read.reference_id
                    # Leftmost position on this contig that may still gain
                    # reads: the current group start, or any mate on this
                    # contig still waiting for its pair. Mates left on earlier
                    # contigs are irrelevant to positions on this one.
                    leftmost_start = min(
                        [start] + [pending.reference_start for pending in paired.values() if pending.reference_id == tid]
                    )
                    # Call consensus for genes whose bodies we've fully passed.
                    # Contig order is the BAM's reference order, not name order.
                    for gene in list(gx_barcode_umi_groups):
                        gene_info = gene_infos[gene]
                        gene_tid = f.get_tid(gene_info['chromosome'])
                        if gene_tid < tid or (gene_tid == tid and gene_info['segment'].end <= leftmost_start):
                            submit_gene(gene, gx_barcode_umi_groups.pop(gene), out, header_dict)

                    # Drop passed genes so find_genes doesn't rescan them.
                    genes_on_contig = contig_gene_order.get(contig, [])
                    to_remove = 0
                    for gene in genes_on_contig:
                        if gene_infos[gene]['segment'].end <= leftmost_start:
                            to_remove += 1
                        else:
                            break
                    if to_remove > 0:
                        contig_gene_order[contig] = genes_on_contig[to_remove:]

                    drain_results(out, header)

            # Paired alignments whose mate never appeared (e.g. the mate was
            # unmapped or filtered out) cannot be grouped; keep them as-is
            # instead of silently dropping them.
            if paired:
                logger.debug(f'Writing {len(paired)} alignments whose mates were not found')
                for orphan in paired.values():
                    out.write(orphan)
                paired.clear()

            # Put remaining
            for gene, barcode_umi_groups in gx_barcode_umi_groups.items():
                submit_gene(gene, barcode_umi_groups, out, header_dict)
            gx_barcode_umi_groups.clear()

            # Signal to workers to terminate once queue is depleted.
            for _ in range(len(workers)):
                args_q.put(None)
            for worker in workers:
                worker.join()
            drain_results(out, header)

    # Sort and index
    return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir)