#!/usr/bin/env python3.8
# coding: latin-1
# (c) Massachusetts Institute of Technology 2015-2018
# (c) Brian Teague 2018-2022
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
cytoflow.operations.registration
------------------------------------
The `register` module contains two classes:
`RegistrationOp` -- warps channels to bring areas of high density into registration
`RegistrationDiagnosticView` -- a diagnostic view to make sure
that `RegistrationOp` performed correctly.
"""
from traits.api import (HasStrictTraits, Str, Dict, Int, List,
Float, Constant, provides, Instance, Union,
Callable, Any, Enum)
import numpy as np
import pandas as pd
import scipy.signal
import sklearn
from sklearn.neighbors import KernelDensity
from statsmodels.nonparametric.bandwidths import bw_scott, bw_silverman
from skfda import FDataGrid
from skfda.preprocessing.registration import landmark_elastic_registration_warping, invert_warping
import matplotlib.pyplot as plt
import cytoflow.utility as util
from cytoflow.views import IView
from cytoflow.views.kde_1d import _kde_support
from .i_operation import IOperation
[docs]
@provides(IOperation)
class RegistrationOp(HasStrictTraits):
"""
`RegistrationOp` is used to *register* different data sets with eachother.
It identifies areas of high density that are shared across all most of the
data sets, then applies a warp function to align those areas of high
density. This is commonly used to correct sample-to-sample variation
across large data sets. This is *not* a multidimensional algorithm --
if you apply it to multiple channels, each channel is warped
independently.
Attributes
----------
channels : List(Str)
The channels to register.
scale : Dict(Str : {"linear", "logicle", "log"})
How to scale the channels before registering.
by : List(Str)
Which conditions to use to group samples? These are usually
experimental conditions, not gates!
subset : Str
How to filter the data before estimating the transformation?
kernel : Str (default = ``gaussian``)
The kernel to use for the kernel density estimate. Choices are:
- ``gaussian`` (the default)
- ``tophat``
- ``epanechnikov``
- ``exponential``
- ``linear``
- ``cosine``
bw : Str or Float (deafult = ``scott``)
The bandwidth for the kernel, controls how lumpy or smooth the
kernel estimate is. Choices are:
- ``scott`` (the default) - ``1.059 * A * nobs ** (-1/5.)``, where ``A`` is ``min(std(X),IQR/1.34)``
- ``silverman`` - ``.9 * A * nobs ** (-1/5.)``, where ``A`` is ``min(std(X),IQR/1.34)``
If a float is given, it is the bandwidth. Note, this is in
scaled units, not data units.
gridsize : int (default = 200)
How many locations should we evaluate the kernel?
Notes
-----
The registration algorithm follows the approach from the ``warpSet``
function in the R/Bioconductor ``flowStats`` package. The precise details
differ depending on what is available in the scientific Python ecosystem,
but the overall flow remains the same. For each channel:
- Rescale the data (if requested)
- Smooth the data using a kernel density estimate
- Use a peak-finding algorithm to find landmarks in the distribution
- Use 1-dimensional K-means across groups to group landmarks together
- Determine the (scaled) mean of each group. These are the "destinations" for our
warp functions.
- Using tools from functional data analysis, compute a "warp" function
that can be applied to each group to move the landmarks to the median.
- Apply the warp function to the underlying data, scaling and then
inverting as you do so.
Every step except the last is performed by the `estimate` function. The
diagnostic plot shows the smoothed distribution, the peaks, their
clusters and means, and the warped (smoothed) distribution.
Examples
--------
.. plot::
:context: close-figs
Make a little data set.
>>> import cytoflow as flow
>>> import_op = flow.ImportOp()
>>> import_op.tubes = [flow.Tube(file = "module_examples/itn_02.fcs",
... conditions = {'Sample' : 2}),
... flow.Tube(file = "module_examples/itn_03.fcs",
... conditions = {'Sample' : 3})]
>>> import_op.conditions = {'Sample' : 'category'}
>>> ex = import_op.apply()
Plot the samples "before":
.. plot::
:context: close-figs
>>> flow.Kde1DView(channel = 'CD3',
... huefacet = 'Sample',
... scale = 'log').plot(ex)
Create and parameterize the operation.
.. plot::
:context: close-figs
>>> op = flow.RegistrationOp(channels = ['CD3', 'CD4'],
... scale = {'CD3' : 'log',
... 'CD4' : 'log'},
... by = ['Sample'])
Estimate the clusters
.. plot::
:context: close-figs
>>> op.estimate(ex)
Plot a diagnostic view
.. plot::
:context: close-figs
>>> op.default_view().plot(ex, plot_name = 'CD3')
Apply the warp
.. plot::
:context: close-figs
>>> ex2 = op.apply(ex)
Plot the same KDE after the warp.
.. plot::
:context: close-figs
>>> flow.Kde1DView(channel = 'CD3',
... huefacet = 'Sample',
... scale = 'log').plot(ex2)
"""
# traits
id = Constant('cytoflow.operations.register')
friendly_id = Constant("Density Registration")
name = Constant("Registration")
channels = List(Str)
scale = Dict(Str, util.ScaleEnum)
by = List(Str)
# Smoothing
kernel = Enum('gaussian','tophat','epanechnikov','exponential','linear','cosine')
bw = Union(Enum('scott', 'silverman'), Float)
gridsize = Int(200)
# these are really only saved to support plotting
_scale = Dict(Str, Instance(util.IScale))
_groups = List(Any)
_support = Dict(Str, np.ndarray) # channel --> kde support
_kde = Dict(Str, Dict(Any, np.ndarray)) # channel, group --> kde density
_peaks = Dict(Str, Dict(Any, List(Float))) # channel, group --> peaks
_clusters = Dict(Str, Dict(Any, List(Union(Int, None)))) # channel, group --> cluster assignments
_means = Dict(Str, List(Union(Float, None))) # channel --> cluster medians
_warping = Dict(Str, Callable) # channel --> warping
[docs]
def estimate(self, experiment, subset = None):
"""
Estimate the calibration coefficients from the beads file.
Parameters
----------
experiment : `Experiment`
The experiment used to compute the calibration.
"""
if experiment is None:
raise util.CytoflowOpError('experiment',
"No experiment specified")
if len(self.channels) == 0:
raise util.CytoflowOpError('channels',
"Must set at least one channel")
if len(self.channels) != len(set(self.channels)):
raise util.CytoflowOpError('channels',
"Must not duplicate channels")
for c in self.channels:
if c not in experiment.data:
raise util.CytoflowOpError('channels',
"Channel {0} not found in the experiment"
.format(c))
for c in self.scale:
if c not in self.channels:
raise util.CytoflowOpError('channels',
"Scale set for channel {0}, but it isn't "
"'channels'"
.format(c))
if not self.by:
raise util.CytoflowOpError('by',
"'by' must not be empty!")
for b in self.by:
if b not in experiment.data:
raise util.CytoflowOpError('by',
"Aggregation metadata {} not found, "
"must be one of {}"
.format(b, experiment.conditions))
if subset:
try:
experiment = experiment.query(subset)
except:
raise util.CytoflowOpError('subset',
"Subset string '{0}' isn't valid"
.format(subset))
if len(experiment) == 0:
raise util.CytoflowOpError('subset',
"Subset string '{0}' returned no events"
.format(subset))
groupby = experiment.data.groupby(self.by, observed = True)
if len(groupby.groups) < 2:
raise util.CytoflowOpError('by',
"Must be more than one group after grouping by 'by'")
self._warping.clear()
self._scale.clear()
# get the scale. estimate the scale params for the ENTIRE data set,
# not subsets we get from groupby(). And we need to save it so that
# the data is transformed the same way when we apply()
for c in self.channels:
if c in self.scale:
self._scale[c] = util.scale_factory(self.scale[c], experiment, channel = c)
else:
self._scale[c] = util.scale_factory(util.get_default_scale(), experiment, channel = c)
warpings = {}
for channel in self.channels:
# scikit-fda requires that all the functions (in this case, the
# KDEs) be on the same support.
all_data = pd.concat([group_data[channel] for _, group_data in groupby],
ignore_index = True, sort = True)
all_scaled_data = self._scale[channel](all_data)
if self.bw == 'scott':
bw = bw_scott(all_scaled_data)
elif self.bw == 'silverman':
bw = bw_silverman(all_scaled_data)
else:
bw = self.bw
# support we calculate on is scaled, not in data units.
support = _kde_support(all_scaled_data, bw, self.gridsize, 3.0,
(-np.inf, np.inf))
# but the support we SAVE is in DATA UNITS.
support = self._scale[channel].inverse(support)
# the support must be strictly increasing, so remove duplicates
# (ie, values that were clipped in the inverse)
self._support[channel] = support = np.unique(support)
# re-scale the (inverted) support
support = self._scale[channel](support)
all_peaks = []
self._kde[channel] = {}
self._peaks[channel] = {}
for group, group_data in groupby:
#compute the KDE
scaled_data = self._scale[channel](group_data[channel])
kde = KernelDensity(kernel = self.kernel, bandwidth = bw)
kde.fit(scaled_data.to_numpy()[:, np.newaxis])
density = np.exp(kde.score_samples(support[:, np.newaxis]))
self._kde[channel][group] = density
# find the peaks
density_max = np.max(density)
peaks = scipy.signal.find_peaks(density,
prominence = 0.1 * density_max)[0].tolist()
peaks = [support[p] for p in peaks]
self._peaks[channel][group] = [self._scale[channel].inverse(p) for p in peaks]
if not all_peaks:
all_peaks = peaks
else:
all_peaks.extend(peaks)
# cluster the peaks ACROSS GROUPS. we want the minumum number
# of clusters where no two peaks in the same group are
# assigned to the same cluster.
self._clusters[channel] = {}
for n_clusters in range(1, len(all_peaks)):
km = sklearn.cluster.KMeans(n_clusters = n_clusters,
random_state = 0)
km.fit(np.array(all_peaks).reshape(-1, 1))
for group, _ in groupby:
peaks = self._scale[channel](self._peaks[channel][group])
cluster_assignments = km.predict(np.array(peaks).reshape(-1, 1)).tolist()
# check for duplicates
if len(cluster_assignments) != len(set(cluster_assignments)):
continue
self._clusters[channel][group] = cluster_assignments
if len(self._clusters[channel]) == groupby.ngroups:
break
# get rid of clusters that don't have a peak in each group
for cluster in range(0, n_clusters):
in_group = [cluster in self._clusters[channel][group] for group in self._clusters[channel]]
if not all(in_group):
for group in self._clusters[channel]:
try:
clust_idx = self._clusters[channel][group].index(cluster) # after previous, should only be in here once!
self._clusters[channel][group][clust_idx] = None
except ValueError:
# value wasn't in the list
pass
# now that we have clusters, compute the median of each cluster
self._means[channel] = []
for cluster in range(n_clusters):
clust_peaks = []
for group, _ in groupby:
peaks = self._scale[channel](self._peaks[channel][group])
cluster_assignments = self._clusters[channel][group]
try:
peak_idx = cluster_assignments.index(cluster)
except ValueError:
# this group didn't have a peak assigned to this cluster
continue
clust_peaks.append(peaks[peak_idx])
if clust_peaks:
self._means[channel].append(self._scale[channel].inverse(np.mean(clust_peaks)))
else:
self._means[channel].append(None)
# compute the warping to register the landmarks to the means
# we compute the warping on SCALED data
fd = FDataGrid(data_matrix = [self._scale[channel](self._kde[channel][group]) for group in self._kde[channel]],
grid_points = self._scale[channel](self._support[channel]))
landmarks = [[self._peaks[channel][group][i]
for i in range(len(self._peaks[channel][group]))
if self._clusters[channel][group][i] is not None]
for group in self._peaks[channel]]
landmarks = [self._scale[channel](el) for el in ([sorted(s) for s in landmarks])]
location = [s for s in self._means[channel] if s is not None]
location = self._scale[channel](sorted(location))
warping = landmark_elastic_registration_warping(fd,
landmarks = landmarks,
location = location)
# i don't know why i need to do this :(
# clearly i don't understand FDA
warpings[channel] = invert_warping(warping)
# # set atomically to support the GUI
self._warping = warpings
# self._warp_functions = warp_functions
[docs]
def apply(self, experiment):
"""
Applies the bleedthrough correction to an experiment.
Parameters
----------
experiment : `Experiment`
the experiment to which this operation is applied
Returns
-------
Experiment
A new experiment with the specified channels warped to bring their
density maxima into registration.
"""
if experiment is None:
raise util.CytoflowOpError('experiment', "No experiment specified")
if not self._warping:
raise util.CytoflowOpError(None,
"Registration warp not found. "
"Did you forget to call estimate()?")
if not set(self.channels) <= set(experiment.channels):
raise util.CytoflowOpError('units',
"Warp channels don't match experiment channels")
if set(self.channels) != set(self._warping):
raise util.CytoflowOpError('units',
"Registration warp doesn't match channels. "
"Did you forget to call estimate()?")
new_experiment = experiment.clone(deep = True)
if self.by:
groupby = experiment.data.groupby(self.by, observed = False)
else:
# use a lambda expression to return a group that contains
# all the events
groupby = experiment.data.groupby(lambda _: True, observed = False)
for channel in self.channels:
scale = self._scale[channel]
warping = self._warping[channel]
for group_idx, (_, group_data) in enumerate(groupby):
new_experiment.data.loc[group_data.index, channel] = \
scale.inverse(warping(scale(group_data[channel])))[group_idx]
if 'range' in new_experiment.metadata[channel]:
new_experiment.metadata[channel]['range'] = max(scale.inverse(warping(scale(experiment.metadata[channel]['range']))))[0][0]
if 'voltage' in new_experiment.metadata[channel]:
del new_experiment.metadata[channel]['voltage']
new_experiment.history.append(self.clone_traits(transient = lambda t: True))
return new_experiment
[docs]
def default_view(self, **kwargs):
"""
Returns a diagnostic plot to see if the peak finding is working.
Returns
-------
`IView`
An diagnostic view, call `BeadCalibrationDiagnostic.plot` to
see the diagnostic plots
"""
v = RegistrationDiagnosticView(op = self)
v.trait_set(**kwargs)
return v
[docs]
@provides(IView)
class RegistrationDiagnosticView(HasStrictTraits):
"""
A diagnostic view for `RegistrationOp`.
Plots the smoothed histogram of the bead data; the peak locations;
a scatter plot of the raw bead fluorescence values vs the calibrated unit
values; and a line plot of the model that was computed. Make sure that the
relationship is linear; if it's not, it likely isn't a good calibration!
Attributes
----------
op : Instance(`BeadCalibrationOp`)
The operation instance whose diagnostic we're plotting. Set
automatically if you created the instance using
`BeadCalibrationOp.default_view`.
"""
# traits
id = Constant("cytoflow.views.registrationdiagnosticview")
friendly_id = Constant("Registration Diagnostic")
op = Instance(RegistrationOp)
[docs]
def enum_plots(self, experiment):
"""
Enumerate the named plots we can make from this set of statistics.
Returns
-------
iterator
An iterator across the possible plot names.
"""
if experiment is None:
raise util.CytoflowViewError('experiment',
"No experiment specified")
if self.op._support and self.op.by:
return util.IterByWrapper(iter(self.op._support), ["Channel"])
else:
return util.IterByWrapper(iter([]), [])
[docs]
def plot(self, experiment, plot_name = None, **kwargs):
"""
Plots the diagnostic view.
Parameters
----------
experiment : `Experiment`
The experiment used to create the diagnostic plot.
plot_name : Str
The channel name to plot.
"""
if experiment is None:
raise util.CytoflowViewError('experiment', "No experiment specified")
if not self.op._support:
raise util.CytoflowViewError(None, "Must estimate the operations parameters first!")
if plot_name is None and len(self.op._support) == 1:
plot_name = list(self.op._support)[0]
if not plot_name:
raise util.CytoflowViewError('plot_name', "Must set 'plot_name' to one of the channels that was estimated!")
if not plot_name in self.op._support:
raise util.CytoflowViewError('plot_name',
"Channel {} was not estimated!"
.format(self.channel))
channel = plot_name
scale = self.op._scale[channel]
groups = self.op._kde[channel].keys()
kde_support = self.op._support[channel]
# let's not use the FacetGrid stuff here, eh?
fig, axes = plt.subplots(len(groups), 1, sharex = True)
fig.set_constrained_layout_pads(hspace = 0.0, h_pad = 0.0)
for i, group in enumerate(groups):
ax = axes[i]
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylabel(', '.join([str(g) for g in group]))
ax.set_xscale(scale.name, **scale.get_mpl_params(plt.gca().get_xaxis()))
ax.grid(False)
ax.tick_params(axis = 'y', which = "both", left = False, labelleft = False)
if i < len(self.op._groups) - 1:
ax.tick_params(axis = 'x', which = "both", bottom = False, labelbottom = False)
# plot the density
kde_density = self.op._kde[channel][group]
x = kde_support
before_artist = axes[i].plot(x, kde_density)
scaled_support = scale(kde_support)
warped_x = self.op._warping[channel](scaled_support)[i]
warped_x = scale.inverse(warped_x)
after_artist = axes[i].plot(warped_x, kde_density, color = 'r')
if i == 0:
before_artist[0].set_label("Before registration")
after_artist[0].set_label("After registration")
# plot the peaks
peaks = self.op._peaks[channel][group]
for peak in peaks:
ax.axvline(peak, color = 'b', linestyle = '--')
# plot cluster peaks, means
means = self.op._means[channel]
cluster_assignments = self.op._clusters[channel][group]
for cluster_idx, mean in enumerate(means):
if mean is None:
continue
ax.axvline(mean, color = 'grey', linestyle = '-')
try:
peak_idx = cluster_assignments.index(cluster_idx)
y = 0.1
if abs(scale(peaks[peak_idx]) - scale(mean)) > 0.01:
ax.annotate("", xytext = (peaks[peak_idx], y), xy = (mean, y),
arrowprops=dict(width = 1, headwidth = 5, headlength = 3, color = 'k'))
except ValueError:
# this group didn't have a peak assigned to this cluster
continue
# axes[i].set_title("{} = {}".format(', '.join(self.op.by),
# ', '.join([str(g) for g in group])))
fig.draw_without_rendering()
plt.xlabel(channel)
fig.supylabel(', '.join(self.op.by))
# plot a figure legend
fig.legend(loc = 'outside right upper')
title = kwargs.pop('title', None)
if title:
plt.suptitle(title)