diff --git a/assemble_traps_image.py b/assemble_traps_image.py new file mode 100644 index 0000000..61fccc7 --- /dev/null +++ b/assemble_traps_image.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +from __future__ import print_function, division +import six + +# import modules +import sys +import os +# import time +import inspect +import argparse +import numpy as np +import yaml +from pprint import pprint # for human readable file output +try: + import cPickle as pickle +except: + import pickle +import numpy as np +from scipy.io import savemat +from skimage import io +import skimage +import matplotlib.pyplot as plt + +# user modules +# realpath() will make your script run, even if you symlink it +cmd_folder = os.path.realpath(os.path.abspath( + os.path.split(inspect.getfile(inspect.currentframe()))[0])) +if cmd_folder not in sys.path: + sys.path.insert(0, cmd_folder) + +# This makes python look for modules in ./external_lib +cmd_subfolder = os.path.realpath(os.path.abspath( + os.path.join(os.path.split(inspect.getfile( + inspect.currentframe()))[0], "external_lib"))) +if cmd_subfolder not in sys.path: + sys.path.insert(0, cmd_subfolder) + +import mm3_helpers as mm3 + +if __name__ == "__main__": + + # set switches and parameters + parser = argparse.ArgumentParser(prog='python assemble_traps_image.py', + description='Assemble and save images of traps side-by-side.') + parser.add_argument('-f', '--paramfile', type=str, + required=True, help='Yaml file containing parameters.') + parser.add_argument('--frame', required=True, type=int, help="Defines the frame (1-indexed) to be sliced from each stack.") + parser.add_argument('-o', '--fov', type=str, + required=False, help='List of fields of view to analyze. 
Input "1", "1,2,3", or "1-10", etc.') + parser.add_argument('--cell_segs', action='store_true', + required=False, help='Apply this argument if you would like cell segmentation results in addition to phase images.') + parser.add_argument('--focus_segs', action='store_true', + required=False, help='Apply this argument if you would like focus segmentation results in addition to phase images.') + parser.add_argument('--fluorescent_channels', type=str, + required=False, help='List of channels (as integers) to include in addition to the phase channel.') + + namespace = parser.parse_args() + + # Load the project parameters file + mm3.information('Loading experiment parameters.') + if namespace.paramfile: + param_file_path = namespace.paramfile + else: + mm3.warning('No param file specified. Using 100X template.') + param_file_path = 'yaml_templates/params_SJ110_100X.yaml' + p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + + if namespace.fov: + if '-' in namespace.fov: + user_spec_fovs = range(int(namespace.fov.split("-")[0]), + int(namespace.fov.split("-")[1])+1) + else: + user_spec_fovs = [int(val) for val in namespace.fov.split(",")] + else: + user_spec_fovs = [] + + # load specs file + specs = mm3.load_specs() + # print(specs) # for debugging + + # make list of FOVs to process (keys of channel_mask file) + fov_id_list = sorted([fov_id for fov_id in specs.keys()]) + + # remove fovs if the user specified so + if user_spec_fovs: + fov_id_list[:] = [fov for fov in fov_id_list if fov in user_spec_fovs] + + peaks_list = [peak_id for peak_id,val in specs[fov_id_list[0]].items() if val == 1] + img_height, img_width = mm3.load_stack(fov_id_list[0], peaks_list[0], color=p['phase_plane'])[0,:,:].shape + + # how many images total will we concatenate horizontally? 
+    img_count = 0 +    for fov_id in fov_id_list: +        fov_peak_count = len([peak_id for peak_id,val in specs[fov_id].items() if val == 1]) +        img_count += fov_peak_count + +    # placeholder array of img_height, and proper width to hold all pixels from this fov +    phase_arr = np.zeros((230,img_width*img_count), 'uint16') + +    if namespace.cell_segs: +        # p['seg_dir'] = 'segmented' +        seg_arr = np.zeros((230,img_width*img_count), 'uint16') + +    if namespace.focus_segs: +        # p['foci_seg_dir'] = 'segmented_foci' +        focus_arr = np.zeros((230,img_width*img_count), 'uint16') + +    if namespace.fluorescent_channels: +        fluor_arrays = {} +        for fluorescent_channel in namespace.fluorescent_channels: +            fluor_arrays[fluorescent_channel] = np.zeros((230,img_width*img_count), 'uint16') + +    frame = namespace.frame +    frame_idx = frame - 1 + +    img_counter = 0 +    for fov_id in fov_id_list: + +        print("concatenating images from fov_id {}.".format(fov_id)) +        peaks_list = [peak_id for peak_id,val in specs[fov_id].items() if val == 1] + +        for i,peak_id in enumerate(peaks_list): + +            start_x = img_counter * img_width +            end_x = start_x + img_width + +            if namespace.cell_segs: +                # set segmentation image name for saving and loading segmented images +                img = mm3.load_stack(fov_id, peak_id, color='seg_unet')[frame_idx,:230,:] +                seg_arr[:,start_x:end_x] = img + +            if namespace.focus_segs: +                # set segmentation image name for saving and loading segmented images +                img = mm3.load_stack(fov_id, peak_id, color='foci_seg_unet')[frame_idx,:230,:] +                focus_arr[:,start_x:end_x] = img + +            if namespace.fluorescent_channels: +                for fluorescent_channel in namespace.fluorescent_channels: +                    img = mm3.load_stack(fov_id, peak_id, color='c{}'.format(fluorescent_channel))[frame_idx,:230,:] +                    fluor_arrays[fluorescent_channel][:,start_x:end_x] = img + +            # now for the phase images +            img = mm3.load_stack(fov_id, peak_id, color=p['phase_plane'])[frame_idx,:230,:] +            phase_arr[:,start_x:end_x] = img + +            img_counter += 1 + +    fname = 
os.path.join(p['experiment_directory'], '{}_t{:0=4}_combined_phase_peaks.png'.format(p['experiment_name'], frame)) + io.imsave(fname, skimage.img_as_ubyte(phase_arr)) + + if namespace.cell_segs: + fname = os.path.join(p['experiment_directory'], '{}_t{:0=4}_combined_cell_seg_peaks.png'.format(p['experiment_name'], frame)) + io.imsave(fname, skimage.img_as_ubyte(seg_arr)) + + if namespace.focus_segs: + fname = os.path.join(p['experiment_directory'], '{}_t{:0=4}_combined_focus_seg_peaks.png'.format(p['experiment_name'], frame)) + io.imsave(fname, skimage.img_as_ubyte(focus_arr)) + + if namespace.fluorescent_channels: + for fluorescent_channel in namespace.fluorescent_channels: + fname = os.path.join(p['experiment_directory'], '{}_t{:0=4}_combined_c{}_peaks.png'.format(p['experiment_name'], frame, fluorescent_channel)) + io.imsave(fname, fluor_arrays[fluorescent_channel]) + diff --git a/combine_tracks_from_chtc.py b/combine_tracks_from_chtc.py new file mode 100755 index 0000000..789b779 --- /dev/null +++ b/combine_tracks_from_chtc.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +# import modules +import sys +import os +import inspect +import glob +import argparse +import skimage +from skimage import measure, io +from pprint import pprint # for human readable file output +try: + import cPickle as pickle +except: + import pickle + +# user modules +# realpath() will make your script run, even if you symlink it +cmd_folder = os.path.realpath(os.path.abspath( + os.path.split(inspect.getfile(inspect.currentframe()))[0])) +if cmd_folder not in sys.path: + sys.path.insert(0, cmd_folder) + +# This makes python look for modules in ./external_lib +cmd_subfolder = os.path.realpath(os.path.abspath( + os.path.join(os.path.split(inspect.getfile( + inspect.currentframe()))[0], "external_lib"))) +if cmd_subfolder not in sys.path: + sys.path.insert(0, cmd_subfolder) + +import mm3_helpers as mm3 + +#%% +# when using this script as a function and not as a library the following will execute 
+if __name__ == "__main__": + + # set switches and parameters + parser = argparse.ArgumentParser( + prog='python combine_tracks_from_chtc.py', + description='CHTC saves a separate track file for each fov/peak. Here we combine them.' + ) + parser.add_argument( + '-f', + '--paramfile', + type=str, + required=True, + help='Yaml file containing parameters.' + ) + + namespace = parser.parse_args() + + # Load the project parameters file + mm3.information('Loading experiment parameters.') + if namespace.paramfile: + param_file_path = namespace.paramfile + else: + mm3.warning('No param file specified. Using 100X template.') + param_file_path = 'yaml_templates/params_SJ110_100X.yaml' + p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + + # Get file names + fnames = glob.glob(os.path.join(p['cell_dir'], "{}*_tracks.pkl".format(p['experiment_name']))) + + ### Now prune and save the data. + mm3.information("Reading cell data from each file and combining into one.") + + tracks = {} + + for fname in fnames: + with open(fname, 'rb') as cell_file: + cell_data = pickle.load(cell_file) + os.remove(fname) + tracks.update(cell_data) + + with open(p['cell_dir'] + '/all_cells.pkl', 'wb') as cell_file: + pickle.dump(tracks, cell_file, protocol=pickle.HIGHEST_PROTOCOL) + + if os.path.isfile(os.path.join(p['cell_dir'], 'complete_cells.pkl')): + os.remove(os.path.join(p['cell_dir'], 'complete_cells.pkl')) + + os.symlink( + os.path.join(p['cell_dir'], 'all_cells.pkl'), + os.path.join(p['cell_dir'], 'complete_cells.pkl') + ) + + mm3.information("Finished curating and saving cell data.") diff --git a/docker/mm3-py3-tf2/Dockerfile b/docker/mm3-py3-tf2/Dockerfile new file mode 100644 index 0000000..95d941c --- /dev/null +++ b/docker/mm3-py3-tf2/Dockerfile @@ -0,0 +1,25 @@ +# Tensorflow image based on Ubuntu 16.04 with Python 3 and Jupyter +FROM tensorflow/tensorflow:2.1.0-gpu-py3 + +# update repositories +RUN apt-get update && \ + apt -y dist-upgrade + +# Choose 
locales. Set to US. +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y locales +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata +RUN locale-gen en_US.UTF-8 # en_GB.UTF-8 eu_FR.UTF-8 + +# set timezone. Set to PST. +RUN rm -f /etc/localtime /etc/timezone && echo "tzdata tzdata/Areas select America" > myfile && echo "tzdata tzdata/Zones/America select Los_Angeles" >> myfile && debconf-set-selections myfile && dpkg-reconfigure -f noninteractive tzdata && rm myfile + +# Install minimal packages +RUN apt-get install -y sudo vim less + +# install package for mac OSX socket sharing through x11 +RUN apt-get install -y libglu1-mesa + +# Copy installation script with Python packages into container and execute +COPY install_mm3_dependencies.sh /root/ +RUN /bin/bash /root/install_mm3_dependencies.sh +RUN rm -f /root/install_mm3_dependencies.sh diff --git a/docker/mm3-py3-tf2/install_mm3_dependencies.sh b/docker/mm3-py3-tf2/install_mm3_dependencies.sh new file mode 100644 index 0000000..5530983 --- /dev/null +++ b/docker/mm3-py3-tf2/install_mm3_dependencies.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# -y means yes +apt-get install -y python3-tk + +echo "Install other packages for imaging and plotting." +apt-get install -y libreadline-dev libsqlite3-dev libbz2-dev libssl-dev +apt-get install -y libblas-dev liblapack-dev libatlas-dev +apt-get install -y libpng-dev libfreetype6-dev tk-dev pkg-config +apt-get install -y ffmpeg + +echo "Install additional python packages with pip." 
+# Note: pip softlinks to pip3 +python -m pip install scipy +python -m pip install matplotlib +python -m pip install seaborn +python -m pip install Pillow +python -m pip install scikit-image +python -m pip install pyYAML +python -m pip install pandas +python -m pip install pims_nd2 +python -m pip install sklearn +python -m pip install tensorflow-gpu==2.0 + +# python binding +python -m pip install freetype-py + +# PyQt +python -m pip install PyQt5 +sudo apt-get -y install python3-pyqt5 +sudo apt-get -y install pyqt5-dev-tools +sudo apt-get -y install qttools5-dev-tools diff --git a/focus_track_training_file_paths.csv b/focus_track_training_file_paths.csv index 78ed17f..6df23c2 100644 --- a/focus_track_training_file_paths.csv +++ b/focus_track_training_file_paths.csv @@ -1,12 +1,74 @@ file_path,include -/home/wanglab/Users_local/Jeremy/Imaging/20190729/analysis/tracking_foci/20190729_JDW3705-inv_xy001_p0073_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190729/analysis/tracking_foci/20190729_JDW3705-inv_xy001_p0074_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy001_p0012_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy001_p0013_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy001_p0014_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy002_p0006_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy002_p0007_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy002_p0009_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy002_p0010_updated_tracks.pkl,1 
-/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy003_p0008_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20190726/analysis/tracking_foci/20190726_JDW3705-inv_xy003_p0012_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0001_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0007_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0008_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0009_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0011_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy003_p0001_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy005_p0007_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy005_p0009_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy005_p0010_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy005_p0015_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy005_p0016_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0020_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0021_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0022_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0023_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0024_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0025_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0036_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200714/analysis/tracking_foci/20200714_JDW3907_xy002_p0037_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy001_p0042_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy001_p0122_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy001_p0124_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy002_p0033_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy002_p0034_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy002_p0035_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy003_p0035_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy003_p0037_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0033_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0035_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0038_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0039_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0072_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0123_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0124_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy004_p0125_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0077_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0084_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0099_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0100_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0105_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0108_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0111_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200818/analysis/tracking_foci/20200818_JDW3930_xy001_p0120_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0097_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0104_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0108_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0097_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0102_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0023_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0024_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0025_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0026_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0029_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0030_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy002_p0066_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy002_p0072_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy002_p0087_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy001_p0040_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy001_p0041_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy002_p0036_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200819/analysis/tracking_foci/20200819_JDW3930_xy003_p0041_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200817/analysis/tracking_foci/20200817_JDW3930_xy001_p0016_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20200817/analysis/tracking_foci/20200817_JDW3930_xy001_p0018_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0105_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy001_p0123_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0119_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0115_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy002_p0121_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0080_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0036_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0091_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200716/analysis/tracking_foci/20200716_JDW3907_xy001_p0065_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20200713/analysis/tracking_foci/20200713_JDW3907_xy003_p0032_updated_tracks.pkl,1 diff --git a/mm3_CellTrackGUI.py b/mm3_CellTrackGUI.py index 0d7c4b9..f4f1619 100755 --- a/mm3_CellTrackGUI.py +++ b/mm3_CellTrackGUI.py @@ -26,7 +26,7 @@ import pandas as pd sys.path.insert(0, '/home/wanglab/src/mm3/') # Jeremy's path to mm3 folder -sys.path.insert(0, '/home/wanglab/src/mm3/aux/') +sys.path.insert(0, '/home/wanglab/src/mm3/sup/') import mm3_helpers as mm3 import mm3_plots @@ -487,7 +487,7 @@ def save_updates(self): def no_track_pickle_lookup(self): - self.track_info = self.create_tracking_information(self.fov_id, self.peak_id, self.labelStack) + self.track_info = self.create_tracking_information(self.fov_id, 
self.peak_id, self.labelStack, self.phaseStack) def get_track_pickle(self): @@ -495,7 +495,7 @@ def get_track_pickle(self): # look for previously updated tracking information and load that if it is found. if not os.path.isfile(self.pickle_file_name): # get tracking information in a format usable and updatable by qgraphicsscene - self.track_info = self.create_tracking_information(self.fov_id, self.peak_id, self.labelStack) + self.track_info = self.create_tracking_information(self.fov_id, self.peak_id, self.labelStack, self.phaseStack) else: with open(self.pickle_file_name, 'rb') as pickle_file: try: @@ -701,13 +701,22 @@ def phase_img_and_regions(self, frame_index, watershed=False): return(phaseQpixmap, time_regions_and_events) - def create_tracking_information(self, fov_id, peak_id, label_stack): + def create_tracking_information(self, fov_id, peak_id, label_stack, phase_stack): - Complete_Lineages = mm3_plots.organize_cells_by_channel(self.Cells, self.specs) - All_Lineages = mm3_plots.organize_cells_by_channel(self.All_Cells, self.specs) + # Complete_Lineages = mm3_plots.organize_cells_by_channel(self.Cells, self.specs) + # All_Lineages = mm3_plots.organize_cells_by_channel(self.All_Cells, self.specs) t_adj = 1 + # correct for occasional missing of a frame by our microscope + dt = phase_stack.dtype + for k,img in enumerate(phase_stack): + # if the mean phase image signal is less than 200, add its index to list + if ((dt == 'uint16') and (np.mean(img) < 200)): + label_stack[k,...] = label_stack[k-1,...] + elif ((dt == 'uint8') and (np.mean(img) < 200/(2**16-1)*(2**8-1))): + label_stack[k,...] = label_stack[k-1,...] 
+ regions_by_time = {frame+t_adj: measure.regionprops(label_stack[frame,:,:]) for frame in range(label_stack.shape[0])} regions_and_events_by_time = {frame+t_adj : {'regions' : {}, 'matrix' : None} for frame in range(label_stack.shape[0])} diff --git a/mm3_ChannelPicker.py b/mm3_ChannelPicker.py index a9beee1..d572860 100755 --- a/mm3_ChannelPicker.py +++ b/mm3_ChannelPicker.py @@ -25,15 +25,16 @@ plt.rcParams['axes.linewidth']=0.5 from skimage.exposure import rescale_intensity # for displaying in GUI -from skimage import io, morphology, segmentation -from scipy.misc import imresize +from skimage import io, morphology, segmentation, transform +# from scipy.misc import imresize # deprecated from skimage.external import tifffile as tiff import multiprocessing from multiprocessing import Pool import warnings import h5py -from tensorflow.python.keras import models +from tensorflow.keras import models +from tensorflow.keras.preprocessing.image import ImageDataGenerator # user modules # realpath() will make your script run, even if you symlink it @@ -90,6 +91,9 @@ def fov_plot_channels(fov_id, crosscorrs, specs, outputdir='.', phase_plane='c1' # load data for figure image_data = mm3.load_stack(fov_id, peak_id, color=phase_plane) + io.imshow(image_data[0,...]) + plt.show(); + first_img = rescale_intensity(image_data[0,:,:]) # phase image at t=0 last_img = rescale_intensity(image_data[-1,:,:]) # phase image at end @@ -100,7 +104,7 @@ def fov_plot_channels(fov_id, crosscorrs, specs, outputdir='.', phase_plane='c1' # plot the first image in each channel in top row ax=axhi - ax.imshow(first_img,cmap=plt.cm.gray, interpolation='nearest') + ax.imshow(first_img, cmap=plt.cm.gray, interpolation='nearest') ax.axis('off') ax.set_title(str(peak_id), fontsize = 12) if n == 0: @@ -289,7 +293,15 @@ def fov_cell_segger_plot_channels(fov_id, predictionDict, specs, outputdir='.', # load data for figure image_data = mm3.load_stack(fov_id, peak_id, color=phase_plane) - first_img = 
rescale_intensity(image_data[0,:,:]) # phase image at t=0 + img_idx = 0 + first_img = image_data[img_idx,:,:] # phase image at t=0 + print(np.mean(first_img)) + while np.mean(first_img) < 200: + img_idx += 1 + first_img = image_data[img_idx,:,:] + print(np.mean(first_img)) + + first_img = rescale_intensity(first_img) # phase image at t=0 last_img = rescale_intensity(image_data[-1,:,:]) # phase image at end # append an axis handle to ax list while adding a subplot to the figure which has a @@ -329,7 +341,7 @@ def fov_cell_segger_plot_channels(fov_id, predictionDict, specs, outputdir='.', # finally plot the prediction values as horizontal bar chart ax=axlo if predictionDict: - ax.barh(range(len(predictions)), predictions) + ax.barh([0], [predictions]) #ax.vlines(x=p['channel_picker']['channel_picking_threshold'], ymin=-1, ymax=5, linestyles='dashed',colors='red') ax.set_title('cell count', fontsize = 8) else: @@ -340,7 +352,7 @@ def fov_cell_segger_plot_channels(fov_id, predictionDict, specs, outputdir='.', if not n == 0: ax.get_yaxis().set_ticks([]) else: - ax.set_yticklabels(labels=["","1","2","3","4","5"]) + ax.set_yticklabels(labels=[""]) ax.set_ylabel("") fig.suptitle("FOV {:d}".format(fov_id),fontsize=14) @@ -745,7 +757,8 @@ def onclick_cells(event): # finally plot the prediction values as horizontal bar chart ax.append(fig.add_subplot(3, npeaks, n + 2*npeaks)) if predictionDict: - ax[-1].barh(range(len(predictions)), predictions) + # ax[-1].barh(range(len(predictions)), predictions) + ax[-1].barh([0], [predictions]) #ax[-1].vlines(x=p['channel_picker']['channel_picking_threshold'], ymin=-1, ymax=5, linestyles='dashed',colors='red') ax[-1].set_title('cell count', fontsize = 8) else: @@ -756,7 +769,8 @@ def onclick_cells(event): if not n == 1: ax[-1].get_yaxis().set_ticks([]) else: - ax[-1].set_yticklabels(labels=["",'1','2','3','4','5']) + ax[-1].set_yticklabels(labels=["1"]) + # ax[-1].set_yticklabels(labels=["",'1','2','3','4','5']) ax[-1].set_ylabel("") # 
show the plot finally @@ -804,10 +818,10 @@ def preload_images(specs, fov_id_list): UI_images[fov_id][peak_id] = {'first' : None, 'last' : None} # init dictionary # phase image at t=0. Rescale intenstiy and also cut the size in half first_image = p['channel_picker']['first_image'] - UI_images[fov_id][peak_id]['first'] = imresize(image_data[first_image,:,:], 0.5) + UI_images[fov_id][peak_id]['first'] = transform.resize(image_data[first_image,:,:], (int(np.floor(image_data.shape[1]*0.5)),int(np.floor(image_data.shape[2]*0.5)))) last_image = p['channel_picker']['last_image'] # phase image at end - UI_images[fov_id][peak_id]['last'] = imresize(image_data[last_image,:,:], 0.5) + UI_images[fov_id][peak_id]['last'] = transform.resize(image_data[last_image,:,:], (int(np.floor(image_data.shape[1]*0.5)),int(np.floor(image_data.shape[2]*0.5)))) return UI_images @@ -972,6 +986,7 @@ def preload_images(specs, fov_id_list): unet_shape = (p['segment']['trained_model_image_height'], p['segment']['trained_model_image_width']) + batch_size = p['segment']['batch_size'] cellClassThreshold = p['segment']['cell_class_threshold'] if cellClassThreshold == 'None': # yaml imports None as a string cellClassThreshold = False @@ -986,37 +1001,40 @@ def preload_images(specs, fov_id_list): 'shuffle':False} # arguments to predict_generator predict_args = dict(use_multiprocessing=True, - workers=p['num_analyzers'], - verbose=1) + workers=p['num_analyzers'], + verbose=1) for fov_id in fov_id_list: predictionDict[fov_id] = {} - mm3.information('Inferring number of cells in five evenly spaced frames for each trap in fov {}.'.format(fov_id)) + mm3.information('Inferring number of cells in first frame for each trap in fov {}.'.format(fov_id)) # assign each prediction to the proper fov_id, peak_id in predictions dict - counter = 0 + # counter = 0 peak_number = len(channel_masks[fov_id]) for i,peak_id in enumerate(sorted(channel_masks[fov_id].keys())): # get list of tiff file names tiff_file_name = 
glob.glob(os.path.join(chnl_dir, "*xy{:0=3}_p{:0=4}_c1.tif".format(fov_id, peak_id)))[0] img_array = io.imread(tiff_file_name) - img_height = img_array.shape[1] - img_width = img_array.shape[2] - slice_increment = int(img_array.shape[0]/5) + # slice_increment = int(img_array.shape[0]/5) # set up stack for images from all peaks # this is a bit more complicated than just doing 5 images at a time, but it is much faster # because you don't have nearly as many data transfer steps if i == 0: - img_stack = np.zeros((5*peak_number,img_height,img_width),dtype='uint16') + # img_stack = np.zeros((5*peak_number,img_height,img_width),dtype='uint16') + img_height = img_array.shape[1] + img_width = img_array.shape[2] + img_stack = np.zeros((peak_number,img_height,img_width),dtype='uint16') # switched to just looking at first timepoint # grab 5 images to load and run cell segmentation - for j in range(5): - img_stack[counter,...] = img_array[slice_increment*j,...] - counter += 1 + # for j in range(5): + # img_stack[counter,...] = img_array[slice_increment*j,...] + # counter += 1 + img_stack[i,...] = img_array[0,...] + # counter += 1 pad_dict = mm3.get_pad_distances(unet_shape, img_height, img_width) @@ -1031,15 +1049,18 @@ def preload_images(specs, fov_id_list): mode='constant') img_stack = np.expand_dims(img_stack, -1) - # set up image generator - image_generator = mm3.CellSegmentationDataGenerator(img_stack, **data_gen_args) - # run predictions - predictions = model.predict_generator(image_generator, **predict_args)[:,:,:,0] - if p['debug']: - fig,ax = plt.subplots(ncols=5); - for i in range(5): - ax[i].imshow(predictions[i,:,:]); - plt.show(); + image_datagen = ImageDataGenerator() + image_generator = image_datagen.flow(x=img_stack, + batch_size=batch_size, + shuffle=False) # keep same order + + # predict cell locations. This has multiprocessing built in but I need to mess with the parameters to see how to best utilize it. 
*** + predictions = model.predict_generator(image_generator, **predict_args) + # if p['debug']: + # fig,ax = plt.subplots(ncols=5); + # for i in range(5): + # ax[i].imshow(predictions[i,:,:]); + # plt.show(); # binarized and label (if there is a threshold value, otherwise, save a grayscale for debug) if cellClassThreshold: @@ -1060,17 +1081,18 @@ def preload_images(specs, fov_id_list): segmented_imgs[frame,:,:] = morphology.label(predictions[frame,:,:], connectivity=1) else: # in this case you just want to scale the 0 to 1 float image to 0 to 255 - information('Converting predictions to grayscale.') + mm3.information('Converting predictions to grayscale.') segmented_imgs = np.around(predictions * 100) # put number of cells detected into array for predictionDict - counter = 0 + # counter = 0 for i,peak_id in enumerate(sorted(channel_masks[fov_id].keys())): - cell_count_array = np.zeros(5, dtype='uint8') - for j in range(5): - cell_count_array[j] = int(np.max(segmented_imgs[counter,:,:])) - counter += 1 + cell_count_array = int(np.max(segmented_imgs[i,:,:])) + # cell_count_array = np.zeros(5, dtype='uint8') + # for j in range(5): + # cell_count_array[j] = int(np.max(segmented_imgs[counter,:,:])) + # counter += 1 predictionDict[fov_id][peak_id] = cell_count_array diff --git a/mm3_Compile.py b/mm3_Compile.py index f6439d2..dfedc99 100755 --- a/mm3_Compile.py +++ b/mm3_Compile.py @@ -13,6 +13,7 @@ import re from skimage import io, measure, morphology from skimage.external import tifffile as tiff +from sklearn import cluster from scipy import stats from pprint import pprint # for human readable file output try: @@ -56,6 +57,8 @@ description='Identifies and slices out channels into individual TIFF stacks through time.') parser.add_argument('-f', '--paramfile', type=str, required=True, help='Yaml file containing parameters.') + parser.add_argument('-p', '--path', type=str, + required=False, help='Path to data directory. 
Overrides what is in param file') parser.add_argument('-o', '--fov', type=str, required=False, help='List of fields of view to analyze. Input "1", "1,2,3", or "1-10", etc.') parser.add_argument('-j', '--nproc', type=int, @@ -73,10 +76,15 @@ param_file_path = 'yaml_templates/params_SJ110_100X.yaml' p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + if namespace.path: + p = mm3.init_mm3_helpers(param_file_path, datapath=namespace.path) # initialized the helper library + else: + p = mm3.init_mm3_helpers(param_file_path, datapath=None) + if namespace.fov: if '-' in namespace.fov: - user_spec_fovs = range(int(namespace.fov.split("-")[0]), - int(namespace.fov.split("-")[1])+1) + user_spec_fovs = [i for i in range(int(namespace.fov.split("-")[0]), + int(namespace.fov.split("-")[1])+1)] else: user_spec_fovs = [int(val) for val in namespace.fov.split(",")] else: @@ -147,7 +155,7 @@ mm3.information('Removing images after time {}'.format(t_end)) # go through list and find first place where timepoint is equivalent to t_end for n, ifile in enumerate(found_files): - string = re.compile('t%03dxy|t%04dxy' % (t_end, t_end)) # account for 3 and 4 digit + string = re.compile('t{:0=3}xy|t{:0=4}xy'.format(t_end, t_end)) # account for 3 and 4 digit if re.search(string, ifile): found_files = found_files[:n] break @@ -158,7 +166,7 @@ mm3.information('Filtering TIFFs by FOV.') fitered_files = [] for fov_id in user_spec_fovs: - fov_string = 'xy%02d' % fov_id # xy01 + fov_string = 'xy{:0=2}'.format(fov_id) # xy01 fitered_files += [ifile for ifile in found_files if fov_string in ifile] found_files = fitered_files[:] @@ -212,13 +220,17 @@ else: model_file_path = p['compile']['model_file_traps'] # *** Need parameter for weights - model = models.load_model(model_file_path, - custom_objects={'tversky_loss': mm3.tversky_loss, - 'cce_tversky_loss': mm3.cce_tversky_loss}) + model = models.load_model( + model_file_path, + custom_objects={ + 'tversky_loss': mm3.tversky_loss, 
+ 'cce_tversky_loss': mm3.cce_tversky_loss + } + ) mm3.information("Model loaded.") # initialize pool for getting image metadata - pool = Pool(p['num_analyzers']) + # pool = Pool(p['num_analyzers']) # loop over images and get information for fn in found_files: @@ -226,24 +238,24 @@ # for each file name. Won't look for channels, just gets the metadata for later use by Unet # This is the non-parallelized version (useful for debug) - # analyzed_imgs[fn] = mm3.get_initial_tif_params(fn) + analyzed_imgs[fn] = mm3.get_initial_tif_params(fn) # Parallelized - analyzed_imgs[fn] = pool.apply_async(mm3.get_initial_tif_params, args=(fn,)) + # analyzed_imgs[fn] = pool.apply_async(mm3.get_initial_tif_params, args=(fn,)) - mm3.information('Waiting for image metadata pool to be finished.') - pool.close() # tells the process nothing more will be added. - pool.join() # blocks script until everything has been processed and workers exit + # mm3.information('Waiting for image metadata pool to be finished.') + # pool.close() # tells the process nothing more will be added. 
+ # pool.join() # blocks script until everything has been processed and workers exit mm3.information('Image metadata pool finished, getting results.') # get results from the pool and put them in a dictionary - for fn in analyzed_imgs.keys(): - result = analyzed_imgs[fn] - if result.successful(): - analyzed_imgs[fn] = result.get() # put the metadata in the dict if it's good - else: - analyzed_imgs[fn] = False # put a false there if it's bad + # for fn in analyzed_imgs.keys(): + # result = analyzed_imgs[fn] + # if result.successful(): + # analyzed_imgs[fn] = result.get() # put the metadata in the dict if it's good + # else: + # analyzed_imgs[fn] = False # put a false there if it's bad # print(analyzed_imgs) @@ -253,6 +265,10 @@ file_names = np.asarray(file_names) fov_ids = [analyzed_imgs[key]['fov'] for key in analyzed_imgs.keys()] + if p['debug']: + print(file_names) + print(fov_ids) + unique_fov_ids = np.unique(fov_ids) if p['compile']['do_channel_masks']: @@ -295,13 +311,15 @@ # produces predition stack with 3 "pages", index 0 is for traps, index 1 is for central tough, index 2 is for background mm3.information("Predicting trap locations for first frame.") - first_frame_trap_prediction = mm3.get_frame_predictions(img, - model, - stack_weights, - trap_align_metadata['shift_distance'], - subImageNumber=16, - padSubImageNumber=25, - debug=p['debug']) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + first_frame_trap_prediction = mm3.get_frame_predictions(img, + model, + stack_weights, + trap_align_metadata['shift_distance'], + subImageNumber=16, + padSubImageNumber=25, + debug=p['debug']) if p['debug']: fig,ax = plt.subplots(nrows=1, ncols=4, figsize=(12,12)) @@ -324,11 +342,13 @@ trap_props = measure.regionprops(trap_labels) trap_area_threshold = p['compile']['trap_area_threshold'] - trap_bboxes = mm3.get_frame_trap_bounding_boxes(trap_labels, - trap_props, - trapAreaThreshold=trap_area_threshold, - 
trapWidth=trap_align_metadata['trap_width'], - trapHeight=trap_align_metadata['trap_height']) + trap_bboxes, trap_rotations = mm3.get_frame_trap_bounding_boxes( + trap_labels, + trap_props, + trapAreaThreshold=trap_area_threshold, + trapWidth=trap_align_metadata['trap_width'], + trapHeight=trap_align_metadata['trap_height'] + ) # create boolean array to contain filtered, correctly-shaped trap bounding boxes first_frame_trap_mask = np.zeros(traps.shape) @@ -337,6 +357,9 @@ good_trap_labels = measure.label(first_frame_trap_mask) good_trap_props = measure.regionprops(good_trap_labels) + # add trap rotation angle to each regionprops object in good_trap_props + for i,reg in enumerate(good_trap_props): + reg.rotation_angle = trap_rotations[i] # widen the traps to merge them into "trap regions" above and below the central trough dilated_traps = morphology.dilation(first_frame_trap_mask, dilator) @@ -415,13 +438,17 @@ 'n_channels':1, 'normalize_to_one':True, 'shuffle':False} - predict_gen_args = {'verbose':1, - 'use_multiprocessing':True, - 'workers':p['num_analyzers']} + predict_gen_args = { + 'verbose':1, + 'use_multiprocessing':False, + # 'workers':p['num_analyzers'], + } - img_generator = mm3.TrapSegmentationDataGenerator(align_region_stack, **data_gen_args) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + img_generator = mm3.TrapSegmentationDataGenerator(align_region_stack, **data_gen_args) - align_region_predictions = model.predict_generator(img_generator, **predict_gen_args) + align_region_predictions = model.predict_generator(img_generator, **predict_gen_args) #align_region_stack = mm3.apply_median_filter_and_normalize(align_region_stack) #align_region_predictions = model.predict(align_region_stack, batch_size=batch_size) # reduce dimensionality such that the class predictions are now (frame_number,512,512), and each voxel is labelled as the predicted region, i.e., 0=trap, 1=central trough, 2=background. 
@@ -457,11 +484,13 @@ frame_trap_labels = measure.label(align_traps[frame,:,:]) frame_trap_props = measure.regionprops(frame_trap_labels) - trap_bboxes = mm3.get_frame_trap_bounding_boxes(frame_trap_labels, - frame_trap_props, - trapAreaThreshold=trap_area_threshold, - trapWidth=trap_align_metadata['trap_width'], - trapHeight=trap_align_metadata['trap_height']) + trap_bboxes, _ = mm3.get_frame_trap_bounding_boxes( + frame_trap_labels, + frame_trap_props, + trapAreaThreshold=trap_area_threshold, + trapWidth=trap_align_metadata['trap_width'], + trapHeight=trap_align_metadata['trap_height'] + ) for i,bbox in enumerate(trap_bboxes): align_trap_mask_stack[frame,bbox[0]:bbox[2],bbox[1]:bbox[3]] = True @@ -481,8 +510,8 @@ trapTriggered = False for frame in range(trap_align_metadata['frame_count']): anyTraps = np.any(labelled_align_trap_mask_stack[frame,:,:] > 0) - # if anyTraps is False, that means no traps were detected for this frame. This usuall occurs due to a bug in our imaging system, - # which can cause it to miss the occasional frame. Should be fine to snag labels from prior frame. + # if anyTraps is False, that means no traps were detected for this frame. This usually occurs due to a bug in our imaging system, + # which can cause it to miss the occasional frame. Should be fine to snag labels from adjacent frame. if not anyTraps: trapTriggered = True mm3.information("Frame at index {} has no detected traps. 
Borrowing labels from an adjacent frame.".format(frame)) @@ -522,7 +551,7 @@ for label in bad_align_trap_props: labelled_align_trap_mask_stack[labelled_align_trap_mask_stack == label] = 0 - + align_centroids = [] for frame in range(trap_align_metadata['frame_count']): align_centroids.append([reg.centroid for reg in measure.regionprops(labelled_align_trap_mask_stack[frame,:,:])]) @@ -532,14 +561,22 @@ integer_shifts = np.round(shifts).astype('int16') good_trap_bboxes_dict = {} + good_trap_rotations_dict = {} for trap in good_trap_props: good_trap_bboxes_dict[trap.label] = trap.bbox + good_trap_rotations_dict[trap.label] = trap.rotation_angle # pprint(good_trap_bboxes_dict) # uncomment for debugging bbox_shift_dict = mm3.shift_bounding_boxes(good_trap_bboxes_dict, integer_shifts, img.shape[0]) # pprint(bbox_shift_dict) # uncomment for debugging - trap_images_fov_dict, trap_closed_end_px_dict = mm3.crop_traps(fov_file_names, good_trap_props, good_trap_labels, bbox_shift_dict, trap_align_metadata) + trap_images_fov_dict, trap_closed_end_px_dict = mm3.crop_traps( + fov_file_names, + good_trap_rotations_dict, + good_trap_labels, + bbox_shift_dict, + trap_align_metadata + ) for fn in fov_file_names: analyzed_imgs[fn]['channels'] = trap_closed_end_px_dict[fn] diff --git a/mm3_DetectFoci.py b/mm3_DetectFoci.py index 65cceaf..e2aa059 100755 --- a/mm3_DetectFoci.py +++ b/mm3_DetectFoci.py @@ -35,6 +35,27 @@ import mm3_helpers as mm3 +def segment_foci_single_file(infile_name, params, namespace): + + mm3.information("Segmenting image {}.".format(infile_name)) + # load model to pass to algorithm + mm3.information("Loading model...") + + if namespace.modelfile: + model_file_path = namespace.modelfile + else: + model_file_path = params['foci']['foci_model_file'] + + seg_model = models.load_model(model_file_path, + custom_objects={'bce_dice_loss': mm3.bce_dice_loss, + 'dice_loss': mm3.dice_loss, + 'precision_m': mm3.precision_m, + 'recall_m': mm3.recall_m, + 'f_precision_m': 
mm3.f_precision_m}) + mm3.information("Model loaded.") + mm3.segment_stack_unet(infile_name, seg_model, mode='foci') + sys.exit("Completed segmenting image {}.".format(infile_name)) + # when using this script as a function and not as a library the following will execute if __name__ == "__main__": @@ -43,6 +64,8 @@ description='Segment cells and create lineages.') parser.add_argument('-f', '--paramfile', type=str, required=True, help='Yaml file containing parameters.') + parser.add_argument('-i', '--infile', type=str, + required=False, help='Use this argument to segment ONLY on image. Name the single file to be segmented.') parser.add_argument('-o', '--fov', type=str, required=False, help='List of fields of view to analyze. Input "1", "1,2,3", or "1-10", etc.') parser.add_argument('-j', '--nproc', type=int, @@ -82,6 +105,11 @@ p['seg_img'] = 'foci_seg_unet' p['pred_img'] = 'foci_pred_unet' + if namespace.infile: + + fname = namespace.infile + segment_foci_single_file(fname, p, namespace) + # load specs file specs = mm3.load_specs() # print(specs) # for debugging @@ -106,12 +134,19 @@ else: model_file_path = p['foci']['foci_model_file'] # *** Need parameter for weights - seg_model = models.load_model(model_file_path, - custom_objects={'bce_dice_loss': mm3.bce_dice_loss, - 'dice_loss': mm3.dice_loss, - 'precision_m': mm3.precision_m, - 'recall_m': mm3.recall_m, - 'f_precision_m': mm3.f_precision_m}) + seg_model = models.load_model( + model_file_path, + custom_objects = { + 'weighted_bce_dice_loss': mm3.weighted_bce_dice_loss, + 'weighted_bce': mm3.weighted_bce, + 'bce_dice_loss': mm3.bce_dice_loss, + 'dice_loss': mm3.dice_loss, + 'precision_m': mm3.precision_m, + 'recall_m': mm3.recall_m, + 'f_precision_m': mm3.f_precision_m, + 'f2_m': mm3.f2_m + } + ) mm3.information("Model loaded.") for fov_id in fov_id_list: diff --git a/mm3_FocusTrackGUI.py b/mm3_FocusTrackGUI.py index e468466..36e266b 100755 --- a/mm3_FocusTrackGUI.py +++ b/mm3_FocusTrackGUI.py @@ -25,8 +25,7 @@ 
import multiprocessing import pandas as pd -sys.path.insert(0, '/home/wanglab/src/mm3/') # Jeremy's path to mm3 folder -sys.path.insert(0, '/home/wanglab/src/mm3/aux/') +sys.path.insert(0,os.path.dirname(os.path.realpath(__file__))) import mm3_helpers as mm3 import mm3_plots @@ -240,9 +239,8 @@ def __init__(self, specs, channel): self.specs = specs self.channel = channel - # add QImages to scene (try three frames) + # add QImages to scene self.fov_id_list = [fov_id for fov_id in self.specs.keys()] - # self.center_frame_index = 1 self.fovIndex = 0 self.fov_id = self.fov_id_list[self.fovIndex] @@ -251,22 +249,33 @@ def __init__(self, specs, channel): self.peakIndex = 0 self.peak_id = self.peak_id_list_in_fov[self.peakIndex] + # print(self.peak_id) + # print(self.fov_id) # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) - # read in images self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) - # time_int = params['moviemaker']['seconds_per_time_index']/60 + # read in cell tracking info + cell_filename_all = os.path.join(params['cell_dir'], 'all_cells.pkl') + with open(cell_filename_all, 'rb') as cell_file: + self.All_Cells = pickle.load(cell_file) - focus_filename_all = os.path.join(params['foci_track_dir'], 'all_foci.pkl') + self.fov_Cells = mm3.filter_cells(self.All_Cells, attr="fov", val=self.fov_id) + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) + # read in focus 
tracking info + focus_filename_all = os.path.join(params['foci_track_dir'], 'all_foci.pkl') with open(focus_filename_all, 'rb') as focus_file: self.All_Foci = pickle.load(focus_file) - # mm3.calculate_pole_age(self.foci) # add poleage + + self.fov_Foci = mm3.filter_cells(self.All_Foci, attr="fov", val=self.fov_id) + self.these_Foci = mm3.filter_cells(self.fov_Foci, attr="peak", val=self.peak_id) plot_dir = os.path.join(params['foci_track_dir'], 'plots') if not os.path.exists(plot_dir): @@ -300,7 +309,7 @@ def __init__(self, specs, channel): # the below lookup table may need reworked to handle migration and child lines to/from a focus # or the handling may be better done in the update_focus_info function self.event_type_index_lookup = {"MigrationLine":0, - "ChildLine":1, + "ChildLine":1, "AppearSymbol":2, "DisappearSymbol":3, "JoinLine":4, @@ -337,13 +346,15 @@ def save_updates(self): track_info = self.track_info for t, time_info in track_info.items(): - for region_label, region in time_info['regions'].items(): + # all_t_cells = time_info['cells'] + # for cell_id, cell_info in all_t_cells.items(): + for region_label, region in time_info['regions'].items(): - if 'region_graphic' in track_info[t]['regions'][region_label]: - if 'pen' in track_info[t]['regions'][region_label]['region_graphic']: - track_info[t]['regions'][region_label]['region_graphic'].pop('pen') - if 'brush' in track_info[t]['regions'][region_label]['region_graphic']: - track_info[t]['regions'][region_label]['region_graphic'].pop('brush') + if 'region_graphic' in region: + if 'pen' in region['region_graphic']: + region['region_graphic'].pop('pen') + if 'brush' in region['region_graphic']: + region['region_graphic'].pop('brush') with open(self.pickle_file_name, 'wb') as track_file: try: @@ -353,8 +364,9 @@ def save_updates(self): except Exception as e: track_file.close() print(str(e)) - - df_file_name = '/home/wanglab/src/mm3/focus_track_training_file_paths.csv' + + + df_file_name = 
os.path.dirname(os.path.realpath(__file__))+r'\focus_track_training_file_paths.csv' # df_file_name = '/Users/jt/code/mm3/focus_track_training_file_paths.csv' if os.path.isfile(df_file_name): @@ -411,9 +423,14 @@ def go_to_fov_and_peak_id(self, fov_id, peak_id): # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) + + self.fov_Cells = mm3.filter_cells(self.All_Cells, attr="fov", val=self.fov_id) + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) # look for previously edited info and load it if it is found self.get_track_pickle() @@ -434,9 +451,13 @@ def next_peak(self): # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) + + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) # look for 
previously edited info and load it if it is found self.get_track_pickle() @@ -457,11 +478,13 @@ def prior_peak(self): # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) - # print(self.phaseImgPath) - # print(self.labelImgPath) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) + + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) # look for previously edited info and load it if it is found self.get_track_pickle() @@ -487,11 +510,14 @@ def next_fov(self): # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) - # print(self.phaseImgPath) - # print(self.labelImgPath) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) + + self.fov_Cells = mm3.filter_cells(self.All_Cells, attr="fov", val=self.fov_id) + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) # look for previously edited info and 
load it if it is found self.get_track_pickle() @@ -516,9 +542,14 @@ def prior_fov(self): # construct image stack file names from params self.phaseImgPath = os.path.join(params['chnl_dir'], "{}_xy{:0=3}_p{:0=4}_c{}.tif".format(params['experiment_name'], self.fov_id, self.peak_id, self.channel)) self.labelImgPath = os.path.join(params['foci_seg_dir'], "{}_xy{:0=3}_p{:0=4}_foci_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) + self.cellLabelImgPath = os.path.join(params['seg_dir'], "{}_xy{:0=3}_p{:0=4}_seg_unet.tif".format(params['experiment_name'], self.fov_id, self.peak_id)) self.labelStack = io.imread(self.labelImgPath) self.phaseStack = io.imread(self.phaseImgPath) + self.cellLabelStack = io.imread(self.cellLabelImgPath) + + self.fov_Cells = mm3.filter_cells(self.All_Cells, attr="fov", val=self.fov_id) + self.these_Cells = mm3.filter_cells(self.fov_Cells, attr="peak", val=self.peak_id) # look for previously edited info and load it if it is found self.get_track_pickle() @@ -531,12 +562,12 @@ def all_phase_img_and_regions(self): frame_dict_by_time = {} xPos = 0 - for time in self.track_info.keys(): + for time,t_info in self.track_info.items(): frame_index = time-1 - frame, regions = self.phase_img_and_regions(frame_index) + frame,t_info = self.phase_img_and_regions(frame_index) frame_dict_by_time[time] = self.addPixmap(frame) frame_dict_by_time[time].time = time - self.add_regions_to_frame(regions, frame_dict_by_time[time]) + self.add_regions_to_frame(t_info, frame_dict_by_time[time]) frame_dict_by_time[time].setPos(xPos, 0) xPos += frame_dict_by_time[time].pixmap().width() @@ -563,17 +594,18 @@ def phase_img_and_regions(self, frame_index, watershed=False): RGBLabelImg = (RGBLabelImg*255).astype('uint8') originalHeight, originalWidth, RGBLabelChannelNumber = RGBLabelImg.shape RGBLabelImg = QImage(RGBLabelImg, originalWidth, originalHeight, RGBLabelImg.strides[0], QImage.Format_RGB888)#.scaled(512, 512, aspectRatioMode=Qt.KeepAspectRatio) 
- # pprint(regions) - time_regions_and_events = self.track_info[time] - regions = time_regions_and_events['regions'] - time_regions_and_events['time'] = time - for region_id in regions.keys(): + time_info = self.track_info[time] + time_info['time'] = time + + # for cell_data in time_info['cells'].values(): + regions = time_info['regions'] + for region in regions.values(): brush = QBrush() brush.setStyle(Qt.SolidPattern) pen = QPen() pen.setStyle(Qt.SolidLine) - props = regions[region_id]['props'] + props = region['props'] coords = props.coords min_row, min_col, max_row, max_col = props.bbox label = props.label @@ -583,33 +615,98 @@ def phase_img_and_regions(self, frame_index, watershed=False): brush.setColor(brushColor) pen.setColor(brushColor) - regions[region_id]['region_graphic'] = {'top_y':min_row, 'bottom_y':max_row, - 'left_x':min_col, 'right_x':max_col, - 'coords':coords, - 'pen':pen, 'brush':brush} + region['region_graphic'] = {'top_y':min_row, 'bottom_y':max_row, + 'left_x':min_col, 'right_x':max_col, + 'coords':coords, + 'pen':pen, 'brush':brush} + + return(phaseQpixmap, time_info) + + # def get_foci_regions_by_time(self, focus_label_stack): + # '''This function provides support to detect to which cell each focus belongs, and to + # add additional information to each focus' regionprops based on that focus' position + # within its cell. 
+ # ''' + + # t_adj = 1 + # foci_by_time = {frame+t_adj: {} for frame in range(focus_label_stack.shape[0])} + # cell_id_list = [k for k in self.these_Cells.keys()] + + # # get to first timepoint containing cells + # for t in foci_by_time.keys(): + + # t_foci = mm3.filter_cells_containing_val_in_attr(self.these_Cells, attr='times', val=t) + # if t_cells: + # break + + + + + + # skipped_cells = [] + # # loop over cells, begining with cells at current time + # for cell_id,cell in t_cells.items(): + + # foci_in_cell = cell.foci + + # for t in cell.times: + + # cell_idx = cell.times.index(t) + # foci_in_cell = cell.foci + + # these_foci = {focus_id:focus for focus_id,focus in foci_in_cell.items() if t in focus.times} + + + # t_foci = mm3.filter_cells_containing_val_in_attr(self.these_Foci, attr='times', val=t) + + # if len(t_foci) > 0: + + # for focus_id,focus in t_foci.items(): + + # focus_idx = focus.times.index(t) + # # loop over cells to which this focus belongs + # for cell in focus.cells: + # cell_times = cell.times + # # if the current time is in this cell grab the index + # # for this cell and its id + # if t in cell_times: + # cell_idx = cell.times.index(t) + # cell_id = cell.id + # break + + # cell_y,cell_x = cell.centroids[cell_idx] + # focus_y,focus_x = focus.regions[focus_idx].centroid + + # focus.regions[focus_idx].cell_y_pos = cell_y-focus_y + # focus.regions[focus_idx].cell_x_pos = cell_x-focus_x + # focus.daughter_cell_ids = [] + # if cell.daughters: + # for daughter_cell in cell.daughters: + # focus.daughter_cell_ids.append(daughter_cell.id) + # cell_data[t].append(focus) + + # return(foci_by_cell_and_time) - return(phaseQpixmap, time_regions_and_events) def create_tracking_information(self, fov_id, peak_id, label_stack, img_stack): t_adj = 1 regions_by_time = {frame+t_adj: measure.regionprops(label_stack[frame,:,:], img_stack[frame,:,:]) for frame in range(label_stack.shape[0])} - regions_and_events_by_time = {frame+t_adj : {'regions' : {}, 'matrix' : 
None} for frame in range(label_stack.shape[0])} + regions_and_events_by_time = {frame+t_adj: {'regions' : {}, 'matrix' : None} for frame in range(label_stack.shape[0])} + # iterate through all times, collecting focus information and adding + # cell-centric tracking info to the dictionary as we go + for t,regions in regions_by_time.items(): - # loop through regions and add them to the main dictionary. - for t, regions in regions_by_time.items(): - # this is a list, while we want it to be a dictionary with the region label as the key for region in regions: default_events = np.zeros(6, dtype=np.int) default_events[5] = 1 # set N to 1 - regions_and_events_by_time[t]['regions'][region.label] = {'props' : region, - 'events' : default_events} - # create default interaction matrix - # Now that we know how many regions there are per time point, we will create a default matrix which indicates how regions are connected to each between this time point, t, and the next one, t+1. The row index will be the region label of the region in t, and the column index will be the region label of the region in t+1. - # If a region migrates from t to t+1, its row should have a sum of 1 corresponding from which region (row) to which region (column) it moved. If the region divided, then both of the daughter columns will get value 1. - # Note that the regions are labeled from 1, but Numpy arrays and indexed from zero. We can use this to store some additional informaiton. If a region disappears, it will receive a 1 in the column with index 0. 
- # In the last time point all regions will be connected to the disappear column + # add information at 'focus_cell_label' + regions_and_events_by_time[t]['regions'][region.label] = { + 'props' : region, + 'events' : default_events, + } + for t, t_data in regions_and_events_by_time.items(): n_regions_in_t = len(regions_by_time[t]) if t+1 in regions_by_time: @@ -620,51 +717,48 @@ def create_tracking_information(self, fov_id, peak_id, label_stack, img_stack): t_data['matrix'] = np.zeros((n_regions_in_t+1, n_regions_in_t_plus_1+1), dtype=np.int) # Loop over foci and edit event information - # We will use the focus dictionary All_Foci. - # Each focus object has a number of attributes that are useful to us. + # We will use the dictionary All_Foci. + # Each Focus object has a number of attributes that are useful to us. # We will go through each focus by its time points and edit the events associated with that region. # We will also edit the matrix when appropriate. # pull out only the foci in of this FOV foci_tmp = mm3_plots.find_cells_of_fov_and_peak(self.All_Foci, fov_id, peak_id) print('There are {} foci for this trap'.format(len(foci_tmp))) - for focus_id, focus_tmp in foci_tmp.items(): + for focus_id,focus_tmp in foci_tmp.items(): # Check for when focus has less time points than it should - # if (focus_tmp.times[-1] - focus_tmp.times[0])+1 > len(focus_tmp.times): - # print('Focus {} has less time points than it should, skipping.'.format(focus_id)) - # continue - unique_times = list(np.unique(focus_tmp.times)) - focus_tmp_labels = [focus_tmp.labels[focus_tmp.times.index(t)] for t in unique_times] + if (focus_tmp.times[-1] - focus_tmp.times[0])+1 > len(focus_tmp.times): + print('Focus {} has less time points than it should, skipping.'.format(focus_id)) + continue - # Go over the time points of this focus and edit appropriate information main dictionary - for i, t in enumerate(unique_times): + for i,t in enumerate(focus_tmp.times): - # get the region label - label_tmp 
= focus_tmp_labels[i] + t_info = regions_and_events_by_time[t] + label_tmp = focus_tmp.labels[i] # M migration, event 0 # If the focus has another time point after this one then it must have migrated - if i != len(unique_times)-1: - regions_and_events_by_time[t]['regions'][label_tmp]['events'][0] = 1 + max_focus_time = np.max(focus_tmp.times) + # print('focus {} has max time {}.'.format(focus_id,max_focus_time)) + if t < max_focus_time: + t_info['regions'][label_tmp]['events'][0] = 1 # update matrix using this region label and the next one - # print(label_tmp, focus_tmp_labels[i+1], regions_and_events_by_time[t]['matrix']) - regions_and_events_by_time[t]['matrix'][label_tmp, focus_tmp_labels[i+1]] = 1 + # print(focus_cell_label, focus_cell_labels[focus_idx+1], cell_info['matrix']) + t_info['matrix'][label_tmp, focus_tmp.labels[i+1]] = 1 # S division, 1 - if focus_tmp.daughters and i == len(unique_times)-1: - regions_and_events_by_time[t]['regions'][label_tmp]['events'][1] = 1 + if focus_tmp.daughters and t == max_focus_time: + t_info['regions'][label_tmp]['events'][1] = 1 # daughter 1 and 2 label - # d1_label = self.All_Foci[focus_tmp.daughters[0].id].labels[0] - d1_label = self.All_Foci[focus_tmp.daughters[0]].labels[0] + d1_label = self.All_Foci[focus_tmp.daughters[0]].labels[0] try: - # d2_label = self.All_Foci[focus_tmp.daughters[1].id].labels[0] - d2_label = self.All_Foci[focus_tmp.daughters[1]].labels[0] - regions_and_events_by_time[t]['matrix'][label_tmp, d1_label] = 1 - regions_and_events_by_time[t]['matrix'][label_tmp, d2_label] = 1 + d2_label = self.All_Foci[focus_tmp.daughters[1]].labels[0] + t_info['matrix'][label_tmp, d1_label] = 1 + t_info['matrix'][label_tmp, d2_label] = 1 except IndexError as e: print("At timepoint {} there was an index error in assigning daughters: {}".format(t,e)) @@ -672,33 +766,183 @@ def create_tracking_information(self, fov_id, peak_id, label_stack, img_stack): # I appears, 2 if not t == 1: if not focus_tmp.parent and i == 0: - 
regions_and_events_by_time[t]['regions'][label_tmp]['events'][2] = 1 + t_info['regions'][label_tmp]['events'][2] = 1 # O disappears, 3 - if not focus_tmp.daughters and i == len(unique_times)-1: - regions_and_events_by_time[t]['regions'][label_tmp]['events'][3] = 1 - regions_and_events_by_time[t]['matrix'][label_tmp, 0] = 1 + if not focus_tmp.daughters and t == max_focus_time: + t_info['regions'][label_tmp]['events'][3] = 1 + t_info['matrix'][label_tmp, 0] = 1 # N no data, 4 - Set this to zero as this region as been checked. - regions_and_events_by_time[t]['regions'][label_tmp]['events'][5] = 0 + t_info['regions'][label_tmp]['events'][5] = 0 - # Set remaining regions to event space [0 0 0 0 1 1] - # Also make their appropriate matrix value 1, which should be in the first column. - for t, t_data in regions_and_events_by_time.items(): - for region, region_data in t_data['regions'].items(): + # If any focus has still not been visited, + # make their appropriate matrix value 1, which should be in the first column. 
+ for t,t_data in regions_and_events_by_time.items(): + for region,region_data in t_data['regions'].items(): if region_data['events'][5] == 1: - t_data['matrix'][region, 0] = 1 return(regions_and_events_by_time) - def add_regions_to_frame(self, regions_and_events, frame): + # pprint(regions_and_events_by_time) + + # if not cell_id in regions_and_events_by_cell_and_time: + # regions_and_events_by_cell_and_time[cell_id] = {} + + # cell_data = regions_and_events_by_cell_and_time[cell_id] + + # for t,t_region_list in this_cell_time_info.items(): + + # if not t in cell_data: + # cell_data[t] = {'regions' : {}, 'matrix' : None} + + # for i,region in enumerate(t_region_list): + + # default_events = np.zeros(6, dtype=np.int) + # default_events[5] = 1 # set N to 1 + # cell_data[t]['regions'][i+1] = { + # 'props' : region, + # 'events' : default_events, + # } + + # create default interaction matrix + # Now that we know how many regions there are per time point, we will create a default matrix which indicates how regions are connected to each between this time point, t, and the next one, t+1. The row index will be the region label of the region in t, and the column index will be the region label of the region in t+1. + # If a region migrates from t to t+1, its row should have a sum of 1 corresponding from which region (row) to which region (column) it moved. If the region divided, then both of the daughter columns will get value 1. + # Note that the regions are labeled from 1, but Numpy arrays and indexed from zero. We can use this to store some additional informaiton. If a region disappears, it will receive a 1 in the column with index 0. 
+ # In the last time point all regions will be connected to the disappear column + # for cell_id, this_cell_time_info in regions_and_events_by_cell_and_time.items(): + # for t,t_data in this_cell_time_info.items(): + # n_regions_in_t = len(t_data['regions']) + # if t+1 in this_cell_time_info: + # n_regions_in_t_plus_1 = len(this_cell_time_info[t+1]['regions']) + # else: + # n_regions_in_t_plus_1 = 0 + + # t_data['matrix'] = np.zeros((n_regions_in_t+1, n_regions_in_t_plus_1+1), dtype=np.int) + + + # regions_and_events_by_time = {frame+t_adj : {'regions' : {}, 'matrix' : None} for frame in range(label_stack.shape[0])} + + # # loop through regions and add them to the main dictionary. + # for t, regions in regions_by_time.items(): + # # this is a list, while we want it to be a dictionary with the region label as the key + # for region in regions: + # default_events = np.zeros(6, dtype=np.int) + # default_events[5] = 1 # set N to 1 + # regions_and_events_by_time[t]['regions'][region.label] = {'props' : region, + # 'events' : default_events} + # create default interaction matrix + # Now that we know how many regions there are per time point, we will create a default matrix which indicates how regions are connected to each between this time point, t, and the next one, t+1. The row index will be the region label of the region in t, and the column index will be the region label of the region in t+1. + # If a region migrates from t to t+1, its row should have a sum of 1 corresponding from which region (row) to which region (column) it moved. If the region divided, then both of the daughter columns will get value 1. + # Note that the regions are labeled from 1, but Numpy arrays and indexed from zero. We can use this to store some additional informaiton. If a region disappears, it will receive a 1 in the column with index 0. 
+ # In the last time point all regions will be connected to the disappear column + # for t, t_data in regions_and_events_by_time.items(): + # n_regions_in_t = len(regions_by_time[t]) + # if t+1 in regions_by_time: + # n_regions_in_t_plus_1 = len(regions_by_time[t+1]) + # else: + # n_regions_in_t_plus_1 = 0 + + # t_data['matrix'] = np.zeros((n_regions_in_t+1, n_regions_in_t_plus_1+1), dtype=np.int) + + # Loop over foci and edit event information + # We will use the focus dictionary All_Foci. + # Each focus object has a number of attributes that are useful to us. + # We will go through each focus by its time points and edit the events associated with that region. + # We will also edit the matrix when appropriate. + # pull out only the foci in of this FOV + + # for focus_id,focus_tmp in foci_tmp.items(): + + # cell_id = focus_id_cell_id_lut[focus_id] + # cell_data = regions_and_events_by_cell_and_time[cell_id] + # # Check for when focus has less time points than it should + # # if (focus_tmp.times[-1] - focus_tmp.times[0])+1 > len(focus_tmp.times): + # # print('Focus {} has less time points than it should, skipping.'.format(focus_id)) + # # continue + # unique_times = list(np.unique(focus_tmp.times)) + # focus_tmp_labels = [focus_tmp.cell_labels[focus_tmp.times.index(t)] for t in unique_times] + + # # Go over the time points of this focus and edit appropriate information main dictionary + # for i, t in enumerate(unique_times): + + # # get the region label + # label_tmp = focus_tmp_labels[i] + + # # M migration, event 0 + # # If the focus has another time point after this one then it must have migrated + # if i != len(unique_times)-1: + # cell_data[t]['regions'][label_tmp]['events'][0] = 1 + # # regions_and_events_by_time[t]['regions'][label_tmp]['events'][0] = 1 + + # # update matrix using this region label and the next one + # # print(label_tmp, focus_tmp_labels[i+1], regions_and_events_by_time[t]['matrix']) + # cell_data[t]['matrix'][label_tmp, 
focus_tmp_labels[i+1]] = 1 + # # regions_and_events_by_time[t]['matrix'][label_tmp, focus_tmp_labels[i+1]] = 1 + + # # S division, 1 + # if focus_tmp.daughters and i == len(unique_times)-1: + # cell_data[t]['regions'][label_tmp]['events'][1] = 1 + # # regions_and_events_by_time[t]['regions'][label_tmp]['events'][1] = 1 + + # # daughter 1 and 2 label + # # d1_label = self.All_Foci[focus_tmp.daughters[0].id].labels[0] + # d1_label = self.All_Foci[focus_tmp.daughters[0]].labels[0] + + # try: + # # d2_label = self.All_Foci[focus_tmp.daughters[1].id].labels[0] + # d2_label = self.All_Foci[focus_tmp.daughters[1]].labels[0] + # cell_data[t]['matrix'][label_tmp, d1_label] = 1 + # cell_data[t]['matrix'][label_tmp, d2_label] = 1 + # # regions_and_events_by_time[t]['matrix'][label_tmp, d1_label] = 1 + # # regions_and_events_by_time[t]['matrix'][label_tmp, d2_label] = 1 + + # except IndexError as e: + # print("At timepoint {} there was an index error in assigning daughters: {}".format(t,e)) + + # # I appears, 2 + # if not t == 1: + # if not focus_tmp.parent and i == 0: + # cell_data[t]['regions'][label_tmp]['events'][2] = 1 + # # regions_and_events_by_time[t]['regions'][label_tmp]['events'][2] = 1 + + # # O disappears, 3 + # if not focus_tmp.daughters and i == len(unique_times)-1: + # cell_data[t]['regions'][label_tmp]['events'][3] = 1 + # cell_data[t]['matrix'][label_tmp, 0] = 1 + # # regions_and_events_by_time[t]['regions'][label_tmp]['events'][3] = 1 + # # regions_and_events_by_time[t]['matrix'][label_tmp, 0] = 1 + + # # N no data, 4 - Set this to zero as this region as been checked. + # cell_data[t]['regions'][label_tmp]['events'][5] = 0 + # # regions_and_events_by_time[t]['regions'][label_tmp]['events'][5] = 0 + + # # Set remaining regions to event space [0 0 0 0 1 1] + # # Also make their appropriate matrix value 1, which should be in the first column. 
+ # for cell_id,cell_data in regions_and_events_by_cell_and_time.items(): + # for t, t_data in cell_data.items(): + # for region, region_data in t_data['regions'].items(): + # if region_data['events'][5] == 1: + + # t_data['matrix'][region, 0] = 1 + + # for t, t_data in regions_and_events_by_time.items(): + # for region, region_data in t_data['regions'].items(): + # if region_data['events'][5] == 1: + + # t_data['matrix'][region, 0] = 1 + + # return(regions_and_events_by_time) + + def add_regions_to_frame(self, t_info, frame): # loop through foci within this frame and add their ellipses as children of their corresponding qpixmap object - regions = regions_and_events['regions'] + # all_t_cells = t_info['cells'] # print(regions_and_events) - frame_time = regions_and_events['time'] - for region_id in regions.keys(): - region = regions[region_id] + frame_time = t_info['time'] + # for cell_id,cell_info in all_t_cells.items(): + regions = t_info['regions'] + for region in regions.values(): # construct the ellipse graphic = region['region_graphic'] top_left = QPoint(graphic['left_x'],graphic['top_y']) @@ -714,9 +958,16 @@ def add_regions_to_frame(self, regions_and_events, frame): # instantiate a QGraphicsEllipseItem ellipse = QGraphicsEllipseItem(rect, frame) # add focus information to the QGraphicsEllipseItem - ellipse.focusMatrix = regions_and_events['matrix'] - ellipse.focusEvents = regions_and_events['regions'][region_id]['events'] - ellipse.focusProps = regions_and_events['regions'][region_id]['props'] + ellipse.focusMatrix = t_info['matrix'] + ellipse.focusEvents = region['events'] + ellipse.focusProps = region['props'] + + #################### 2020-10-31 ##################### + # ellipse.cell_id = cell_id + # ellipse.cell_daughters = t_info['daughters'] + ##################################################### + + # print(dir(region['props'])) ellipse.time = frame_time ellipse.setBrush(brush) ellipse.setPen(pen) @@ -767,6 +1018,8 @@ def draw_focus_events(self, 
start_time=1, end_time=None, update=False, original_ # loop through frames at valid times for time in valid_times: + # print() + # print(time) if time in self.all_frames_by_time_dict: frame = self.all_frames_by_time_dict[time] @@ -788,14 +1041,25 @@ def draw_focus_events(self, start_time=1, end_time=None, update=False, original_ focus_label = focus_properties.label focus_interactions = startItem.focusMatrix[focus_label,:] focus_events = startItem.focusEvents - # print(focus_events) - # print(focus_interactions) + ############################ 2020-10-31 ################################# + # start_cell_id = startItem.cell_id + # start_cell_compare_ids = startItem.cell_daughters + # if start_cell_compare_ids is not None: + # start_cell_compare_ids.append(start_cell_id) + # else: + # start_cell_compare_ids = [start_cell_id] + ######################################################################## + + # print('cell label: ', focus_label) + # print('events: ', focus_events) + # print('interactions: ', focus_interactions) # get centroid of focus represented by this qgraphics item firstPointY = focus_properties.centroid[0] firstPointX = focus_properties.centroid[1] + startItem.parentItem().x() firstPoint = QPoint(firstPointX, firstPointY) # which events happened to this focus? 
event_indices = np.where(focus_events == 1)[0] + # print(event_indices) if 2 in event_indices: # if the second element in event_indices was 1, @@ -811,38 +1075,47 @@ def draw_focus_events(self, start_time=1, end_time=None, update=False, original_ eventItem = self.set_event_item(firstPoint=firstPoint, startItem=startItem) self.addItem(eventItem) - try: - nextFrame = self.all_frames_by_time_dict[time+1] - for endItem in nextFrame.childItems(): - # if the item is an ellipse, move on to look into it further - if endItem.type() == 4: - end_focus_properties = endItem.focusProps - end_focus_label = end_focus_properties.label - # test whether this ellipse represents the - # focus that interacts with the focus represented - # by self.startItem - if focus_interactions[end_focus_label] == 1: - # if this is the focus that interacts with the former frame's focus, draw the line. - endPointY = end_focus_properties.centroid[0] - endPointX = end_focus_properties.centroid[1] + endItem.parentItem().x() - lastPoint = QPoint(endPointX, endPointY) - if 0 in event_indices: - # If the zero-th element in event_indices was 1, the focus migrates in the next frame - # get the information from focus_matrix to figure out to which region - # in the next frame it migrated - # set self.migration = True - self.set_migration() - - if 1 in event_indices: - self.set_children() - - if 4 in event_indices: - self.set_join() - - eventItem = self.set_event_item(firstPoint=firstPoint, startItem=startItem, lastPoint=lastPoint, endItem=endItem) - self.addItem(eventItem) - except KeyError: - continue + # if this focus didn't disappear, try the next stuff for migration + else: + try: + nextFrame = self.all_frames_by_time_dict[time+1] + for endItem in nextFrame.childItems(): + # if the item is an ellipse, move on to look into it further + if endItem.type() == 4: + end_focus_properties = endItem.focusProps + end_focus_label = end_focus_properties.label + # end_cell_id = endItem.cell_id + # see whether this 
ellipse's cell is in the startItem's cell or daughters + + ############### 2020-10-31 #################### + # if end_cell_id in start_cell_compare_ids: + ############################################### + + # test whether this ellipse represents the + # focus that interacts with the focus represented + # by self.startItem + if focus_interactions[end_focus_label] == 1: + # if this is the focus that interacts with the former frame's focus, draw the line. + endPointY = end_focus_properties.centroid[0] + endPointX = end_focus_properties.centroid[1] + endItem.parentItem().x() + lastPoint = QPoint(endPointX, endPointY) + if 0 in event_indices: + # If the zero-th element in event_indices was 1, the focus migrates in the next frame + # get the information from focus_matrix to figure out to which region + # in the next frame it migrated + # set self.migration = True + self.set_migration() + + if 1 in event_indices: + self.set_children() + + if 4 in event_indices: + self.set_join() + + eventItem = self.set_event_item(firstPoint=firstPoint, startItem=startItem, lastPoint=lastPoint, endItem=endItem) + self.addItem(eventItem) + except KeyError: + continue if original_event_type is not None: if original_event_type == "MigrationLine": @@ -981,8 +1254,14 @@ def update_frame_info(self, frame): # # Fetch the focus's original information # 'matrix' is a 2D array, for which the row index is the region label at time t, and the column index is the region label at time t+1 # If a region disappears from t to t+1, it will receive a 1 in the column with index 0. 
- time_matrix = self.track_info[frame_time]['matrix'] - focus_events = self.track_info[frame_time]['regions'][focus_label]['events'] + + t_info = self.track_info[frame_time] + # pprint(t_info) + # all_t_cells = t_info['cells'] + # for cell_info in all_t_cells.values(): + + time_matrix = t_info['matrix'] + focus_events = t_info['regions'][focus_label]['events'] # print("Events and matrix for focus {} at time {}: \n\n".format(focus_label, frame_time), # focus_events, "\n\n", time_matrix, "\n") @@ -1011,18 +1290,6 @@ def update_frame_info(self, frame): # print(start_focus_label, end_focus_label) time_matrix[start_focus_label,end_focus_label] = 1 - # if event_type == "ChildLine": - # # if the event is a child line that terminates with this focus, - # # set the 'born' index of focus_events to 1 - # if event.endItem == focus: - # focus_events[self.event_type_index_lookup["BornSymbol"]] = 1 - - # # If the event is zeroFocusSymbol or Appear, do this - # elif event_type in self.end_check_events_list: - - # for i,old_event_type in enumerate(events): - # print(i, old_event_type) - # if the event is either disappear or die, do this stuff else: focus_events[self.event_type_index_lookup[event_type]] = 1 @@ -1031,13 +1298,14 @@ def update_frame_info(self, frame): time_matrix[focus_label,0] = 1 # print("New events and matrix for focus {} at time {}: \n\n".format(focus_label, frame_time), - # focus_events, "\n\n", time_matrix, "\n") + # focus_events, "\n\n", time_matrix, "\n") def remove_old_conflicting_events(self, event): - # This function derives the foci involved in a given newly-annotated event and evaluates - # whether a focus has conflicting events, such as two migrations to the - # next frame, or three children, etc.. It then attempts to resolve the conflict by removing - # the older annotation. 
+ '''This function derives the foci involved in a given newly-annotated event and evaluates + whether a focus has conflicting events, such as two migrations to the + next frame, or three children, etc.. It then attempts to resolve the conflict by removing + the older annotation. + ''' event_type = event.type() if event_type in self.line_events_list: diff --git a/mm3_GUI_helpers.py b/mm3_GUI_helpers.py index 505fbac..87e7c84 100755 --- a/mm3_GUI_helpers.py +++ b/mm3_GUI_helpers.py @@ -602,29 +602,35 @@ def __init__(self, parent,imgPaths,fov_id_list,image_dir):#,frame_index,peak_id, self.frameIndex = 0 self.img = self.phaseStack[self.frameIndex,:,:] - # self.originalImgMax = np.max(self.img) - self.originalImgMax = np.max(self.phaseStack) - originalRGBImg = color.gray2rgb(self.img/2**16*2**8).astype('uint8') - self.originalPhaseQImage = QImage(originalRGBImg, originalRGBImg.shape[1], originalRGBImg.shape[0], originalRGBImg.strides[0], QImage.Format_RGB888) + self.originalImgMax = np.max(self.img) + # self.originalImgMax = np.max(self.phaseStack) + originalRGBImg = color.gray2rgb(self.img) + self.originalPhaseQImage = QImage(originalRGBImg, originalRGBImg.shape[1], originalRGBImg.shape[0], originalRGBImg.strides[0], QImage.Format_RGBA64) - rescaledImg = self.img/self.originalImgMax*255 - RGBImg = color.gray2rgb(rescaledImg).astype('uint8') + rescaledImg = self.img/self.originalImgMax*(2**16-1) + RGBImg = color.gray2rgba(rescaledImg) self.originalHeight, self.originalWidth, self.originalChannelNumber = RGBImg.shape - self.phaseQimage = QImage(RGBImg, RGBImg.shape[1], RGBImg.shape[0], RGBImg.strides[0], QImage.Format_RGB888).scaled(1024, 1024, aspectRatioMode=Qt.KeepAspectRatio) + self.phaseQimage = QImage(RGBImg, RGBImg.shape[1], RGBImg.shape[0], RGBImg.strides[0], QImage.Format_RGBA64).scaled(1024, 1024, aspectRatioMode=Qt.KeepAspectRatio) self.phaseQpixmap = QPixmap(self.phaseQimage) + # rescaledImg = self.img/self.originalImgMax*255 + # RGBImg = 
color.gray2rgb(rescaledImg).astype('uint8') + # self.originalHeight, self.originalWidth, self.originalChannelNumber = RGBImg.shape + # self.phaseQimage = QImage(RGBImg, RGBImg.shape[1], RGBImg.shape[0], RGBImg.strides[0], QImage.Format_RGB888).scaled(1024, 1024, aspectRatioMode=Qt.KeepAspectRatio) + # self.phaseQpixmap = QPixmap(self.phaseQimage) + self.label = QLabel(self) self.label.setPixmap(self.phaseQpixmap) def setImg(self): - # self.originalImgMax = np.max(self.img) - originalRGBImg = color.gray2rgb(self.img/2**16*2**8).astype('uint8') - self.originalPhaseQImage = QImage(originalRGBImg, originalRGBImg.shape[1], originalRGBImg.shape[0], originalRGBImg.strides[0], QImage.Format_RGB888) + self.originalImgMax = np.max(self.img) + originalRGBImg = color.gray2rgb(self.img) + self.originalPhaseQImage = QImage(originalRGBImg, originalRGBImg.shape[1], originalRGBImg.shape[0], originalRGBImg.strides[0], QImage.Format_RGBA64) # rescaledImg = self.img/np.max(self.img)*255 - rescaledImg = self.img/self.originalImgMax*255 - RGBImg = color.gray2rgb(rescaledImg).astype('uint8') - self.phaseQimage = QImage(RGBImg, RGBImg.shape[1], RGBImg.shape[0], RGBImg.strides[0], QImage.Format_RGB888).scaled(1024, 1024, aspectRatioMode=Qt.KeepAspectRatio) + rescaledImg = self.img/self.originalImgMax*(2**16-1) + RGBImg = color.gray2rgb(rescaledImg) + self.phaseQimage = QImage(RGBImg, RGBImg.shape[1], RGBImg.shape[0], RGBImg.strides[0], QImage.Format_RGBA64).scaled(1024, 1024, aspectRatioMode=Qt.KeepAspectRatio) self.phaseQpixmap = QPixmap(self.phaseQimage) self.label.setPixmap(self.phaseQpixmap) diff --git a/mm3_Segment-Unet.py b/mm3_Segment-Unet.py index de094e6..eacc2bd 100755 --- a/mm3_Segment-Unet.py +++ b/mm3_Segment-Unet.py @@ -35,6 +35,24 @@ import mm3_helpers as mm3 +def segment_single_file(infile_name, params, namespace): + + mm3.information("Segmenting image {}.".format(infile_name)) + # load model to pass to algorithm + mm3.information("Loading model...") + + if 
namespace.modelfile: + model_file_path = namespace.modelfile + else: + model_file_path = params['segment']['model_file'] + + seg_model = models.load_model(model_file_path, + custom_objects={'bce_dice_loss': mm3.bce_dice_loss, + 'dice_loss': mm3.dice_loss}) + mm3.information("Model loaded.") + mm3.segment_stack_unet(infile_name, seg_model, mode='segment') + sys.exit("Completed segmenting image {}.".format(infile_name)) + # when using this script as a function and not as a library the following will execute if __name__ == "__main__": @@ -43,6 +61,8 @@ description='Segment cells and create lineages.') parser.add_argument('-f', '--paramfile', type=str, required=True, help='Yaml file containing parameters.') + parser.add_argument('-i', '--infile', type=str, + required=False, help='Use this argument to segment ONLY on image. Name the single file to be segmented.') parser.add_argument('-o', '--fov', type=str, required=False, help='List of fields of view to analyze. Input "1", "1,2,3", or "1-10", etc.') parser.add_argument('-j', '--nproc', type=int, @@ -59,6 +79,14 @@ mm3.warning('No param file specified. 
Using 100X template.') param_file_path = 'yaml_templates/params_SJ110_100X.yaml' p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + # set segmentation image name for saving and loading segmented images + p['seg_img'] = 'seg_unet' + p['pred_img'] = 'pred_unet' + + if namespace.infile: + + fname = namespace.infile + segment_single_file(fname, p, namespace) if namespace.fov: if '-' in namespace.fov: @@ -80,10 +108,6 @@ if not os.path.exists(p['cell_dir']): os.makedirs(p['cell_dir']) - # set segmentation image name for saving and loading segmented images - p['seg_img'] = 'seg_unet' - p['pred_img'] = 'pred_unet' - # load specs file specs = mm3.load_specs() # print(specs) # for debugging diff --git a/mm3_Track.py b/mm3_Track.py index f57fe74..366bd73 100755 --- a/mm3_Track.py +++ b/mm3_Track.py @@ -6,6 +6,7 @@ import sys import os # import time +import re import inspect import argparse import yaml @@ -17,9 +18,11 @@ import numpy as np from scipy.io import savemat -from skimage import measure +from skimage import measure, io from tensorflow.keras import models +from matplotlib import pyplot as plt # for debugging + # user modules # realpath() will make your script run, even if you symlink it cmd_folder = os.path.realpath(os.path.abspath( @@ -36,22 +39,365 @@ import mm3_helpers as mm3 +def extract_fov_and_peak_ids(infile_name): + + fov_id = mm3.get_fov(infile_name) + peak_id = mm3.get_peak(infile_name) + + return (fov_id,peak_id) + +def track_single_file( + phase_file_name, + seg_file_name, + params, + namespace): + + mm3.information("Tracking cells in {}.".format(seg_file_name)) + # load model to pass to algorithm + mm3.information("Loading model...") + + params['tracking']['migrate_model'] = namespace.migrate_modelfile + params['tracking']['child_model'] = namespace.child_modelfile + params['tracking']['appear_model'] = namespace.appear_modelfile + params['tracking']['die_model'] = namespace.die_modelfile + params['tracking']['disappear_model'] 
= namespace.disappear_modelfile + params['tracking']['born_model'] = namespace.born_modelfile + + model_dict = mm3.get_tracking_model_dict() + + fov_id,peak_id = extract_fov_and_peak_ids(phase_file_name) + + tracks = {} + track_loop( + fov_id, + peak_id, + params, + tracks, + model_dict, + phase_file_name=phase_file_name, + seg_file_name=seg_file_name + ) + + track_file_name = "{}_xy{:0=3}_p{:0=4}_tracks.pkl".format( + params['experiment_name'], + fov_id, + peak_id + ) + + with open(track_file_name, 'wb') as cell_file: + pickle.dump(tracks, cell_file) + + sys.exit("Completed tracking cells in stack {}.".format(seg_file_name)) + +def run_cells(tracks, + peak_id, + fov_id, + params, + predictions_dict, + regions_by_time, + born_threshold = 0.85, + appear_threshold = 0.85): + + G,graph_df = mm3.initialize_track_graph(peak_id=peak_id, + fov_id=fov_id, + experiment_name=params['experiment_name'], + predictions_dict=predictions_dict, + regions_by_time = regions_by_time, + born_threshold=born_threshold, + appear_threshold=appear_threshold) + + tracks.update(mm3.create_lineages_from_graph(G, graph_df, fov_id, peak_id)) + +def run_foci(tracks, + peak_id, + fov_id, + params, + predictions_dict, + regions_by_time, + Cells, + appear_threshold = 0.85, + max_cell_number = 6): + + G,graph_df = mm3.initialize_focus_track_graph( + peak_id=peak_id, + fov_id=fov_id, + experiment_name=params['experiment_name'], + predictions_dict=predictions_dict, + regions_by_time = regions_by_time, + appear_threshold=appear_threshold, + ) + + tracks.update(mm3.create_focus_lineages_from_graph(G, graph_df, fov_id, peak_id, Cells, max_cell_number)) + + +def track_loop( + fov_id, + peak_id, + params, + tracks, + model_dict, + cell_number = 6, + data_number = 9, + img_file_name = None, + seg_file_name = None, + track_type = 'cells', + max_cell_number = 6): + + if img_file_name is None: + + if track_type == 'cells': + seg_stack = mm3.load_stack(fov_id, peak_id, color=params['seg_img']) + img_stack = 
mm3.load_stack(fov_id, peak_id, color=params['phase_plane']) + elif track_type == 'foci': + seg_stack = mm3.load_stack(fov_id, peak_id, color=params['seg_img']) + img_stack = mm3.load_stack(fov_id, peak_id, color=params['foci']['foci_plane']) + + else: + seg_stack = io.imread(seg_file_name) + img_stack = io.imread(img_file_name) + + # run predictions for each tracking class + # consider only the top six cells for a given trap when doing tracking + frame_number = seg_stack.shape[0] + + # sometimes a phase contrast image is missed and has no signal. + # This is a workaround for that problem + no_signal_frames = [] + for k,img in enumerate(img_stack): + if track_type == 'foci': + if np.max(img) < 100: + no_signal_frames.append(k) + elif track_type == 'cells': + # if the mean phase image signal is less than 200, add its index to list + if np.mean(img) < 200: + no_signal_frames.append(k) + + # loop through segmentation stack and replace frame from missed phase image + # with the prior frame. + for k,label_img in enumerate(seg_stack): + if k in no_signal_frames: + seg_stack[k,...] = seg_stack[k-1,...] 
+ + if track_type == 'cells': + regions_by_time = [measure.regionprops(label_image=img) for img in seg_stack] + elif track_type == 'foci': + with open(p['cell_dir'] + '/all_cells.pkl', 'rb') as cell_file: + Cells = pickle.load(cell_file) + regions_by_time = [] + for i,img in enumerate(seg_stack): + regs = measure.regionprops(label_image=img, intensity_image=img_stack[i,:,:]) + regs_sorted = mm3.sort_regions_in_list(regs) + regions_by_time.append(regs_sorted) + + if track_type == 'cells': + # have generator yield info for top six cells in all frames + prediction_generator = mm3.PredictTrackDataGenerator(regions_by_time, batch_size=frame_number, dim=(cell_number,5,data_number), track_type=track_type) + elif track_type == 'foci': + prediction_generator = mm3.PredictTrackDataGenerator( + regions_by_time, + batch_size=frame_number, + dim=(cell_number,5,data_number), + track_type=track_type, + img_stack=img_stack, + images=True, + img_dim=(5,256,32) + ) + cell_info = prediction_generator.__getitem__(0) + + predictions_dict = {} + # run data through each classification model + for key,mod in model_dict.items(): + + # Run predictions and add to dictionary + if key in ['zero_cell_model', 'one_cell_model' , 'two_cell_model', 'geq_three_cell_model']: + continue + + mm3.information('Predicting probability of {} events in FOV {}, trap {}.'.format('_'.join(key.split('_')[:-1]), fov_id, peak_id)) + predictions_dict['{}_predictions'.format(key)] = mod.predict(cell_info) + + if track_type == 'cells': + run_cells( + tracks, + peak_id, + fov_id, + params, + predictions_dict, + regions_by_time, + ) + + elif track_type == 'foci': + pred_dict = {} + ( + outbound1, + outbound2, + outbound3, + outbound4, + outbound5, + outbound6, + pred_dict['appear_model_predictions'] + ) = predictions_dict['all_model_predictions'] + # for this in predictions_dict['all_model_predictions']: + # print(this.shape) + # 
pred_dict['appear_model_predictions'],pred_dict['disappear_model_predictions'],pred_dict['appear_model_predictions'] = predictions_dict['all_model_predictions'] + + # take the -2nd element of each outbound array. the -1st is for "no focus", -2nd is for 'disappear, 0:6 are for migrate. + pred_dict['disappear_model_predicitons'] = np.transpose(np.array( + [outbound1[:,-2],outbound2[:,-2],outbound3[:,-2],outbound4[:,-2],outbound5[:,-2],outbound6[:,-2]] + )) + + # take the 0:6 elements of each outbound prediction result. + pred_dict['migrate_model_predictions'] = np.concatenate( + [ + outbound1[:,:6], + outbound2[:,:6], + outbound3[:,:6], + outbound4[:,:6], + outbound5[:,:6], + outbound6[:,:6], + ], + axis=1 + ) + + # print(pred_dict['migrate_model_predictions'].shape) + + run_foci( + tracks, + peak_id, + fov_id, + params, + pred_dict, + regions_by_time, + Cells, + max_cell_number=max_cell_number, + appear_threshold=0.85 + ) + # when using this script as a function and not as a library the following will execute if __name__ == "__main__": # set switches and parameters - parser = argparse.ArgumentParser(prog='python mm3_Track.py', - description='Track cells and create lineages.') - parser.add_argument('-f', '--paramfile', type=str, - required=True, help='Yaml file containing parameters.') - parser.add_argument('-o', '--fov', type=str, - required=False, help='List of fields of view to analyze. Input "1", "1,2,3", or "1-10", etc.') - # parser.add_argument('-p', '--peak', type=str, - # required=False, help='List of peak ids to analyze. Input "1", "1,2,3", or "1-10", etc.') - parser.add_argument('-j', '--nproc', type=int, - required=False, help='Number of processors to use.') - parser.add_argument('-m', '--modelfile', type=str, - required=False, help='Path to trained model.') + parser = argparse.ArgumentParser( + prog='python mm3_Track.py', + description='Track cells or fluroescent foci and create lineages.' 
+ ) + subparsers = parser.add_subparsers(help='commands', dest='command') + + # cells + cell_parser = subparsers.add_parser( + 'cells', + help = "Track cells", + ) + + # foci + focus_parser = subparsers.add_parser( + 'foci', + help = "Track fluorescent foci" + ) + + parser.add_argument( + '-f', + '--paramfile', + type=str, + required=True, + help='Yaml file containing parameters.' + ) + parser.add_argument( + '-o', + '--fov', + type=str, + required=False, + help='List of fields of view to analyze. Input "1", "1,2,3", or "1-10", etc.' + ) + parser.add_argument( + '--peak', + type=str, + required=False, + help='List of peaks to analyze. Input "1", "1,2,3", or "1-10", etc.' + ) + # parser.add_argument( + # '-j', + # '--nproc', + # type=int, + # required=False, + # help='Number of processors to use.' + # ) + parser.add_argument( + '-r', + '--chtc', + action='store_true', + required=False, + help='Add this flag at the command line if the job will run at chtc.' + ) + cell_parser.add_argument( + '-p', + '--phase_file_name', + type=str, + required=False, + help='Name of file containing stack of images for a single fov/peak' + ) + focus_parser.add_argument( + '-fl', + '--fluor_file_name', + type=str, + required=False, + help='Name of file containing stack of fluorescent images for a single fov/peak' + ) + parser.add_argument( + '-s', + '--seg_file_name', + type=str, + required=False, + help='Name of file containing stack of images for a single fov/peak' + ) + parser.add_argument( + '--migrate_modelfile', + type=str, + required=False, + help='Path to trained migration model.' + ) + cell_parser.add_argument( + '--child_modelfile', + type=str, + required=False, + help='Path to trained child model.' + ) + parser.add_argument( + '--appear_modelfile', + type=str, + required=False, + help='Path to trained appear model.' + ) + cell_parser.add_argument( + '--die_modelfile', + type=str, + required=False, + help='Path to trained die model.' 
+ ) + parser.add_argument( + '--disappear_modelfile', + type=str, + required=False, + help='Path to trained disappear model.' + ) + cell_parser.add_argument( + '--born_modelfile', + type=str, + required=False, + help='Path to trained born model.' + ) + parser.add_argument( + '--specfile', + type=str, + required=False, + help='Path to specs file.' + ) + parser.add_argument( + '--timefile', + type=str, + required=False, + help='Path to file containing time table.' + ) + namespace = parser.parse_args() # Load the project parameters file @@ -72,29 +418,40 @@ else: user_spec_fovs = [] - # if namespace.peak: - # if '-' in namespace.peak: - # user_spec_peaks = range(int(namespace.fov.split("-")[0]), - # int(namespace.fov.split("-")[1])+1) - # else: - # user_spec_peaks = [int(val) for val in namespace.fov.split(",")] - # else: - # user_spec_peaks = [] - - # number of threads for multiprocessing - if namespace.nproc: - p['num_analyzers'] = namespace.nproc - mm3.information('Using {} threads for multiprocessing.'.format(p['num_analyzers'])) - - if not os.path.exists(p['cell_dir']): - os.makedirs(p['cell_dir']) + if namespace.peak: + if '-' in namespace.peak: + user_spec_peaks = range(int(namespace.peak.split("-")[0]), + int(namespace.peak.split("-")[1])+1) + else: + user_spec_peaks = [int(val) for val in namespace.peak.split(",")] + else: + user_spec_peaks = [] # set segmentation image name for saving and loading segmented images - p['seg_img'] = 'seg_unet' + if namespace.command == 'cells': + p['seg_img'] = 'seg_unet' + elif namespace.command == 'foci': + p['seg_img'] = 'foci_seg_unet' # load specs file - specs = mm3.load_specs() - # pprint(specs) # for debugging + if namespace.chtc: + specs = mm3.load_specs(fname=namespace.specfile) + mm3.load_time_table(fname=namespace.timefile) + else: + specs = mm3.load_specs() + mm3.load_time_table() + + if namespace.command == 'cells': + if namespace.phase_file_name: + track_single_file( + namespace.phase_file_name, + 
namespace.seg_file_name, + p, + namespace + ) + + if not os.path.exists(p['cell_dir']): + os.makedirs(p['cell_dir']) # make list of FOVs to process (keys of channel_mask file) fov_id_list = sorted([fov_id for fov_id in specs.keys()]) @@ -111,13 +468,10 @@ # read in models as dictionary # keys are 'migrate_model', 'child_model', 'appear_model', 'die_model', 'disappear_model', etc. # NOTE on 2019-07-15: For now, some of the models are ignored by the tracking algorithm, as they don't yet perform well - model_dict = mm3.get_tracking_model_dict() - - # Load time table, which goes into params - mm3.load_time_table() - - # This dictionary holds information for all cells - # Cells = {} + if namespace.command == 'cells': + model_dict = mm3.get_tracking_model_dict() + elif namespace.command == 'foci': + model_dict = mm3.get_focus_tracking_model_dict() # do lineage creation per fov, per trap tracks = {} @@ -126,54 +480,70 @@ # update will add the output from make_lineages_function, which is a # dict of Cell entries, into Cells ana_peak_ids = [peak_id for peak_id in specs[fov_id].keys() if specs[fov_id][peak_id] == 1] - # ana_peak_ids = [9,13,15,19,25,33,36,37,38,39] # was used for debugging + if user_spec_peaks: + ana_peak_ids[:] = [peak for peak in ana_peak_ids if peak in user_spec_peaks] + for j,peak_id in enumerate(ana_peak_ids): - seg_stack = mm3.load_stack(fov_id, peak_id, color=p['seg_img']) - # run predictions for each tracking class - # consider only the top six cells for a given trap when doing tracking - cell_number = 6 - frame_number = seg_stack.shape[0] - # get region properties - regions_by_time = [measure.regionprops(label_image=img) for img in seg_stack] + if namespace.command == 'cells': + + track_loop( + fov_id, + peak_id, + p, + tracks, + model_dict, + track_type = namespace.command + ) + + elif namespace.command == 'foci': + + track_loop( + fov_id, + peak_id, + p, + tracks, + model_dict, + data_number = 11, + track_type = namespace.command, + 
max_cell_number = 6 + ) + + mm3.information("Finished lineage creation.") - # have generator yield info for top six cells in all frames - prediction_generator = mm3.PredictTrackDataGenerator(regions_by_time, batch_size=frame_number, dim=(cell_number,5,9)) - cell_info = prediction_generator.__getitem__(0) + ### Now prune and save the data. + if namespace.command == 'cells': + mm3.information("Saving cell data.") - predictions_dict = {} - # run data through each classification model - for key,mod in model_dict.items(): + ### save the cell data. Use the script mm3_OutputData for additional outputs. + # All cell data (includes incomplete cells) + if not os.path.isdir(p['cell_dir']): + os.mkdir(p['cell_dir']) - # Run predictions and add to dictionary - if key in ['zero_cell_model', 'one_cell_model' , 'two_cell_model', 'geq_three_cell_model']: - continue + with open(p['cell_dir'] + '/all_cells.pkl', 'wb') as cell_file: + pickle.dump(tracks, cell_file, protocol=pickle.HIGHEST_PROTOCOL) - mm3.information('Predicting probability of {} events in FOV {}, trap {}.'.format('_'.join(key.split('_')[:-1]), fov_id, peak_id)) - predictions_dict['{}_predictions'.format(key)] = mod.predict(cell_info) + if os.path.isfile(os.path.join(p['cell_dir'], 'complete_cells.pkl')): + os.remove(os.path.join(p['cell_dir'], 'complete_cells.pkl')) - G,graph_df = mm3.initialize_track_graph(peak_id=peak_id, - fov_id=fov_id, - experiment_name=p['experiment_name'], - predictions_dict=predictions_dict, - regions_by_time = regions_by_time, - born_threshold=0.85, - appear_threshold=0.85) + os.symlink( + os.path.join(p['cell_dir'], 'all_cells.pkl'), + os.path.join(p['cell_dir'], 'complete_cells.pkl') + ) - # tracks[fov_id][peak_id] = mm3.create_lineages_from_graph_2(G, graph_df, fov_id, peak_id) - tracks.update(mm3.create_lineages_from_graph(G, graph_df, fov_id, peak_id)) + mm3.information("Finished curating and saving cell data.") - mm3.information("Finished lineage creation.") + elif namespace.command == 
'foci': + mm3.information("Saving focus track data.") - ### Now prune and save the data. - mm3.information("Saving cell data.") + if not os.path.isdir(p['foci_track_dir']): + os.mkdir(p['foci_track_dir']) - ### save the cell data. Use the script mm3_OutputData for additional outputs. - # All cell data (includes incomplete cells) - if not os.path.isdir(p['cell_dir']): - os.mkdir(p['cell_dir']) + with open(os.path.join(p['foci_track_dir'], 'all_foci.pkl'), 'wb') as foci_file: + pickle.dump(tracks, foci_file, protocol=pickle.HIGHEST_PROTOCOL) - with open(p['cell_dir'] + '/all_cells.pkl', 'wb') as cell_file: - pickle.dump(tracks, cell_file, protocol=pickle.HIGHEST_PROTOCOL) + # with open(os.path.join(p['cell_dir'],'all_cells_with_foci.pkl'), 'wb') as cell_file: + # pickle.dump(Cells, cell_file, protocol=pickle.HIGHEST_PROTOCOL) - mm3.information("Finished curating and saving cell data.") + mm3.information("Finished curating and saving focus data in {} and updated cell data in {}.".format(os.path.join(p['foci_track_dir'], 'all_foci.pkl'), + os.path.join(p['cell_dir'], 'all_cells_with_foci.pkl'))) diff --git a/mm3_TrackFoci.py b/mm3_TrackFoci.py index 69a16a2..c42ff88 100755 --- a/mm3_TrackFoci.py +++ b/mm3_TrackFoci.py @@ -117,9 +117,20 @@ with open(os.path.join(p['cell_dir'],'all_cells.pkl'), 'rb') as cell_file: Cells = pickle.load(cell_file) + ######################################################################################################################## + ########## TO DO: reorganize how tracking is done, so that it goes cell-by-cell, rather than frame-by-frame. 
########### + ######################################################################################################################## foci = {} # foci_info_unet modifies foci dictionary in place, so nothing returned here - mm3.dev_foci_info_unet(foci, + # mm3.foci_info_unet( + # foci, + # Cells, + # specs, + # p['time_table'], + # channel_name="c{}".format(namespace.channel) + # ) + + mm3.foci_info_unet(foci, Cells, specs, p['time_table'], diff --git a/mm3_curateTrainingData.py b/mm3_curateTrainingData.py index abb2c2a..0727179 100755 --- a/mm3_curateTrainingData.py +++ b/mm3_curateTrainingData.py @@ -60,7 +60,7 @@ required=False, default=1, help='Which channel, e.g. phase or some fluorescence image, should be used for creating masks. \ Accepts integers. Default is 1, which is usually your phase contrast images.') - parser.add_argument('-n', '--no_prior_mask', action='store_true', + parser.add_argument('-m', '--focus_mask', action='store_true', help='Apply this argument is you are making masks de novo, i.e., if no masks exist yet for your images.') namespace = parser.parse_args() @@ -84,8 +84,15 @@ else: user_spec_fovs = [] - if not os.path.exists(p['seg_dir']): - sys.exit("Exiting: Segmentation directory, {}, not found.".format(p['seg_dir'])) + if namespace.focus_mask: + seg_dir = 'foci_seg_dir' + seg_img_search = r'.+(foci_seg_.+)\.tif$' + else: + seg_dir = 'seg_dir' + seg_img_search = r'.+(seg_.+)\.tif$' + + if not os.path.exists(p[seg_dir]): + sys.exit("Exiting: Segmentation directory, {}, not found.".format(p[seg_dir])) if not os.path.exists(p['chnl_dir']): sys.exit("Exiting: Channel directory, {}, not found.".format(p['chnl_dir'])) @@ -97,27 +104,22 @@ fov_id_list[:] = [fov for fov in fov_id_list if fov in user_spec_fovs] # set segmentation image name for segmented images - seg_suffix_finder = re.compile(r'.+(seg_.+)\.tif$') - test_name = glob.glob(os.path.join(p['seg_dir'],'*xy{:0=3}*.tif'.format(fov_id_list[0])))[0] + seg_suffix_finder = 
re.compile(seg_img_search) + test_name = glob.glob(os.path.join(p[seg_dir],'*xy{:0=3}*.tif'.format(fov_id_list[0])))[0] mat = seg_suffix_finder.match(test_name) p['seg_img'] = mat.groups()[0] ## be careful here, it is lookgin for segmented images + print(p['seg_img']) # get paired phase file names and mask file names for each fov fov_filename_dict = {} for fov_id in fov_id_list: - if namespace.no_prior_mask: - mask_filenames = None - else: - mask_filenames = [os.path.join(p['seg_dir'],fname) for fname in glob.glob(os.path.join(p['seg_dir'],'*xy{:0=3}*{}.tif'.format(fov_id, p['seg_img'])))] - - image_filenames = [fname.replace(p['seg_dir'], p['chnl_dir']).replace(p['seg_img'], 'c{}'.format(namespace.channel)) for fname in mask_filenames] + mask_filenames = [os.path.join(p[seg_dir],fname) for fname in glob.glob(os.path.join(p[seg_dir],'*xy{:0=3}*{}.tif'.format(fov_id, p['seg_img'])))] + + image_filenames = [fname.replace(p[seg_dir], p['chnl_dir']).replace(p['seg_img'], 'c{}'.format(namespace.channel)) for fname in mask_filenames] fov_filename_dict[fov_id] = [] - if mask_filenames is not None: - for i in range(len(mask_filenames)): - fov_filename_dict[fov_id].append((image_filenames[i],mask_filenames[i])) - else: - fov_filename_dict[fov_id].append((image_filenames[i], None)) + for i in range(len(mask_filenames)): + fov_filename_dict[fov_id].append((image_filenames[i],mask_filenames[i])) # print([names for names in fov_filename_dict[1]]) # for debugging diff --git a/mm3_helpers.py b/mm3_helpers.py index 41f54be..8a148b8 100755 --- a/mm3_helpers.py +++ b/mm3_helpers.py @@ -88,12 +88,15 @@ def information(*objs): print(time.strftime("%H:%M:%S", time.localtime()), *objs, file=sys.stdout) # load the parameters file into a global dictionary for this module -def init_mm3_helpers(param_file_path): +def init_mm3_helpers(param_file_path, datapath = None): # load all the parameters into a global dictionary global params with open(param_file_path, 'r') as param_file: params = 
yaml.safe_load(param_file) + if datapath is not None: + params['experiment_directory'] = datapath + # set up how to manage cores for multiprocessing params['num_analyzers'] = multiprocessing.cpu_count() @@ -136,7 +139,7 @@ def julian_day_number(): return jdn def get_plane(filepath): - pattern = r'(c\d+).tif' + pattern = r'(c\d+)\.tif' res = re.search(pattern,filepath) if (res != None): return res.group(1) @@ -144,7 +147,31 @@ def get_plane(filepath): return None def get_fov(filepath): - pattern = r'xy(\d+)\w*.tif' + pattern = r'xy(\d{2,4})\w*\.tif' + res = re.search(pattern,filepath) + if (res != None): + return int(res.group(1)) + else: + return None + +def get_peak(filepath): + pattern = r'p(\d{3,4})\w*\.tif' + res = re.search(pattern,filepath) + if (res != None): + return int(res.group(1)) + else: + return None + +def get_pkl_fov(filepath): + pattern = r'xy(\d{2,4})\w*\.pkl' + res = re.search(pattern,filepath) + if (res != None): + return int(res.group(1)) + else: + return None + +def get_pkl_peak(filepath): + pattern = r'p(\d{3,4})\w*\.pkl' res = re.search(pattern,filepath) if (res != None): return int(res.group(1)) @@ -152,7 +179,7 @@ def get_fov(filepath): return None def get_time(filepath): - pattern = r't(\d+)xy\w+.tif' + pattern = r't(\d+)xy\w+\.tif' res = re.search(pattern,filepath) if (res != None): return np.int_(res.group(1)) @@ -225,18 +252,26 @@ def load_stack(fov_id, peak_id, color='c1', image_return_number=None): return img_stack # load the time table and add it to the global params -def load_time_table(): +def load_time_table(fname=None): '''Add the time table dictionary to the params global dictionary. This is so it can be used during Cell creation. 
''' # try first for yaml, then for pkl - try: - with open(os.path.join(params['ana_dir'], 'time_table.yaml'), 'rb') as time_table_file: - params['time_table'] = yaml.safe_load(time_table_file) - except: - with open(os.path.join(params['ana_dir'], 'time_table.pkl'), 'rb') as time_table_file: - params['time_table'] = pickle.load(time_table_file) + if fname is None: + try: + with open(os.path.join(params['ana_dir'], 'time_table.yaml'), 'rb') as time_table_file: + params['time_table'] = yaml.safe_load(time_table_file) + except: + with open(os.path.join(params['ana_dir'], 'time_table.pkl'), 'rb') as time_table_file: + params['time_table'] = pickle.load(time_table_file) + else: + try: + with open(fname, 'rb') as time_table_file: + params['time_table'] = yaml.safe_load(time_table_file) + except: + with open(fname, 'rb') as time_table_file: + params['time_table'] = pickle.load(time_table_file) return @@ -264,18 +299,29 @@ def load_channel_masks(): return channel_masks # function for loading the specs file -def load_specs(): +def load_specs(fname = None): '''Load specs file which indicates which channels should be analyzed, used as empties, or ignored.''' - try: - with open(os.path.join(params['ana_dir'], 'specs.yaml'), 'r') as specs_file: - specs = yaml.safe_load(specs_file) - except: + if fname is None: try: - with open(os.path.join(params['ana_dir'], 'specs.pkl'), 'rb') as specs_file: - specs = pickle.load(specs_file) - except ValueError: - warning('Could not load specs file.') + with open(os.path.join(params['ana_dir'], 'specs.yaml'), 'r') as specs_file: + specs = yaml.safe_load(specs_file) + except: + try: + with open(os.path.join(params['ana_dir'], 'specs.pkl'), 'rb') as specs_file: + specs = pickle.load(specs_file) + except ValueError: + warning('Could not load specs file.') + else: + try: + with open(fname, 'r') as specs_file: + specs = yaml.safe_load(specs_file) + except: + try: + with open(fname, 'rb') as specs_file: + specs = pickle.load(specs_file) + except 
ValueError: + warning('Could not load specs file.') return specs @@ -633,11 +679,6 @@ def make_time_table(analyzed_imgs): time_table[int(idata['fov'])][int(idata['t'])] = int(t_in_seconds) - # save to .pkl. This pkl will be loaded into the params - # with open(os.path.join(params['ana_dir'], 'time_table.pkl'), 'wb') as time_table_file: - # pickle.dump(time_table, time_table_file, protocol=pickle.HIGHEST_PROTOCOL) - # with open(os.path.join(params['ana_dir'], 'time_table.txt'), 'w') as time_table_file: - # pprint(time_table, stream=time_table_file) with open(os.path.join(params['ana_dir'], 'time_table.yaml'), 'w') as time_table_file: yaml.dump(data=time_table, stream=time_table_file, default_flow_style=False, tags=None) information('Time table saved.') @@ -1123,19 +1164,23 @@ def predict_first_image_channels(img, model, 'n_channels':1, 'normalize_to_one':True, 'shuffle':False} - predict_gen_args = {'verbose':1, - 'use_multiprocessing':True, - 'workers':params['num_analyzers']} + predict_gen_args = { + 'verbose':1, + 'use_multiprocessing':False, + # 'workers':params['num_analyzers'], + } img_generator = TrapSegmentationDataGenerator(crops, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) prediction = imageConcatenatorFeatures(predictions, subImageNumber=subImageNumber) #print(prediction.shape) cropsExpand = tileImage(imgStackExpand, subImageNumber=padSubImageNumber) cropsExpand = np.expand_dims(cropsExpand, -1) img_generator = TrapSegmentationDataGenerator(cropsExpand, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) predictionExpand = imageConcatenatorFeatures2(predictions, 
subImageNumber=padSubImageNumber) predictionExpand = util.crop(predictionExpand, ((0,0),(shiftDistance,shiftDistance),(shiftDistance,shiftDistance),(0,0))) #print(predictionExpand.shape) @@ -1143,7 +1188,8 @@ def predict_first_image_channels(img, model, cropsShiftLeft = tileImage(imgStackShiftLeft, subImageNumber=subImageNumber) cropsShiftLeft = np.expand_dims(cropsShiftLeft, -1) img_generator = TrapSegmentationDataGenerator(cropsShiftLeft, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) predictionLeft = imageConcatenatorFeatures(predictions, subImageNumber=subImageNumber) predictionLeft = np.pad(predictionLeft, pad_width=((0,0),(0,0),(0,shiftDistance),(0,0)), mode='constant', constant_values=((0,0),(0,0),(0,0),(0,0)))[:,:,shiftDistance:,:] @@ -1152,7 +1198,8 @@ def predict_first_image_channels(img, model, cropsShiftRight = tileImage(imgStackShiftRight, subImageNumber=subImageNumber) cropsShiftRight = np.expand_dims(cropsShiftRight, -1) img_generator = TrapSegmentationDataGenerator(cropsShiftRight, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) predictionRight = imageConcatenatorFeatures(predictions, subImageNumber=subImageNumber) predictionRight = np.pad(predictionRight, pad_width=((0,0),(0,0),(shiftDistance,0),(0,0)), mode='constant', constant_values=((0,0),(0,0),(0,0),(0,0)))[:,:,:(-1*shiftDistance),:] @@ -1162,7 +1209,8 @@ def predict_first_image_channels(img, model, #print(cropsShiftUp.shape) cropsShiftUp = np.expand_dims(cropsShiftUp, -1) img_generator = TrapSegmentationDataGenerator(cropsShiftUp, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # 
predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) predictionUp = imageConcatenatorFeatures(predictions, subImageNumber=subImageNumber) predictionUp = np.pad(predictionUp, pad_width=((0,0),(0,shiftDistance),(0,0),(0,0)), mode='constant', constant_values=((0,0),(0,0),(0,0),(0,0)))[:,shiftDistance:,:,:] @@ -1171,7 +1219,8 @@ def predict_first_image_channels(img, model, cropsShiftDown = tileImage(imgStackShiftDown, subImageNumber=subImageNumber) cropsShiftDown = np.expand_dims(cropsShiftDown, -1) img_generator = TrapSegmentationDataGenerator(cropsShiftDown, **data_gen_args) - predictions = model.predict_generator(img_generator, **predict_gen_args) + # predictions = model.predict_generator(img_generator, **predict_gen_args) + predictions = model.predict(img_generator, **predict_gen_args) predictionDown = imageConcatenatorFeatures(predictions, subImageNumber=subImageNumber) predictionDown = np.pad(predictionDown, pad_width=((0,0),(shiftDistance,0),(0,0),(0,0)), mode='constant', constant_values=((0,0),(0,0),(0,0),(0,0)))[:,:(-1*shiftDistance),:,:] @@ -1186,7 +1235,7 @@ def predict_first_image_channels(img, model, # takes initial U-net centroids for trap locations, and creats bounding boxes for each trap at the defined height and width def get_frame_trap_bounding_boxes(trapLabels, trapProps, trapAreaThreshold=2000, trapWidth=27, trapHeight=256): - badTrapLabels = [reg.label for reg in trapProps if reg.area < trapAreaThreshold] # filter out small "trap" regions + badTrapLabels = [reg.label for reg in trapProps if reg.area < trapAreaThreshold] # filter out small regions goodTraps = trapLabels.copy() for label in badTrapLabels: @@ -1194,9 +1243,12 @@ def get_frame_trap_bounding_boxes(trapLabels, trapProps, trapAreaThreshold=2000, goodTrapProps = measure.regionprops(goodTraps) trapCentroids = [(int(np.round(reg.centroid[0])),int(np.round(reg.centroid[1]))) for reg in goodTrapProps] # get 
centroids as integers + trap_orientations = [reg.orientation for reg in goodTrapProps] + trapBboxes = [] + trap_rotations = [] - for centroid in trapCentroids: + for i,centroid in enumerate(trapCentroids): rowIndex = centroid[0] colIndex = centroid[1] @@ -1217,20 +1269,35 @@ def get_frame_trap_bounding_boxes(trapLabels, trapProps, trapAreaThreshold=2000, trapBboxes.append((minRow,minCol,maxRow,maxCol)) - return(trapBboxes) + orientation = trap_orientations[i] + + if orientation < 0: + orientation = np.pi + orientation + rotation_angle = -1 * orientation / np.pi * 180 + 90 + + trap_rotations.append(rotation_angle) + + return(trapBboxes,trap_rotations) # this function performs image alignment as defined by the shifts passed as an argument -def crop_traps(fileNames, trapProps, labelledTraps, bboxesDict, trap_align_metadata): +def crop_traps(fileNames, trap_rotations_dict, labelledTraps, bboxesDict, trap_align_metadata): frameNum = trap_align_metadata['frame_count'] channelNum = trap_align_metadata['plane_number'] - trapImagesDict = {key:np.zeros((frameNum, - trap_align_metadata['trap_height'], - trap_align_metadata['trap_width'], - channelNum)) for key in bboxesDict} + trapImagesDict = { + key:np.zeros( + ( + frameNum, + trap_align_metadata['trap_height'], + trap_align_metadata['trap_width'], + channelNum + ) + ) for key in bboxesDict + } trapClosedEndPxDict = {} flipImageDict = {} trapMask = labelledTraps + pad_size = 20 for frame in range(frameNum): @@ -1242,12 +1309,54 @@ def crop_traps(fileNames, trapProps, labelledTraps, bboxesDict, trap_align_metad if len(fullFrameImg.shape) == 3: if fullFrameImg.shape[0] < 3: # for tifs with less than three imaging channels, the first dimension separates channels fullFrameImg = np.transpose(fullFrameImg, (1,2,0)) + + # if frame == 0: + # row_num = fullFrameImg.shape[0] + # col_num = fullFrameImg.shape[1] + trapClosedEndPxDict[fileNames[frame]] = {key:{} for key in bboxesDict.keys()} for key in trapImagesDict.keys(): + 
rotation_angle = trap_rotations_dict[key] bbox = bboxesDict[key][frame] - trapImagesDict[key][frame,:,:,:] = fullFrameImg[bbox[0]:bbox[2],bbox[1]:bbox[3],:] + # rotation_center = (int(np.round((bbox[1]+bbox[3])/2)), int(np.round((bbox[0]+bbox[2])/2))) + + if np.abs(rotation_angle) > 1.5: + + # PADDING IS SLOW!!!! MOVE ABOVE AND PAD FOR EACH IMAGE, REGARDLESS OF ROTATION OF A GIVEN TRAP + padded_fullFrameImg = np.pad( + fullFrameImg, + ((pad_size,pad_size), (pad_size,pad_size), (0,0)), + mode='constant' + ) + + min_row = bbox[0] + max_row = bbox[2] + 2*pad_size + min_col = bbox[1] + max_col = bbox[3] + 2*pad_size + + tmp_img = padded_fullFrameImg[min_row:max_row, min_col:max_col, :] + + for rot_frame in range(tmp_img.shape[-1]): + tmp_img[:,:,rot_frame] = transform.rotate( + tmp_img[:,:,rot_frame], + angle=rotation_angle, + preserve_range=True + ) + trapImagesDict[key][frame,:,:,:] = tmp_img[pad_size:-pad_size,pad_size:-pad_size,:] + + # for rot_frame in range(fullFrameImg.shape[-1]): + + # fullFrameImg[:,:,rot_frame] = transform.rotate( + # fullFrameImg[:,:,rot_frame], + # angle=rotation_angle, + # center=rotation_center, + # preserve_range=True + # ) + + else: + trapImagesDict[key][frame,:,:,:] = fullFrameImg[bbox[0]:bbox[2],bbox[1]:bbox[3],:] #tmpImg = np.reshape(fullFrameImg[trapMask==key], (trapHeight,trapWidth,channelNum)) @@ -2295,6 +2404,14 @@ def bce_dice_loss(y_true, y_pred): loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) return loss +def weighted_bce(y_true, y_pred): + bce = losses.binary_crossentropy(y_true, y_pred) + bce = tf.expand_dims(bce, -1) + y_true = tf.cast(y_true, 'float32') + weights = y_true * 270. + 1. 
+ weighted_loss = tf.reduce_mean(bce * weights) + return weighted_loss + def tversky_loss(y_true, y_pred): alpha = 0.5 beta = 0.5 @@ -2317,6 +2434,43 @@ def cce_tversky_loss(y_true, y_pred): loss = losses.categorical_crossentropy(y_true, y_pred) + tversky_loss(y_true, y_pred) return loss +def weighted_categorical_crossentropy(weights): + """ + A weighted version of keras.objectives.categorical_crossentropy + + Variables: + weights: numpy array of shape (C,) where C is the number of classes + + Usage: + weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x. + loss = weighted_categorical_crossentropy(weights) + model.compile(loss=loss,optimizer='adam') + """ + + weights = K.variable(weights) + + def loss(y_true, y_pred): + # scale predictions so that the class probas of each sample sum to 1 + y_pred /= K.sum(y_pred, axis=-1, keepdims=True) + # clip to prevent NaN's and Inf's + y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon()) + # calc + loss = y_true * K.log(y_pred) * weights + loss = -K.sum(loss, -1) + return loss + + return loss + +def loss(y_true, y_pred): + # scale predictions so that the class probas of each sample sum to 1 + y_pred /= K.sum(y_pred, axis=-1, keepdims=True) + # clip to prevent NaN's and Inf's + y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon()) + # calc + loss = y_true * K.log(y_pred) * weights + loss = -K.sum(loss, -1) + return loss + def get_pad_distances(unet_shape, img_height, img_width): '''Finds padding and trimming sizes to make the input image the same as the size expected by the U-net model. 
@@ -2352,13 +2506,110 @@ def get_pad_distances(unet_shape, img_height, img_width): return pad_dict +def prediction_post_processing(predictions, pad_dict, unet_shape, mode='segment', remote=False): + + if mode == 'segment': + cellClassThreshold = params[mode]['cell_class_threshold'] + min_object_size = params['segment']['min_object_size'] + + elif mode == 'foci': + cellClassThreshold = params[mode]['focus_threshold'] + if cellClassThreshold == 'None': # yaml imports None as a string + cellClassThreshold = False + + # remove padding including the added last dimension + predictions = predictions[:, pad_dict['top_pad']:unet_shape[0]-pad_dict['bottom_pad'], + pad_dict['left_pad']:unet_shape[1]-pad_dict['right_pad'], 0] + + # pad back incase the image had been trimmed + predictions = np.pad(predictions, + ((0,0), + (0,pad_dict['bottom_trim']), + (0,pad_dict['right_trim'])), + mode='constant') + + if params[mode]['save_predictions']: + + # save out the segmented image + if remote: + information("Saving predictions in remote mode is not yet supported!!!!!!!!\n \ + Moving on to saving your segmentation result.") + + elif mode == 'segment': + if not os.path.isdir(params['pred_dir']): + os.makedirs(params['pred_dir']) + + pred_filename = params['experiment_name'] + '_xy%03d_p%04d_%s.tif' % (fov_id, peak_id, params['pred_img']) + elif mode == 'foci': + pred_filename = params['experiment_name'] + '_xy%03d_p%04d_%s.tif' % (fov_id, peak_id, params['foci_pred_img']) + if not os.path.isdir(params['foci_pred_dir']): + os.makedirs(params['foci_pred_dir']) + + int_preds = (predictions * 255).astype('uint8') + + if not remote: + if mode == 'segment': + tiff.imsave(os.path.join(params['pred_dir'], pred_filename), + int_preds, compress=4) + elif mode == 'foci': + tiff.imsave(os.path.join(params['foci_pred_dir'], pred_filename), + int_preds, compress=4) + + # binarized and label (if there is a threshold value, otherwise, save a grayscale for debug) + if cellClassThreshold: + 
+ # remove labels which touch the border
= median(tmpImg, selem) + + # robust normalization of peak's image stack to 1 + max_val = np.max(med_stack) + img_stack = img_stack/max_val + img_stack[img_stack > 1] = 1 + + return(img_stack) + def segment_cells_unet(ana_peak_ids, fov_id, pad_dict, unet_shape, model): batch_size = params['segment']['batch_size'] cellClassThreshold = params['segment']['cell_class_threshold'] if cellClassThreshold == 'None': # yaml imports None as a string cellClassThreshold = False - min_object_size = params['segment']['min_object_size'] + # min_object_size = params['segment']['min_object_size'] # arguments to data generator # data_gen_args = {'batch_size':batch_size, @@ -2366,27 +2617,28 @@ def segment_cells_unet(ana_peak_ids, fov_id, pad_dict, unet_shape, model): # 'normalize_to_one':False, # 'shuffle':False} # arguments to predict_generator - predict_args = dict(use_multiprocessing=True, - workers=params['num_analyzers'], + predict_args = dict(use_multiprocessing=False, + # workers=params['num_analyzers'], verbose=1) for peak_id in ana_peak_ids: information('Segmenting peak {}.'.format(peak_id)) img_stack = load_stack(fov_id, peak_id, color=params['phase_plane']) + dt = img_stack.dtype + # sometimes a phase contrast image is missed and has no signal. + # This is a workaround for that problem + # we just take the prior timepoint's image and replace the + # missing image with it. + for k,img in enumerate(img_stack): + # if the mean phase image signal is less than 200, add its index to list + if ((dt == 'uint16') and (np.mean(img) < 200)): + img_stack[k,...] = img_stack[k-1,...] + elif ((dt == 'uint8') and (np.mean(img) < 200/(2**16-1)*(2**8-1))): + img_stack[k,...] = img_stack[k-1,...] if params['segment']['normalize_to_one'] is not None: - med_stack = np.zeros(img_stack.shape) - selem = morphology.disk(1) - - for frame_idx in range(img_stack.shape[0]): - tmpImg = img_stack[frame_idx,...] - med_stack[frame_idx,...] 
= median(tmpImg, selem) - - # robust normalization of peak's image stack to 1 - max_val = np.max(med_stack) - img_stack = img_stack/max_val - img_stack[img_stack > 1] = 1 + img_stack = normalize_stack(img_stack) # trim and pad image to correct size img_stack = img_stack[:, :unet_shape[0], :unet_shape[1]] @@ -2404,52 +2656,11 @@ def segment_cells_unet(ana_peak_ids, fov_id, pad_dict, unet_shape, model): shuffle=False) # keep same order # predict cell locations. This has multiprocessing built in but I need to mess with the parameters to see how to best utilize it. *** - predictions = model.predict_generator(image_generator, **predict_args) + # predictions = model.predict_generator(image_generator, **predict_args) + predictions = model.predict(image_generator, **predict_args) # post processing - # remove padding including the added last dimension - predictions = predictions[:, pad_dict['top_pad']:unet_shape[0]-pad_dict['bottom_pad'], - pad_dict['left_pad']:unet_shape[1]-pad_dict['right_pad'], 0] - - # pad back incase the image had been trimmed - predictions = np.pad(predictions, - ((0,0), - (0,pad_dict['bottom_trim']), - (0,pad_dict['right_trim'])), - mode='constant') - - if params['segment']['save_predictions']: - pred_filename = params['experiment_name'] + '_xy%03d_p%04d_%s.tif' % (fov_id, peak_id, params['pred_img']) - if not os.path.isdir(params['pred_dir']): - os.makedirs(params['pred_dir']) - int_preds = (predictions * 255).astype('uint8') - tiff.imsave(os.path.join(params['pred_dir'], pred_filename), - int_preds, compress=4) - - # binarized and label (if there is a threshold value, otherwise, save a grayscale for debug) - if cellClassThreshold: - predictions[predictions >= cellClassThreshold] = 1 - predictions[predictions < cellClassThreshold] = 0 - predictions = predictions.astype('uint8') - - segmented_imgs = np.zeros(predictions.shape, dtype='uint8') - # process and label each frame of the channel - for frame in range(segmented_imgs.shape[0]): - # get rid of 
small holes - predictions[frame,:,:] = morphology.remove_small_holes(predictions[frame,:,:], min_object_size) - # get rid of small objects. - predictions[frame,:,:] = morphology.remove_small_objects(morphology.label(predictions[frame,:,:], connectivity=1), min_size=min_object_size) - # remove labels which touch the boarder - predictions[frame,:,:] = segmentation.clear_border(predictions[frame,:,:]) - # relabel now - segmented_imgs[frame,:,:] = morphology.label(predictions[frame,:,:], connectivity=1) - - else: # in this case you just want to scale the 0 to 1 float image to 0 to 255 - information('Converting predictions to grayscale.') - segmented_imgs = np.around(predictions * 100) - - # both binary and grayscale should be 8bit. This may be ensured above and is unneccesary - segmented_imgs = segmented_imgs.astype('uint8') + segmented_imgs = prediction_post_processing(predictions, pad_dict, unet_shape) # save out the segmented stacks if params['output'] == 'TIFF': @@ -2472,6 +2683,66 @@ def segment_cells_unet(ana_peak_ids, fov_id, pad_dict, unet_shape, model): compression="gzip", shuffle=True, fletcher32=True) h5f.close() + +def segment_stack_unet(fname, model, mode='segment'): + ''' + Segments the channels from one fov using the U-net CNN model. 
+ + Parameters + ---------- + fname : str + model : TensorFlow model + ''' + + # load segmentation parameters + unet_shape = (params['segment']['trained_model_image_height'], + params['segment']['trained_model_image_width']) + + img_stack = io.imread(fname) + img_height = img_stack.shape[1] + img_width = img_stack.shape[2] + + pad_dict = get_pad_distances(unet_shape, img_height, img_width) + batch_size = params[mode]['batch_size'] + + # arguments to predict_generator + predict_args = dict(use_multiprocessing=False, + verbose=1) + + if mode == 'segment': + if params['segment']['normalize_to_one'] is not None: + img_stack = normalize_stack(img_stack) + + # trim and pad image to correct size + img_stack = img_stack[:, :unet_shape[0], :unet_shape[1]] + img_stack = np.pad(img_stack, + ((0,0), + (pad_dict['top_pad'],pad_dict['bottom_pad']), + (pad_dict['left_pad'],pad_dict['right_pad'])), + mode='constant') + img_stack = np.expand_dims(img_stack, -1) # TF expects images to be 4D + # set up image generator + # image_generator = CellSegmentationDataGenerator(img_stack, **data_gen_args) + image_datagen = ImageDataGenerator() + image_generator = image_datagen.flow(x=img_stack, + batch_size=batch_size, + shuffle=False) # keep same order + + # pred = model.predict_generator(image_generator, **predict_args) + pred = model.predict(image_generator, **predict_args) + + seg = prediction_post_processing(pred, pad_dict, unet_shape, mode=mode, remote=True) + + # save out the segmented image + fname_split = fname.split('_')[:-1] + seg_filename = '_'.join(fname_split) + seg_filename = seg_filename + "_{}.tif".format(params['seg_img']) + # print(seg_filename) + + tiff.imsave(seg_filename, seg, compress=4) + + return + def segment_fov_unet(fov_id, specs, model, color=None): ''' Segments the channels from one fov using the U-net CNN model. 
@@ -2548,11 +2819,12 @@ def segment_foci_unet(ana_peak_ids, fov_id, pad_dict, unet_shape, model): (pad_dict['left_pad'],pad_dict['right_pad'])), mode='constant') img_stack = np.expand_dims(img_stack, -1) + # print(img_stack.dtype) # set up image generator image_generator = FocusSegmentationDataGenerator(img_stack, **data_gen_args) # predict foci locations. - predictions = model.predict_generator(image_generator, **predict_args) + predictions = model.predict(image_generator, **predict_args) # post processing # remove padding including the added last dimension @@ -2718,10 +2990,12 @@ def __data_generation(self, array_list_temp): X = X[:i,...] break + tmpImg = tmpImg.astype('float32') + # ensure image is uint8 - if tmpImg.dtype=="uint16": - tmpImg = tmpImg / 2**16 * 2**8 - tmpImg = tmpImg.astype('uint8') + # if tmpImg.dtype=="uint16": + # tmpImg = tmpImg / 2**16 * 2**8 + # tmpImg = tmpImg.astype('uint8') if self.normalize_to_one: with warnings.catch_warnings(): @@ -2855,6 +3129,7 @@ def __getitem__(self, index): # Generate data X = self.__data_generation(array_list_temp) + X = X.astype('float32') return X @@ -2867,10 +3142,13 @@ def on_epoch_end(self): def __data_generation(self, array_list_temp): 'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels) # Initialization - X = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.n_channels), 'uint16') - if self.normalize_to_one: - max_pixels = [] + X = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.n_channels), 'float64') + else: + X = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.n_channels), 'uint16') + + # if self.normalize_to_one: + # max_pixels = [] # Generate data for i in range(self.batch_size): @@ -2878,9 +3156,9 @@ def __data_generation(self, array_list_temp): try: tmpImg = array_list_temp[i] if self.normalize_to_one: - # tmpMedian = filters.median(tmpImg, self.selem) - tmpMax = np.max(tmpImg) - max_pixels.append(tmpMax) + medImg = filters.median(tmpImg, 
self.selem) + tmpImg = tmpImg/np.max(medImg) + tmpImg[tmpImg > 1] = 1 except IndexError: X = X[:i,...] break @@ -2900,13 +3178,13 @@ def __data_generation(self, array_list_temp): X[i,:,:,0] = tmpImg - if self.normalize_to_one: - channel_max = np.max(max_pixels) / (2**8 - 1) - # print("Channel max: {}".format(channel_max)) - # print("Array max: {}".format(np.max(X))) - X = X/channel_max - # print("Normalized array max: {}".format(np.max(X))) - X[X > 1] = 1 + # if self.normalize_to_one: + # channel_max = np.max(max_pixels) / (2**8 - 1) + # # print("Channel max: {}".format(channel_max)) + # # print("Array max: {}".format(np.max(X))) + # X = X/channel_max + # # print("Normalized array max: {}".format(np.max(X))) + # X[X > 1] = 1 return (X) @@ -3041,6 +3319,18 @@ def absolute_dice_loss(y_true, y_pred): loss = dice_loss(y_true, y_pred) + absolute_diff(y_true, y_pred) return loss +def weighted_bce(y_true, y_pred): + bce = losses.binary_crossentropy(y_true, y_pred) + bce = tf.expand_dims(bce, -1) + y_true = tf.cast(y_true, 'float32') + weights = y_true * 80. + 1. 
+ weighted_loss = tf.reduce_mean(bce * weights) + return weighted_loss + +def weighted_bce_dice_loss(y_true, y_pred): + loss = weighted_bce(y_true, y_pred) + dice_loss(y_true, y_pred) + return loss + def recall_m(y_true, y_pred): true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) @@ -3240,7 +3530,7 @@ def get_tracking_model_dict(): if not 'migrate_model' in model_dict: model_dict['migrate_model'] = models.load_model(params['tracking']['migrate_model'], - custom_objects={'all_loss':all_loss, + custom_objects={'bce_dice_loss':bce_dice_loss, 'f2_m':f2_m}) if not 'child_model' in model_dict: model_dict['child_model'] = models.load_model(params['tracking']['child_model'], @@ -3248,20 +3538,21 @@ def get_tracking_model_dict(): 'f2_m':f2_m}) if not 'appear_model' in model_dict: model_dict['appear_model'] = models.load_model(params['tracking']['appear_model'], - custom_objects={'all_loss':all_loss, + custom_objects={'bce_dice_loss':bce_dice_loss, 'f2_m':f2_m}) if not 'die_model' in model_dict: model_dict['die_model'] = models.load_model(params['tracking']['die_model'], - custom_objects={'all_loss':all_loss, + custom_objects={'bce_dice_loss':bce_dice_loss, 'f2_m':f2_m}) if not 'disappear_model' in model_dict: model_dict['disappear_model'] = models.load_model(params['tracking']['disappear_model'], - custom_objects={'all_loss':all_loss, + custom_objects={'bce_dice_loss':bce_dice_loss, 'f2_m':f2_m}) if not 'born_model' in model_dict: model_dict['born_model'] = models.load_model(params['tracking']['born_model'], - custom_objects={'all_loss':all_loss, + custom_objects={'bce_dice_loss':bce_dice_loss, 'f2_m':f2_m}) + # In future work, may implement predictors for cell number in a given region. 
# if not 'zero_cell_model' in model_dict: # model_dict['zero_cell_model'] = models.load_model(params['tracking']['zero_cell_model'], # custom_objects={'absolute_dice_loss':absolute_dice_loss, @@ -3281,6 +3572,33 @@ def get_tracking_model_dict(): return(model_dict) +def get_focus_tracking_model_dict(): + + model_dict = {} + + if not 'all_model' in model_dict: + model_dict['all_model'] = models.load_model(params['foci']['all_model'], + custom_objects={'bce_dice_loss':bce_dice_loss, + 'f2_m':f2_m, + 'f1_m':f1_m, + 'weighted_categorical_crossentropy':weighted_categorical_crossentropy, + 'loss':loss}) + + # if not 'migrate_model' in model_dict: + # model_dict['migrate_model'] = models.load_model(params['foci']['migrate_model'], + # custom_objects={'bce_dice_loss':bce_dice_loss, + # 'f2_m':f2_m}) + # if not 'appear_model' in model_dict: + # model_dict['appear_model'] = models.load_model(params['foci']['appear_model'], + # custom_objects={'bce_dice_loss':bce_dice_loss, + # 'f2_m':f2_m}) + # if not 'disappear_model' in model_dict: + # model_dict['disappear_model'] = models.load_model(params['foci']['disappear_model'], + # custom_objects={'bce_dice_loss':bce_dice_loss, + # 'f2_m':f2_m}) + + return(model_dict) + # Creates lineage for a single channel def make_lineage_chnl_stack(fov_and_peak_id): ''' @@ -3513,9 +3831,11 @@ def __init__(self, detection_id, region, t): self.area = region.area # calculating cell length and width by using Feret Diamter. 
These values are in pixels - length_tmp, width_tmp = feretdiameter(region) - if length_tmp == None: - mm3.warning('feretdiameter() failed for ' + self.id + ' at t=' + str(t) + '.') + # length_tmp, width_tmp = feretdiameter(region) + length_tmp = region.major_axis_length + width_tmp = region.minor_axis_length + # if length_tmp == None: + # mm3.warning('feretdiameter() failed for ' + self.id + ' at t=' + str(t) + '.') self.length = length_tmp self.width = width_tmp @@ -3769,6 +4089,63 @@ def get_cell(self, cell_id): def get_top_from_cell(self, cell_id): pass +class OrphanFocusCell(): + ''' + The CellFromGraph class is one cell that has been born. + It is not neccesarily a cell that has divided. + ''' + + # initialize (birth) the cell + def __init__(self, fov_id, peak_id, t): + '''The cell must be given a unique cell_id and passed the region + information from the segmentation + + Parameters + __________ + + cell_id : str + cell_id is a string in the form fXpXtXrX + f is 3 digit FOV number + p is 4 digit peak number + t is 4 digit time point at time of birth + r is region label for that segmentation + Use the function create_cell_id to do return a proper string. + + region : region properties object + Information about the labeled region from + skimage.measure.regionprops() + + parent_id : str + id of the parent if there is one. 
+ ''' + + # create all the attributes + # id + self.id = 'orphan' + + # identification convenience + self.fov = fov_id + self.peak = peak_id + self.birth_label = 'orphan' + self.regions = None + + # parent is a CellFromGraph object, can be None + self.parent = None + + # daughters is updated when cell divides + # if this is none then the cell did not divide + self.daughters = None + + # birth and division time + self.birth_time = t + self.division_time = None # filled out if cell divides + + # the following information is on a per timepoint basis + self.times = [t] + self.abs_times = [params['time_table'][self.fov][t]] # elapsed time in seconds + self.centroids = [(0,0)] + + # this is the object that holds all information for a cell class CellFromGraph(): ''' @@ -4069,13 +4446,9 @@ def make_long_df(self): return(df) -# this is the object that holds all information for a fluorescent focus -# this class can eventually be used in focus tracking, much like the Cell class -# is used for cell tracking -class Focus(): +class OrphanFocus(): ''' - The Focus class holds information on fluorescent foci. - A single focus can be present in multiple different cells. + The OrphanFocus class holds information on fluorescent foci that are not in any cell object. 
''' # initialize the focus @@ -4091,7 +4464,7 @@ def __init__(self, Parameters __________ - cell : a Cell object + cell : an OrphanFocusCell object region : region properties object Information about the labeled region from @@ -4118,29 +4491,90 @@ def __init__(self, self.regions = [region] self.fov = cell.fov self.peak = cell.peak + self.labels = [region.label] + self.daughters = None + self.parent = None - # cell is a CellFromGraph object - # cells are added later using the .add_cell method + # cell is an OrphanFocusCell object self.cells = [cell] - # daughters is updated when focus splits - # if this is none then the focus did not split - self.parent = None - self.daughters = None - self.merger_partner = None - # appearance and split time self.appear_time = t - self.split_time = None # filled out if focus splits - - # the following information is on a per timepoint basis self.times = [t] - self.abs_times = [params['time_table'][cell.fov][t]] # elapsed time in seconds - self.labels = [region.label] - self.bboxes = [region.bbox] - self.areas = [region.area] - # calculating focus length and width by using Feret Diamter. +# this is the object that holds all information for a fluorescent focus +# this class can eventually be used in focus tracking, much like the Cell class +# is used for cell tracking +class Focus(): + ''' + The Focus class holds information on fluorescent foci. + A single focus can be present in multiple different cells. 
+ ''' + + # initialize the focus + def __init__(self, + cell, + region, + seg_img, + intensity_image, + t): + '''The cell must be given a unique cell_id and passed the region + information from the segmentation + + Parameters + __________ + + cell : a Cell object + + region : region properties object + Information about the labeled region from + skimage.measure.regionprops() + + seg_img : 2D numpy array + Labelled image of cell segmentations + + intensity_image : 2D numpy array + Fluorescence image with foci + ''' + + # create all the attributes + # id + focus_id = create_focus_id(region, + t, + cell.peak, + cell.fov, + experiment_name=params['experiment_name']) + self.id = focus_id + + # identification convenience + self.appear_label = int(region.label) + self.regions = [region] + self.fov = cell.fov + self.peak = cell.peak + + # cell is a CellFromGraph object + # cells are added later using the .add_cell method + self.cells = [cell] + + # daughters is updated when focus splits + # if this is none then the focus did not split + self.parent = None + self.daughters = None + self.merger_partner = None + + # appearance and split time + self.appear_time = t + self.split_time = None # filled out if focus splits + + # the following information is on a per timepoint basis + self.times = [t] + self.abs_times = [params['time_table'][cell.fov][t]] # elapsed time in seconds + self.labels = [region.label] + # self.cell_labels = [region.cell_centric_label] + self.bboxes = [region.bbox] + self.areas = [region.area] + + # calculating focus length and width by using Feret Diamter. 
# These values are in pixels # NOTE: in the future, update to straighten a focus an get straightened length/width # print(region) @@ -4160,7 +4594,7 @@ def __init__(self, self.orientations = [region.orientation] self.centroids = [region.centroid] - # special information for focci + # special information for foci self.elong_rate = None self.disappear = None self.area_mean_fluorescence = [] @@ -4203,6 +4637,7 @@ def grow(self, self.times.append(t) self.abs_times.append(params['time_table'][self.cells[-1].fov][t]) self.labels.append(region.label) + # self.cell_labels.append(region.cell_centric_label) self.bboxes.append(region.bbox) self.areas.append(region.area) self.regions.append(region) @@ -4238,11 +4673,15 @@ def calculate_fluorescence(self, # get the focus' displacement from center of cell # find x and y position relative to the whole image (convert from small box) - # calculate distance of foci from middle of cell (scikit image) + # calculate distance of focus from middle of cell (scikit image) orientation = region.orientation if orientation < 0: orientation = np.pi+orientation + # print('focus labels: {}'.format(self.labels)) + # print('cell labels: {}'.format(self.cells[-1].labels)) + # print('cell times: {}'.format(self.cells[-1].times)) + # print('focus times: {}'.format(self.times)) cell_idx = self.cells[-1].times.index(self.times[-1]) # final time in self.times is current time cell_centroid = self.cells[-1].centroids[cell_idx] focus_centroid = region.centroid @@ -4347,19 +4786,61 @@ def make_long_df(self): return(df) +def sort_regions_in_list(regions): + + y_positions = [] + for reg in regions: + y,_ = reg.centroid + y_positions.append(y) + + order = np.argsort(y_positions) + # print(order) + sorted_regs = [regions[idx] for idx in order] + return(sorted_regs) + + class PredictTrackDataGenerator(utils.Sequence): '''Generates data for running tracking class preditions Input is a stack of labeled images''' def __init__(self, data, batch_size=32, - 
dim=(4,5,9)): + dim=(4,5,9), + img_dim=(5,256,32), + track_type = 'cells', + images = False, + img_stack = None): 'Initialization' self.batch_size = batch_size self.data = data self.dim = dim self.on_epoch_end() + self.track_type = track_type + self.images = images + self.img_dim = img_dim + + if images: + + unet_shape = ( + params['segment']['trained_model_image_height'], + params['segment']['trained_model_image_width'] + ) + + pad_dict = get_pad_distances( + self.img_dim[1:], + img_stack.shape[1], + img_stack.shape[2] + ) + + img_stack = img_stack[:, :unet_shape[0], :unet_shape[1]] + self.img_stack = np.pad( + img_stack, + ((0,0), # add nothing to the time dimension + (pad_dict['top_pad'],pad_dict['bottom_pad']), + (pad_dict['left_pad'],pad_dict['right_pad'])), + mode='constant' + ) def __len__(self): 'Denotes the number of batches per epoch' @@ -4384,13 +4865,14 @@ def __data_generation(self, batch_indices): # Initialization # shape is (batch_size, max_cell_num, frame_num, cell_feature_num, 1) X = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.dim[2], 1)) + if self.images: + img_X = np.zeros((self.batch_size, self.dim[1], self.img_dim[1], self.img_dim[2])) # Generate data for idx in batch_indices: start_idx = idx-2 end_idx = idx+3 -# print(start_idx, end_idx) if start_idx < 0: batch_frame_list = [] for empty_idx in range(abs(start_idx)): @@ -4407,13 +4889,14 @@ def __data_generation(self, batch_indices): for i,frame_region_list in enumerate(batch_frame_list): - # shape is (max_cell_num, frame_num, cell_feature_num) -# tmp_x = np.zeros((self.dim[0], self.dim[1], self.dim[2])) - if not frame_region_list: continue - for region_idx, region, in enumerate(frame_region_list): + if self.images: + img_X[idx, i, :, :] = self.img_stack[start_idx+i,...] 
# add i to start_idx to get correct img_stack frame + + for region_idx,region in enumerate(frame_region_list): + y,x = region.centroid bbox = region.bbox orientation = region.orientation @@ -4425,15 +4908,21 @@ def __data_generation(self, batch_indices): length = region.major_axis_length cell_label = region.label cell_index = cell_label - 1 - cell_info = (min_x, max_x, x, min_y, max_y, y, orientation, area, length) - - if region_idx + 1 > self.dim[0]: + if self.track_type == 'cells': + cell_info = (min_x, max_x, x, min_y, max_y, y, orientation, area, length) + elif self.track_type == 'foci': + mean_fluor = region.mean_intensity + max_fluor = region.max_intensity + cell_info = (min_x, max_x, x, min_y, max_y, y, orientation, area, length, mean_fluor, max_fluor) + + # only take self.dim[0] number of regions + if cell_label > self.dim[0]: continue - # supplement tmp_x at (region_idx, ) -# tmp_x[region_idx, i, :] = cell_info + X[idx, cell_index, i, :,0] = cell_info - X[idx, cell_index, i, :,0] = cell_info # tmp_x + if self.images: + return (X, img_X) return X @@ -4503,9 +4992,9 @@ def create_lineages_from_graph(graph, experiment_name=params['experiment_name']) current_cell = CellFromGraph(cell_id, - prior_node_region, - prior_node_time, - parent=None) + prior_node_region, + prior_node_time, + parent=None) if not cell_id in tracks.keys(): tracks[cell_id] = current_cell @@ -4561,6 +5050,8 @@ def create_lineages_from_graph(graph, # move on. Otherwise, append to our list if graph.nodes[successor_node_id]['visited']: continue + if graph.nodes[successor_node_id]['has_input']: + continue else: unvisited_node_ids.append(successor_node_id) @@ -4615,6 +5106,7 @@ def create_lineages_from_graph(graph, current_cell.add_daughter(new_cell, new_cell_time) # initialize a scores array to select highest score from the available options + # This is working toward identification of the second daughter cell. 
unvisited_detection_nodes = [unvisited_node_id for unvisited_node_id in unvisited_node_ids if unvisited_node_id.startswith(params['experiment_name'])] child_scores = np.zeros(len(unvisited_detection_nodes)) @@ -4622,22 +5114,23 @@ def create_lineages_from_graph(graph, for i in range(len(unvisited_detection_nodes)): successor_node_id = unvisited_detection_nodes[i] if successor_node_id == next_node_id: - child_scores[i] = -np.inf + child_scores[i] = -np.inf # give already-identified child a score of -inf so as not to re-identify it as the second daughter continue child_score = get_score_by_type(prior_node_id, successor_node_id, graph, score_type='child') child_scores[i] = child_score try: - second_daughter_score = np.max(child_scores) + # second_daughter_score = np.max(child_scores) # sometimes a second daughter doesn't exist: perhaps parent is at mouth of a trap and one # daughter is lost to the central channel at division time. In this case, do the following: - if second_daughter_score < np.log(0.5): + if len(child_scores) == 1: current_cell = new_cell else: second_daughter_index = np.argmax(child_scores) # grab the node_id corresponding to traversing the highest-scoring edge from the prior node other_daughter_node_id = unvisited_detection_nodes[second_daughter_index] + graph.nodes[other_daughter_node_id]['has_input'] = True other_daughter_cell_time = graph.nodes[other_daughter_node_id]['time'] other_daughter_cell_region = graph.nodes[other_daughter_node_id]['region'] @@ -4655,6 +5148,14 @@ def create_lineages_from_graph(graph, tracks[other_daughter_cell_id] = other_daughter_cell current_cell.add_daughter(other_daughter_cell, new_cell_time) + + ############################## TO DO: ################################################# + #### fix problem where other_daughter_cell doesn't get 'visited' here ################# + #### I can't just 'visit' it here, because that would block it from use ############### + #### downstream in the graph. 
I may need to split visited into ################ + #### has_input and has_output. I may also just be able to add has_input ############### + ####################################################################################### + # now we remove current_cell, since it's done, and move on to one of the daughters current_cell = new_cell @@ -4704,7 +5205,287 @@ def create_lineages_from_graph(graph, if same_iter_num > 10: print("WARNING: Ten iterations surpassed without decreasing the number of visited nodes.\n \ - Breaking tracking loop now. You should probably not trust these results.") + Breaking tracking loop now.\n \ + You should not trust these results!") + break + + return tracks + +def get_focus_cell_distance(focus_label_img, cell_label_img, focus_regions): + + dist_arr = np.zeros((len(focus_regions),np.max(cell_label_img))) + cell_regions = measure.regionprops(cell_label_img) + + for i,focus_region in enumerate(focus_regions): + + focus_y,focus_x = focus_region.centroid + focus_label = focus_region.label + + for j,cell_region in enumerate(cell_regions): + cell_y,cell_x = cell_region.centroid + cell_label = cell_region.label + bin_img = np.ones(focus_label_img.shape) + bin_img[focus_label_img == focus_label] = 0 + bin_img[cell_label_img == cell_label] = 0 + dist_img = ndi.morphology.distance_transform_edt(bin_img, sampling=None, return_distances=True, return_indices=False, distances=None, indices=None) + line_y, line_x = np.linspace(focus_x, cell_x, 100), np.linspace(focus_y, cell_y, 100) + + # Extract the values along the line, using cubic interpolation + line_dists = ndi.map_coordinates(dist_img, np.vstack((line_x,line_y)), order=1) + dist_arr[i,j] = np.max(line_dists) + + # print(dist_arr.shape) + return(dist_arr) + +def create_focus_lineages_from_graph(graph, + graph_df, + fov_id, + peak_id, + Cells, + max_cell_number): + ''' + This function iterates through nodes in a graph of detections + to link the nodes as "Focus" objects, eventually + leading to the 
ultimate goal of returning + a CellTree object with each cell's information for the experiment. + ''' + + # keep cells with this fov_id/peak_id + fov_cells = filter_cells(Cells, attr='fov', val=fov_id) + peak_cells = filter_cells(fov_cells, attr='peak', val=peak_id) + + # read in focus label images and fluorescence intensity images + seg_cell_stack = load_stack(fov_id, peak_id, color='seg_unet') + seg_stack = load_stack(fov_id, peak_id, color='foci_seg_unet') + fluor_stack = load_stack(fov_id, peak_id, color=params['foci']['foci_plane']) + + # iterate through all nodes in graph + tracks = {} + + for node_id in graph.nodes: + graph.nodes[node_id]['visited'] = False + graph_df['visited'] = False + num_unvisited = count_unvisited(graph, params['experiment_name']) + + while num_unvisited > 0: + + # which detection nodes are not yet visited + unvisited_detection_nodes = graph_df[(~(graph_df.visited) & graph_df.node_id.str.startswith(params['experiment_name']))] + # grab the first unvisited node_id from the dataframe + prior_node_id = unvisited_detection_nodes.iloc[0,1] + prior_node_time = graph.nodes[prior_node_id]['time'] + prior_node_region = graph.nodes[prior_node_id]['region'] + + # print('Prior node id: {}'.format(prior_node_id)) + # print('Prior node time: {}'.format(prior_node_time)) + # print('Prior node region: {}'.format(prior_node_region)) + + focus_id = create_focus_id( + prior_node_region, + prior_node_time, + peak_id, + fov_id, + experiment_name=params['experiment_name'] + ) + + # print('focus id: {}'.format(focus_id)) + + frame_cells = filter_cells_containing_val_in_attr( + peak_cells, + attr='times', + val=prior_node_time, # putting here for now. check this. with logic below I'm note sure this is right + ) + + # print(frame_cells) + + seg_cell_img = seg_cell_stack[prior_node_time - 1,:,:] # putting here for now. check this. with logic below I'm note sure this is right + seg_foci_img = seg_stack[prior_node_time - 1,:,:] # putting here for now. check this. 
with logic below I'm note sure this is right + fluor_img = fluor_stack[prior_node_time - 1,:,:] # putting here for now. check this. with logic below I'm note sure this is right + + cell = get_focus_cell( + frame_cells, + seg_cell_img, + seg_foci_img, + prior_node_region, # putting here for now. check this. with logic below I'm note sure this is right + prior_node_time, + max_cell_number + ) + + # print(cell) + if cell == 'too far down': + cell = OrphanFocusCell(fov_id, peak_id, prior_node_time) + + current_focus = Focus( + cell, + prior_node_region, + seg_foci_img, + fluor_img, + prior_node_time, + ) + + if cell == 'too far down': + disappear_time = graph.nodes[prior_node_id]['time'] + disappear_region = graph.nodes[prior_node_id]['region'] + current_focus.disappears(disappear_region, disappear_time) + continue + + if not focus_id in tracks.keys(): + tracks[focus_id] = current_focus + else: + current_focus = tracks[focus_id] + + # for use later in establishing predecessors + current_node_id = prior_node_id + + # set this detection's "visited" status to True in the graph and in the dataframe + graph.nodes[prior_node_id]['visited'] = True + graph_df.iloc[np.where(graph_df.node_id==prior_node_id)[0][0],3] = True + + # build current_track list to this detection's node + current_track = collections.deque() + current_track.append(current_node_id) + predecessors_list = [k for k in graph.predecessors(prior_node_id)] + unvisited_predecessors_list = [k for k in predecessors_list if not graph.nodes[k]['visited']] + + while len(unvisited_predecessors_list) != 0: + + # initialize a scores array to select highest score from the available options + predecessor_scores = np.zeros(len(unvisited_predecessors_list)) + + # populate array with scores + for i in range(len(unvisited_predecessors_list)): + predecessor_node_id = unvisited_predecessors_list[i] + edge_type, edge_score = get_greatest_score_info(predecessor_node_id, current_node_id, graph) + predecessor_scores[i] = edge_score + 
+ # find highest score + max_index = np.argmax(predecessor_scores) + # grab the node_id corresponding to traversing the highest-scoring edge from the prior node + current_node_id = unvisited_predecessors_list[max_index] + current_track.appendleft(current_node_id) + + predecessors_list = [k for k in graph.predecessors(current_node_id)] + unvisited_predecessors_list = [k for k in predecessors_list if not graph.nodes[k]['visited']] + + while prior_node_id is not 'B': + + # which nodes succeed our current node? + successor_node_ids = [node_id for node_id in graph.successors(prior_node_id)] + + # keep only the potential successor detections that have not yet been visited + unvisited_node_ids = [] + for i,successor_node_id in enumerate(successor_node_ids): + + # if it starts with params['experiment_name'], it is a detection node, and not born, appear, etc. + if successor_node_id.startswith(params['experiment_name']): + + # if it has been used in the focus track graph, i.e., if 'visited' is True, + # move on. 
Otherwise, append to our list + if graph.nodes[successor_node_id]['visited']: + continue + if graph.nodes[successor_node_id]['has_input']: + continue + else: + unvisited_node_ids.append(successor_node_id) + + # if it doesn't start with params['experiment_name'], it is a born, appear, etc., and should always be appended + else: + unvisited_node_ids.append(successor_node_id) + + # initialize a scores array to select highest score from the available options + successor_scores = np.zeros(len(unvisited_node_ids)) + successor_edge_types = [] + + # populate array with scores + for i in range(len(unvisited_node_ids)): + successor_node_id = unvisited_node_ids[i] + edge_type, edge_score = get_greatest_score_info(prior_node_id, successor_node_id, graph) + successor_scores[i] = edge_score + successor_edge_types.append(edge_type) + + # find highest score + max_score = np.max(successor_scores) + max_index = np.argmax(successor_scores) + # grab the node_id corresponding to traversing the highest-scoring edge from the prior node + next_node_id = unvisited_node_ids[max_index] + max_edge_type = successor_edge_types[max_index] + + # if the max_score in successor_scores isn't greater than log(0.1), just make the focus disappear for now. + if max_score < np.log(0.1): + max_edge_type = 'disappear' + next_node_id = [n_id for n_id in unvisited_node_ids if n_id.startswith('disappear')][0] + + # if this is a migration, grow the current_focus. 
+ elif max_edge_type == 'migrate': + + focus_time = graph.nodes[next_node_id]['time'] + focus_region = graph.nodes[next_node_id]['region'] + + frame_cells = filter_cells_containing_val_in_attr( + peak_cells, + attr='times', + val=focus_time, + ) + + seg_cell_img = seg_cell_stack[focus_time - 1,:,:] + seg_foci_img = seg_stack[focus_time - 1,:,:] + fluor_img = fluor_stack[focus_time - 1,:,:] + + cell = get_focus_cell( + frame_cells, + seg_cell_img, + seg_foci_img, + focus_region, + focus_time, + max_cell_number + ) + + # print(cell) + + # if the cell's label was greater than the max number of cells tracked, it won't exist, + # so get_focus_cell returns 'too far down'. Here, we just have the focus disappear in this case. + if cell == 'too far down': + disappear_time = graph.nodes[prior_node_id]['time'] + disappear_region = graph.nodes[prior_node_id]['region'] + current_focus.disappears(disappear_region, disappear_time) + + else: + current_focus.grow( + focus_region, + focus_time, + seg_foci_img, + fluor_img, + cell + ) + + # if the event represents disappearance, end the focus + elif max_edge_type == 'disappear': + + if prior_node_id.startswith(params['experiment_name']): + disappear_time = graph.nodes[prior_node_id]['time'] + disappear_region = graph.nodes[prior_node_id]['region'] + current_focus.disappears(disappear_region, disappear_time) + + # set the next node to 'visited' + graph.nodes[next_node_id]['visited'] = True + if next_node_id != 'B': + graph_df.iloc[np.where(graph_df.node_id==next_node_id)[0][0],3] = True + + # reset prior_node_id to iterate to next frame and append node_id to current track + prior_node_id = next_node_id + + if num_unvisited != count_unvisited(graph, params['experiment_name']): + same_iter_num = 0 + else: + same_iter_num += 1 + + num_unvisited = count_unvisited(graph, params['experiment_name']) + print("{} detections remain unvisited.".format(num_unvisited)) + + if same_iter_num > 10: + print("WARNING: Ten iterations surpassed without 
decreasing the number of visited nodes.\n \ + Breaking tracking loop now.\n \ + You should not trust these results!") break return tracks @@ -5231,7 +6012,7 @@ def feretdiameter(region): cosorient = np.cos(region.orientation) sinorient = np.sin(region.orientation) # print(cosorient, sinorient) - amp_param = 1.2 #amplifying number to make sure the axis is longer than actual cell length + amp_param = 1.2 # amplifying number to make sure the axis is longer than actual cell length # coordinates relative to bounding box # r_coords = region.coords - [np.int16(region.bbox[0]), np.int16(region.bbox[1])] @@ -5241,6 +6022,10 @@ def feretdiameter(region): distance_image = ndi.distance_transform_edt(region_binimg) r_coords = np.where(distance_image == 1) r_coords = list(zip(r_coords[0], r_coords[1])) + if params['foci']['debug']: + print(r_coords) + io.imshow(distance_image) + plt.show(); # coordinates are already sorted by y. partion into top and bottom to search faster later # if orientation > 0, L1 is closer to top of image (lower Y coord) @@ -5270,12 +6055,26 @@ def feretdiameter(region): # pt_L1 = r_coords[np.argmin([np.sqrt(np.power(Pt[0]-L1_pt[0],2) + np.power(Pt[1]-L1_pt[1],2)) for Pt in r_coords])] # pt_L2 = r_coords[np.argmin([np.sqrt(np.power(Pt[0]-L2_pt[0],2) + np.power(Pt[1]-L2_pt[1],2)) for Pt in r_coords])] - try: - pt_L1 = L1_coords[np.argmin([np.sqrt(np.power(Pt[0]-L1_pt[0],2) + np.power(Pt[1]-L1_pt[1],2)) for Pt in L1_coords])] - pt_L2 = L2_coords[np.argmin([np.sqrt(np.power(Pt[0]-L2_pt[0],2) + np.power(Pt[1]-L2_pt[1],2)) for Pt in L2_coords])] - length = np.sqrt(np.power(pt_L1[0]-pt_L2[0],2) + np.power(pt_L1[1]-pt_L2[1],2)) - except: - length = None + # try: + L1_euclids = [np.sqrt(np.power(Pt[0]-L1_pt[0],2) + np.power(Pt[1]-L1_pt[1],2)) for Pt in L1_coords] + pt_L1 = L1_coords[np.argmin(L1_euclids)] + L2_euclids = [np.sqrt(np.power(Pt[0]-L2_pt[0],2) + np.power(Pt[1]-L2_pt[1],2)) for Pt in L2_coords] + if params['foci']['debug']: + print('orientation: 
{}'.format(region.orientation)) + print(L2_coords) + print(L2_euclids) + pt_L2 = L2_coords[np.argmin(L2_euclids)] + if params['foci']['debug']: + print(pt_L1) + print(pt_L2) + length = np.sqrt(np.power(pt_L1[0]-pt_L2[0],2) + np.power(pt_L1[1]-pt_L2[1],2)) + # except: + # length = None + + # if params['foci']['debug']: + # # if length is None: + # io.imshow(region.image) + # plt.show(); ##################### # calculate cell width @@ -5291,6 +6090,8 @@ def feretdiameter(region): W_coords.append(r_coords[:int(np.round(len(r_coords)/2))]) # starting points + # print(cosorient) + # print(length) x1 = x0 + cosorient * 0.5 * length*0.4 y1 = y0 - sinorient * 0.5 * length*0.4 x2 = x0 - cosorient * 0.5 * length*0.4 @@ -5476,7 +6277,7 @@ def initialize_track_graph(peak_id, if det.area is not None: # if the detection represents a segmentation from our imaging, add its ID, # which is also its key in detection_dict, as a node in G - G.add_node(det.id, visited=False, cell_count=1, region=region, time=timepoint) + G.add_node(det.id, visited=False, cell_count=1, region=region, time=timepoint, has_input=False) timepoint_list.append(timepoint) node_id_list.append(detection_id) region_label_list.append(region.label) @@ -5616,32 +6417,225 @@ def initialize_track_graph(peak_id, 'region_label':region_label_list}) return(G, graph_df) -# function for a growing cell, used to calculate growth rate -def cell_growth_func(t, sb, elong_rate): - ''' - Assumes you have taken log of the data. 
- It also allows the size at birth to be a free parameter, rather than fixed - at the actual size at birth (but still uses that as a guess) - Assumes natural log, not base 2 (though I think that makes less sense) +def initialize_focus_track_graph(peak_id, + fov_id, + experiment_name, + predictions_dict, + regions_by_time, + max_focus_number=6, + appear_threshold=0.5): - old form: sb*2**(alpha*t) - ''' - return sb+elong_rate*t + detection_dict = {} + frame_num = predictions_dict['migrate_model_predictions'].shape[0] -# functions for checking if a cell has divided or not -# this function should also take the variable t to -# weight the allowed changes by the difference in time as well -def check_growth_by_region(cell, region): - '''Checks to see if it makes sense - to grow a cell by a particular region''' - # load parameters for checking - max_growth_length = params['track']['max_growth_length'] - min_growth_length = params['track']['min_growth_length'] - max_growth_area = params['track']['max_growth_area'] - min_growth_area = params['track']['min_growth_area'] + ebunch = [] - # check if length is not too much longer - if cell.lengths[-1]*max_growth_length < region.major_axis_length: + G = nx.MultiDiGraph() + # create common start point + G.add_node('A') + # create common end point + G.add_node('B') + + last_frame = False + + node_id_list = [] + timepoint_list = [] + region_label_list = [] + + # for k,vals in predictions_dict.items(): + # print(k) + # print(vals.shape) + # pprint(predictions_dict) + + for frame_idx in range(frame_num): + # print(frame_idx) + + timepoint = frame_idx + 1 + paired_detection_time = timepoint+1 + + # get detections for this frame + frame_regions_list = regions_by_time[frame_idx] + + # if we're at the end of the imaging, make all cells migrate to node 'B' + if timepoint == frame_num: + last_frame = True + else: + paired_frame_regions_list = regions_by_time[frame_idx+1] + + # get state change probabilities (class predictions) for this frame 
+ frame_prediction_dict = {key:val[frame_idx,...] for key,val in predictions_dict.items() if key != 'general_model_predictions'} + + # create the "will appear" nodes for this frame + prior_appear_state = 'appear_{:0=4}'.format(timepoint-1) + appear_state = 'appear_{:0=4}'.format(timepoint) + G.add_node(appear_state, visited=False, time=timepoint) + + if frame_idx == 0: + ebunch.append(('A', appear_state, 'start', {'weight':appear_threshold, 'score':1*np.log(appear_threshold)})) + + # create the "Disappeared" nodes to link from prior frame + prior_disappear_state = 'disappear_{:0=4}'.format(timepoint-1) + disappear_state = 'disappear_{:0=4}'.format(timepoint) + next_disappear_state = 'disappear_{:0=4}'.format(timepoint+1) + G.add_node(disappear_state, visited=False, time=timepoint) + + node_id_list.extend([appear_state, disappear_state]) + timepoint_list.extend([timepoint, timepoint]) + region_label_list.extend([0,0]) + + if frame_idx > 0: + ebunch.append((prior_disappear_state, disappear_state, 'disappear', {'weight':1.1, 'score':1*np.log(1.1)})) # impossible to move out of disappear track + ebunch.append((prior_appear_state, appear_state, 'appear', {'weight':appear_threshold, 'score':1*np.log(appear_threshold)})) + + if last_frame: + ebunch.append((appear_state, 'B', 'end', {'weight':1, 'score':1*np.log(1)})) + ebunch.append((disappear_state, 'B', 'end', {'weight':1, 'score':1*np.log(1)})) + + for region_idx in range(max_focus_number): + + # the tracking models assume there are 6 detections in each frame, regardless of how many + # are actually there. Therefore, this try/except logic will catch cases where there + # were fewer than 6 detections in a frame. 
+ try: + region = frame_regions_list[region_idx] + region_label = region.label + except IndexError: + region = None + region_label = region_idx + 1 + + # create the name for this detection + detection_id = create_detection_id( + timepoint, + peak_id, + fov_id, + region_label, + experiment_name=experiment_name + ) + + # print(region) + det = Detection(detection_id, region, timepoint) + detection_dict[det.id] = det + + if det.area is not None: + # if the detection represents a segmentation from our imaging, add its ID, + # which is also its key in detection_dict, as a node in G + G.add_node(det.id, visited=False, cell_count=1, region=region, time=timepoint, has_input=False) + timepoint_list.append(timepoint) + node_id_list.append(detection_id) + region_label_list.append(region.label) + # also set up all edges for this detection's node in our ebunch + # loop through prediction types and add each to the ebunch + + for key,val in frame_prediction_dict.items(): + + if frame_idx == 0: + + ebunch.append(('A', detection_id, 'start', {'weight':1, 'score':1*np.log(1)})) + + if last_frame: + + ebunch.append((detection_id, 'B', 'end', {'weight':1, 'score':1*np.log(1)})) + + if val.shape[0] == max_focus_number ** 2: + continue + + else: + frame_predictions = val + detection_prediction = frame_predictions[region_idx] + + if key == 'appear_model_predictions': + if frame_idx == 0: + continue + elem = (prior_appear_state, detection_id, 'appear', {'weight':detection_prediction, 'score':1*np.log(detection_prediction)}) + + ebunch.append(elem) + + else: + # if the array is cell_number^2, reshape it to cell_number x cell_number + # Then slice our detection's row and iterate over paired_cells + if val.shape[0] == max_focus_number**2: + + frame_predictions = val.reshape((max_focus_number,max_focus_number)) + detection_predictions = frame_predictions[region_idx,:] + + # loop through paired detection predictions, test whether paired detection exists + # then append the edge to our ebunch + 
for paired_cell_idx in range(detection_predictions.size): + + # attempt to grab the paired detection. If we get an IndexError, it doesn't exist. + try: + paired_detection = paired_frame_regions_list[paired_cell_idx] + except IndexError: + continue + + # create the paired detection's id for use in our ebunch + paired_detection_id = create_detection_id( + paired_detection_time, + peak_id, + fov_id, + paired_detection.label, + experiment_name=experiment_name + ) + + paired_prediction = detection_predictions[paired_cell_idx] + + if 'migrate_' in key: + migrate_weight = paired_prediction + elem = (detection_id, paired_detection_id, 'migrate', {'migrate_weight':migrate_weight, 'score':1*np.log(migrate_weight)}) + ebunch.append(elem) + + # if the array is cell_number long, do similar stuff as above. + elif val.shape[0] == max_focus_number: + + frame_predictions = val + detection_prediction = frame_predictions[region_idx] + + if key == 'appear_model_predictions': + if frame_idx == 0: + continue + # print("Linking {} to {}.".format(prior_appear_state, detection_id)) + elem = (prior_appear_state, detection_id, 'appear', {'weight':detection_prediction, 'score':1*np.log(detection_prediction)}) + + elif 'disappear_' in key: + if last_frame: + continue + # print("Linking {} to {}.".format(detection_id, next_disappear_state)) + elem = (detection_id, next_disappear_state, 'disappear', {'weight':detection_prediction, 'score':1*np.log(detection_prediction)}) + + ebunch.append(elem) + + G.add_edges_from(ebunch) + graph_df = pd.DataFrame(data={'timepoint':timepoint_list, + 'node_id':node_id_list, + 'region_label':region_label_list}) + return(G, graph_df) + +# function for a growing cell, used to calculate growth rate +def cell_growth_func(t, sb, elong_rate): + ''' + Assumes you have taken log of the data. 
+ It also allows the size at birth to be a free parameter, rather than fixed + at the actual size at birth (but still uses that as a guess) + Assumes natural log, not base 2 (though I think that makes less sense) + + old form: sb*2**(alpha*t) + ''' + return sb+elong_rate*t + +# functions for checking if a cell has divided or not +# this function should also take the variable t to +# weight the allowed changes by the difference in time as well +def check_growth_by_region(cell, region): + '''Checks to see if it makes sense + to grow a cell by a particular region''' + # load parameters for checking + max_growth_length = params['track']['max_growth_length'] + min_growth_length = params['track']['min_growth_length'] + max_growth_area = params['track']['max_growth_area'] + min_growth_area = params['track']['min_growth_area'] + + # check if length is not too much longer + if cell.lengths[-1]*max_growth_length < region.major_axis_length: return False # check if it is not too short (cell should not shrink really) @@ -5747,15 +6741,21 @@ def filter_foci(Foci, label, t, debug=False): Filtered_Foci = {} - for focus_id, focus in Foci.items(): + for focus_id,focus in Foci.items(): # copy the times list so as not to update it in-place times = focus.times if debug: + print(t) print(times) match_inds = [i for i,time in enumerate(times) if time == t] labels = [focus.labels[idx] for idx in match_inds] + + if debug: + print(match_inds) + print(label) + print(labels) if label in labels: Filtered_Foci[focus_id] = focus @@ -5780,6 +6780,18 @@ def filter_cells(Cells, attr, val, idx=None, debug=False): return Filtered_Cells +def filter_orphan_foci(Foci): + '''Return only cells whose designated attribute equals "val".''' + + Filtered_Foci = {} + + for focus_id, focus in Foci.items(): + if isinstance(focus, OrphanFocus): + continue + Filtered_Foci[focus_id] = focus + + return Filtered_Foci + def filter_cells_containing_val_in_attr(Cells, attr, val): '''Return only cells that have val in list 
attribute, attr.''' @@ -5826,7 +6838,7 @@ def compile_cell_info_df(Cells): data['{}_volume_mean_fluorescence'.format(fluorescence_channel)] = np.zeros(long_df_row_number) data['{}_total_fluorescence'.format(fluorescence_channel)] = np.zeros(long_df_row_number) - data = populate_focus_arrays(Cells, data, cell_quants=True) + data = populate_cell_arrays(Cells, data, cell_quants=True) long_df = pd.DataFrame(data=data) wide_df_row_number = len(Cells) @@ -5850,7 +6862,7 @@ def compile_cell_info_df(Cells): 'death': np.zeros(wide_df_row_number), 'disappear': np.zeros(wide_df_row_number) } - data = populate_focus_arrays(Cells, data, cell_quants=True, wide=True) + data = populate_cell_arrays(Cells, data, cell_quants=True, wide=True) # data['parent_id'] = data['parent_id'].decode() # data['child1_id'] = data['child1_id'].decode() # data['child2_id'] = data['child2_id'].decode() @@ -5858,52 +6870,171 @@ def compile_cell_info_df(Cells): return(wide_df,long_df) -def populate_focus_arrays(Foci, data_dict, cell_quants=False, wide=False): +def populate_focus_arrays(objects, data_dict, cell_quants=False, wide=False): + + ''' + populate_focus_arrays + + Parameters + __________ + + objects : a dictionary of either Cell objects or Focus objects + + data_dict : a dictionary of key/value pairs that will end up + being column headers/data in the dataframe - focus_counter = 0 - focus_count = len(Foci) + cell_quants : set True if 'objects' argument is a dictionary of Cell + objects, false if it is a dictionary of Foci + + wide : return a wide DataFrame if set to True, otherwise return long + ''' + + object_counter = 0 + object_count = len(objects) end_idx = 0 - for i,focus in enumerate(Foci.values()): + for i,obj in enumerate(objects.values()): + + try: + len(obj) + except TypeError: + continue if wide: start_idx = i end_idx = i + 1 else: + start_idx = end_idx + end_idx = len(obj) + start_idx + + if object_counter % 100 == 0: + print("Generating information for object {} out of 
{}.".format(object_counter+1, object_count)) + + # loop over keys in data dictionary, and set + # values in appropriate array, at appropriate indices + # to those we find in the object. + for key in data_dict.keys(): + + if '_id' in key: + + if key == 'parent_id': + if obj.parent is None: + data_dict[key][start_idx:end_idx] = '' + else: + data_dict[key][start_idx:end_idx] = obj.parent.id + + if obj.daughters is None: + if key == 'child1_id' or key == 'child2_id': + data_dict[key][start_idx:end_idx] = '' + elif len(obj.daughters) == 1: + if key == 'child2_id': + data_dict[key][start_idx:end_idx] = '' + elif key == 'child1_id': + data_dict[key][start_idx:end_idx] = obj.daughters[0].id + elif key == 'child2_id': + data_dict[key][start_idx:end_idx] = obj.daughters[1].id + + else: + # if '_fluorescence' in key: + + # key_elements = key.split('_') + # fluorescence_channel = key_elements[0] + # key_base = '_'.join(key_elements[1:]) + # attr_vals = getattr(obj, key_base) + # attr_vals = attr_vals[fluorescence_channel] + + # else: + attr_vals = getattr(obj, key) + + if (cell_quants and key=='abs_times'): + if len(attr_vals) == end_idx-start_idx: + data_dict[key][start_idx:end_idx] = attr_vals + else: + data_dict[key][start_idx:end_idx] = attr_vals[:-1] + else: + # print(key) + # print(attr_vals) + data_dict[key][start_idx:end_idx] = attr_vals + + object_counter += 1 + data_dict['id'] = data_dict['id'].decode() + + return(data_dict) + +def populate_cell_arrays(objects, data_dict, cell_quants=False, wide=False): + + ''' + populate_cell_arrays + + Parameters + __________ + + objects : a dictionary of either Cell objects or Focus objects + + data_dict : a dictionary of key/value pairs that will end up + being column headers/data in the dataframe + + cell_quants : set True if 'objects' argument is a dictionary of Cell + objects, false if it is a dictionary of Foci + + wide : return a wide DataFrame if set to True, otherwise return long + ''' + + object_counter = 0 + object_count 
= len(objects) + end_idx = 0 + + for i,obj in enumerate(objects.values()): + + if wide: + start_idx = i + end_idx = i + 1 + + else: start_idx = end_idx - end_idx = len(focus) + start_idx + end_idx = len(obj) + start_idx - if focus_counter % 100 == 0: - print("Generating focus information for focus {} out of {}.".format(focus_counter+1, focus_count)) + if object_counter % 100 == 0: + print("Generating information for object {} out of {}.".format(object_counter+1, object_count)) # loop over keys in data dictionary, and set # values in appropriate array, at appropriate indices - # to those we find in the focus. + # to those we find in the object. for key in data_dict.keys(): if '_id' in key: if key == 'parent_id': - if focus.parent is None: + if obj.parent is None: data_dict[key][start_idx:end_idx] = '' else: - data_dict[key][start_idx:end_idx] = focus.parent.id + data_dict[key][start_idx:end_idx] = obj.parent.id - if focus.daughters is None: + if obj.daughters is None: if key == 'child1_id' or key == 'child2_id': data_dict[key][start_idx:end_idx] = '' - elif len(focus.daughters) == 1: + elif len(obj.daughters) == 1: if key == 'child2_id': data_dict[key][start_idx:end_idx] = '' elif key == 'child1_id': - data_dict[key][start_idx:end_idx] = focus.daughters[0].id + data_dict[key][start_idx:end_idx] = obj.daughters[0].id elif key == 'child2_id': - data_dict[key][start_idx:end_idx] = focus.daughters[1].id + data_dict[key][start_idx:end_idx] = obj.daughters[1].id else: - attr_vals = getattr(focus, key) + if '_fluorescence' in key: + + key_elements = key.split('_') + fluorescence_channel = key_elements[0] + key_base = '_'.join(key_elements[1:]) + attr_vals = getattr(obj, key_base) + attr_vals = attr_vals[fluorescence_channel] + + else: + attr_vals = getattr(obj, key) + if (cell_quants and key=='abs_times'): if len(attr_vals) == end_idx-start_idx: data_dict[key][start_idx:end_idx] = attr_vals @@ -5914,12 +7045,13 @@ def populate_focus_arrays(Foci, data_dict, 
cell_quants=False, wide=False): # print(attr_vals) data_dict[key][start_idx:end_idx] = attr_vals - focus_counter += 1 + object_counter += 1 data_dict['id'] = data_dict['id'].decode() return(data_dict) + def compile_foci_info_long_df(Foci): ''' Parameters @@ -5937,8 +7069,13 @@ def compile_foci_info_long_df(Foci): # count the number of rows that will be in the long dataframe long_df_row_number = 0 - for focus in Foci.values(): - long_df_row_number += len(focus) + orphan_ids = [] + for k,focus in Foci.items(): + try: + long_df_row_number += len(focus) + except TypeError: + orphan_ids.append(k) + continue # initialize some arrays for filling with data data = { @@ -6444,6 +7581,50 @@ def foci_lap(img, img_foci, cell, t): return disp_l, disp_w, foci_h +# get this focus' cell +def get_focus_cell(frame_cells, seg_cell_img, seg_foci_img, tracked_focus, t, max_cell_number): + + focus_regions = [tracked_focus] + + cell_distances = get_focus_cell_distance( + seg_foci_img, + seg_cell_img, + focus_regions + ) + + # print(t) + # print(cell_distances) + # print(cell_distances.size) + # print(tracked_focus.label) + + # sometimes focus is bad in phase, so all cells are gone, + # but foci remain. Handle that here. 
+ if cell_distances.size == 0: + bad = 'too far down' + return(bad) + + this_label = np.argmin(cell_distances) + 1 + + # print(this_label) + # print(max_cell_number) + # if the focus is too far from any cell, handle that here + if np.min(cell_distances) > 5: + bad = 'too far down' + return(bad) + + if this_label > max_cell_number: + bad = 'too far down' + return(bad) + + # determine which cell this focus belongs to + for cell_id,cell in frame_cells.items(): + + cell_idx = cell.times.index(t) + cell_label = cell.labels[cell_idx] + + if this_label == cell_label: + return(cell) + # actual worker function for foci detection def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): '''foci_info_unet operates on cells in which foci have been found using @@ -6475,14 +7656,14 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): # iterate over each peak in fov for peak_id,peak_value in fov_peaks.items(): - # print(fov_id, peak_id) - # keep cells with this peak_id - peak_cells = filter_cells(fov_cells, attr='peak', val=peak_id) - # if peak_id's value is not 1, go to next peak if peak_value != 1: continue + # print(fov_id, peak_id) + # keep cells with this peak_id + peak_cells = filter_cells(fov_cells, attr='peak', val=peak_id) + print("Analyzing foci in experiment {}, channel {}, fov {}, peak {}.".format(params['experiment_name'], channel_name, fov_id, peak_id)) # Load fluorescent images and segmented images for this channel fl_stack = load_stack(fov_id, peak_id, color=channel_name) @@ -6507,6 +7688,7 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): frame_cells = filter_cells_containing_val_in_attr(peak_cells, attr='times', val=t) # loop over focus regions in this frame focus_regions = measure.regionprops(seg_foci_img) + orig_focus_regions = focus_regions.copy() # compare this frame's foci to prior frame's foci for tracking if frame > 0: @@ -6518,21 +7700,26 @@ def foci_info_unet(foci, Cells, specs, 
time_table, channel_name='sub_c2'): peak_foci = filter_cells(fov_foci, attr='peak', val=peak_id) + prior_frame_foci = filter_cells_containing_val_in_attr(peak_foci, attr='times', val=t-1) + prior_frame_foci = filter_orphan_foci(prior_frame_foci) # if there were foci in prior frame, do stuff if len(prior_frame_foci) > 0: prior_regions = measure.regionprops(prior_seg_foci_img) # compare_array is prior_focus_number x this_focus_number - # contains dice indices for each pairwise comparison - # between focus positions + # contains product of multiplying blurred focus mask + # for each pair-wise comparison of foci from the prior + # frame to this frame compare_array = np.zeros((np.max(prior_seg_foci_img), np.max(seg_foci_img))) + # populate the array with dice indices for prior_focus_idx in range(np.max(prior_seg_foci_img)): prior_focus_mask = np.zeros(seg_foci_img.shape) + # set this focus' pixels to 1 prior_focus_mask[prior_seg_foci_img == (prior_focus_idx + 1)] = 1 # apply gaussian blur with sigma=1 to prior focus mask @@ -6549,16 +7736,31 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): # multiply the two images and place max into campare_array product = gaus_1 * gaus_2 compare_array[prior_focus_idx, this_focus_idx] = np.max(product) - + # which rows of each column are maximum product of gaussian blurs? + # reset other cells in compare_array to 0 + max_cols_by_row = np.argmax(compare_array, axis=1) + for r_idx in range(compare_array.shape[0]): + max_col_idx = max_cols_by_row[r_idx] + for c_idx in range(compare_array.shape[1]): + if c_idx == max_col_idx: + continue + compare_array[r_idx,c_idx] = 0 + max_inds = np.argmax(compare_array, axis=0) - # because np.argmax returns zero if all rows are equal, we - # need to evaluate if all rows are equal. - # If std_dev is zero, then all were equal, - # and we omit that index from consideration for - # focus tracking. 
- sd_vals = np.std(compare_array, axis=0) - tracked_inds = np.where(sd_vals > 0)[0] + + # if compare_array has greater than 1 row, check sd of each column + if compare_array.shape[0] > 1: + # because np.argmax returns zero if all rows are equal, we + # need to evaluate if all rows are equal. + # If std_dev is zero, then all were equal, + # and we omit that index from consideration for + # focus tracking. + sd_vals = np.std(compare_array, axis=0) + tracked_inds = np.where(sd_vals > 0)[0] + else: + tracked_inds = np.array([max_inds[0]]) + # if there is an index from a tracked focus, do this if tracked_inds.size > 0: @@ -6566,8 +7768,7 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): # grab this frame's region belonging to tracked focus tracked_label = tracked_idx + 1 (tracked_region_idx, tracked_region) = [(_,reg) for _,reg in enumerate(focus_regions) if reg.label == tracked_label][0] - # pop the region from focus_regions - del focus_regions[tracked_region_idx] + (orig_tracked_idx, orig_tracked_region) = [(_,reg) for _,reg in enumerate(orig_focus_regions) if reg.label == tracked_label][0] # grab prior frame's region belonging to tracked focus prior_tracked_label = max_inds[tracked_idx] + 1 @@ -6575,6 +7776,7 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): # grab the focus for which the prior_tracked_label is in # any of the labels in the prior focus from the prior time + # print(t) prior_tracked_foci = filter_foci( prior_frame_foci, label=prior_tracked_label, @@ -6582,6 +7784,13 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): debug=False ) + if len(prior_tracked_foci) == 0: + continue + # pprint(prior_tracked_foci) + + # pop the region from focus_regions + del focus_regions[tracked_region_idx] + prior_tracked_focus = [val for val in prior_tracked_foci.values()][0] # determine which cell this focus belongs to @@ -6590,28 +7799,98 @@ def foci_info_unet(foci, Cells, specs, 
time_table, channel_name='sub_c2'): cell_idx = cell.times.index(t) cell_label = cell.labels[cell_idx] + # create binary image with only our query cell + # labelled as 1 masked_cell_img = np.zeros(seg_cell_img.shape) masked_cell_img[seg_cell_img == cell_label] = 1 + # create binary image with only our query focus + # labelled as 1 masked_focus_img = np.zeros(seg_foci_img.shape) - masked_focus_img[seg_foci_img == tracked_region.label] = 1 + masked_focus_img[seg_foci_img == tracked_label] = 1 + # add images together intersect_img = masked_cell_img + masked_focus_img - pixels_two = len(np.where(intersect_img == 2)) - pixels_one = len(np.where(masked_focus_img == 1)) + # how many pixels are 2 in sum image? + pixels_two = len(np.where(intersect_img == 2)[0]) + # how many pixels were 1 in focus binary image? + pixels_one = len(np.where(masked_focus_img == 1)[0]) # if over half the focus is within this cell, do the following - if pixels_two/pixels_one >= 0.5: - - prior_tracked_focus.grow( - region=tracked_region, - t=t, - seg_img=seg_foci_img, - intensity_image=fl_img, - current_cell=cell - ) - + if pixels_two/pixels_one > 0: + + # focus belongs to this cell, so now we need to keep only foci + # in this cell, relabel foci with cell-centric focus labels + # current focus' label is 'tracked_label'. 
+ # iterate from 1 to 'tracked_label', successively adding + # foci within this cell to the mask, then labelling the resulting mask + # Then keep the resulting cell-centric label of this tracked focus + foci_in_cell_seg_img = np.zeros(seg_foci_img.shape) + + for focus_reg in orig_focus_regions: + masked_focus_img = np.zeros(seg_foci_img.shape) + masked_focus_img[seg_foci_img == focus_reg.label] = 1 + + foci_intersect_img = masked_cell_img + masked_focus_img + + pixels_two = len(np.where(foci_intersect_img == 2)[0]) + pixels_one = len(np.where(masked_focus_img == 1)[0]) + + # if over half the focus is within this cell, keep the focus + if pixels_two/pixels_one > 0: + foci_in_cell_seg_img = foci_in_cell_seg_img + masked_focus_img + + foci_in_cell_label_img = measure.label(foci_in_cell_seg_img) + tracked_region.cell_centric_label = foci_in_cell_label_img[seg_foci_img == tracked_label][0] + + # sanity check: does this cell make sense given prior cell? + prior_cell = prior_tracked_focus.cells[-1] + prior_cell_end_time = prior_cell.times[-1] + # if the prior cell is this cell, we're good + if prior_cell.id == cell.id: + + prior_tracked_focus.grow( + region=tracked_region, + t=t, + seg_img=seg_foci_img, + intensity_image=fl_img, + current_cell=cell + ) + + # if the prior cell ended one frame ago, this cell must be one of + # the prior cell's daughters to grow the focus in this cell + elif (prior_cell_end_time - t) == -1: + prior_cell_daughters = prior_cell.daughters + + if prior_cell_daughters is not None: + if cell in prior_cell_daughters: + prior_tracked_focus.grow( + region=tracked_region, + t=t, + seg_img=seg_foci_img, + intensity_image=fl_img, + current_cell=cell + ) + + else: + # make the focus_id + new_id = create_focus_id( + region = focus_region, + t = t, + peak = peak_id, + fov = fov_id, + experiment_name = params['experiment_name'] + ) + + foci[new_id] = Focus( + cell = cell, + region = tracked_region, + seg_img = seg_foci_img, + intensity_image = fl_img, + t = 
t + ) + # after tracking foci, those that were tracked have been removed from focus_regions list # now we check if any regions remain in the list # if there are any remaining, instantiate new foci @@ -6626,7 +7905,8 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): t = t, peak = peak_id, fov = fov_id, - experiment_name = params['experiment_name']) + experiment_name = params['experiment_name'] + ) # populate list for later checking if any are missing # from foci dictionary's keys new_ids.append(new_id) @@ -6645,286 +7925,104 @@ def foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): intersect_img = masked_cell_img + masked_focus_img - pixels_two = len(np.where(intersect_img == 2)) - pixels_one = len(np.where(masked_focus_img == 1)) + pixels_two = len(np.where(intersect_img == 2)[0]) + pixels_one = len(np.where(masked_focus_img == 1)[0]) # if over half the focus is within this cell, do the following - if pixels_two/pixels_one >= 0.5: + if pixels_two/pixels_one > 0: + + # focus belongs to this cell, so now we need to keep only foci + # in this cell, relabel foci, and get new labels for cell-centric focus + # labelling + retained_foci_seg_img = np.zeros(seg_foci_img.shape) + + for focus_reg in orig_focus_regions: + masked_focus_img = np.zeros(seg_foci_img.shape) + masked_focus_img[seg_foci_img == focus_reg.label] = 1 + + intersect_img = masked_cell_img + masked_focus_img + + pixels_two = len(np.where(intersect_img == 2)[0]) + pixels_one = len(np.where(masked_focus_img == 1)[0]) + + # if over half the focus is within this cell, keep the focus + if pixels_two/pixels_one > 0: + retained_foci_seg_img = retained_foci_seg_img + masked_focus_img + + retained_foci_label_img = measure.label(retained_foci_seg_img) + focus_region.cell_centric_label = retained_foci_label_img[seg_foci_img == focus_region.label][0] + # set up the focus # if no foci in cell, just add this one. 
- foci[new_id] = Focus(cell = cell, - region = focus_region, - seg_img = seg_foci_img, - intensity_image = fl_img, - t = t) - - for new_id in new_ids: - # if new_id is not a key in the foci dictionary, - # that suggests the focus doesn't overlap well - # with any cells in this frame, so we'll relabel - # this frame of seg_foci_stack to zero for that - # focus to avoid trying to track a focus - # that doesn't exist. + foci[new_id] = Focus( + cell = cell, + region = focus_region, + seg_img = seg_foci_img, + intensity_image = fl_img, + t = t + ) + # since we've already identified the cell, break this loop + # and move on to next focus + break + + # if we didn't find a cell for this focus, + # create the focus object, but make it's cell label 0 to denote its orphan status if new_id not in foci: - # get label of new_id's region - this_label = int(new_id[-2:]) - # set pixels in this frame that match this label to 0 - seg_foci_stack[frame, seg_foci_img == this_label] = 0 + # create a placeholder region with label = 0 to mark as orphan + focus_region.cell_centric_label = 0 + + foci[new_id] = OrphanFocus( + cell = OrphanFocusCell( + fov_id, + peak_id, + t + ), + region = focus_region, + seg_img = seg_foci_img, + intensity_image = fl_img, + t = t + ) + + # for new_id in new_ids: + # # if new_id is not a key in the foci dictionary, + # # that suggests the focus doesn't overlap well + # # with any cells in this frame, so we'll relabel + # # this frame of seg_foci_stack to zero for that + # # focus to avoid trying to track a focus + # # that doesn't exist. + # if new_id not in foci: + + # # get label of new_id's region + # this_label = int(new_id[-2:]) + # # set pixels in this frame that match this label to 0 + # seg_foci_stack[frame, seg_foci_img == this_label] = 0 return -# def dev_foci_info_unet(foci, Cells, specs, time_table, channel_name='sub_c2'): -# '''foci_info_unet operates on cells in which foci have been found using -# using Unet. 
- -# Parameters -# ---------- -# Foci : empty dictionary for Focus objects to be placed into -# Cells : dictionary of Cell objects to which foci will be added -# specs : dictionary containing information on which fov/peak ids -# are to be used, and which are to be excluded from analysis -# time_table : dictionary containing information on which time -# points correspond to which absolute times in seconds -# channel_name : name of fluorescent channel for reading in -# fluorescence images for focus quantification - -# Returns -# ------- -# Updates cell information in Cells in-place. -# Cells must have .foci attribute -# ''' - -# # iterate over each fov in specs -# for fov_id,fov_peaks in specs.items(): - -# # keep cells with this fov_id -# fov_cells = filter_cells(Cells, attr='fov', val=fov_id) - -# # iterate over each peak in fov -# for peak_id,peak_value in fov_peaks.items(): - -# # print(fov_id, peak_id) -# # keep cells with this peak_id -# peak_cells = filter_cells(fov_cells, attr='peak', val=peak_id) - -# # if peak_id's value is not 1, go to next peak -# if peak_value != 1: -# continue - -# print("Analyzing foci in experiment {}, channel {}, fov {}, peak {}.".format(params['experiment_name'], channel_name, fov_id, peak_id)) -# # Load fluorescent images and segmented images for this channel -# fl_stack = load_stack(fov_id, peak_id, color=channel_name) -# seg_foci_stack = load_stack(fov_id, peak_id, color='foci_seg_unet') -# seg_cell_stack = load_stack(fov_id, peak_id, color='seg_unet') - -# # loop over each frame -# for frame in range(fl_stack.shape[0]): - -# fl_img = fl_stack[frame, ...] -# seg_foci_img = seg_foci_stack[frame, ...] -# seg_cell_img = seg_cell_stack[frame, ...] 
- -# # if there are no foci in this frame, move to next frame -# if np.max(seg_foci_img) == 0: -# continue -# # if there are no cells in this fov/peak/frame, move to next frame -# if np.max(seg_cell_img) == 0: -# continue - -# t = frame+1 -# frame_cells = filter_cells_containing_val_in_attr(peak_cells, attr='times', val=t) -# next_frame_cells = filter_cells_containing_val_in_attr(peak_cells, attr='times', val=t+1) - -# # prepare focus regions in this frame -# focus_regions = measure.regionprops(seg_foci_img) - -# # loop over cells in this frame, linking to same cell or inherited cells in next frame -# for cell in frame_cells: - -# pass - - - -# # compare this frame's foci to prior frame's foci for tracking -# if frame > 0: -# prior_seg_foci_img = seg_foci_stack[frame-1, ...] - -# fov_foci = filter_cells(foci, -# attr='fov', -# val=fov_id) -# peak_foci = filter_cells(fov_foci, -# attr='peak', -# val=peak_id) -# prior_frame_foci = filter_cells_containing_val_in_attr(peak_foci, attr='times', val=t-1) - -# # if there were foci in prior frame, do stuff -# if len(prior_frame_foci) > 0: -# prior_regions = measure.regionprops(prior_seg_foci_img) - -# # compare_array is prior_focus_number x this_focus_number -# # contains dice indices for each pairwise comparison -# # between focus positions -# compare_array = np.zeros((np.max(prior_seg_foci_img), -# np.max(seg_foci_img))) -# # populate the array with dice indices -# for prior_focus_idx in range(np.max(prior_seg_foci_img)): - -# prior_focus_mask = np.zeros(seg_foci_img.shape) -# prior_focus_mask[prior_seg_foci_img == (prior_focus_idx + 1)] = 1 - -# # apply gaussian blur with sigma=1 to prior focus mask -# sig = 1 -# gaus_1 = filters.gaussian(prior_focus_mask, sigma=sig) - -# for this_focus_idx in range(np.max(seg_foci_img)): - -# this_focus_mask = np.zeros(seg_foci_img.shape) -# this_focus_mask[seg_foci_img == (this_focus_idx + 1)] = 1 - -# # apply gaussian blur with sigma=1 to this focus mask -# gaus_2 = 
filters.gaussian(this_focus_mask, sigma=sig) -# # multiply the two images and place max into campare_array -# product = gaus_1 * gaus_2 -# compare_array[prior_focus_idx, this_focus_idx] = np.max(product) - -# # which rows of each column are maximum product of gaussian blurs? -# max_inds = np.argmax(compare_array, axis=0) -# # because np.argmax returns zero if all rows are equal, we -# # need to evaluate if all rows are equal. -# # If std_dev is zero, then all were equal, -# # and we omit that index from consideration for -# # focus tracking. -# sd_vals = np.std(compare_array, axis=0) -# tracked_inds = np.where(sd_vals > 0)[0] -# # if there is an index from a tracked focus, do this -# if tracked_inds.size > 0: - -# for tracked_idx in tracked_inds: -# # grab this frame's region belonging to tracked focus -# tracked_label = tracked_idx + 1 -# (tracked_region_idx, tracked_region) = [(_,reg) for _,reg in enumerate(focus_regions) if reg.label == tracked_label][0] -# # pop the region from focus_regions -# del focus_regions[tracked_region_idx] - -# # grab prior frame's region belonging to tracked focus -# prior_tracked_label = max_inds[tracked_idx] + 1 -# # prior_tracked_region = [reg for reg in prior_regions if reg.label == prior_tracked_label][0] - -# # grab the focus for which the prior_tracked_label is in -# # any of the labels in the prior focus from the prior time -# prior_tracked_foci = filter_foci( -# prior_frame_foci, -# label=prior_tracked_label, -# t = t-1, -# debug=False -# ) - -# prior_tracked_focus = [val for val in prior_tracked_foci.values()][0] - -# # determine which cell this focus belongs to -# for cell_id,cell in frame_cells.items(): - -# cell_idx = cell.times.index(t) -# cell_label = cell.labels[cell_idx] - -# masked_cell_img = np.zeros(seg_cell_img.shape) -# masked_cell_img[seg_cell_img == cell_label] = 1 - -# masked_focus_img = np.zeros(seg_foci_img.shape) -# masked_focus_img[seg_foci_img == tracked_region.label] = 1 - -# intersect_img = 
masked_cell_img + masked_focus_img - -# pixels_two = len(np.where(intersect_img == 2)) -# pixels_one = len(np.where(masked_focus_img == 1)) - -# # if over half the focus is within this cell, do the following -# if pixels_two/pixels_one >= 0.5: - -# prior_tracked_focus.grow( -# region=tracked_region, -# t=t, -# seg_img=seg_foci_img, -# intensity_image=fl_img, -# current_cell=cell -# ) - -# # after tracking foci, those that were tracked have been removed from focus_regions list -# # now we check if any regions remain in the list -# # if there are any remaining, instantiate new foci -# if len(focus_regions) > 0: -# new_ids = [] - -# for focus_region in focus_regions: - -# # make the focus_id -# new_id = create_focus_id( -# region = focus_region, -# t = t, -# peak = peak_id, -# fov = fov_id, -# experiment_name = params['experiment_name']) -# # populate list for later checking if any are missing -# # from foci dictionary's keys -# new_ids.append(new_id) - -# # determine which cell this focus belongs to -# for cell_id,cell in frame_cells.items(): - -# cell_idx = cell.times.index(t) -# cell_label = cell.labels[cell_idx] - -# masked_cell_img = np.zeros(seg_cell_img.shape) -# masked_cell_img[seg_cell_img == cell_label] = 1 - -# masked_focus_img = np.zeros(seg_foci_img.shape) -# masked_focus_img[seg_foci_img == focus_region.label] = 1 - -# intersect_img = masked_cell_img + masked_focus_img - -# pixels_two = len(np.where(intersect_img == 2)) -# pixels_one = len(np.where(masked_focus_img == 1)) - -# # if over half the focus is within this cell, do the following -# if pixels_two/pixels_one >= 0.5: -# # set up the focus -# # if no foci in cell, just add this one. 
- -# foci[new_id] = Focus(cell = cell, -# region = focus_region, -# seg_img = seg_foci_img, -# intensity_image = fl_img, -# t = t) - -# for new_id in new_ids: -# # if new_id is not a key in the foci dictionary, -# # that suggests the focus doesn't overlap well -# # with any cells in this frame, so we'll relabel -# # this frame of seg_foci_stack to zero for that -# # focus to avoid trying to track a focus -# # that doesn't exist. -# if new_id not in foci: - -# # get label of new_id's region -# this_label = int(new_id[-2:]) -# # set pixels in this frame that match this label to 0 -# seg_foci_stack[frame, seg_foci_img == this_label] = 0 +def foci_info_unet_curated(foci, Cells, specs, time_table, channel_name='sub_c2'): -# return + for Cell in Cells: + pass + pass def update_cell_foci(cells, foci): '''Updates cells' .foci attribute in-place using information in foci dictionary ''' for focus_id, focus in foci.items(): + # print(focus) + if isinstance(focus, OrphanFocus): + continue for cell in focus.cells: cell_id = cell.id + if cell_id == 'orphan': + continue cells[cell_id].foci[focus_id] = focus -# finds best fit for 2d gaussian using functin above +# finds best fit for 2d gaussian using function above def fitgaussian(data): """Returns (height, x, y, width_x, width_y) the gaussian parameters of a 2D distribution found by a fit diff --git a/mm3_metamorphToTIFF.py b/mm3_metamorphToTIFF.py index f834edb..0abec2c 100755 --- a/mm3_metamorphToTIFF.py +++ b/mm3_metamorphToTIFF.py @@ -47,10 +47,12 @@ def information(*objs): '''Edit TIFFs from Jeremy's format to the one expected by mm3.''' # set switches and parameters - parser = argparse.ArgumentParser(prog='python mm3_Compile.py', - description='Identifies and slices out channels into individual TIFF stacks through time.') + parser = argparse.ArgumentParser(prog='python mm3_metamorphToTIFF.py', + description='Converts TIFFs from metamorph into the format expected by mm3.') parser.add_argument('-f', '--paramfile', type=str, 
required=True, help='Yaml file containing parameters.') + parser.add_argument('-p', '--path', type=str, + required=False, help='Path to data directory. Overrides what is in param file') namespace = parser.parse_args() # Load the project parameters file @@ -60,7 +62,11 @@ def information(*objs): else: mm3.warning('No param file specified. Using 100X template.') param_file_path = 'yaml_templates/params_SJ110_100X.yaml' - p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + + if namespace.path: + p = mm3.init_mm3_helpers(param_file_path, datapath=namespace.path) # initialized the helper library + else: + p = mm3.init_mm3_helpers(param_file_path, datapath=None) # define variables here source_dir = p['experiment_directory'] diff --git a/mm3_plots.py b/mm3_plots.py index fae0b61..90eb8d7 100755 --- a/mm3_plots.py +++ b/mm3_plots.py @@ -270,6 +270,50 @@ def find_last_daughter(cell, Cells): # finally, return the deepest cell return cell +def get_lineage_list(cell, all_cell_ids): + '''Finds the cells in a lineage after cell passed to function''' + + # work back to original cell in lineage + while ((cell.parent is not None) and (cell.parent in all_cell_ids)): + cell = cell.parent + + # initialize the list of cells in the lineage + lineage_list = [cell.id] + + # if this cell has daughters, do the following + while cell.daughters is not None: + cell = cell.daughters[0] + lineage_list.append(cell.id) + + lineage_list.sort() + + # finally, return the list of cell id's in the lineage + return lineage_list + +def find_all_lineages(Cells): + ''' + Generates lists of cell_id's. 
+ Each list corresponds to the cells in a single continuous + lineage + ''' + + all_cell_ids = [cell_id for cell_id in Cells.keys()] + all_cell_ids.sort() + all_lineages = [] + + while len(all_cell_ids) > 0: + + cell_lineage = get_lineage_list(Cells[all_cell_ids[0]], all_cell_ids) + + for lineage_cell_id in cell_lineage: + if lineage_cell_id in all_cell_ids: + all_cell_ids.pop(all_cell_ids.index(lineage_cell_id)) + + all_lineages.append(cell_lineage) + + return all_lineages + + def find_continuous_lineages(Cells, specs, t1=0, t2=1000): ''' Uses a recursive function to only return cells that have continuous @@ -344,6 +388,81 @@ def find_continuous_lineages(Cells, specs, t1=0, t2=1000): return Cells +def find_lineages(Cells, specs): + ''' + Uses a recursive function to only return cells that have continuous + lineages between two time points. Takes a "lineage" form of Cells and + returns a dictionary of the same format. Good for plotting + with saw_tooth_plot() + + t1 : int + First cell in lineage must be born before this time point + t2 : int + Last cell in lineage must be born after this time point + ''' + + Lineages = organize_cells_by_channel(Cells, specs) + + # This is a mirror of the lineages dictionary, just for the continuous cells + Continuous_Lineages = {} + + for fov, peaks in six.iteritems(Lineages): + # print("fov = {:d}".format(fov)) + # Create a dictionary to hold this FOV + Continuous_Lineages[fov] = {} + + for peak, Cells in six.iteritems(peaks): + # print("{:<4s}peak = {:d}".format("",peak)) + # sort the cells by time in a list for this peak + cells_sorted = [(cell_id, cell) for cell_id, cell in six.iteritems(Cells)] + cells_sorted = sorted(cells_sorted, key=lambda x: x[1].birth_time) + + # Sometimes there are not any cells for the channel even if it was to be analyzed + if not cells_sorted: + continue + + # look through list to find the cell born immediately before t1 + # and divides after t1, but not after t2 + first_cell_index = 0 + # for i, 
cell_data in enumerate(cells_sorted): + # cell_id, cell = cell_data + # if cell.birth_time < t1 and t1 <= cell.division_time < t2: + # first_cell_index = i + # break + + # filter cell_sorted or skip if you got to the end of the list + if i == len(cells_sorted) - 1: + continue + else: + cells_sorted = cells_sorted[i:] + + # get the first cell and it's last contiguous daughter + first_cell = cells_sorted[0][1] + last_daughter = find_last_daughter(first_cell, Cells) + + # check to the daughter makes the second cut off + if last_daughter.birth_time > t2: + # print(fov, peak, 'Made it') + + # now retrieve only those cells within the two times + # use the function to easily return in dictionary format + Cells_cont = find_cells_born_after(Cells, born_after=t1) + # Cells_cont = find_cells_born_before(Cells_cont, born_before=t2) + + # append the first cell which was filtered out in the above step + Cells_cont[first_cell.id] = first_cell + + # and add it to the big dictionary + Continuous_Lineages[fov][peak] = Cells_cont + + # remove keys that do not have any lineages + if not Continuous_Lineages[fov]: + Continuous_Lineages.pop(fov) + + Cells = lineages_to_dict(Continuous_Lineages) # revert back to return + + return Cells + def find_generation_gap(cell, Cells, gen): '''Finds how many continuous ancestors this cell has.''' diff --git a/aux/edit_tiffs.py b/sup/edit_tiffs.py similarity index 100% rename from aux/edit_tiffs.py rename to sup/edit_tiffs.py diff --git a/aux/fitmodel.py b/sup/fitmodel.py similarity index 100% rename from aux/fitmodel.py rename to sup/fitmodel.py diff --git a/aux/mm3_Colors.py b/sup/mm3_Colors.py similarity index 100% rename from aux/mm3_Colors.py rename to sup/mm3_Colors.py diff --git a/aux/mm3_Foci.py b/sup/mm3_Foci.py similarity index 100% rename from aux/mm3_Foci.py rename to sup/mm3_Foci.py diff --git a/aux/mm3_MovieMaker.py b/sup/mm3_MovieMaker.py similarity index 100% rename from aux/mm3_MovieMaker.py rename to sup/mm3_MovieMaker.py diff --git 
a/aux/mm3_OutputData.py b/sup/mm3_OutputData.py similarity index 100% rename from aux/mm3_OutputData.py rename to sup/mm3_OutputData.py diff --git a/aux/mm3_nd2ToTIFF.py b/sup/mm3_nd2ToTIFF.py similarity index 100% rename from aux/mm3_nd2ToTIFF.py rename to sup/mm3_nd2ToTIFF.py diff --git a/track_training_file_paths.csv b/track_training_file_paths.csv index 93b7c10..84a4c02 100644 --- a/track_training_file_paths.csv +++ b/track_training_file_paths.csv @@ -1,27 +1,37 @@ file_path,include /home/wanglab/sandbox/trackingGUI/testset1/analysis_20190423/cell_data/testset1_xy001_p0033_updated_tracks.pkl,0 /home/wanglab/sandbox/trackingGUI/testset1/analysis_20190423/cell_data/testset1_xy001_p0077_updated_tracks.pkl,0 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy001_p0056_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy003_p0063_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy003_p0067_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy003_p0068_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy003_p0069_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy001_p0056_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy003_p0063_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy003_p0067_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy003_p0068_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy003_p0069_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20190404/analysis/cell_data/20190404_JDW3407_xy001_p0006_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20190404/analysis/cell_data/20190404_JDW3407_xy001_p0010_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20190329/analysis/cell_data/20190329_JDW3407_xy001_p0020_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20190329/analysis/cell_data/20190329_JDW3407_xy001_p0032_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20190110/analysis/cell_data/20190110_JDW3411_xy001_p0086_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy001_p0039_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy001_p0041_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy001_p0051_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy001_p0052_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy002_p0040_updated_tracks.pkl,1 -/home/wanglab/Users_local/Jeremy/Imaging/20181011/normedSeg_analysis/cell_data/20181011_JDW3308_xy002_p0057_updated_tracks.pkl,1 -/Volumes/JunLabSSD_04/test_data/testset1/analysis/cell_data/testset1_xy001_p0077_updated_tracks.pkl,1 -/Volumes/JunLabSSD_04/test_data/testset1/analysis/cell_data/testset1_xy001_p0033_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy001_p0039_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy001_p0041_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy001_p0051_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy001_p0052_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy002_p0040_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20181011/analysis/cell_data/curated_tracks/20181011_JDW3308_xy002_p0057_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0034_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0043_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0107_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0116_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0121_updated_tracks.pkl,1 /home/wanglab/Users_local/Jeremy/Imaging/20191024/analysis/cell_data/20191024_JDW3410_xy001_p0122_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191101/analysis/cell_data/20191101_JDW3410_xy001_p0024_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191101/analysis/cell_data/20191101_JDW3410_xy001_p0031_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191101/analysis/cell_data/20191101_JDW3410_xy001_p0032_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191111/analysis/cell_data/20191111_JDW3410_xy001_p0003_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy001_p0012_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy001_p0022_updated_tracks.pkl,1 
+/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy001_p0023_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy001_p0027_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy002_p0006_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy002_p0009_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy002_p0013_updated_tracks.pkl,1 +/home/wanglab/Users_local/Jeremy/Imaging/20191112/analysis/cell_data/20191112_JDW3410_xy002_p0015_updated_tracks.pkl,1 diff --git a/transfer_channels_to_chtc.py b/transfer_channels_to_chtc.py new file mode 100755 index 0000000..86d6e06 --- /dev/null +++ b/transfer_channels_to_chtc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +# import modules +import sys +import os +import argparse +import yaml +import inspect +from pprint import pprint +from getpass import getpass +import glob +try: + import cPickle as pickle +except: + import pickle + +import paramiko + + +# user modules +# realpath() will make your script run, even if you symlink it +cmd_folder = os.path.realpath(os.path.abspath( + os.path.split(inspect.getfile(inspect.currentframe()))[0])) +if cmd_folder not in sys.path: + sys.path.insert(0, cmd_folder) + +# This makes python look for modules in ./external_lib +cmd_subfolder = os.path.realpath(os.path.abspath( + os.path.join(os.path.split(inspect.getfile( + inspect.currentframe()))[0], "external_lib"))) +if cmd_subfolder not in sys.path: + sys.path.insert(0, cmd_subfolder) + +# this is the mm3 module with all the useful functions and classes +import mm3_helpers as mm3 + + +# set switches and parameters +parser = argparse.ArgumentParser(prog='python mm3_Compile.py', + description='Identifies and slices out channels into individual TIFF stacks through time.') 
+parser.add_argument('-f', '--paramfile', type=str, + required=True, help='Yaml file containing parameters.') +parser.add_argument('-s', '--transfer_segmentation', action='store_true', + required=False, help='Add this option at command line to send segmentation files.') +parser.add_argument('-c', '--transfer_all_channels', action='store_true', + required=False, help='Add this option at command line to send all channels, not just phase.') +parser.add_argument('-j', '--transfer_job_file', action='store_true', + required=False, help='Add this option at command line to compile text file containing file names for job submission at chtc.') +namespace = parser.parse_args() + +# Load the project parameters file +mm3.information('Loading experiment parameters.') +if namespace.paramfile: + param_file_path = namespace.paramfile +else: + mm3.warning('No param file specified. Using 100X template.') + param_file_path = 'yaml_templates/params_SJ110_100X.yaml' + +param_file_path = os.path.join(os.getcwd(), param_file_path) +p = mm3.init_mm3_helpers(param_file_path) # initialized the helper library + +# load specs file +specs = mm3.load_specs() + +# identify files to be copied to chtc +files_to_transfer = [] +if namespace.transfer_job_file: + job_file_name = '{}_files_list.txt'.format(os.path.basename(param_file_path.split('.')[0])) + job_file = open(job_file_name,'w') + spec_file_name = os.path.join(p['ana_dir'], 'specs.yaml') + new_spec_file_name = '{}_specs.yaml'.format(p['experiment_name']) + time_file_name = os.path.join(p['ana_dir'], 'time_table.yaml') + new_time_file_name = '{}_time.yaml'.format(p['experiment_name']) + new_param_file_name = '{}_params.yaml'.format(p['experiment_name']) + +for fov_id,peak_ids in specs.items(): + for peak_id,val in peak_ids.items(): + if val == 1: + + if namespace.transfer_all_channels: + base_name = '{}_xy{:0=3}_p{:0=4}_*.tif'.format( + p['experiment_name'], + fov_id, + peak_id + ) + match_list = 
glob.glob(os.path.join(p['chnl_dir'],base_name)) + match_list.sort() + + else: + base_name = '{}_xy{:0=3}_p{:0=4}_{}.tif'.format( + p['experiment_name'], + fov_id, + peak_id, + p['phase_plane'] + ) + match_list = glob.glob(os.path.join(p['chnl_dir'],base_name)) + + if namespace.transfer_segmentation: + base_name = '{}_xy{:0=3}_p{:0=4}_{}.tif'.format( + p['experiment_name'], + fov_id, + peak_id, + 'seg_unet' + ) + fname = os.path.join(p['seg_dir'],base_name) + match_list.append(fname) + + # pprint(match_list) + + files_to_transfer.extend(match_list) + match_base_names = [os.path.basename(fname) for fname in match_list] + + if namespace.transfer_job_file: + + match_base_names.append(new_spec_file_name) + match_base_names.append(new_time_file_name) + match_base_names.append(param_file_path.split('/')[-1]) + + line_to_write = ','.join(match_base_names) + line_to_write = line_to_write + '\n' + + job_file.write(line_to_write) + + +if namespace.transfer_job_file: + job_file.close() + files_to_transfer.append(job_file_name) + +# files_to_transfer.append(param_file_path) + +print("You'll be sending {} files total to chtc.".format(len(files_to_transfer))) + +# connect to chtc +ssh = paramiko.SSHClient() +ssh.load_host_keys(os.path.expanduser(os.path.join("~", ".ssh", "known_hosts"))) +username = input("Username: ") +server = input("Hostname: ") +password = getpass("Password for {}@{}: ".format(username,server)) +ssh.connect(server, username=username, password=password) + +# copy files +sftp = ssh.open_sftp() +for localpath in files_to_transfer: + + print(localpath) + remotepath = localpath.split('/')[-1] + sftp.put(localpath, remotepath) + +sftp.put(param_file_path, new_param_file_name) + +if namespace.transfer_job_file: + + sftp.put(spec_file_name, new_spec_file_name) + sftp.put(time_file_name, new_time_file_name) + +sftp.close() +ssh.close() \ No newline at end of file diff --git a/weights/20200108_migration_predictor.hdf5 b/weights/20200108_migration_predictor.hdf5 new 
file mode 100644 index 0000000..ee521ac Binary files /dev/null and b/weights/20200108_migration_predictor.hdf5 differ diff --git a/weights/20200109_appear_predictor.hdf5 b/weights/20200109_appear_predictor.hdf5 new file mode 100644 index 0000000..44ed7d5 Binary files /dev/null and b/weights/20200109_appear_predictor.hdf5 differ diff --git a/weights/20200109_born_predictor.hdf5 b/weights/20200109_born_predictor.hdf5 new file mode 100644 index 0000000..1e538a5 Binary files /dev/null and b/weights/20200109_born_predictor.hdf5 differ diff --git a/weights/20190722_migration_predictor.hdf5 b/weights/20200109_child_predictor.hdf5 similarity index 77% rename from weights/20190722_migration_predictor.hdf5 rename to weights/20200109_child_predictor.hdf5 index a60f2e6..2adb77c 100644 Binary files a/weights/20190722_migration_predictor.hdf5 and b/weights/20200109_child_predictor.hdf5 differ diff --git a/weights/20200109_dies_predictor.hdf5 b/weights/20200109_dies_predictor.hdf5 new file mode 100644 index 0000000..6a9db7d Binary files /dev/null and b/weights/20200109_dies_predictor.hdf5 differ diff --git a/weights/20200109_disappear_predictor.hdf5 b/weights/20200109_disappear_predictor.hdf5 new file mode 100644 index 0000000..9cd8871 Binary files /dev/null and b/weights/20200109_disappear_predictor.hdf5 differ diff --git a/weights/20201002_RecA-GFP_foci.hdf5 b/weights/20201002_RecA-GFP_foci.hdf5 new file mode 100644 index 0000000..166d898 Binary files /dev/null and b/weights/20201002_RecA-GFP_foci.hdf5 differ diff --git a/weights/20201002_dropout_normed_cropped_trap_weights.hdf5 b/weights/20201002_dropout_normed_cropped_trap_weights.hdf5 new file mode 100644 index 0000000..91d68c9 Binary files /dev/null and b/weights/20201002_dropout_normed_cropped_trap_weights.hdf5 differ diff --git a/weights/20190808_RecA-GFP_foci.hdf5 b/weights/20201027_RecA-GFP_foci.hdf5 similarity index 74% rename from weights/20190808_RecA-GFP_foci.hdf5 rename to weights/20201027_RecA-GFP_foci.hdf5 index 
624710b..ce582e2 100644 Binary files a/weights/20190808_RecA-GFP_foci.hdf5 and b/weights/20201027_RecA-GFP_foci.hdf5 differ diff --git a/weights/20201103_focus_disappear_predictor.hdf5 b/weights/20201103_focus_disappear_predictor.hdf5 new file mode 100644 index 0000000..897823c Binary files /dev/null and b/weights/20201103_focus_disappear_predictor.hdf5 differ diff --git a/weights/20201103_focus_migration_predictor.hdf5 b/weights/20201103_focus_migration_predictor.hdf5 new file mode 100644 index 0000000..fc8f5cc Binary files /dev/null and b/weights/20201103_focus_migration_predictor.hdf5 differ diff --git a/weights/20201104_focus_appear_predictor.hdf5 b/weights/20201104_focus_appear_predictor.hdf5 new file mode 100644 index 0000000..06c3ce2 Binary files /dev/null and b/weights/20201104_focus_appear_predictor.hdf5 differ diff --git a/yaml_templates/params_Unet.yaml b/yaml_templates/params_Unet.yaml index 9bb93d7..38990be 100644 --- a/yaml_templates/params_Unet.yaml +++ b/yaml_templates/params_Unet.yaml @@ -28,7 +28,7 @@ phase_plane: 'c1' # conversion factor from pixels to microns # use 0.065 for 100X, use 0.108 for 60X (Andor Neo) # use 0.11 for 100X (Photometrics Prime 95B) -pxl2um: 0.105 +pxl2um: 0.108 ### process control ############################################################ # Use these flags to control script specific processes and settings