Source code for flyeye.data.cells

from copy import deepcopy
from functools import reduce
from operator import add
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re

from ..utilities.string_handling import format_channel, standardize_channels
from ..dynamics.visualization import TimeseriesPlot, IntervalPlot
from ..dynamics.resampling import DiscResampler


class CellProperties:
    """
    Read-only properties derived from a table of cell measurements.

    This class does not define ``data`` itself; every property reads
    ``self.data``, so an inheriting class must provide that attribute.
    """

    @property
    def channels(self):
        """ List of fluorescence channels, i.e. columns named 'ch' + digits. """
        pattern = 'ch[0-9]+$'
        matched = []
        for column in self.data.columns:
            # keep columns whose name ends in 'ch' followed by digits
            if re.search(pattern, column) is not None:
                matched.append(column)
        return matched

    @property
    def normalized_channels(self):
        """ List of channel names, each with the '_normalized' suffix. """
        return [f'{name}_normalized' for name in self.channels]

    @property
    def num_channels(self):
        """ Number of entries in the channels list. """
        channel_names = self.channels
        return len(channel_names)

    @property
    def is_sorted(self):
        """ True if centroid_x never decreases from one row to the next. """
        steps = self.data.centroid_x.diff()
        num_decreasing = (steps < 0).sum()
        return num_decreasing == 0

    @property
    def cell_type_counts(self):
        """ Number of records per value of the 'label' column. """
        labels_by_type = self.data.groupby('label')['label']
        return labels_by_type.count()

    @property
    def cell_types(self):
        """ List of unique cell types (index of the per-label counts). """
        counts = self.cell_type_counts
        return counts.index.tolist()
class Cells(CellProperties):
    """
    Object represents a population of cells. Each cell is completely
    described by a single record in a DataFrame of cell measurements. These
    measurements include cell positions, expression levels, and cell type
    annotations. Object may contain cells of one or more cell types.

    Attributes:

        data (pd.DataFrame) - cell measurement data

        normalization (str or int) - channel used to normalize intensities

    """

    def __init__(self, data=None, normalization=None):
        """
        Instantiate population of cells.

        Args:

            data (pd.DataFrame) - cell measurement data

            normalization (str or int) - channel used to normalize intensities

        """

        # store measurements; an empty table stands in when none are given
        if data is None:
            data = pd.DataFrame()
        self.data = standardize_channels(data)

        # store normalization
        self.normalization = format_channel(normalization)

        # sort measurements on the default key (skipped when there are no
        # rows, since an empty table has no column to sort on)
        if len(self.data) > 0:
            self.sort()

    def __add__(self, cells):
        """
        Concatenate second Cells instance.

        Args:

            cells (data.cells.Cells) - population to append

        Returns:

            cells (data.cells.Cells) - new instance holding both sets of
            measurements, sorted by 't', with this instance's normalization

        """
        combined = pd.concat((self.data, cells.data), sort=True)
        cells = Cells(combined, self.normalization)
        cells.sort(by='t')
        return cells

    def sort(self, by='centroid_x'):
        """
        Sort cell measurements in place (self.data is replaced by the sorted
        table).

        Args:

            by (str) - key on which measurements are sorted

        """
        self.data = self.data.sort_values(by=by, ascending=True)

    def apply_lag(self, lag):
        """
        Shift cells in time by adding lag to the 't' column in place.

        Args:

            lag (float) - shift (NOTE: x-positions are unaffected)

        """
        self.data['t'] += lag

    def select_cell_type(self, cell_types):
        """
        Select subset of cells corresponding to a specified label.

        The precursor labels 'pre' and 'p' are treated as synonyms: asking
        for either one selects both.

        Args:

            cell_types (str or list) - type of cells to be selected (e.g. pre,
            r8). The argument itself is never modified.

        Returns:

            cells (data.cells.Cells)

        """

        # work on a private list so the caller's argument is not mutated
        if isinstance(cell_types, str):
            cell_types = [cell_types]
        else:
            cell_types = list(cell_types)

        # add both precursor labels if needed
        if 'pre' in cell_types:
            cell_types.append('p')
        elif 'p' in cell_types:
            cell_types.append('pre')

        # select cells
        data = self.data[self.data.label.isin(cell_types)]

        # instantiate cells object
        return Cells(data, self.normalization)

    def select_by_position(self,
                           xmin=-np.inf, xmax=np.inf,
                           ymin=-np.inf, ymax=np.inf,
                           zmin=-np.inf, zmax=np.inf,
                           tmin=-np.inf, tmax=np.inf):
        """
        Select subset of cells within specified spatial bounds. All bounds
        are inclusive.

        Args:

            xmin, xmax (float) - x-coordinate bounds

            ymin, ymax (float) - y-coordinate bounds

            zmin, zmax (float) - z-coordinate (layer number) bounds

            tmin, tmax (float) - time interval bounds

        Returns:

            cells (data.cells.Cells) - copied subset of cells

        """

        # combine all bounds into a single row mask
        frame = self.data
        within = (frame['centroid_x'].between(xmin, xmax)
                  & frame['centroid_y'].between(ymin, ymax)
                  & frame['layer'].between(zmin, zmax)
                  & frame['t'].between(tmin, tmax))

        # copy only the selected rows so the subpopulation owns its data
        data = frame[within].copy()

        # instantiate subpopulation
        return Cells(data, self.normalization)

    def get_nuclear_diameter(self):
        """
        Returns median nuclear diameter. Diameters are approximated as that
        of a circle with equivalent area to each nuclear contour, using the
        'pixel_count' column as the area.

        Returns:

            nuclear_diameter (float) - median diameter

        """
        diameters = 2 * np.sqrt(self.data.pixel_count / np.pi)
        return diameters.median()

    @staticmethod
    def get_binned_mean(x, values, bins=None, bin_width=1):
        """
        Bin cells and compute mean for each bin.

        Args:

            x (pd.Series) - coordinate on which to bin values

            values (pd.Series) - values to be aggregated

            bins (np array) - edges for specified bins

            bin_width (float) - width of bins used if no bins specified

        Returns:

            bin_centers (list) - bin centers

            means (np array) - mean value within each bin

        """

        # imported here because the module does not import scipy.stats
        from scipy import stats as st

        # NOTE: np.arange excludes its stop value, so the default edges end
        # before x.max() and values beyond the last edge are not binned
        if bins is None:
            bins = np.arange(x.min(), x.max(), bin_width)

        # midpoint of each pair of adjacent edges
        bin_centers = [lo + (hi - lo)/2 for lo, hi in zip(bins[:-1], bins[1:])]

        means, _, _ = st.binned_statistic(x, values,
                                          statistic='mean', bins=bins)
        return bin_centers, means

    def plot_dynamics(self, channel,
                      ax=None,
                      scatter=False,
                      average=True,
                      interval=False,
                      marker_kw=None,
                      line_kw=None,
                      interval_kw=None,
                      ma_kw=None):
        """
        Plot expression dynamics for specified channel.

        Measurements are sorted by 't' in place as a side effect.

        Args:

            channel (str) - expression channel

            ax (mpl.axes.AxesSubplot) - if None, create axes

            scatter (bool) - if True, add markers for each measurement

            average (bool) - if True, add moving average

            interval - if True, add confidence interval for moving average

            marker_kw (dict) - keyword arguments for marker formatting

            line_kw (dict) - keyword arguments for line formatting

            interval_kw (dict) - keyword arguments for interval formatting

            ma_kw (dict) - keyword arguments for interval construction

        Returns:

            ax (mpl.axes.AxesSubplot)

        """

        # use a fresh dict per call rather than a shared mutable default
        marker_kw = {} if marker_kw is None else marker_kw
        line_kw = {} if line_kw is None else line_kw
        interval_kw = {} if interval_kw is None else interval_kw
        ma_kw = {} if ma_kw is None else ma_kw

        # sort values inplace
        self.sort('t')

        # instantiate TimeseriesPlot
        x, y = self.data.t.values, self.data[channel].values
        tsplot = TimeseriesPlot(x, y, ax=ax)

        # plot dynamics
        tsplot.plot(scatter=scatter,
                    average=average,
                    interval=interval,
                    marker_kw=marker_kw,
                    line_kw=line_kw,
                    interval_kw=interval_kw,
                    ma_kw=ma_kw)

        return tsplot.ax

    def plot_resampled_dynamics(self, channel,
                                ax=None,
                                average=True,
                                interval=False,
                                marker_kw=None,
                                line_kw=None,
                                interval_kw=None,
                                resampling_kw=None):
        """
        Plot expression dynamics for specified channel, resampling from
        discrete subpopulations of cells.

        Measurements are sorted by 't' in place as a side effect.

        Args:

            channel (str) - expression channel

            ax (mpl.axes.AxesSubplot) - if None, create axes

            average (bool) - if True, add moving average

            interval - if True, add confidence interval for moving average

            marker_kw (dict) - accepted but not used by this method

            line_kw (dict) - keyword arguments for line formatting

            interval_kw (dict) - keyword arguments for interval formatting

            resampling_kw (dict) - keyword arguments for disc resampler

        Returns:

            ax (mpl.axes.AxesSubplot)

        """

        # use a fresh dict per call rather than a shared mutable default
        line_kw = {} if line_kw is None else line_kw
        interval_kw = {} if interval_kw is None else interval_kw
        resampling_kw = {} if resampling_kw is None else resampling_kw

        # sort values inplace
        self.sort('t')

        # resample discs and cells within them
        time = DiscResampler(self, 't', **resampling_kw).mean
        resampler = DiscResampler(self, channel, **resampling_kw)
        mean = resampler.mean
        lower, upper = resampler.confidence_interval

        # construct interval plot
        interval_plot = IntervalPlot(time, lower, upper, mean, ax=ax)

        # plot dynamics
        interval_plot.plot(average=average,
                           interval=interval,
                           line_kw=line_kw,
                           interval_kw=interval_kw)

        return interval_plot.ax

    def scatterplot(self, x, y,
                    color='grey',
                    s=5,
                    alpha=0.5,
                    fraction=False,
                    ax=None):
        """
        Create XY scatterplot of two fluorescence channels. Both axes are
        limited to the range (0, 2).

        Args:

            x, y (str or int) - channels used for x and y axes

            color (str) - marker color

            s (float) - marker size

            alpha (float) - transparency of markers

            fraction (bool) - if True, annotate fraction above midline

            ax (mpl.axes.AxesSubplot) - if None, create figure

        Returns:

            ax (mpl.axes.AxesSubplot)

        """

        # get string representation of channel names
        x, y = format_channel(x), format_channel(y)

        # create figure
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))

        # format axes
        ax.set_xlim(0, 2)
        ax.set_ylim(0, 2)
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        ax.tick_params(labelsize=10)
        ax.grid(True)

        # scatter data
        ax.scatter(self.data[x], self.data[y],
                   c=color, s=s, alpha=alpha, lw=0)

        # add fraction above midline
        if fraction:
            ratio = self.data[y]/self.data[x]
            self.annotate_fraction(ax, ratio, p=2.5)

        return ax

    @staticmethod
    def annotate_fraction(ax, ratio, p=2.5):
        """
        Add fraction of cells above midline.

        The fraction of ratios that are >= 1 is written right-aligned at
        (p, p+0.5); its complement is written left-aligned at (p, p-0.5).

        Args:

            ax (mpl.axes.AxesSubplot)

            ratio (array like) - vector of ratios

            p (float) - text position relative to center line

        """
        fraction = sum(ratio >= 1) / len(ratio)
        ax.text(p, p+0.5, '{:2.1%}'.format(fraction),
                ha='right', fontsize=8)
        ax.text(p, p-0.5, '{:2.1%}'.format(1-fraction),
                ha='left', fontsize=8)

    def plot_spectrogram(self, channel,
                         periods=None,
                         ymax=None,
                         ax=None,
                         **kwargs):
        """
        Plot Lomb Scargle periodogram. Only precursor cells are used, with
        their y-positions as the sampled coordinate.

        Args:

            channel (str or int) - expression channel

            periods (array like) - spectral frequencies to be tested

            ymax (float) - max spectral power

            ax (mpl.axes.AxesSubplot)

            kwargs: spectrogram visualization keywords

        Returns:

            ax (mpl.axes.AxesSubplot)

        """

        # get string representation of channel name
        channel = format_channel(channel)

        # compile spectrogram
        # NOTE(review): Spectrogram is not imported in the visible part of
        # this module; confirm where it is defined and import it, otherwise
        # this call raises NameError.
        precursors = self.select_cell_type('pre')
        spectrogram = Spectrogram(precursors.data.centroid_y.values,
                                  precursors.data[channel].values,
                                  periods=periods)

        # plot power spectrum
        ax = spectrogram.simple_visualization(ax=ax, ymax=ymax, **kwargs)

        return ax