import random
import tempfile
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystan
from tqdm import tqdm
from .. import config, estimation
from ..preprocessing import aggregation, conversion
def generate_sequence(k, seed=None):
    """Build a random nucleotide string of length ``k``.

    Characters are drawn with replacement from ``conversion.BASE_COLUMNS``.
    Note that this re-seeds the module-level ``random`` generator with
    ``seed`` as a side effect.

    :param k: number of characters to draw
    :param seed: value handed to ``random.seed``, defaults to `None`

    :return: the drawn characters joined into a single string
    """
    random.seed(seed)
    drawn = random.choices(conversion.BASE_COLUMNS, k=k)
    return ''.join(drawn)
def simulate_reads(sequence, p_e, p_c, pi, l=100, n=100, seed=None):  # noqa
    """Simulate per-read nucleotide content and conversion counts.

    ``n`` windows of length ``l`` are sampled from ``sequence``. The first
    ``int(n * pi)`` reads are simulated as "new" reads and the remainder as
    "old" reads (see :func:`_simulate_read_conversions` below for how the two
    differ); the rows are then shuffled. All randomness comes from a single
    ``numpy.random.RandomState(seed)``, so a fixed ``seed`` reproduces the
    same table.

    :param sequence: string to sample read windows from
    :param p_e: per-candidate probability of a background conversion
    :param p_c: probability that a ``T`` in a new read is counted as ``TC``
    :param pi: fraction of the ``n`` reads that are new (truncated by ``int``)
    :param l: read length, defaults to `100`
    :param n: number of reads, defaults to `100`
    :param seed: seed for the random state, defaults to `None`

    :return: dataframe with one row per read and columns ``conversion.COLUMNS``
    """
    generator = np.random.RandomState(seed)
    n_new = int(n * pi)

    contents = []
    convs = []
    for read_index in range(n):
        # randint's upper bound is exclusive: + 1 keeps the last window
        # reachable and allows len(sequence) == l.
        start = generator.randint(0, len(sequence) - l + 1)  # noqa
        subsequence = sequence[start:start + l]  # noqa

        # Nucleotide content. Start every known base at zero so a base that
        # is absent from this read is reported as 0 rather than NaN/missing.
        content = {base: 0 for base in conversion.BASE_COLUMNS}
        content.update(Counter(subsequence))
        contents.append(content)

        # New reads come first; the final shuffle mixes them with old reads.
        convs.append(
            _simulate_read_conversions(subsequence, p_e, p_c, generator, new=read_index < n_new)
        )

    df_contents = pd.DataFrame(contents)
    df_conversions = pd.DataFrame(convs)
    df_counts = pd.concat((df_contents, df_conversions), axis=1)[conversion.COLUMNS]
    shuffled = generator.permutation(n)
    return df_counts.iloc[shuffled].reset_index(drop=True)


def _simulate_read_conversions(subsequence, p_e, p_c, generator, new):
    """Count simulated conversions along a single read.

    For every base, the other bases are visited in random order and the first
    one whose draw falls below ``p_e`` is recorded as a conversion. In a new
    read, a ``T`` is first given a chance ``p_c`` of being recorded as ``TC``;
    if that fails, ``TC`` can no longer be produced by the background step.

    :param subsequence: the read's bases
    :param p_e: per-candidate probability of a background conversion
    :param p_c: probability of a ``TC`` conversion on a ``T`` of a new read
    :param generator: ``numpy.random.RandomState`` supplying all draws
    :param new: whether the read is simulated as a new read

    :return: dictionary mapping each of ``conversion.CONVERSION_COLUMNS`` to a count
    """
    conv = {column: 0 for column in conversion.CONVERSION_COLUMNS}
    for base in subsequence:
        if new and base == 'T' and generator.random() < p_c:
            conv['TC'] += 1
            continue

        other_bases = [b for b in conversion.BASE_COLUMNS if b != base]
        # Use the seeded generator (not the global `random` module) so that
        # the candidate order is reproducible from `seed`.
        for other_base in generator.permutation(other_bases):
            bases = f'{base}{other_base}'
            # The draw is made before the TC check, so a TC candidate in a
            # new read consumes a draw but never yields a conversion.
            if generator.random() < p_e and not (new and bases == 'TC'):
                conv[bases] += 1
                break
    return conv
def initializer(model):
    """Store ``model`` in this module's ``_model`` global.

    Used as the worker ``initializer`` of the process pool in
    :func:`simulate`, so that :func:`_simulate` can fall back to ``_model``
    when it is not given a model explicitly.

    :param model: object to store
    """
    globals()['_model'] = model
def estimate(
    df_counts,
    p_e,
    p_c,
    pi,
    estimate_p_e=False,
    estimate_p_c=False,
    estimate_pi=True,
    model=None,
    nasc=False,
):
    """Estimate ``p_e``, ``p_c`` and ``pi`` from a table of read counts.

    Each parameter is estimated only when its ``estimate_*`` flag is set;
    otherwise the value passed in is returned unchanged. Intermediate results
    of the ``aggregation``/``estimation`` helpers are written to temporary
    files and read back while the file still exists.

    :param df_counts: dataframe of per-read counts; the code reads its
        ``TC``, ``T`` and ``GX`` columns
    :param p_e: value returned as the ``p_e`` estimate when not estimated
    :param p_c: value returned as the ``p_c`` estimate when not estimated
    :param pi: value returned as both guess and ``pi`` estimate when not estimated
    :param estimate_p_e: whether to estimate ``p_e``, defaults to `False`
    :param estimate_p_c: whether to estimate ``p_c``, defaults to `False`
    :param estimate_pi: whether to estimate ``pi``, defaults to `True`
    :param model: forwarded to ``estimation.pi.fit_stan_mcmc``, defaults to `None`
    :param nasc: selects the ``*_nasc`` ``p_e`` estimator, is forwarded to
        ``estimate_p_c``, and replaces the ``pi`` estimate with
        ``estimation.pi.beta_mode(alpha, beta)``, defaults to `False`

    :return: tuple ``(p_e_estimate, p_c_estimate, guess, alpha, beta, pi_estimate)``;
        ``alpha`` and ``beta`` are `None` when ``pi`` is not estimated
    """
    group_keys = ['TC', 'T']

    # p_e
    p_e_estimate = p_e
    if estimate_p_e and nasc:
        with tempfile.NamedTemporaryFile() as rates_file:
            df_rates = aggregation.read_rates(
                aggregation.calculate_mutation_rates(df_counts, rates_file.name)
            )
        with tempfile.NamedTemporaryFile() as p_e_file:
            p_e_estimate = estimation.read_p_e(estimation.estimate_p_e_nasc(df_rates, p_e_file.name))
    elif estimate_p_e:
        with tempfile.NamedTemporaryFile() as p_e_file:
            p_e_estimate = estimation.read_p_e(estimation.estimate_p_e(df_counts, p_e_file.name))

    # p_c
    p_c_estimate = p_c
    if estimate_p_c:
        # Number of reads for every observed (TC, T) combination.
        tc_counts = (
            df_counts.groupby(group_keys, sort=False, observed=True).size().reset_index(name='count')
        ).rename(columns={'TC': 'conversion', 'T': 'base'})
        with tempfile.NamedTemporaryFile() as p_c_file:
            p_c_estimate = estimation.read_p_c(
                estimation.estimate_p_c(tc_counts, p_e_estimate, p_c_file.name, nasc=nasc)
            )

    # pi
    if not estimate_pi:
        return p_e_estimate, p_c_estimate, pi, None, None, pi

    # Only the reads labelled 'gene_0' are used for the pi estimate.
    first_gene = df_counts[df_counts['GX'] == 'gene_0']
    gene_counts = first_gene.groupby(group_keys, sort=False, observed=True).size().reset_index(name='count')
    has_t_and_reads = (gene_counts[['T', 'count']] > 0).all(axis=1)
    vals = gene_counts[has_t_and_reads].values

    # Columns of vals are (TC, T, count). The initial guess is the fraction
    # of reads with at least one TC conversion, clamped to [0.01, 0.99].
    reads_with_conversion = sum(vals[vals[:, 0] > 0][:, 2])
    reads_total = sum(vals[:, 2])
    guess = min(max((reads_with_conversion / reads_total), 0.01), 0.99)

    guess, alpha, beta, pi_estimate = estimation.pi.fit_stan_mcmc(
        vals,
        p_e_estimate,
        p_c_estimate,
        guess=guess,
        model=model,
    )
    if nasc:
        pi_estimate = estimation.pi.beta_mode(alpha, beta)
    return p_e_estimate, p_c_estimate, guess, alpha, beta, pi_estimate
def _simulate(
    p_e,
    p_c,
    pi,
    sequence=None,
    k=10000,
    l=100,  # noqa
    n=100,
    estimate_p_e=False,
    estimate_p_c=False,
    estimate_pi=True,
    seed=None,
    model=None,
    nasc=False,
):
    """Run one simulation: generate reads for one or more genes, then estimate.

    ``pi`` and ``n`` may each be a scalar or a list. A scalar is repeated to
    match the length of the other when that one is a list. One gene, labelled
    ``gene_<i>`` in the ``GX`` column, is simulated per ``(pi, n)`` pair, all
    from the same sequence and with the same ``seed``. The concatenated reads
    are handed to :func:`estimate` together with the first gene's ``pi``.

    :param p_e: passed to :func:`simulate_reads` and :func:`estimate`
    :param p_c: passed to :func:`simulate_reads` and :func:`estimate`
    :param pi: fraction of new reads, or a list with one fraction per gene
    :param sequence: sequence to sample reads from; when falsy one is created
        with :func:`generate_sequence`, defaults to `None`
    :param k: length of the generated sequence, defaults to `10000`
    :param l: read length, defaults to `100`
    :param n: number of reads, or a list with one number per gene, defaults to `100`
    :param estimate_p_e: forwarded to :func:`estimate`, defaults to `False`
    :param estimate_p_c: forwarded to :func:`estimate`, defaults to `False`
    :param estimate_pi: forwarded to :func:`estimate`, defaults to `True`
    :param seed: seed for sequence and read generation, defaults to `None`
    :param model: model forwarded to :func:`estimate`; when falsy, the
        module-level ``_model`` set by :func:`initializer` is used if present,
        defaults to `None`
    :param nasc: forwarded to :func:`estimate`, defaults to `False`

    :return: whatever :func:`estimate` returns
    :raises ValueError: if ``pi`` and ``n`` are lists of different lengths
    """
    # `_model` only exists once `initializer` has run in this process; look
    # it up defensively so a direct call without a model does not NameError.
    model = model or globals().get('_model')

    pi_is_list = isinstance(pi, list)
    n_is_list = isinstance(n, list)
    if pi_is_list and n_is_list:
        pis, ns = pi, n
    elif pi_is_list:
        pis, ns = pi, [n] * len(pi)
    elif n_is_list:
        pis, ns = [pi] * len(n), n
    else:
        pis, ns = [pi], [n]

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let `zip` silently drop the surplus entries.
    if len(pis) != len(ns):
        raise ValueError(
            f'`pi` and `n` must have the same length, got {len(pis)} and {len(ns)}'
        )

    # All genes share one sequence.
    sequence = sequence or generate_sequence(k, seed=seed)

    dfs = []
    for i, (gene_pi, gene_n) in enumerate(zip(pis, ns)):
        df_gene = simulate_reads(sequence, p_e, p_c, gene_pi, l=l, n=gene_n, seed=seed)
        df_gene['GX'] = f'gene_{i}'
        dfs.append(df_gene)
    df_counts = pd.concat(dfs, ignore_index=True)

    return estimate(
        df_counts,
        p_e,
        p_c,
        pis[0],
        estimate_p_e=estimate_p_e,
        estimate_p_c=estimate_p_c,
        estimate_pi=estimate_pi,
        model=model,
        nasc=nasc,
    )
def simulate(
    p_e,
    p_c,
    pi,
    sequence=None,
    k=10000,
    l=100,  # noqa
    n=100,
    n_runs=16,
    n_threads=8,
    estimate_p_e=False,
    estimate_p_c=False,
    estimate_pi=True,
    model=None,
    nasc=False,
):
    """Run :func:`_simulate` ``n_runs`` times in a process pool.

    The model is handed to each worker process through :func:`initializer`
    rather than as an argument of every submitted task. Results are collected
    in completion order, which is not necessarily submission order.

    :param p_e: forwarded to :func:`_simulate`
    :param p_c: forwarded to :func:`_simulate`
    :param pi: forwarded to :func:`_simulate`
    :param sequence: forwarded to :func:`_simulate`, defaults to `None`
    :param k: forwarded to :func:`_simulate`, defaults to `10000`
    :param l: forwarded to :func:`_simulate`, defaults to `100`
    :param n: forwarded to :func:`_simulate`, defaults to `100`
    :param n_runs: number of simulations to submit, defaults to `16`
    :param n_threads: ``max_workers`` of the process pool, defaults to `8`
    :param estimate_p_e: forwarded to :func:`_simulate`, defaults to `False`
    :param estimate_p_c: forwarded to :func:`_simulate`, defaults to `False`
    :param estimate_pi: forwarded to :func:`_simulate`, defaults to `True`
    :param model: model given to the workers; when falsy a
        ``pystan.StanModel`` is built from ``config.MODEL_PATH``, defaults to `None`
    :param nasc: forwarded to :func:`_simulate`, defaults to `False`

    :return: tuple of six lists, one entry per run:
        ``(p_es, p_cs, guesses, alphas, betas, pis)``
    """
    if not model:
        model = pystan.StanModel(file=config.MODEL_PATH, model_name=config.MODEL_NAME)

    run_options = dict(
        sequence=sequence,
        k=k,
        l=l,
        n=n,
        estimate_p_e=estimate_p_e,
        estimate_p_c=estimate_p_c,
        estimate_pi=estimate_pi,
        nasc=nasc,
    )

    # One 6-tuple (p_e, p_c, guess, alpha, beta, pi) per finished run.
    runs = []
    with ProcessPoolExecutor(max_workers=n_threads, initializer=initializer, initargs=(model,)) as executor:
        futures = [executor.submit(_simulate, p_e, p_c, pi, **run_options) for _ in range(n_runs)]
        for finished in as_completed(futures):
            runs.append(finished.result())

    # Transpose the per-run tuples into one list per estimated quantity.
    p_es, p_cs, guesses, alphas, betas, pis = ([run[field] for run in runs] for field in range(6))
    return p_es, p_cs, guesses, alphas, betas, pis
def simulate_batch(
    p_e,
    p_c,
    pi,
    l,  # noqa
    n,
    estimate_p_e,
    estimate_p_c,
    estimate_pi,
    n_runs,
    n_threads,
    model,
    nasc=False
):
    """Helper function to run simulations in batches.

    :func:`simulate` is called once for every combination of the given
    ``p_e``, ``p_c`` and ``pi`` values, with a progress bar over the
    combinations. Each of the three may be a single value or a list of values.

    :param p_e: value or list of values to simulate with
    :param p_c: value or list of values to simulate with
    :param pi: value or list of values to simulate with; an element that is
        itself a list is passed on as-is and reported by its first entry
    :param l: forwarded to :func:`simulate`
    :param n: forwarded to :func:`simulate`
    :param estimate_p_e: forwarded to :func:`simulate`
    :param estimate_p_c: forwarded to :func:`simulate`
    :param estimate_pi: forwarded to :func:`simulate`
    :param n_runs: forwarded to :func:`simulate`
    :param n_threads: forwarded to :func:`simulate`
    :param model: forwarded to :func:`simulate`
    :param nasc: forwarded to :func:`simulate`, defaults to `False`

    :return: dataframe with one row per run, holding the simulated ``p_e``,
        ``p_c`` and ``pi`` next to the estimates returned by :func:`simulate`
    """
    p_es = p_e if isinstance(p_e, list) else [p_e]
    p_cs = p_c if isinstance(p_c, list) else [p_c]
    pis = pi if isinstance(pi, list) else [pi]

    # Materialised so that tqdm knows the total number of combinations.
    combinations = list(product(p_es, p_cs, pis))

    frames = []
    for true_p_e, true_p_c, true_pi in tqdm(combinations):
        p_e_estimates, p_c_estimates, guesses, alphas, betas, pi_estimates = simulate(
            true_p_e,
            true_p_c,
            true_pi,
            l=l,
            n=n,
            estimate_p_e=estimate_p_e,
            estimate_p_c=estimate_p_c,
            estimate_pi=estimate_pi,
            n_runs=n_runs,
            n_threads=n_threads,
            model=model,
            nasc=nasc,
        )
        if isinstance(true_pi, list):
            reported_pi = true_pi[0]
        else:
            reported_pi = true_pi
        frame = pd.DataFrame({
            'p_e': true_p_e,
            'p_c': true_p_c,
            'pi': reported_pi,
            'p_e_estimate': p_e_estimates,
            'p_c_estimate': p_c_estimates,
            'guess': guesses,
            'alpha_estimate': alphas,
            'beta_estimate': betas,
            'pi_estimate': pi_estimates
        })
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)
def plot_estimations(
    X, Y, n_runs, means, truth, ax=None, box=True, tick_decimals=1, title=None, xlabel=None, ylabel=None
):
    """Scatter estimates against their simulated values, with means and truth.

    :param X: x-coordinate of every estimate
    :param Y: y-coordinate of every estimate; for the boxplot it is reshaped
        to rows of ``n_runs`` values (NOTE(review): this presumes ``Y`` is
        grouped by ascending ``X`` with exactly ``n_runs`` values per group —
        confirm against callers)
    :param n_runs: number of estimates per ``X`` value; also shown in the legend
    :param means: object whose ``index`` and ``values`` are plotted as the
        mean markers (e.g. a pandas Series)
    :param truth: either an iterable, plotted against itself as a diagonal,
        or a single value, plotted as a horizontal line across the ``X`` range
    :param ax: axes to draw on; a new 5x5 figure is created and shown when
        `None`, defaults to `None`
    :param box: whether to draw boxplots and pad the x-limits, defaults to `True`
    :param tick_decimals: decimals used for the x tick labels, defaults to `1`
    :param title: axes title, defaults to `None`
    :param xlabel: x-axis label, defaults to `None`
    :param ylabel: y-axis label, defaults to `None`

    :return: the axes that were drawn on
    """
    if ax is not None:
        _ax = ax
    else:
        fig, _ax = plt.subplots(figsize=(5, 5), tight_layout=True)

    xticks = np.sort(np.unique(X))

    if box:
        X_range = max(X) - min(X)
        _ax.boxplot(
            list(np.array(Y).reshape(-1, n_runs)),
            positions=xticks,
            zorder=-1,
            widths=X_range * 0.05,
            medianprops=dict(c='gray', linewidth=1.5),
            boxprops=dict(facecolor='lightgray', color='gray', linewidth=1.5),
            whiskerprops=dict(c='gray', linewidth=1.5),
            capprops=dict(c='gray', linewidth=1.5),
            patch_artist=True,
            showfliers=False,
        )
    _ax.scatter(X, Y, s=3, label=f'n={n_runs}')
    _ax.scatter(means.index, means.values, s=15, label='mean')

    # Only the iterability probe is guarded: a failure while plotting must
    # surface instead of being swallowed and retried with different data.
    try:
        iter(truth)
    except TypeError:
        _ax.plot([min(X), max(X)], [truth, truth], c='red', linewidth=1, label='truth')
    else:
        _ax.plot(truth, truth, c='red', linewidth=1, label='truth')

    if box:
        _ax.set_xlim(left=min(X) - X_range * 0.1, right=max(X) + X_range * 0.1)
    _ax.set_xticks(xticks)
    _ax.set_xticklabels([f'{round(x, tick_decimals)}' for x in xticks])
    _ax.legend()
    if title:
        _ax.set_title(title)
    if xlabel:
        _ax.set_xlabel(xlabel)
    if ylabel:
        _ax.set_ylabel(ylabel)
    if ax is None:
        fig.show()
    return _ax