Source code for flyqma.selection.interface

from os.path import join
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.gridspec import GridSpec

from ..utilities import IO

class LayerVisualization:
    """
    Object for visualizing a single layer.

    Attributes:

        path (str) - layer path

        layer (flyqma.Layer) - layer instance

        axes (array like) - axes for blue/red/green color channels

    """

    def __init__(self, layer, axes):
        """
        Instantiate layer visualization.

        Args:

            layer (Layer) - image layer

            axes (array like) - axes for blue/red/green color channels

        """
        self.layer = layer

        # set selection path (the layer's 'selection' subdirectory)
        layer.make_subdir('selection')
        self.path = layer.subdirs['selection']

        # set axes
        self.axes = axes

        # render images
        self.render_images(layer)

    def render_images(self, layer, cmap=None):
        """
        Add blue, green, and red channels of layer to axes.

        The i-th axis in self.axes shows channel i of the layer, drawn without
        segment overlays and with an equal aspect ratio. The layer's _id is
        written in yellow at data coordinate (0, 0) of the first axis.

        Args:

            layer (Layer) - RGB image layer

            cmap (matplotlib.colors.ColorMap) - colormap passed to show()

        """
        # visualize layers
        for ch, ax in enumerate(self.axes):
            _ = layer.get_channel(ch).show(segments=False, ax=ax, cmap=cmap)
            ax.set_aspect(1)

        # add layer number
        text = ' {:d}'.format(layer._id)
        self.axes[0].text(0, 0, text, fontsize=14, color='y', va='top')

    def add_marker(self, x, y, color='k', markersize=10):
        """
        Add a point marker at (x, y) to every layer image.

        Args:

            x, y (float) - marker position in data coordinates

            color (str) - marker color

            markersize (float) - marker size

        """
        for ax in self.axes:
            ax.plot(x, y, '.', color=color, markersize=markersize, zorder=2)

    def remove_marker(self):
        """
        Remove the most recently added marker from each layer image.

        Axes without any line artists are skipped instead of raising.
        """
        for ax in self.axes:
            if len(ax.lines) > 0:
                ax.lines[-1].remove()

    def update_marker(self, color, markersize, ind=-1):
        """
        Update size and color of a marker (by default the last one added).

        Args:

            color (str) - new marker color

            markersize (float) - new marker size

            ind (int) - index into each axis' lines, defaults to the last

        """
        for ax in self.axes:
            if len(ax.lines) > 0:
                ax.lines[ind].set_color(color)
                ax.lines[ind].set_markersize(markersize)

    def clear_markers(self):
        """ Remove all markers. """
        # the first axis is used as the reference for how many markers remain
        while len(self.axes[0].lines) > 0:
            self.remove_marker()

    def add_polygon(self):
        """
        Add polygon to each image.

        Uses self.pts as the polygon vertices. LayerVisualization never sets
        self.pts itself, so subclasses must define it before calling this.
        """
        for ax in self.axes:
            # fill=False: only the yellow outline is drawn, fc is not rendered
            poly = Polygon(self.pts,
                           ec=(1, 1, 0, 1),
                           lw=1,
                           fc=(1, 1, 1, 0.2),
                           fill=False,
                           zorder=1,
                           closed=True)
            ax.add_patch(poly)

    def remove_polygon(self):
        """
        Remove polygon from each image.

        Axes without any patch are skipped, so this is safe to call when no
        polygon is currently drawn.
        """
        for ax in self.axes:
            if len(ax.patches) > 0:
                ax.patches[0].remove()

    def overlay(self, msg, s=18):
        """
        Overlay string centered on image, dimming the underlying image.

        Args:

            msg (str) - text to display

            s (float) - font size

        """
        for ax in self.axes:
            ax.images[0].set_alpha(0.5)
            x, y = np.mean(ax.get_xlim()), np.mean(ax.get_ylim())
            ax.text(x, y, msg, color='k', fontsize=s, ha='center', va='center')
class LayerInterface(LayerVisualization):
    """
    Event handler for an individual layer.

    Attributes:

        include (bool) - flag for layer inclusion

        active_polygon (bool) - if True, polygon is currently active

        pts (list) - selection boundary points

        traceback (list) - exception traceback

    Inherited attributes:

        path (str) - layer path

        layer (flyqma.Layer) - layer instance

        axes (array like) - axes for blue/red/green color channels

    """

    def __init__(self, layer, axes):
        """
        Instantiate layer interface.

        Args:

            layer (Layer) - image layer

            axes (array like) - axes for blue/red/green color channels

        """
        # call visualization instantiation method
        super().__init__(layer, axes)

        # set layer attributes
        self.include = True

        # no initial polygon
        self.active_polygon = False

        # initialize points list
        self.pts = []
        self.traceback = []

    def load(self):
        """
        Load layer selection from selection.npy and md.json in self.path.

        Saved points are redrawn as markers (the last one highlighted in red),
        the polygon is drawn when at least three points exist, and excluded
        layers are overlaid with 'EXCLUDED'.
        """
        io = IO()

        # load selected points
        pts = io.read_npy(join(self.path, 'selection.npy'))
        self.pts = pts.tolist()

        # load selection metadata
        md = io.read_json(join(self.path, 'md.json'))
        self.include = md['include']

        # draw all points small/yellow, then promote the last one to the
        # large red "current point" style used by add_point
        for pt in self.pts:
            self.add_marker(*pt, color='y', markersize=5)
        self.update_marker('r', markersize=10)

        # add polygon
        if len(self.pts) >= 3:
            self.add_polygon()
            self.active_polygon = True

        # mark excluded layers
        if not self.include:
            self.overlay('EXCLUDED')

    def save(self):
        """
        Save selected points and selection metadata to file.

        If fewer than three points were selected, the selection is replaced by
        the four corners of the layer (i.e. everything is included). Writes
        selection.npy and md.json to self.path, then refreshes the layer's
        inclusion flag, ROI and processed data.
        """
        # if no region was specified, use corners (e.g. include everything)
        if len(self.pts) <= 2:
            # NOTE(review): points are (x, y) pairs (see add_marker); this
            # assumes layer.shape unpacks as (width, height) - confirm, since
            # array shapes are usually (rows, cols).
            w, h = self.layer.shape
            self.pts = [[0, 0], [w, 0], [w, h], [0, h]]

        io = IO()
        pts = np.array(self.pts)
        io.write_npy(join(self.path, 'selection.npy'), pts)
        md = dict(include=self.include)
        io.write_json(join(self.path, 'md.json'), md)

        # update measurements
        self.layer.load_inclusion()
        # NOTE(review): argument reconstructed (the call was previously left
        # unclosed) - confirm against the signature of Layer.define_roi.
        self.layer.define_roi(self.layer.data)
        self.layer.save_processed_data()

    def clear(self):
        """ Clear all points from layer selection bounds. """
        self.pts = []
        self.clear_markers()
        if len(self.axes[0].patches) > 0:
            self.remove_polygon()
        self.active_polygon = False

    def add_point(self, pt):
        """
        Add point to layer selection bounds.

        Args:

            pt (list) - (x, y) coordinates of the new point

        """
        # store point
        self.pts.append(pt)

        # demote previous marker, add new marker as the current (red) one
        self.update_marker(color='y', markersize=5)
        self.add_marker(*pt, color='r', markersize=10)

        # update polygon
        if self.active_polygon:
            self.update_polygon()
        elif len(self.pts) == 3:
            self.add_polygon()
            self.active_polygon = True

    def remove_point(self):
        """
        Remove last point added to layer selection bounds.

        Does nothing when no points have been selected.
        """
        if not self.pts:
            return
        self.pts.pop()
        self.remove_marker()
        # the new last marker becomes the highlighted "current" point
        self.update_marker(color='r', markersize=10)
        if len(self.pts) < 3:
            self.active_polygon = False

    def update_polygon(self):
        """
        Update polygon for each image.

        Any drawn polygon is removed first (skipped if none is drawn); a new
        one is added only while the polygon is active.
        """
        if len(self.axes[0].patches) > 0:
            self.remove_polygon()
        if self.active_polygon:
            self.add_polygon()

    def undo(self):
        """ Remove last point added and update polygon. """
        self.remove_point()
        self.update_polygon()
class StackInterface:
    """
    Object for visualizing multiple layers in an image stack.

    Attributes:

        path (str) - stack path

        fig (matplotlib.figure.Figure) - figure holding all layer axes

        layer_to_interface (dict) - {layer index: LayerInterface} map

        ax_to_layer (dict) - {axis: layer index} map

    """

    def __init__(self, stack):
        """
        Instantiate stack interface.

        Args:

            stack (Stack) - image stack

        """
        self.path = stack.path
        self.build_interface(stack)

    def build_interface(self, stack):
        """
        Build interface by adding interface for each layer.

        Creates a figure with one row per layer and one column per color
        channel, attaches a LayerInterface to every row, and titles the top
        row with channel numbers.

        Args:

            stack (Stack) - image stack

        """
        # create figure
        nrows, ncols = stack.stack_depth, stack.color_depth
        figsize = (2.25 * ncols, 2.25 * nrows)
        self.fig = plt.figure(figsize=figsize)
        gs = GridSpec(nrows=nrows, ncols=ncols, wspace=.01, hspace=.01)

        # instantiate maps
        self.layer_to_interface = {}
        self.ax_to_layer = {}

        # build interface for each layer; every index in range(stack_depth)
        # gets its own row (layers are indexed directly, not iterated)
        for i in range(nrows):
            layer = stack[i]

            # create all axes for current layer: row i, one column per channel
            axes = [self.fig.add_subplot(gs[i, j]) for j in range(ncols)]

            # add layer gui to layer --> interface map
            self.layer_to_interface[i] = LayerInterface(layer, axes)

            # update axis --> layer map
            for ax in axes:
                self.ax_to_layer[ax] = i

            # label top row
            if i == 0:
                for j, ax in enumerate(axes):
                    ch_label = 'Channel {:d}'.format(j)
                    ax.set_title(ch_label, fontsize=14)