Source code for cows.catalogue

import numpy as np

from .filament import label_skeleton, find_filaments


[docs]def gen_catalogue(data, periodic=False, sort=True): ''' Generate a catalogue of filaments Generates a catalogue of filaments given an array containing a cleaned and separated skeleton. Parameters ---------- data : ndarray, 3D An array containing the separated skeleton. Zeros represent background, ones are endpoints and, twos are regular cells. periodic: bool If True, the skeletonization uses periodic boundary conditions for the input array. Input array must be 3D. sort : boolean If sort=True, the filaments are sorted by filament length in descending order and reassigned IDs such that the longest filament has an ID of zero. Returns ------- result : ndarray, 3D An array with data.shape containing the sets of connected cells with their respective ID. catalogue : ndarray, 2D A catalogue containing, for each cell, a row of filament ID, filament length, X-, Y-, Z-position, X-, Y-, Z-direction. Notes ----- The function assumes that values larger than zero are part of the skeleton. Filament length is defined here as the number of member cells. ''' ncells = data.shape[0] # Classify the skeleton to identify endpoints and regular cells data = label_skeleton(data, periodic=periodic) # Connect cells within a 3x3x3 neighbourhood from endpoint to endpoint # and store in the first column _, cat = find_filaments(data, periodic=periodic) # Crop the catalogue if not completely full if np.any(cat[:, 0] == 0): cat = cat[:np.argmin(cat[:, 0])] # Copy into new array catalogue = np.zeros([cat.shape[0], 8], order='c') catalogue[:,0] = cat[:,0] # Return the empty catalogue if no filaments were found if catalogue.shape[0] == 0: return catalogue # Store the filament cell positions catalogue[:,2:5] = cat[:,1:] # Calculate the filament lengths and store in the second column group_lengths = np.diff(np.hstack([0,np.where(np.diff(cat[:,0])!=0)[0]+1, len(cat[:,0])])) catalogue[:,1] = np.repeat(group_lengths, group_lengths) # Calculate and store the filament cell directions catalogue[:,5:8] = _get_direction(cat[:,0], cat[:,1:], ncells) # Sort the catalogue by filament length in descending order if sort: sort_idx = np.lexsort([catalogue[:,0],catalogue[:,1]])[::-1] catalogue = catalogue[sort_idx] group_lengths = np.sort(group_lengths)[::-1] catalogue[:,0] = np.repeat(np.arange(np.max(catalogue[:,0]))+1, group_lengths) return catalogue
def _get_direction(index, pos, box_size): ''' Calculates the direction of a filament cell based on the location of that cells' neighbours. The direction vector is normalised. ''' assert pos.ndim == 2 assert pos.shape[1] == 3 dxyz = np.zeros(pos.shape) # Find the beginning and end of filaments in the catalogue idx_diff = 1 - np.diff(index) # Calculate position vector between 2 neighbours and account for periodic dxyz_tmp = np.mod((pos[:-1]-pos[1:])+1, box_size) - 1 # Add direction to appropriate indices dxyz_tmp = dxyz_tmp * idx_diff[:,None] # set to 0 between filaments dxyz[:-1] += dxyz_tmp dxyz[1:] += dxyz_tmp # Noramlise direction vector r = np.sqrt(np.sum(dxyz**2, axis=1)) return dxyz/r[:,None]