Source code for dynast.estimation.p_c

from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from numba import njit
from scipy import sparse

from .. import utils
from ..logging import logger


def read_p_c(p_c_path: str,
             group_by: Optional[List[str]] = None) -> Union[float, Dict[str, float], Dict[Tuple[str, ...], float]]:
    """Load p_c estimates previously written to disk.

    Args:
        p_c_path: Path to the file containing p_c values. When `group_by` is
            `None` the whole file is parsed as a single float; otherwise it is
            read as a CSV that has a `p_c` column plus the `group_by` columns.
        group_by: Columns to use as dictionary keys, defaults to `None`

    Returns:
        A single float when `group_by` is `None`. Otherwise a dictionary
        mapping the `group_by` values (a tuple if there are multiple columns)
        to the corresponding p_c value.
    """
    if group_by is not None:
        # Key columns are read as categoricals so they are never parsed as numbers.
        key_dtypes = dict.fromkeys(group_by, 'category')
        table = pd.read_csv(p_c_path, dtype=key_dtypes)
        indexed = table.set_index(group_by)
        return dict(indexed['p_c'])

    # Ungrouped estimates are stored as a bare number.
    with open(p_c_path, 'r') as handle:
        contents = handle.read()
    return float(contents)
@njit(cache=True)
def binomial_pmf(k: int, n: int, p: float) -> float:
    """Numbaized binomial PMF function for faster calculation.

    The binomial coefficient is accumulated as a floating-point running
    product rather than as a ratio of two integer products, because integer
    products wrap around silently once they exceed the 64-bit range (which
    already happens for `k!` at `k = 21`).

    Args:
        k: Number of successes
        n: Number of trials
        p: Probability of success

    Returns:
        Probability of observing `k` successes in `n` trials with probability
        of success `p`. Zero when `k` lies outside `[0, n]`.
    """
    # Outside the support the probability is zero; returning early also avoids
    # raising (1 - p) to a negative power below.
    if k < 0 or k > n:
        return 0.0

    # After step i the running value equals C(n - k + 1 + i, i + 1), which is
    # an integer, so multiplying before dividing keeps every intermediate exact
    # for as long as it fits in a float64 mantissa.
    coef = 1.0
    for i in range(k):
        coef *= n - k + 1 + i
        coef /= i + 1
    return coef * (p**k) * ((1 - p)**(n - k))
@njit(cache=True)
def expectation_maximization_nasc(values: np.ndarray, p_e: float, threshold: float = 0.01) -> float:
    """NASC-seq pipeline variant of the EM algorithm to estimate average
    conversion rate in labeled RNA.

    Cells `(k, n)` of `values` are first flagged when the background
    expectation (the reads in the same column with more than `k` conversions,
    multiplied by the binomial probability of `k` conversions at rate `p_e`)
    exceeds `threshold` times the observed count. The EM loop then alternates
    between re-estimating the flagged cells and re-estimating `p_c`, narrowing
    a `[lower, upper]` interval that starts at `[0, 1]` until it is smaller
    than 1e-8 or the estimate stops changing.

    Args:
        values: N x C Numpy array where N is the number of conversions, C is the
            nucleotide content, and the value at this position is the number of
            reads observed
        p_e: Background mutation rate of unlabeled RNA
        threshold: Filter threshold

    Returns:
        Estimated conversion rate
    """
    n_conversions = values.shape[0]
    n_contents = values.shape[1]

    # Flag the cells whose expected background exceeds the allowed share.
    flagged = []
    for k in range(n_conversions):
        for n in range(n_contents):
            background = (values[k + 1:, n]).sum() * binomial_pmf(k, n, p_e)
            if background > threshold * values[k, n]:
                flagged.append((k, n))

    upper = 1
    lower = 0
    p_c = (upper + lower) / 2
    previous = p_c
    ordered_cells = sorted(flagged)
    flagged_lookup = set(flagged)
    estimates = values.astype(np.float32)  # astype returns a copy, `values` is left untouched
    while upper - lower >= 1e-8:
        # E step: each flagged cell is recomputed from the flagged cells of its
        # own column. Cells are updated in place in sorted order, so later
        # cells see the values written for earlier ones.
        for k, n in ordered_cells:
            weighted_reads = 0
            total_pmf = 0
            for kp in range(n_conversions):
                if (kp, n) not in flagged_lookup:
                    continue
                weighted_reads += binomial_pmf(k, n, p_c) * estimates[kp, n]
                total_pmf += binomial_pmf(kp, n, p_c)
            estimates[k, n] = weighted_reads / total_pmf

        # M step: ratio of count-weighted conversions to count-weighted content.
        conversion_total = 0
        content_total = 0
        for k in range(n_conversions):
            for n in range(n_contents):
                conversion_total += k * estimates[k, n]
                content_total += n * estimates[k, n]

        previous = p_c
        p_c = conversion_total / content_total if content_total > 0 else 0
        if previous == p_c:
            break

        # Move whichever bound the previous estimate now replaces.
        if p_c < previous:
            upper = previous
        else:
            lower = previous
    return p_c
@njit(cache=True)
def expectation_maximization(
    values: np.ndarray, p_e: float, p_c: float = 0.1, threshold: float = 0.01, max_iters: int = 300
) -> float:
    """Run EM algorithm to estimate average conversion rate in labeled RNA.

    This function runs the following two steps.
    1) Flags every cell `(k, n)` of `values` whose background expectation (the
       reads in column `n` from row `k` onwards, multiplied by the binomial
       probability of `k` conversions at rate `p_e`) exceeds `threshold` times
       the observed count.
    2) Runs an EM algorithm that iteratively re-estimates the flagged cells
       from the unflagged cells of the same column and then updates the
       conversion rate estimate.
    See https://doi.org/10.1093/bioinformatics/bty256.

    Args:
        values: 2-D array indexed as `values[k, n]` where
            k: number of conversions
            n: nucleotide content
            value: number of reads
        p_e: Background mutation rate of unlabeled RNA
        p_c: Initial p_c value
        threshold: Filter threshold
        max_iters: Maximum number of EM iterations

    Returns:
        Estimated conversion rate

    Raises:
        ValueError: If the final estimate is not greater than `p_e`
    """
    n_conversions = values.shape[0]
    n_contents = values.shape[1]

    # Flag the cells whose expected background exceeds the allowed share.
    flagged = []
    for k in range(n_conversions):
        for n in range(n_contents):
            background = (values[k:, n]).sum() * binomial_pmf(k, n, p_e)
            if background > threshold * values[k, n]:
                flagged.append((k, n))

    previous = p_c
    ordered_cells = sorted(flagged)
    flagged_lookup = set(flagged)
    estimates = values.astype(np.float32)  # astype returns a copy, `values` is left untouched
    for _ in range(max_iters):
        # E step: each flagged cell is recomputed from the unflagged cells of
        # its own column.
        for k, n in ordered_cells:
            weighted_reads = 0
            total_pmf = 0
            for kp in range(n_conversions):
                if (kp, n) in flagged_lookup:
                    continue
                weighted_reads += binomial_pmf(k, n, p_c) * estimates[kp, n]
                total_pmf += binomial_pmf(kp, n, p_c)
            # A cell with nothing to estimate from keeps its current value.
            if total_pmf > 0:
                estimates[k, n] = weighted_reads / total_pmf

        # M step: ratio of count-weighted conversions to count-weighted content.
        conversion_index = np.arange(n_conversions).reshape(-1, 1)
        content_index = np.arange(n_contents)
        conversion_total = (estimates * conversion_index).sum()
        content_total = (estimates * content_index).sum()
        previous = p_c
        p_c = conversion_total / content_total if content_total > 0 else p_c
        if previous == p_c:
            break

    if p_c <= p_e:
        raise ValueError('p_c <= p_e')
    return p_c
def _dense_count_matrix(triplets: np.ndarray) -> np.ndarray:
    """Turn rows of (conversion, base, count) into a dense conversion x base matrix."""
    # Initialize with sparse matrix because duplicates need to be added.
    # `.toarray()` is the documented conversion method; the `.A` shorthand is
    # not available on every SciPy sparse container.
    return sparse.csr_matrix((triplets[:, 2], (triplets[:, 0], triplets[:, 1]))).toarray()


def estimate_p_c(
    df_aggregates: pd.DataFrame,
    p_e: Union[float, Dict[str, float], Dict[Tuple[str, ...], float]],
    p_c_path: str,
    group_by: Optional[List[str]] = None,
    threshold: int = 1000,
    n_threads: int = 8,
    nasc: bool = False
) -> str:
    """Estimate the average conversion rate in labeled RNA.

    Without `group_by`, a single estimate is computed from all rows and written
    to `p_c_path` as a bare number. With `group_by`, one estimate is computed
    per group in a pool of worker processes and the results are written as a
    CSV with the `group_by` columns followed by `p_c`. Groups with fewer than
    `threshold` reads are skipped, and groups whose estimation raises are
    left out of the output; both are reported through the logger.

    Args:
        df_aggregates: Pandas dataframe containing aggregate values; the
            `conversion`, `base` and `count` columns are used
        p_e: Background mutation rate of unlabeled RNA. Must be indexable by
            group key when `group_by` is provided
        p_c_path: Path to output CSV containing p_c estimates
        group_by: Columns to group by
        threshold: Read count threshold
        n_threads: Number of worker processes
        nasc: Flag to indicate whether to use NASC-seq pipeline variant of the
            EM algorithm

    Returns:
        Path to output CSV containing p_c estimates
    """
    em_func = expectation_maximization_nasc if nasc else expectation_maximization
    values = df_aggregates[['conversion', 'base', 'count']].values
    logger.debug('Running EM algorithm')
    if group_by is None:
        p_c = em_func(_dense_count_matrix(values), p_e)
    else:
        groups = df_aggregates.groupby(group_by, sort=False, observed=True).indices
        p_cs = {}
        skipped = 0
        failed = 0
        logger.debug(f'Spawning {n_threads} processes')
        with ProcessPoolExecutor(max_workers=n_threads) as executor:
            futures = {}
            for key, idx in groups.items():
                vals = values[idx]
                if vals[:, 2].sum() < threshold:
                    skipped += 1
                    continue

                futures[executor.submit(em_func, _dense_count_matrix(vals), p_e[key])] = key

            for future in utils.as_completed_with_progress(futures):
                key = futures[future]
                try:
                    p_cs[key] = future.result()
                except Exception as e:
                    # Best effort: one failing group must not abort the rest,
                    # but record which group failed and why.
                    failed += 1
                    logger.debug(f'Estimation failed for {key}: {e}')
        if skipped > 0:
            logger.warning(
                f'Estimation skipped for {skipped} barcodes because they contain less than '
                f'{threshold} reads. Use `--cell-threshold` to change.'
            )
        if failed > 0:
            logger.warning(f'Estimation failed {failed} times.')

    logger.debug(f'Writing p_c estimates to {p_c_path}')
    with open(p_c_path, 'w') as f:
        if group_by is None:
            f.write(str(p_c))
        else:
            f.write(f'{",".join(group_by)},p_c\n')
            for key in sorted(p_cs.keys()):
                p_c = p_cs[key]
                # Multi-column keys are tuples and are flattened into CSV fields.
                formatted_key = key if isinstance(key, str) else ",".join(key)
                f.write(f'{formatted_key},{p_c}\n')

    return p_c_path