from typing import Dict, FrozenSet, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
def read_alpha(alpha_path: str,
               group_by: Optional[List[str]] = None) -> Union[float, Dict[str, float], Dict[Tuple[str, ...], float]]:
    """Load alpha estimates previously written to disk.

    Without `group_by`, the file is expected to contain a single number and
    that number is returned. With `group_by`, the file is parsed as a CSV that
    has the `group_by` columns plus an `alpha` column, and a dictionary keyed
    by the `group_by` values is returned.

    Args:
        alpha_path: Path to the file containing the alpha value(s)
        group_by: Columns that identify each alpha value, defaults to `None`

    Returns:
        A single float when `group_by` is `None`; otherwise a dictionary
        mapping each `group_by` value (a tuple if there are multiple columns)
        to its alpha
    """
    if group_by is not None:
        # Key columns are read as categoricals so their values are not
        # re-interpreted as numbers by the CSV parser.
        key_dtypes = dict.fromkeys(group_by, 'category')
        table = pd.read_csv(alpha_path, dtype=key_dtypes)
        alpha_by_key = table.set_index(group_by)['alpha']
        return {key: alpha_by_key[key] for key in alpha_by_key.keys()}

    with open(alpha_path, 'r') as handle:
        contents = handle.read()
    return float(contents)
def estimate_alpha(
    df_counts: pd.DataFrame,
    pi_c: Union[float, Dict[str, float], Dict[Tuple[str, ...], float]],
    alpha_path: str,
    conversions: FrozenSet[str] = frozenset({'TC'}),
    group_by: Optional[List[str]] = None,
    pi_c_group_by: Optional[List[str]] = None,
) -> str:
    """Estimate the detection rate alpha.

    For each group (or for the whole dataframe when `group_by` is `None`),
    alpha is the fraction of rows that have a positive count in at least one
    of the `conversions` columns, divided by that group's `pi_c`.

    Args:
        df_counts: Pandas dataframe containing conversion counts. Must contain
            a `barcode` column, one column per conversion, and every column
            named in `group_by` and `pi_c_group_by`
        pi_c: Labeled mutation rate. A float, or (when `pi_c_group_by` is
            provided) a dictionary keyed by the `pi_c_group_by` values
        alpha_path: Path to output CSV containing alpha estimates
        conversions: Conversions to consider
        group_by: Columns to group by
        pi_c_group_by: Columns that were used to group when calculating pi_c

    Returns:
        Path to output CSV containing alpha estimates

    Raises:
        Exception: If `group_by` is `None` and `pi_c` is not a positive float,
            or if the rows of a single group map to more than one `pi_c`.
    """
    import csv

    # Validate the scalar case up front so that bad input fails before any
    # work is done on the dataframe.
    if group_by is None:
        if not isinstance(pi_c, float):
            raise Exception('`pi_c` must be a float when `group_by` is not provided')
        if pi_c <= 0:
            raise Exception(f'Estimated `pi_c` must be positive, but got {pi_c}')

    conversion_columns = list(conversions)
    columns = ['barcode'] + conversion_columns
    if group_by is not None:
        columns += group_by
    if pi_c_group_by is not None:
        columns += pi_c_group_by
    # dict.fromkeys removes duplicates while keeping a deterministic order.
    df_full = df_counts[list(dict.fromkeys(columns))].copy()

    if pi_c_group_by is not None:
        df_full.set_index(pi_c_group_by, inplace=True)
        df_full['pi_c'] = df_full.index.map(pi_c)
        df_full.reset_index(inplace=True)
    else:
        df_full['pi_c'] = pi_c
    df_full.dropna(subset=['pi_c'], inplace=True)  # Drop NA values due to pi_c

    # A row counts as "new" when any considered conversion was observed.
    is_new = (df_full[conversion_columns] > 0).any(axis=1)

    if group_by is None:
        total = df_full.shape[0]
        new = is_new.sum()
        ntr = new / total
        alpha = ntr / pi_c
    else:
        grouped = df_full.groupby(group_by, sort=False, observed=True)
        total = grouped.size()
        new = df_full[is_new].groupby(group_by, observed=True, sort=False).size()
        # Groups without any new rows are absent from `new`; count them as 0.
        ntr = new.reindex(total.index, fill_value=0.) / total
        # Everything below is aligned on `total.index`, so no positional or
        # label lookups (whose key shape depends on the grouper) are needed.
        pi_c_per_group = grouped['pi_c'].unique().reindex(total.index)

        alphas = {}
        for key, fraction, values in zip(total.index, ntr.to_numpy(), pi_c_per_group.to_numpy()):
            pi_c_unique = np.unique(values)
            if len(pi_c_unique) > 1:
                raise Exception(f'`pi_c` for each aggregate group must be a constant, but instead got {pi_c_unique}.')
            # Groups with a non-positive pi_c are left out of the output.
            if pi_c_unique[0] > 0:
                alphas[key] = fraction / pi_c_unique[0]

    with open(alpha_path, 'w') as f:
        if group_by is None:
            f.write(str(alpha))
        else:
            # csv.writer stringifies non-string labels and quotes labels that
            # contain delimiters. '\n' is used as the terminator (with the
            # file in default text mode) so line endings match what a plain
            # f.write('...\n') would produce.
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow(list(group_by) + ['alpha'])
            for key in sorted(alphas):
                key_fields = key if isinstance(key, tuple) else (key,)
                writer.writerow([*key_fields, str(alphas[key])])
    return alpha_path