Source code for dynast.estimate

import datetime as dt
import os
from typing import Dict, List, Optional, Union

import pandas as pd
import pystan
from typing_extensions import Literal

from . import config, constants, estimation, preprocessing, utils
from .logging import logger
from .stats import Stats


@logger.namespaced('estimate')
def estimate(
    count_dirs: List[str],
    out_dir: str,
    reads: Union[Literal['complete'], List[Literal['total', 'transcriptome', 'spliced', 'unspliced']]] = 'complete',
    barcodes: Optional[List[List[str]]] = None,
    groups: Optional[List[Dict[str, str]]] = None,
    ignore_groups_for_est: bool = True,
    genes: Optional[List[str]] = None,
    downsample: Optional[Union[int, float]] = None,
    downsample_mode: Literal['uniform', 'cell', 'group'] = 'uniform',
    cell_threshold: int = 1000,
    cell_gene_threshold: int = 16,
    control_p_e: Optional[float] = None,
    control: bool = False,
    method: Literal["pi_g", "alpha"] = "alpha",
    n_threads: int = 8,
    temp_dir: Optional[str] = None,
    nasc: bool = False,
    by_name: bool = False,
    seed: Optional[int] = None,
):
    """Main interface for the `estimate` command.

    Reads the output of one or more `count` runs, optionally filters/groups/downsamples
    the counts, estimates p_e (background conversion rate), p_c (labeled conversion rate)
    and either pi_g or alpha per read group and conversion, and writes all intermediate
    CSVs plus a final AnnData object into `out_dir`.

    Args:
        count_dirs: Paths to directories containing `count` command output
        out_dir: Output directory
        reads: What read group(s) to quantify. `'complete'` selects every read group
            present in the counts.
        barcodes: Cell barcodes, one list per count dir (indexed in the same order as
            `count_dirs`)
        groups: Cell groups, as barcode -> group mappings. A list with one mapping per
            count dir is merged into one mapping with `-{i}` suffixed barcodes.
        ignore_groups_for_est: Ignore cell groups for final estimation
        genes: Genes to consider
        downsample: Downsample factor (float) or number (int)
        downsample_mode: Downsampling mode
        cell_threshold: Run estimation only for cells with at least this many counts
        cell_gene_threshold: Run estimation for cell-genes with at least this many counts
        control_p_e: Old RNA conversion rate (p_e), estimated from control samples
        control: Whether this is a control sample. If True, only p_e is estimated.
        method: Estimation method to use
        n_threads: Number of threads
        temp_dir: Temporary directory (not referenced by this function)
        nasc: Whether to match NASC-seq pipeline behavior
        by_name: Whether to group counts by gene name instead of ID
        seed: Random seed

    Returns:
        None. Results are written to `out_dir`.

    Raises:
        Exception: If the count dirs have mismatching conversions, if a requested read
            group has no (or, for `total`, only transcriptome) reads, or if `method` is
            not recognized.
    """
    stats = Stats()
    stats.start()
    stats_path = os.path.join(
        out_dir,
        f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json',
    )
    os.makedirs(out_dir, exist_ok=True)

    # Check that all the conversions are the same if there are multiple count dirs
    conversions = utils.read_pickle(os.path.join(count_dirs[0], constants.CONVS_FILENAME))
    for count_dir in count_dirs[1:]:
        _conversions = utils.read_pickle(os.path.join(count_dir, constants.CONVS_FILENAME))
        if conversions != _conversions:
            raise Exception(
                f'Conversions for {count_dir} doesn\'t match conversions for {count_dirs[0]}. '
                f'({_conversions} != {conversions}).'
            )
    logger.info(f'Conversions: {" ".join(",".join(convs) for convs in conversions)}')
    all_conversions = sorted(utils.flatten_iter(conversions))
    gene_infos = utils.read_pickle(os.path.join(count_dirs[0], constants.GENES_FILENAME))

    # When several count dirs are combined, every barcode is suffixed with `-{i}` (the
    # index of its count dir) so identical barcodes from different runs stay distinct.
    # Anything keyed by barcode (counts, rates, groups) must use the same suffixing.
    suffix_barcodes = len(count_dirs) > 1

    # Read each counts dataframe and suffix barcodes if needed
    dfs = []
    for i, count_dir in enumerate(count_dirs):
        counts_path = os.path.join(count_dir, f'{constants.COUNTS_PREFIX}_{"_".join(all_conversions)}.csv')
        logger.info(
            f'Reading {counts_path}' + (f' and suffixing all barcodes with `-{i}`' if suffix_barcodes else '')
        )
        _df_counts = preprocessing.read_counts(counts_path)
        # Filter barcodes. `.copy()` so the column assignment below modifies an
        # independent frame rather than a slice of the unfiltered one.
        if barcodes:
            _df_counts = _df_counts[_df_counts['barcode'].isin(barcodes[i])].copy()
        if suffix_barcodes:
            _df_counts['barcode'] = _df_counts['barcode'].astype(str) + f'-{i}'
        dfs.append(_df_counts)
    df_counts_uncomplemented = pd.concat(dfs, ignore_index=True) if suffix_barcodes else dfs[0]
    df_counts_uncomplemented['barcode'] = df_counts_uncomplemented['barcode'].astype('category')
    df_counts_uncomplemented.drop(columns=['umi'], inplace=True)

    # Clean groups by combining multiple into one. Always work on a fresh dict so the
    # filtering below never mutates the mapping(s) the caller passed in.
    if isinstance(groups, list):
        if len(groups) == 1:
            groups = dict(groups[0])
        else:
            groups = {
                f'{barcode}-{i}': group
                for i, _groups in enumerate(groups)
                for barcode, group in _groups.items()
            }
    elif groups:
        groups = dict(groups)

    # Contains group name to list of cells mapping (only populated when groups are given)
    group_cells = {}
    if groups:
        # Remove any barcodes not present.
        present_barcodes = set(df_counts_uncomplemented['barcode'])
        to_remove = set(groups.keys()) - present_barcodes
        if to_remove:
            logger.warning(
                f'Removing {len(to_remove)} barcodes from the provided groups that are not present in the counts CSV.'
            )
            groups = {barcode: group for barcode, group in groups.items() if barcode not in to_remove}

        for barcode, group in groups.items():
            group_cells.setdefault(group, []).append(barcode)

    # Change barcodes to cell groups. We add groups here instead of after
    # complementing because the user may want to downsample per group instead of per cell.
    if groups:
        logger.warning(f'Barcodes that are not among the {len(groups)} barcodes with assigned groups will be ignored.')
        df_counts_uncomplemented = df_counts_uncomplemented[
            df_counts_uncomplemented['barcode'].isin(groups.keys())
        ].reset_index(drop=True)
        df_counts_uncomplemented['group'] = df_counts_uncomplemented['barcode'].map(groups).astype('category')

    # Downsample here
    if downsample:
        _proportion = None
        _count = None
        # A whole-valued `downsample` is interpreted as an absolute count, anything
        # else as a proportion.
        if int(downsample) == downsample:
            _count = int(downsample)
        else:
            _proportion = downsample
        _group_by = None
        if downsample_mode == 'cell':
            _group_by = ['barcode']
        elif downsample_mode == 'group':
            _group_by = ['group']

        target = f'to {_count} entries' if _count else f'to a factor of {_proportion}'
        if not _group_by:
            logger.info('Downsampling uniformly at random ' + target)
        else:
            logger.info(f'Downsampling per {_group_by} {target}')
        old_count = df_counts_uncomplemented.shape[0]
        df_counts_uncomplemented = utils.downsample_counts(
            df_counts_uncomplemented, proportion=_proportion, count=_count, seed=seed, group_by=_group_by
        )
        logger.debug(f'Downsampled from {old_count} to {df_counts_uncomplemented.shape[0]} entries')

    # Check that all requested read groups can be corrected.
    transcriptome_any = df_counts_uncomplemented['transcriptome'].any()
    transcriptome_all = df_counts_uncomplemented['transcriptome'].all()
    if reads != 'complete':
        if 'transcriptome' in reads and not transcriptome_any:
            raise Exception(
                'No reads are assigned to `transcriptome`, so estimation is not supported for this read group.'
            )
        if 'total' in reads and transcriptome_all:
            raise Exception(
                'All reads are assigned to `transcriptome`, so estimation is not supported for `total` read group.'
            )
        for key in set(reads).intersection(('spliced', 'unspliced')):
            if not (df_counts_uncomplemented['velocity'] == key).any():
                raise Exception(
                    f'No reads are assigned to `{key}`, so estimation is not supported for this read group.'
                )
    else:
        reads = []
        if transcriptome_any:
            reads.append('transcriptome')
        if not transcriptome_all:
            reads.append('total')
        reads += list(set(df_counts_uncomplemented['velocity'].unique()) - set(config.VELOCITY_BLACKLIST))
    logger.info(f'Estimation will be done on the following read groups: {reads}')

    df_counts = preprocessing.complement_counts(df_counts_uncomplemented, gene_infos)
    logger.info(
        f'Final counts: {df_counts.shape[0]} reads '
        f'across {df_counts["barcode"].nunique()} barcodes'
        + (f' and {df_counts["group"].nunique()} groups.' if groups else '.')
    )

    # Convert gene IDs to names (falling back to the ID when no name is recorded)
    if by_name:
        logger.info('`--gene-names` provided. Converting gene IDs to names.')
        df_counts['GX'] = df_counts['GX'].apply(lambda gx: gene_infos[gx]['gene_name'] or gx)

    # Subset to provided genes
    if genes:
        logger.info(f'`--genes` provided. Ignoring genes not in the {len(genes)} provided.')
        df_counts = df_counts[df_counts['GX'].isin(genes)]

    # Estimate p_e
    p_key = 'group' if groups else 'barcode'
    p_e_path = os.path.join(out_dir, constants.P_E_FILENAME)
    # `is not None` so an explicitly provided p_e of 0.0 is honored.
    if control_p_e is not None:
        logger.info('`--p-e` provided. No background mutation rate estimation will be done.')
        df_barcodes = df_counts[[p_key]].drop_duplicates().reset_index(drop=True)
        df_barcodes['p_e'] = control_p_e
        df_barcodes.to_csv(p_e_path, header=[p_key, 'p_e'], index=False)
    else:
        logger.info(f'Estimating average conversion rate in unlabeled RNA per {p_key} to {p_e_path}')
        if control:
            p_e_path = estimation.estimate_p_e_control(
                df_counts,
                p_e_path,
                conversions=conversions,
            )
        elif nasc:
            # Read rates from every count dir (not only the last one) and suffix their
            # barcodes exactly like the counts, so they match `groups` and the barcodes
            # used downstream.
            rates_dfs = []
            for i, count_dir in enumerate(count_dirs):
                _rates = preprocessing.read_rates(os.path.join(count_dir, f'{constants.RATES_PREFIX}.csv'))
                if suffix_barcodes:
                    _rates['barcode'] = _rates['barcode'].astype(str) + f'-{i}'
                rates_dfs.append(_rates)
            rates = pd.concat(rates_dfs, ignore_index=True) if suffix_barcodes else rates_dfs[0]
            if groups:
                rates['group'] = rates['barcode'].map(groups)
            p_e_path = estimation.estimate_p_e_nasc(
                rates,
                p_e_path,
                group_by=[p_key],
            )
        else:
            p_e_path = estimation.estimate_p_e(
                df_counts,
                p_e_path,
                conversions=conversions,
                group_by=[p_key],
            )

    if control:
        logger.info('Downstream processing skipped for controls')
        logger.info(f'Use `--p-e {p_e_path}` to run test samples')
        stats.end()
        stats.save(stats_path)
        return

    p_es = estimation.read_p_e(p_e_path, group_by=[p_key])

    # Aggregate counts to construct A matrix
    # NOTE: we don't use groupings here because we may need to use individual
    # barcodes later. For instance, p_c may be estimated in groups, but pi_g may
    # be estimated per cell. So that the aggregated A matrix is compatible with both
    # estimation procedures, we don't care about groupings here. Instead, groupings
    # should be manually done at each step that requires such groupings.
    # `base_key` is the read group p_c is estimated from.
    base_key = 'transcriptome' if transcriptome_all else 'total'
    aggregates_paths = {}
    for key in set(reads).union([base_key]):
        logger.info(f'Aggregating counts for `{key}`')
        df = preprocessing.subset_counts(df_counts, key)
        for convs in conversions:
            convs = sorted(convs)
            other_convs = list(set(all_conversions) - set(convs))
            # Only keep reads without any conversion from other conversion groups.
            aggregates_paths.setdefault(key, {})[tuple(convs)] = preprocessing.aggregate_counts(
                df[(df[other_convs] == 0).all(axis=1)] if other_convs else df,
                os.path.join(out_dir, f'A_{key}_{"_".join(convs)}.csv'),
                conversions=convs,
            )

    # Estimate p_c
    p_c_paths = {}
    for convs in conversions:
        convs = sorted(convs)
        p_c_path = os.path.join(out_dir, f'{constants.P_C_PREFIX}_{"_".join(convs)}.csv')
        logger.info(
            f'Estimating {convs} conversion rate in labeled RNA per {p_key} to {p_c_path}. '
            'Consider downsampling with `--downsample` if this step takes too long.'
        )
        df_aggregates = preprocessing.read_aggregates(aggregates_paths[base_key][tuple(convs)])
        if groups:
            df_aggregates['group'] = df_aggregates['barcode'].map(groups).astype('category')
        p_c_paths[tuple(convs)] = estimation.estimate_p_c(
            df_aggregates,
            p_es,
            p_c_path,
            group_by=[p_key],
            threshold=cell_threshold,
            n_threads=n_threads,
            nasc=nasc,
        )
    # Keyed by the *sorted* conversion tuple, consistent with p_c_paths and every
    # p_cs lookup below (unsorted keys raised KeyError for unsorted conversion groups).
    p_cs = {
        tuple(sorted(convs)): estimation.read_p_c(p_c_paths[tuple(sorted(convs))], group_by=[p_key])
        for convs in conversions
    }

    logger.info(f'Compiling STAN model from {config.MODEL_PATH}')
    model = pystan.StanModel(file=config.MODEL_PATH, model_name=config.MODEL_NAME)

    pis = None
    alphas = None
    if method == 'pi_g':
        pi_key = 'barcode' if ignore_groups_for_est or not groups else 'group'
        pi_paths = {}
        pi_as = {}
        pi_bs = {}
        pis = {}
        for key in reads:
            for convs in conversions:
                convs = sorted(convs)
                pi_path = os.path.join(out_dir, f'pi_{key}_{"_".join(convs)}.csv')
                logger.info(
                    f'Estimating fraction of labeled `{key}` RNA for conversions {convs} per {pi_key}-gene to {pi_path}'
                )
                df_aggregates = preprocessing.read_aggregates(aggregates_paths[key][tuple(convs)])
                if groups:
                    df_aggregates['group'] = df_aggregates['barcode'].map(groups).astype('category')
                pi_paths.setdefault(key, {})[tuple(convs)] = estimation.estimate_pi(
                    df_aggregates,
                    p_es,
                    p_cs[tuple(convs)],
                    pi_path,
                    group_by=[pi_key, 'GX'],
                    p_group_by=[p_key],
                    n_threads=n_threads,
                    threshold=cell_gene_threshold,
                    seed=seed,
                    nasc=nasc,
                    model=model,
                )
                pi_a, pi_b, pi = estimation.read_pi(pi_path, group_by=[pi_key, 'GX'])
                pi_as.setdefault(key, {})[tuple(convs)] = pi_a
                pi_bs.setdefault(key, {})[tuple(convs)] = pi_b
                pis.setdefault(key, {})[tuple(convs)] = pi

        # Estimated pis need to be per cell because the adata is per cell
        if groups and not ignore_groups_for_est:
            pis = {
                key: {
                    convs: {
                        (barcode, gx): value
                        for (group, gx), value in pis[key][convs].items()
                        for barcode in group_cells[group]
                    }
                    for convs in pis[key]
                }
                for key in pis
            }
    elif method == 'alpha':
        pi_c_key = 'group' if groups else 'barcode'
        alpha_key = 'group' if groups and not ignore_groups_for_est else 'barcode'
        pi_c_paths = {}
        pi_cs = {}
        pi_as = {}
        pi_bs = {}
        alpha_paths = {}
        alphas = {}
        for key in reads:
            for convs in conversions:
                convs = sorted(convs)
                pi_c_path = os.path.join(out_dir, f'pi_{key}_{"_".join(convs)}.csv')
                logger.info(
                    f'Estimating fraction of labeled `{key}` RNA for conversions {convs} per '
                    f'{pi_c_key} to {pi_c_path}. Consider downsampling with `--downsample` if '
                    'this step takes too long.'
                )
                df_aggregates = preprocessing.read_aggregates(aggregates_paths[key][tuple(convs)])
                if groups:
                    df_aggregates['group'] = df_aggregates['barcode'].map(groups).astype('category')
                pi_c_path = estimation.estimate_pi(
                    df_aggregates,
                    p_es,
                    p_cs[tuple(convs)],
                    pi_c_path,
                    group_by=[pi_c_key],
                    p_group_by=[p_key],
                    n_threads=n_threads,
                    threshold=cell_gene_threshold,
                    seed=seed,
                    nasc=nasc,
                    model=model,
                )
                pi_c_paths.setdefault(key, {})[tuple(convs)] = pi_c_path
                pi_a, pi_b, pi_c = estimation.read_pi(pi_c_path, group_by=[pi_c_key])
                pi_as.setdefault(key, {})[tuple(convs)] = pi_a
                pi_bs.setdefault(key, {})[tuple(convs)] = pi_b
                pi_cs.setdefault(key, {})[tuple(convs)] = pi_c

                alpha_path = os.path.join(out_dir, f'alpha_{key}_{"_".join(convs)}.csv')
                logger.info(
                    f'Estimating detection rate of `{key}` RNA for conversions {convs} per {alpha_key} to {alpha_path}'
                )
                alpha_path = estimation.estimate_alpha(
                    preprocessing.subset_counts(df_counts, key),
                    pi_c,
                    alpha_path,
                    conversions=convs,
                    group_by=[alpha_key],
                    pi_c_group_by=[pi_c_key],
                )
                alpha_paths.setdefault(key, {})[tuple(convs)] = alpha_path
                alphas.setdefault(key, {})[tuple(convs)] = estimation.read_alpha(alpha_path, group_by=[alpha_key])

        # Alphas need to be per cell because the adata is per cell
        if groups and not ignore_groups_for_est:
            alphas = {
                key: {
                    convs: {
                        barcode: value
                        for group, value in alphas[key][convs].items()
                        for barcode in group_cells[group]
                    }
                    for convs in alphas[key]
                }
                for key in alphas
            }
    else:
        raise Exception(f'Unrecognized method {method}')

    adata_path = os.path.join(out_dir, constants.ADATA_FILENAME)
    logger.info(f'Combining results into Anndata object at {adata_path}')
    adata = utils.results_to_adata(
        df_counts, conversions, gene_infos=gene_infos if not by_name else None, pis=pis, alphas=alphas
    )
    # If groups were provided, add the group as a column
    if groups:
        adata.obs['group'] = adata.obs.index.map(groups).astype('category')
    # Add the count dir that was provided as input as another column if multiple
    # count dirs were provided (recovered from the `-{i}` barcode suffix)
    if suffix_barcodes:
        adata.obs['count_dir'] = adata.obs.index.str.split('-').str[-1].astype(int).map({
            i: count_dir for i, count_dir in enumerate(count_dirs)
        }).astype('category')
    adata.obs.reset_index(inplace=True)
    # Add p_e, p_c estimates
    adata.obs['p_e'] = adata.obs[p_key].map(p_es).astype(float)
    for convs in conversions:
        convs = sorted(convs)
        convs_key = "_".join(convs)
        adata.obs[f'p_c_{convs_key}'] = adata.obs[p_key].map(p_cs[tuple(convs)]).astype(float)
    adata.obs.set_index('barcode', inplace=True)
    adata.write(adata_path, compression='gzip')
    stats.end()
    stats.save(stats_path)