from scipy.spatial import Delaunay
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cmaps
from matplotlib import colors
import matplotlib.gridspec as gs
from ..dynamics.averages import get_rolling_mean
from ..dynamics.visualization import plot_mean
[docs]class Triangulation:
"""
Object for estimating the median distance between adjacent columns of R8 cells within an individual eye disc. Distance estimate is obtained by constructing a Delaunay graph connecting all annotated R8 neurons, filtering the edges by length and angle relative to the horizontal axis, then evaluating the median x-component of remaining edges.
The median inter-column distance is multiplied by the estimated MF velocity (0.5 columns/hr) to generate a distance-to-time scaling factor.
Attributes:
params (dict) - triangulation parameters, {name: value}
xycoords (np.ndarray) - R8 cell positions
delaunay (scipy.spatial.tri) - Delaunay triangulation
distances (np.ndarray) - distances between adjacent R8 cells
edges (np.ndarray) - edge vertices
hours_per_pixel (float) - distance to time scaling factor
disc (data.discs.Disc)
"""
def __init__(self, disc,
furrow_velocity=2,
threshold=None,
min_angle=30,
max_angle=60,
include_x=True,
include_y=False):
"""
Instantiate object for estimating the median distance between adjacent columns of R8 cells.
Args:
disc (data.discs.Disc)
furrow_velocity (float) - furrow inverse-velocity (hours per column)
threshold (float) - max. quantile of included distances, 0 to 100
min_angle, max_angle (float) - min/max angle of included edges
include_x (bool) - if True, include x-distance
include_y (bool) - if True, include y-distance
"""
self.triangulate(disc, furrow_velocity, threshold, min_angle=min_angle, max_angle=max_angle, include_x=include_x, include_y=include_y)
self.params = {'furrow_velocity': furrow_velocity,
'threshold': threshold,
'min_angle': min_angle,
'max_angle': max_angle,
'include_x': include_x,
'include_y': include_y}
def __call__(self, disc):
"""
Apply distance to time scaling to a disc.
Args:
disc (data.discs.Disc)
Returns:
disc (data.discs.Disc) - disc with estimated developmental times
"""
disc = self._apply_time_scaling(disc, self.hours_per_pixel)
return disc
[docs] def get_disc(self):
""" Return disc. """
return self.disc
@staticmethod
def _apply_time_scaling(disc, hours_per_pixel):
"""
Update developmental times.
Args:
disc (data.discs.Disc)
hours_per_pixel (float) - distance to time scaling factor
Returns:
disc (data.discs.Disc) - disc with estimated developmental times
"""
disc.data['t'] = disc.data.centroid_x * hours_per_pixel
return disc
@staticmethod
def _get_delaunay(xycoords):
""" Return Delaunay triangulation for xy points. """
return Delaunay(xycoords)
@staticmethod
def _get_edges(delaunay,
min_angle=30,
max_angle=60,
include_x=True,
include_y=False):
""" Get distances between adjacent R8 neurons. """
# get indices of vertex neighbors
indices, indptr = delaunay.vertex_neighbor_vertices
# iterate through all R8 cells
distances, edges = [], []
for k, (x1, y1) in enumerate(delaunay.points):
# get all neighbors of current R8
neighbors = delaunay.points[indptr[indices[k]:indices[k+1]]]
# if neighbor is within a 30-60 degree angle from horizontal, include x-coordinate of its distance
for x2, y2 in neighbors:
theta = np.arctan(abs((y2-y1)/(x2-x1)))
# if neighbor is within the specified angle from horizontal, include distance
if theta >= min_angle*np.pi/180 and theta <= max_angle*np.pi/180:
distances.append(np.sqrt(((x2-x1)**2)*include_x + ((y2-y1)**2)*include_y))
edges.append([[x1, x2], [y1, y2]])
return np.array(distances), np.array(edges)
@staticmethod
def _filter_edges_by_length(distances, edges, threshold=1.75):
"""
Filter edges by length. Length threshold is computed as a multiple of the median edge length.
Args:
distances (np.ndarray) - edge lengths
edges (np.ndarray) - edges
threshold (float) - maximum multiple of median length
Returns:
distances (np.ndarray) - filtered edge lengths
edges (np.ndarray) - filtered edges
"""
indices = np.where(distances < threshold*np.median(distances))[0]
return distances[indices], edges[indices]
[docs] def triangulate(self, disc,
furrow_velocity=2,
threshold=None,
min_angle=30,
max_angle=60,
include_x=True,
include_y=False):
"""
Run triangulation.
Args:
disc (data.discs.Disc)
furrow_velocity (float) - furrow inverse-velocity (hr/column)
threshold (float) - max quantile of included distances, 0 to 100
min_angle, max_angle (float) - min/max angle of included edges
include_x (bool) - if True, include x-distance
include_y (bool) - if True, include y-distance
"""
# get coordinates
xycoords = disc.data[disc.data.label=='r8'][['centroid_x', 'centroid_y']].values
self.xycoords = xycoords
# get triangulation
self.delaunay = self._get_delaunay(xycoords)
# get edges
self.distances, self.edges = self._get_edges(self.delaunay, min_angle=min_angle, max_angle=max_angle, include_x=include_x, include_y=include_y)
# filter edges
if threshold is not None:
self.distances, self.edges = self._filter_edges_by_length(distances=self.distances, edges=self.edges, threshold=threshold)
# compute mean distance to time scaling
self.hours_per_pixel = furrow_velocity/np.mean(self.distances)
# update time vector
self.disc = self._apply_time_scaling(disc, self.hours_per_pixel)
@staticmethod
def _get_log2_fold_change(values):
return np.log2(values/values.mean())
@classmethod
def _add_edges_to_plot(cls,
distances,
edges,
ax,
hours_per_pixel,
cmap=cmaps.coolwarm):
""" Add delaunay edges to existing axes. """
# get scores for colormap
scores = cls._get_log2_fold_change(distances)
# plot lines
for edge, score in zip(edges, scores):
times = [x * hours_per_pixel for x in edge[0]]
ax.plot(times, edge[1], '-', linewidth=2, alpha=1, color=cmap.to_rgba(score), zorder=1)
ax.set_yticks([])
ax.set_xlabel('time (hr)')
[docs] def add_edges_to_plot(self, ax, cmap=cmaps.coolwarm):
""" Add delaunay edges to existing axes. """
cmap = self.get_colormap(cmap)
self._add_edges_to_plot(distances=self.distances, edges=self.edges, ax=ax,
hours_per_pixel=self.hours_per_pixel, cmap=cmap)
ax.set_ylim(self.xycoords[:, 1].min(), self.xycoords[:, 1].max())
@staticmethod
def get_colormap(cmap=cmaps.coolwarm):
norm = colors.Normalize(vmin=-1, vmax=1)
colormap = cmaps.ScalarMappable(norm=norm, cmap=cmap)
return colormap
@classmethod
def _plot_histogram(cls, values, ax=None, dist_type='x'):
""" Plot colored histogram for a set of values. """
# get colormap
cmap = cls.get_colormap(cmap=cmaps.coolwarm)
# histogram values
counts, bin_edges = np.histogram(values)
bin_centers = [(edge+bin_edges[i+1])/2 for i, edge in enumerate(bin_edges[:-1])]
scores = np.log2(bin_centers / np.mean(values))
patches = ax.bar(bin_edges[:-1], counts, width=(bin_edges[1]-bin_edges[0]), color=[cmap.to_rgba(score) for score in scores])
# format plot
ax.set_yticks([])
ax.set_xlim(0, 100)
ax.text(95, .9*max(counts), s=' Mean: {:0.1f} px'.format(np.mean(values)), ha='right', va='top')
ax.text(95, .9*max(counts), s='\n Med.: {:0.1f} px'.format(np.median(values)), ha='right', va='top')
ax.text(95, .9*max(counts), s='\n\n N = {:d}'.format(int(len(values)/2)), ha='right', va='top')
_ = ax.set_xlabel(dist_type+'-distance')
_ = ax.set_ylabel('edges')
return counts, bin_edges, patches
[docs] def plot_histogram(self, ax):
""" Histogram inter-R8 distances. """
dist_type = self.params['include_x']*'x'+self.params['include_y']*'y'
self._plot_histogram(self.distances, ax=ax, dist_type=dist_type)
@staticmethod
def _plot_expression(ax, disc, channel, hours_per_pixel,
window_size=100,
color='black',
alpha=1):
""" Plot expression trajectories. """
cells = disc.select_cell_type('pre')
cells.plot_dynamics(channel, ax=ax,
line_kw={'color': color, 'alpha': alpha, 'lw': 2})
# format x axis
ax.set_ylim(0, 2)
ax.set_yticks([])
ax.set_xlim(-15, 55)
ax.set_xticks(np.arange(-10, 60, step=10))
ax.set_xticklabels([str(int(round(label, 0))) for label in ax.get_xticks()])
[docs] def plot_expression(self, ax, channel, **kwargs):
""" Plot expression trajectory. """
self._plot_expression(ax, self.disc, channel, self.hours_per_pixel, **kwargs)
[docs] def overlay_epression(self, ax, channel, **kwargs):
""" Plot expression trajectory on twin y-axis. """
ax_alt = ax.twinx()
self.plot_expression(ax_alt, channel, **kwargs)
[docs] def show(self,
gs_parent=None,
include_expression=True,
channel=None,
is_subplot=False, **kwargs):
"""
Plot inter-R8 distance distribution, Delaunay triangulation, and expression.
"""
# retriangulate
self.triangulate(self.disc, **self.params)
# create axes
if gs_parent is None:
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(6, 2))
else:
gs_child = gs.GridSpecFromSubplotSpec(1, 2, width_ratios=[1, 1.5], subplot_spec=gs_parent, hspace=0)
ax0 = plt.subplot(gs_child[0])
ax1 = plt.subplot(gs_child[1])
is_subplot = True
# add colorbar
self.add_colorbar(ax1, is_subplot=is_subplot)
# plot edges, expression, and histogram
self.add_edges_to_plot(ax1, cmap=cmaps.coolwarm)
if include_expression and channel is not None:
self.overlay_epression(ax1, channel, **kwargs)
self.plot_histogram(ax=ax0)
plt.tight_layout()
return ax0, ax1
@staticmethod
def add_colorbar(ax, is_subplot=False, fraction=0.2):
mappable = ax.scatter([1e6, 1e6], [0, 0], alpha=1, c=[-1, 1], cmap=cmaps.coolwarm)
cbar = plt.colorbar(mappable=mappable, ax=ax, fraction=fraction)
# simplify colorbar for subplots
if is_subplot is False:
cbar.set_ticks([-1, 0.35])
cbar.ax.tick_params(length=0)
cbar.ax.set_yticklabels(['compressed', 'stretched'], rotation='vertical', fontsize=6, ha='left', va='bottom')
cbar.set_label('Log2(F.C. wrt mean)', fontsize=8)
else:
cbar.set_ticks([-1, -0.5, 0, 0.5, 1])
cbar.ax.tick_params(length=0, labelsize=8)
cbar.set_label('Log2(F.C. wrt mean)', fontsize=8)
[docs]class ExperimentTriangulation:
"""
Object for estimating the median distance between adjacent columns of R8 cells for each disc within an experiment. Distance estimate is obtained by constructing a Delaunay graph connecting all annotated R8 neurons, filtering the edges by length and angle relative to the horizontal axis, then evaluating the median x-component of remaining edges.
The median inter-column distance is multiplied by the estimated MF velocity (0.5 columns/hr) to generate a distance-to-time scaling factor.
Attributes:
experiment (data.experiments.Experiment)
tri (dict) - {disc ID: Triangulation} pairs
"""
def __init__(self, experiment, **kwargs):
"""
Instantiate triangulation objects for all discs in an experiment.discs
Args:
experiment (data.experiments.Experiment)
kwargs: triangulation keyword arguments
"""
discs = experiment.discs
self.tri = self._get_triangulations(discs, **kwargs)
self.experiment = self._apply_triangulations(experiment, self.tri)
def __call__(self):
""" Return experiment with updated developmental times. """
return self.experiment
@staticmethod
def _apply_triangulations(experiment, tri):
"""
Return experiment with updated developmental times.
Args:
experiment (data.experiments.Experiment) - experiment to be updated
tri (dict) - {disc ID: Triangulation} pairs
Returns:
experiment (data.experiments.Experiment) - updated experiment
"""
experiment.discs = {i: t.get_disc() for i, t in tri.items()}
return experiment
@staticmethod
def _get_triangulations(discs, **kwargs):
"""
Return dictionary of Triangulation objects.
Args:
discs (dict) - {disc ID: Disc} pairs
kwargs: keyword arguments for Triangulation
Returns:
tri (dict) - {disc ID: Triangulation} pairs
"""
return {i: Triangulation(disc, **kwargs) for i, disc in discs.items()}
[docs] def plot_expression(self, ax, channel,
color='black',
**kwargs):
""" Plot expression for all triangulations. """
for triangulation in self.triangulations.values():
triangulation.plot_expression(ax, channel=channel, color=color, **kwargs)
[docs] def show_triangulations(self):
""" Visualize all triangulations. """
gs_parent = gs.GridSpec(nrows=len(self.triangulations), ncols=1)
fig = plt.figure(figsize=(6, 1.5*len(self.triangulations)))
for i, (triangulation, gs0) in enumerate(zip(self.triangulations.values(), gs_parent)):
ax0, ax1 = triangulation.show(gs_parent=gs0)
if i != len(self.triangulations)-1:
ax0.set_xticks([]), ax0.set_xlabel('')
ax1.set_xticks([]), ax1.set_xlabel('')
plt.tight_layout()
return fig
[docs] def show_alignment(self, channel,
xoffsets=None,
ax=None,
scatter=False,
legend=True,
window_size=100,
ma_type='sliding',
color_wheel='cmyk',
figsize=(4, 3)):
""" Plot alignment of all discs. """
# assume zero offset for each channel
if xoffsets is None:
xoffsets = np.zeros(len(self.triangulations))
# create axes if none provided
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
# add reference lines
ax.plot([0, 0], [0, 3], '--k')
ax.plot([-15, 55], [0.25, 0.25], '--k')
handles, labels = [], []
for i, triangulation in self.triangulations.items():
# get line color for disc
color = color_wheel[i % len(color_wheel)]
# get cells
disc_cells_data = triangulation.data[np.logical_or(triangulation.data.label=='pre', triangulation.data.label=='p')]
if scatter is True:
ax.plot(disc_cells_data.t + xoffsets[i], disc_cells_data[channel], '.', alpha=0.1, color=color)
# add line average
line = plot_mean(disc_cells_data.t + xoffsets[i], (disc_cells_data[channel]),
ax=ax, label='Disc {:d}'.format(i), ma_type=ma_type, window_size=window_size,
line_color=color, line_width=3, line_alpha=0.5)
handles.append(line[0]), labels.append('Disc {:d}'.format(i))
if legend is True:
ax.legend(handles=handles, labels=labels, loc=0, frameon=False)
ax.tick_params(labelsize=16)
ax.set_ylim(0, 2)
ax.set_ylabel('Fluorescence (a.u.)', fontsize=16)
ax.set_xlim(-5, 55)
ax.set_xlabel('Time (hr)', fontsize=16)