diff --git a/Datasets/Sample_Dataset/test_cli_input2_noseg/Bladder1.png b/Datasets/Sample_Dataset/test_cli_input2_noseg/Bladder1.png new file mode 100644 index 0000000..c66f43d Binary files /dev/null and b/Datasets/Sample_Dataset/test_cli_input2_noseg/Bladder1.png differ diff --git a/Datasets/Sample_Dataset/test_cli_input2_noseg/Lung1.png b/Datasets/Sample_Dataset/test_cli_input2_noseg/Lung1.png new file mode 100644 index 0000000..3e11929 Binary files /dev/null and b/Datasets/Sample_Dataset/test_cli_input2_noseg/Lung1.png differ diff --git a/Model Types.md b/Model Types.md new file mode 100644 index 0000000..86c0ed9 --- /dev/null +++ b/Model Types.md @@ -0,0 +1,78 @@ +# Models + +## Usage +In general, option `--model` controls the type of model used in training: +``` +deepliif train ... --model +``` +During testing or inference, the model type information will automatically be fetched from the recorded options produced during training. + +Note that there could be additional configurable options for different model types. See below for details. + + +## Available Models +Currently we provide 5 models ready to be used with `deepliif` package. + +Input training data typically is a row of same-shaped square patches stitched together. Using DeepLIIF as an example, the input in the published paper is (IHC, Hematoxylin, DAPI, Lap2, Marker, Seg). This can be understood as (base mod, mod 1, mod 2, mod 3, mod 4, seg). + +|Model Name|Full Model Name|Tasks|Description|Training Data|Additional Config| +|----------|---------------|-----|-----------|-------------|-----------------| +|DeepLIIF|DeepLIIF|**paired** modality translation, segmentation (optional)|Based on the original DeepLIIF model as published in [Nature MI'22](https://rdcu.be/cKSBz) and the SDG model published as part of [MICCAI'24](https://arxiv.org/abs/2405.08169), this model class has been adapted to accept a configurable number of modalities and inputs, as well as the optional segmentation task.|Original: (base mod, mod 1, mod 2, mod 3, mod 4, seg)
New training: (base mod 1, ..., mod 1, ..., (seg))|`--seg-gen`: whether involving segmentation task
`--modalities-no`: the number of modalities to translate to (can be 0)
| +|DeepLIIFExt|DeepLIIF Extension|**paired** modality translation, segmentation (optional)|An extended version of the original DeepLIIF model that allows for **paired** mod-seg patches during training.|(base mod, mod 1, ..., seg 1, ...)|`--seg-gen`: whether involving segmentation task
`--modalities-no`: the number of modalities to translate to (>=1)| +|DeepLIIFKD|DeepLIIF Knowledge Distillation|knowledge distillation|This creates a student model (usually much smaller or simpler) and learns the output from the target large teacher model. This currently has only been tested with `DeepLIIF` as the teacher model, and may not be compatible with other model types.|(base mod, mod 1, mod 2, mod 3, mod 4, seg)|`--model-dir-teacher`: the directory of a trained large model to be used as the teacher| +|CycleGAN|CycleGAN|**unpaired** modality translation|This model is adapted from CycleGAN, where the main usage is to learn the translation from input domain to the target domain without paired ground truth data.|(base mod, mod 1)| +|SDG **(deprecatred)**|Synthetic Data Generation|modality translation|This model has been published. It translates provided modalities to the target modalities. In the paper, we used it to increase the resolution of stitched WSI from video frames. **This is now merged with DeepLIIF.**|||``| + + +## Model Details +### 1. DeepLIIF +DeepLIIF uses IHC as training input as well as 4 additional modalities (Hematoxylin, DAPI, Lap2, protein marker) to learn the translation task (4 generators, 1 for each modality), followed by a segmentation ground truth predicted collectively by 5 generators, each relying on one modality (IHC plus 4 translated modalities). The actual training data consists of wide images, where each has IHC + 4 modalities + segmentation stitched together in a row (see the Datasets folder for examples). + +During inference, only the IHC input is needed. + +The original setting employs **ResNet-9block** as the backbone for translation generators and **UNet-512 (9 down layers)** for segmentation generators. + +#### DeepLIIF with a configurable number of modalities and SDG capabilities +We recently updated DeepLIIF model class to allow an arbitrary number of modalities (which can even be 0). Setting `--modalities-no 0` means no modalities to translate to, and there will only be 1 generator that uses the base input mod (e.g., IHC) for the segmentation task (i.e., the whole DeepLIIF model in this case will only have 1 generator, rather than 4+5=9 generators in the original setting). + +Another major update is the merge with SDG as DeepLIIF has a considerable overlap with SDG. This means the DeepLIIF model now accepts `--seg-gen false` (no segmentation task) and the input number can be more than 1 (automatically deteremined using the number of patches in the training image minus the supplied modalities number in the command). + +### 2. DeepLIIFExt +Mainly based on the original DeepLIIF, this extension model allows to: +- learn only the modality translation task (`--seg-gen false`) +- use modality-wise segmentation ground truth (e.g., if use IHC as the base input, Lap2 and Marker as the additional modalities to learn translation to, then the training image should include 2 segmentation ground truth, 1 for each additional modalities) +- train any number of modalities (e.g., `--modalities-no 2` for 2 modalities to learn translation task for) +- modify loss function (DeepLIIF uses BCE for translation and LSGAN for segmentation) + +Some other noticeable differences (from the orignal DeepLIIF) include: +1. The training data to DeepLIIFExt requires **1 segmentation ground truth per modality**, rather than 1 final segmentation as in DeepLIIF. +2. The input to segmentation generator is not only the original base image (IHC in DeepLIIF) or a translated modality, but a concatenated vector of **original IHC, the first translated modality, and the current translated modality**. For example, if the sequence of patches in the training image is (IHC, Lap2, Marker, Lap2-Seg, Marker-Seg), then the input to the first segmentation generator will be (IHC, translated Lap2, translated Lap2), and the input to the second segmentation generator will be (IHC, translated Lap2, translated Marker). +3. The input to segmentation discriminator, as part of conditional GAN's practice, includes more context. In DeepLIIF, the context is the original IHC or a real modality image combined with the tensor to be evaluated in this context which is the aggregated final segmentation (fake case) or the real segmentation output (true case). In DeepLIIFExt, the context includes **the first real modality, and the current real modality**. Using the same example above, the input to the first discriminator will be (IHC, real Lap2, real Lap2, generated Lap2-Seg / real Lap2-Seg), and the input to the second discriminator will be (IHC, real Lap2, real Marker, generated Marker-Seg / real Marker-Seg). +4. DeepLIIF model includes a segmentation generator for the base input modality (e.g., IHC), while DeepLIIFExt does not. + + +|Input to|Original DeepLIIF|DeepLIIFExt| +|--------|--------|-----------| +|Translation generators|(IHC)|(base mod)| +|Segmentation generators|(IHC)
(generated Hema)
...|(base mod, generated mod 1, generated mod 1)
(base mod, generated mod 1, generated mod 2)
...| + +### 3. DeepLIIFKD +Our current KD approach simply flattens the output RGB tensors `(3, 512, 512)` into vectors `(1, 3*512*512)` and then applies KL divergence loss to the student’s output and the teacher’s output. The KL divergence loss is summed up for all 10 outputs (4 modality translations, 5 intermediate segmentations, and 1 final segmentation), and then added to the final loss term for back propagation. + +If we view the whole deepliif model set as one big model that produces the aggregated segmentation image as final output, then the 4 modality translation outputs and 5 intermediate segmentation outputs can be understood as intermediate features, which helps the student model to mimic how the teacher model arrives at the final segmentation output. In this sense, we do not incorporate "real" intermediate feature losses as other approaches did by comparing middle-layer output tensors of both models, but still effectively achieve the same purpose. + +The input to DeepLIIFKD is the same as that to DeepLIIF. + +*Note that DeepLIIFKD has only been tested with the original DeepLIIF model setup and that with a configurable number of modalities other than 4. DeepLIIF model with multiple input modalities or without the segmentation task may not work as expected.* + + +### 4. CycleGAN +The DeepLIIF model family requires paired images to learn the mapping. This, however, is not always achievable during data collection or might require considerably more efforts. Hence we apply the idea of CycleGAN for unpaired modality translation. + +The core idea is for the model to learn `f(A) = B` and `g(f(A)) = A`, where `A` and `B` denote data from the input and target domain, and `f(x)` and `g(x)` denote two mapping functions approximated by neural networks. For example, the input domain can be IHC and the target domain can be Ki-67. Generator `f(x)` learns how to map IHC to Ki-67 and generator `g(x)` learns how to map the translated Ki-67 back to the original IHC. In this case, we only need generator `f(x)` for inference. + +Our implementation of CycleGAN supports learning the mapping to multiple domains at once. Essentially, this becomes a multi-task learning: for each target domain `B1`, `B2`, ..., we create a separate generator set `f(x)` and `g(x)`. The losses from each generator set then gets combined for back propagation. + + +### 5. SDG +The Synthetic Data Generation model was developed for modality translation based on multiple inputs. This model class is now **deprecated** and will not be maintained in future releases. Users of this model class need to migrate to `DeepLIIF`. \ No newline at end of file diff --git a/README.md b/README.md index 136dbc2..d583945 100644 --- a/README.md +++ b/README.md @@ -77,17 +77,23 @@ The package is composed of two parts: You can list all available commands: ``` -(venv) $ deepliif --help +$ deepliif --help Usage: deepliif [OPTIONS] COMMAND [ARGS]... + Commonly used DeepLIIF batch operations + Options: --help Show this message and exit. Commands: prepare-testing-data Preparing data for testing + prepare-training-data Preparing data for training serialize Serialize DeepLIIF models using Torchscript test Test trained models + test-wsi train General-purpose training script for multi-task... + trainlaunch A wrapper method that executes deepliif/train.py... + visualize ``` **Note:** You might need to install a version of PyTorch that is compatible with your CUDA version. @@ -99,20 +105,33 @@ You can confirm if your installation will run on the GPU by checking if the foll import torch torch.cuda.is_available() ``` - -## Training Dataset -For training, all image sets must be 512x512 and combined together in 3072x512 images (six images of size 512x512 stitched -together horizontally). -The data need to be arranged in the following order: +## Dataset for training, validation, and testing +An example data directory looks as the following: ``` -XXX_Dataset + ├── train - └── val + ├── val + ├── val_cli + └── val_cli_gt ``` -We have provided a simple function in the CLI for preparing data for training. -* **To prepare data for training**, you need to have the image dataset for each image (including IHC, Hematoxylin Channel, mpIF DAPI, mpIF Lap2, mpIF marker, and segmentation mask) in the input directory. -Each of the six images for a single image set must have the same naming format, with only the name of the label for the type of image differing between them. The label names must be, respectively: IHC, Hematoxylin, DAPI, Lap2, Marker, Seg. +If you use different subfolder names, you will need to add `--phase {foldername}` in the training or testing commands for the functions to navigate to the correct subfolder. + +Content in each subfolder: +- train: training images used by command `python cli.py train`, see section Training Dataset below +- val: validation images used by command `python cli.py train --with-val`, see section Validation Dataset below +- val_cli: input modalities of the validation images used by command `python cli.py test`, see section Testing below +- val_cli_gt: ground truth of the output modalities from the validation images, used for evaluation purposes + +### Training Dataset +For training in general, each image in the training set is in the form of a set of horizontally stitched patches, in the order of **base input modalities, translation modalities, and segmentation modalities** (whenever applicable). + +Specifically for the DeepLIIF original model, all image sets must be 512x512 and combined together in 3072x512 images (six images of size 512x512 stitched together horizontally). + +We have provided a simple function in the CLI for preparing DeepLIIF data for training. + +* **To use this method to prepare data for training**, you need to have the image dataset for each image (including IHC, Hematoxylin Channel, mpIF DAPI, mpIF Lap2, mpIF marker, and segmentation mask) in the input directory. +Each of the six images for a single image set must have the same naming format, with only the name of the label for the type of image differing between them. To reproduce the original DeepLIIF model, the label names must be, respectively: IHC, Hematoxylin, DAPI, Lap2, Marker, Seg. The command takes the address of the directory containing image set data and the address of the output dataset directory. It first creates the train and validation directories inside the given output dataset directory. It then reads all of the images in the input directory and saves the combined image in the train or validation directory, based on the given `validation_ratio`. @@ -122,6 +141,19 @@ deepliif prepare-training-data --input-dir /path/to/input/images --validation-ratio 0.2 ``` +### Validation Dataset +The validation dataset consists of images of the same format as the training dataset and is totally optional (i.e., DeepLIIF model training command does not require a validation dataset to run). This currently is only implemented for **DeepLIIF or DeepLIIFKD models with segmentation task** (in which case the very last tile in the training / validation image is the segmentation tile). + +To use the validation dataset during training, it is necessary to first acquire the key quantitative statistics for the model to compare against as the training progresses. In tasks that target generating a single number or an array of numbers, validation metrics can be done by simply calculating the differences between the ground truth numbers and predicted numbers. In our image generation tasks, however, the key metrics we want to monitor are segmentation results: number of positive cells, number of negative cells, etc. These are much more informative and better reflect the quality of the model output than differences between pixel values. The ground truth quantitative numbers of segmentation results can be obtained using the `postprocess` function in `deepliif.models`. + +We provide a wrapper function `get_cell_count_metrics` that generates a JSON file for model validation: +``` +from deepliif.stat import get_cell_count_metrics +dir_img = '...' # e.g., directory to the validation images +get_cell_count_metrics(dir_img, dir_save=dir_img, model='DeepLIIF', tile_size=512) +``` + + ## Training To train a model: ``` @@ -137,48 +169,35 @@ python train.py --dataroot /path/to/input/images * To view training losses and results, open the URL http://localhost:8097. For cloud servers replace localhost with your IP. * Epoch-wise intermediate training results are in `DeepLIIF/checkpoints/Model_Name/web/index.html`. * Trained models will be by default be saved in `DeepLIIF/checkpoints/Model_Name`. -* Training datasets can be downloaded [here](https://zenodo.org/record/4751737#.YKRTS0NKhH4). - -**DP**: To train a model you can use DP. DP is single-process. It means that **all the GPUs you want to use must be on the same machine** so that they can be included in the same process - you cannot distribute the training across multiple GPU machines, unless you write your own code to handle inter-node (node = machine) communication. -To split and manage the workload for multiple GPUs within the same process, DP uses multi-threading. -You can find more information on DP [here](https://github.com/nadeemlab/DeepLIIF/blob/main/Multi-GPU%20Training.md). +* Training datasets for the original DeepLIIF model can be downloaded from [Zenodo](https://zenodo.org/record/4751737#.YKRTS0NKhH4). -To train a model with DP (Example with 2 GPUs (on 1 machine)): -``` -deepliif train --dataroot --batch-size 6 --gpu-ids 0 --gpu-ids 1 -``` -Note that `batch-size` is defined per process. Since DP is a single-process method, the `batch-size` you set is the **effective** batch size. +### Multi-GPU Training +You can find more information on multi-gpu training with DeepLIIF code [here](https://github.com/nadeemlab/DeepLIIF/blob/main/Multi-GPU%20Training.md). -**DDP**: To train a model you can use DDP. DDP usually spawns multiple processes. -**DeepLIIF's code follows the PyTorch recommendation to spawn 1 process per GPU** ([doc](https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md#application-process-topologies)). If you want to assign multiple GPUs to each process, you will need to make modifications to DeepLIIF's code (see [doc](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#combine-ddp-with-model-parallelism)). -Despite all the benefits of DDP, one drawback is the extra GPU memory needed for dedicated CUDA buffer for communication. See a short discussion [here](https://discuss.pytorch.org/t/do-dataparallel-and-distributeddataparallel-affect-the-batch-size-and-gpu-memory-consumption/97194/2). In the context of DeepLIIF, this means that there might be situations where you could use a *bigger batch size with DP* as compared to DDP, which may actually train faster than using DDP with a smaller batch size. -You can find more information on DDP [here](https://github.com/nadeemlab/DeepLIIF/blob/main/Multi-GPU%20Training.md). +In short, +- Command `deepliif train` triggers **Data Parallel** (DP). DP is single-process, so **all the GPUs you want to use must be on the same machine** in order for them to be included in the same process. In other words, you cannot distribute the training across multiple GPU machines, unless you write your own code to handle inter-node / inter-machine communication. +- Command `deepliif trainlaunch` triggers **Distributed Data Parallel** (DDP). DDP usually spawns multiple processes and consequently **can be used across machines**. -To launch training using DDP on a local machine, use `deepliif trainlaunch`. Example with 2 GPUs (on 1 machine): +Example commands with 2 GPUs: ``` +deepliif train --dataroot --batch-size 6 --gpu-ids 0 --gpu-ids 1 deepliif trainlaunch --dataroot --batch-size 3 --gpu-ids 0 --gpu-ids 1 --use-torchrun "--nproc_per_node 2" ``` -Note that -1. `batch-size` is defined per process. Since DDP is a single-process method, the `batch-size` you set is the batch size for each process, and the **effective** batch size will be `batch-size` multiplied by the number of processes you started. In the above example, it will be 3 * 2 = 6. -2. You still need to provide **all GPU ids to use** to the training command. Internally, in each process DeepLIIF picks the device using `gpu_ids[local_rank]`. If you provide `--gpu-ids 2 --gpu-ids 3`, the process with local rank 0 will use gpu id 2 and that with local rank 1 will use gpu id 3. -3. `-t 3 --log_dir ` is not required, but is a useful setting in `torchrun` that saves the log from each process to your target log directory. For example: -``` -deepliif trainlaunch --dataroot --batch-size 3 --gpu-ids 0 --gpu-ids 1 --use-torchrun "-t 3 --log_dir --nproc_per_node 2" -``` -4. If your PyTorch is older than 1.10, DeepLIIF calls `torch.distributed.launch` in the backend. Otherwise, DeepLIIF calls `torchrun`. + +### Model Types +In addition to the original DeepLIIF model, the package now supports more model types. Details can be found [here](https://github.com/nadeemlab/DeepLIIF/blob/main/Model%20Types.md). ## Serialize Model -The installed `deepliif` uses Dask to perform inference on the input IHC images. -Before running the `test` command, the model files must be serialized using Torchscript. -To serialize the model files: +The installed `deepliif` package can optionally use serialized model objects to perform inference on the input images. In order to do this, before running the `test` command, the model files need to be serialized using Torchscript: ``` deepliif serialize --model-dir /path/to/input/model/files --output-dir /path/to/output/model/files + --device gpu ``` -* By default, the model files are expected to be located in `DeepLIIF/model-server/DeepLIIF_Latest_Model`. -* By default, the serialized files will be saved to the same directory as the input model files. +* By default, for original DeepLIIF, the model files are expected to be located in `DeepLIIF/model-server/DeepLIIF_Latest_Model`. +* If not specified, the serialized files will be saved to the same directory as the input model files. -## Testing +## Testing / Inference To test the model: ``` deepliif test --input-dir /path/to/input/images @@ -194,19 +213,22 @@ python test.py --dataroot /path/to/input/images --name Model_Name ``` * The latest version of the pretrained models can be downloaded [here](https://zenodo.org/record/4751737#.YKRTS0NKhH4). -* Before running test on images, the model files must be serialized as described above. -* The serialized model files are expected to be located in `DeepLIIF/model-server/DeepLIIF_Latest_Model`. +* The format of input images to `test.py` is the same as training/validation data, while that to `deepliif test` command is only the input modalities (e.g., only IHC for original DeepLIIF). +* Use `deepliif test ... --eager-mode` for the raw model files, or serialize the model files as described above to run the serialized ones. +* For original DeepLIIF, The serialized model files are expected to be located in `DeepLIIF/model-server/DeepLIIF_Latest_Model`. * The test results will be saved to the specified output directory, which defaults to the input directory. * The tile size must be specified and is used to split the image into tiles for processing. The tile size is based on the resolution (scan magnification) of the input image, and the recommended values are a tile size of 512 for 40x images, 256 for 20x, and 128 for 10x. Note that the smaller the tile size, the longer inference will take. -* Testing datasets can be downloaded [here](https://zenodo.org/record/4751737#.YKRTS0NKhH4). +* Testing datasets can be downloaded from [Zenodo](https://zenodo.org/record/4751737#.YKRTS0NKhH4). **Test Command Options:** In addition to the required parameters given above, the following optional parameters are available for `deepliif test`: * `--eager-mode` Run the original model files (instead of serialized model files). * `--seg-intermediate` Save the intermediate segmentation maps created for each modality. * `--seg-only` Save only the segmentation files, and do not infer images that are not needed. +* `--mod-only` Save only the translated modality image; overwrites --seg-only and --seg-intermediate. * `--color-dapi` Color the inferred DAPI image. * `--color-marker` Color the inferred marker image. +* `--BtoA` For models trained with unaligned dataset, this flag instructs the code to load generatorB instead of generatorA. **Whole Slide Image (WSI) Inference:** For translation and segmentation of whole slide images, @@ -246,8 +268,8 @@ on how to deploy the model with Torchserve and for an example of how to run the ## Docker We provide a Dockerfile that can be used to run the DeepLIIF models inside a container. -First, you need to install the [Docker Engine](https://docs.docker.com/engine/install/ubuntu/). -After installing the Docker, you need to follow these steps: +First, you need to install [Docker Engine](https://docs.docker.com/engine/install/ubuntu/). +After installing Docker, you need to follow these steps: * Download the pretrained model [here](https://zenodo.org/record/4751737#.YKRTS0NKhH4) and place them in DeepLIIF/model-server/DeepLIIF_Latest_Model. * To create a docker image from the docker file: ``` diff --git a/cli.py b/cli.py index 8ea7a83..1fc98a9 100644 --- a/cli.py +++ b/cli.py @@ -213,7 +213,10 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd """ assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN','DeepLIIFKD'], f'model class {model} is not implemented' if model in ['DeepLIIF','DeepLIIFKD']: - seg_no = 1 + if seg_gen == True: + seg_no = 1 + else: + seg_no = 0 elif model == 'DeepLIIFExt': if seg_gen: seg_no = modalities_no @@ -223,6 +226,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd seg_no = 0 seg_gen = False + # validation currently is only supported for segmentation results + if seg_gen == False: + with_val = False if model == 'CycleGAN': dataset_mode = "unaligned" @@ -842,7 +848,6 @@ def serialize(model_dir, output_dir, device, epoch, verbose): @click.option('--BtoA', is_flag=True, help='for models trained with unaligned dataset, this flag instructs to load generatorB instead of generatorA') def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids, eager_mode, epoch, seg_intermediate, seg_only, mod_only, color_dapi, color_marker, btoa): - """Test trained models """ output_dir = output_dir or input_dir diff --git a/conftest.py b/conftest.py index 5ac81b0..a74ca13 100644 --- a/conftest.py +++ b/conftest.py @@ -8,12 +8,12 @@ import datetime MODEL_INFO = {'latest':{'model':'DeepLIIF', # cli.py train looks for subfolder "train" under dataroot - 'dir_input_train':['Datasets/Sample_Dataset'], - 'dir_input_testpy':['Datasets/Sample_Dataset'], - 'dir_input_inference':['Datasets/Sample_Dataset/test_cli'], - 'dir_model':['../checkpoints/DeepLIIF_Latest_Model'], - 'modalities_no': [4], - 'seg_gen':[True], + 'dir_input_train':['Datasets/Sample_Dataset','Datasets/Sample_Dataset'], + 'dir_input_testpy':['Datasets/Sample_Dataset','Datasets/Sample_Dataset'], + 'dir_input_inference':['Datasets/Sample_Dataset/test_cli','Datasets/Sample_Dataset/test_cli'], + 'dir_model':['../checkpoints/DeepLIIF_Latest_Model','../checkpoints/deepliif_sdg_test_0225'], + 'modalities_no': [4,4], + 'seg_gen':[True,False], 'tile_size':512}, 'ext':{'model':'DeepLIIFExt', 'dir_input_train':['Datasets/Sample_Dataset_ext_withseg','Datasets/Sample_Dataset_ext_noseg'], diff --git a/deepliif/data/aligned_dataset.py b/deepliif/data/aligned_dataset.py index 33787c6..c177870 100644 --- a/deepliif/data/aligned_dataset.py +++ b/deepliif/data/aligned_dataset.py @@ -51,7 +51,8 @@ def __getitem__(self, index): # split AB image into A and B w, h = AB.size if self.model in ['DeepLIIF','DeepLIIFKD']: - num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image + num_img = self.modalities_no + self.seg_no + self.input_no + # num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image elif self.model == 'DeepLIIFExt': num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1 # +1 for segmentation channel elif self.model == 'SDG': @@ -69,12 +70,19 @@ def __getitem__(self, index): A = A_transform(A) B_Array = [] if self.model in ['DeepLIIF','DeepLIIFKD']: - for i in range(1, num_img): + for i in range(self.input_no, num_img): B = AB.crop((w2 * i, 0, w2 * (i + 1), h)) B = B_transform(B) B_Array.append(B) - - return {'A': A, 'B': B_Array, 'A_paths': AB_path, 'B_paths': AB_path} + if self.input_no > 1: + A_Array = [] + for i in range(self.input_no): + A = AB.crop((w2 * i, 0, w2 * (i+1), h)) + A = A_transform(A) + A_Array.append(A) + return {'A': A_Array, 'B': B_Array, 'A_paths': AB_path, 'B_paths': AB_path} + else: + return {'A': A, 'B': B_Array, 'A_paths': AB_path, 'B_paths': AB_path} elif self.model == 'DeepLIIFExt': for i in range(1, self.modalities_no + 1): B = AB.crop((w2 * i, 0, w2 * (i + 1), h)) diff --git a/deepliif/models/DeepLIIF_model.py b/deepliif/models/DeepLIIF_model.py index da3e6ac..f2f00ed 100644 --- a/deepliif/models/DeepLIIF_model.py +++ b/deepliif/models/DeepLIIF_model.py @@ -18,6 +18,7 @@ def __init__(self, opt): if not hasattr(opt,'net_gs'): opt.net_gs = 'unet_512' + self.seg_gen = opt.seg_gen self.seg_weights = opt.seg_weights self.loss_G_weights = opt.loss_G_weights self.loss_D_weights = opt.loss_D_weights @@ -35,11 +36,13 @@ def __init__(self, opt): for i in range(self.opt.modalities_no): self.loss_names.extend([f'G_GAN_{i+1}', f'G_L1_{i+1}', f'D_real_{i+1}', f'D_fake_{i+1}']) self.visual_names.extend([f'fake_B_{i+1}', f'real_B_{i+1}']) - self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}']) + if self.seg_gen: + self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}']) - for i in range(self.opt.modalities_no+1): - self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}']) # 0 is used for the base input mod - self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'real_B_{self.mod_id_seg}']) + if self.seg_gen: + for i in range(self.opt.modalities_no+1): + self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}']) # 0 is used for the base input mod + self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'real_B_{self.mod_id_seg}']) # specify the images you want to save/display. The training/test scripts will call # specify the models you want to save to the disk. The training/test scripts will call and @@ -51,28 +54,30 @@ def __init__(self, opt): self.model_names.extend([f'G{i}', f'D{i}']) self.model_names_g.append(f'G{i}') - for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod - if self.input_id == '0': - self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}']) - self.model_names_gs.append(f'G{self.mod_id_seg}{i}') - else: - self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}']) - self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}') + if self.seg_gen: + for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod + if self.input_id == '0': + self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}']) + self.model_names_gs.append(f'G{self.mod_id_seg}{i}') + else: + self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}']) + self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}') else: # during test time, only load G self.model_names = [] for i in range(1, self.opt.modalities_no + 1): self.model_names.extend([f'G{i}']) self.model_names_g.append(f'G{i}') - #input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name)) - if self.input_id == '0': - for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod - self.model_names.extend([f'G{self.mod_id_seg}{i}']) - self.model_names_gs.append(f'G{self.mod_id_seg}{i}') - else: - for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod - self.model_names.extend([f'G{self.mod_id_seg}{i+1}']) - self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}') + if self.seg_gen: + #input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name)) + if self.input_id == '0': + for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod + self.model_names.extend([f'G{self.mod_id_seg}{i}']) + self.model_names_gs.append(f'G{self.mod_id_seg}{i}') + else: + for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod + self.model_names.extend([f'G{self.mod_id_seg}{i+1}']) + self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}') # define networks (both generator and discriminator) if isinstance(opt.netG, str): @@ -82,33 +87,41 @@ def __init__(self, opt): for i,model_name in enumerate(self.model_names_g): - setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[i], opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)) + setattr(self,f'net{model_name}',networks.define_G(self.opt.input_nc * self.opt.input_no, opt.output_nc, opt.ngf, opt.netG[i], opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding, opt.upsample)) - # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output - for i,model_name in enumerate(self.model_names_gs): - setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)) + if self.seg_gen: + # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output + # default padding type in define_G: reflect (default padding type in cli/opt is zero - this opt in cli was developed to control padding type in translation generators only) + # default upsample strategy in define_G: convtranspose + for i,model_name in enumerate(self.model_names_gs): + setattr(self,f'net{model_name}',networks.define_G(self.opt.input_nc * self.opt.input_no, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)) if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc self.model_names_d = [f'D{i+1}' for i in range(self.opt.modalities_no)] - if self.input_id == '0': - self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)] - else: - self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)] for model_name in self.model_names_d: - setattr(self,f'net{model_name}',networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD, + setattr(self,f'net{model_name}',networks.define_D(self.opt.input_nc * self.opt.input_no + opt.output_nc , opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)) - for model_name in self.model_names_ds: - setattr(self,f'net{model_name}',networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, - opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)) + self.model_names_ds = [] + if self.seg_gen: + if self.input_id == '0': + self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)] + else: + self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)] + for model_name in self.model_names_ds: + setattr(self,f'net{model_name}',networks.define_D(self.opt.input_nc * self.opt.input_no + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)) if self.is_train: # define loss functions - self.criterionGAN_BCE = networks.GANLoss('vanilla').to(self.device) - self.criterionGAN_lsgan = networks.GANLoss('lsgan').to(self.device) + # self.criterionGAN_BCE = networks.GANLoss('vanilla').to(self.device) + # self.criterionGAN_lsgan = networks.GANLoss('lsgan').to(self.device) + self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device) + self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device) self.criterionSmoothL1 = torch.nn.SmoothL1Loss() + self.criterionVGG = networks.VGGLoss().to(self.device) # initialize optimizers; schedulers will be automatically created by function . #params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters()) @@ -136,7 +149,6 @@ def __init__(self, opt): self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) - self.criterionVGG = networks.VGGLoss().to(self.device) def set_input(self, input): """ @@ -145,12 +157,18 @@ def set_input(self, input): :param input (dict): include the input image and the output modalities """ - self.real_A = input['A'].to(self.device) + self.real_A = input['A'] + if isinstance(self.real_A, list): # from previous SDG setup where multiple input modalities are allowed + As = [A.to(self.device) for A in self.real_A] + self.real_A = torch.cat(As, dim=1) # shape: 1, (3 x input_no), 512, 512 + else: + self.real_A = self.real_A.to(self.device) self.real_B_array = input['B'] for i in range(self.opt.modalities_no): setattr(self,f'real_B_{i+1}',self.real_B_array[i].to(self.device)) - setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg + if self.opt.seg_gen: + setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg self.image_paths = input['A_paths'] @@ -175,13 +193,14 @@ def forward(self): # torch.mul(self.fake_B_5_4, self.seg_weights[3]), # torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0) - for i,model_name in enumerate(self.model_names_gs): - if i == 0: - setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A)) - else: - setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}'))) - - setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0)) + if self.seg_gen: + for i,model_name in enumerate(self.model_names_gs): + if i == 0: + setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A)) + else: + setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}'))) + + setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0)) def backward_D(self): """Calculate GAN loss for the discriminators""" @@ -203,7 +222,7 @@ def backward_D(self): for i,model_name in enumerate(self.model_names_d): fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1) pred_fake = getattr(self,f'net{model_name}')(fake_AB.detach()) - setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_BCE(pred_fake, False)) + setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_mod(pred_fake, False)) #setattr(self,f'fake_AB_{i+1}',torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)) #setattr(self,f'pred_fake_{i+1}',getattr(self,f'netD{i+1}')(getattr)) @@ -227,19 +246,20 @@ def backward_D(self): # torch.mul(pred_fake_5_4, self.seg_weights[3]), # torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0) - l_pred_fake_seg = [] - for i,model_name in enumerate(self.model_names_ds): - if i == 0: - fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1) - else: - fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1) - - pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach()) - l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i])) - pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0) - - #self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False) - setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, False)) + if self.seg_gen: + l_pred_fake_seg = [] + for i,model_name in enumerate(self.model_names_ds): + if i == 0: + fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1) + else: + fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1) + + pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach()) + l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i])) + pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0) + + #self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False) + setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_seg(pred_fake_seg, False)) # real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1) @@ -260,7 +280,7 @@ def backward_D(self): for i,model_name in enumerate(self.model_names_d): real_AB = torch.cat((self.real_A, getattr(self,f'real_B_{i+1}')), 1) pred_real = getattr(self,f'net{model_name}')(real_AB) - setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_BCE(pred_real, True)) + setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_mod(pred_real, True)) # real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1) # real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1) @@ -281,19 +301,20 @@ def backward_D(self): # torch.mul(pred_real_5_4, self.seg_weights[3]), # torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0) - l_pred_real_seg = [] - for i,model_name in enumerate(self.model_names_ds): - if i == 0: - real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1) - else: - real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1) - - pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i) - l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i])) - pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0) - - #self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True) - setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_real_seg, True)) + if self.seg_gen: + l_pred_real_seg = [] + for i,model_name in enumerate(self.model_names_ds): + if i == 0: + real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1) + else: + real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1) + + pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i) + l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i])) + pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0) + + #self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True) + setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_seg(pred_real_seg, True)) # combine losses and calculate gradients # self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \ @@ -305,7 +326,8 @@ def backward_D(self): self.loss_D = torch.tensor(0., device=self.device) for i in range(self.opt.modalities_no): self.loss_D += (getattr(self,f'loss_D_fake_{i+1}') + getattr(self,f'loss_D_real_{i+1}')) * 0.5 * self.loss_D_weights[i] - self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no] + if self.seg_gen: + self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no] self.loss_D.backward() @@ -330,7 +352,7 @@ def backward_G(self): for i,model_name in enumerate(self.model_names_d): fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1) pred_fake = getattr(self,f'net{model_name}')(fake_AB) - setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_BCE(pred_fake, True)) + setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_mod(pred_fake, True)) # fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) @@ -350,19 +372,20 @@ def backward_G(self): # torch.mul(pred_fake_5_4, self.seg_weights[3]), # torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0) - l_pred_fake_seg = [] - for i,model_name in enumerate(self.model_names_ds): - if i == 0: - fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1) - else: - fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1) - - pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i) - l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i])) - pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0) - - # self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True) - setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, True)) + if self.seg_gen: + l_pred_fake_seg = [] + for i,model_name in enumerate(self.model_names_ds): + if i == 0: + fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1) + else: + fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1) + + pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i) + l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i])) + pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0) + + # self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True) + setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_seg(pred_fake_seg, True)) # Second, G(A) = B # self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1 @@ -373,7 +396,8 @@ def backward_G(self): for i in range(self.opt.modalities_no): setattr(self,f'loss_G_L1_{i+1}',self.criterionSmoothL1(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_L1) - setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1) + if self.seg_gen: + setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1) # self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat # self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat @@ -381,7 +405,8 @@ def backward_G(self): # self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat for i in range(self.opt.modalities_no): setattr(self,f'loss_G_VGG_{i+1}',self.criterionVGG(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_feat) - setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat) + if self.seg_gen: + setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat) # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \ # (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \ @@ -392,7 +417,8 @@ def backward_G(self): self.loss_G = torch.tensor(0., device=self.device) for i in range(self.opt.modalities_no): self.loss_G += (getattr(self,f'loss_G_GAN_{i+1}') + getattr(self,f'loss_G_L1_{i+1}') + getattr(self,f'loss_G_VGG_{i+1}')) * self.loss_G_weights[i] - self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i] + if self.seg_gen: + self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i] # combine loss and calculate gradients # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \ @@ -480,4 +506,3 @@ def calculate_losses(self): self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G - diff --git a/deepliif/models/__init__.py b/deepliif/models/__init__.py index aeea049..4e76d1b 100644 --- a/deepliif/models/__init__.py +++ b/deepliif/models/__init__.py @@ -180,8 +180,11 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'): if opt.modalities_no == 0: net_groups = [(f'G{opt.mod_id_seg}{opt.input_id}',)] else: - net_groups = [(f'G{i+1}', f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}') for i in range(opt.modalities_no)] - net_groups += [(f'G{opt.mod_id_seg}{opt.input_id}',)] # this is the generator for the input base mod + if opt.seg_gen: + net_groups = [(f'G{i+1}', f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}') for i in range(opt.modalities_no)] + net_groups += [(f'G{opt.mod_id_seg}{opt.input_id}',)] # this is the generator for the input base mod + else: + net_groups = [(f'G{i+1}',) for i in range(opt.modalities_no)] elif opt.model in ['DeepLIIFExt','SDG']: if opt.seg_gen: net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)] @@ -288,30 +291,33 @@ def forward(input, model): return model(input.to(next(model.parameters()).device)) if opt.model in ['DeepLIIF','DeepLIIFKD']: - if seg_weights is None: - # weights = { - # 'G51': 0.5, # IHC - # 'G52': 0.0, # Hema - # 'G53': 0.0, # DAPI - # 'G54': 0.0, # Lap2 - # 'G55': 0.5, # Marker - # } - weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': 1/(opt.modalities_no+1) for i in range(opt.modalities_no+1)} - else: - # weights = { - # 'G51': seg_weights[0], # IHC - # 'G52': seg_weights[1], # Hema - # 'G53': seg_weights[2], # DAPI - # 'G54': seg_weights[3], # Lap2 - # 'G55': seg_weights[4], # Marker - # } - weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': seg_weight for i,seg_weight in enumerate(seg_weights)} - + # for seg_gen False, we also use this seg_map dictionary - only the keys though seg_map = {f'G{i+1}': f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}' for i in range(opt.modalities_no)} - if seg_only: - seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0} - + + if opt.seg_gen: + if seg_weights is None: + # weights = { + # 'G51': 0.5, # IHC + # 'G52': 0.0, # Hema + # 'G53': 0.0, # DAPI + # 'G54': 0.0, # Lap2 + # 'G55': 0.5, # Marker + # } + weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': 1/(opt.modalities_no+1) for i in range(opt.modalities_no+1)} + else: + # weights = { + # 'G51': seg_weights[0], # IHC + # 'G52': seg_weights[1], # Hema + # 'G53': seg_weights[2], # DAPI + # 'G54': seg_weights[3], # Lap2 + # 'G55': seg_weights[4], # Marker + # } + weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': seg_weight for i,seg_weight in enumerate(seg_weights)} + + if seg_only: + seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0} + lazy_gens = {k: forward(ts, nets[k]) for k in seg_map} if 'Marker' in opt.modalities_names: mod_id_marker = opt.modalities_names.index("Marker") @@ -320,7 +326,7 @@ def forward(input, model): gens = compute(lazy_gens)[0] - if not mod_only: + if opt.seg_gen and not mod_only: lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()} # run seg generator for the base input if weights[f'G{opt.mod_id_seg}{opt.input_id}'] != 0: @@ -332,7 +338,7 @@ def forward(input, model): seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0) if output_tensor: - if mod_only: + if mod_only or not opt.seg_gen: res = gens elif seg_only and opt.modalities_no > 0: res = {f'G{opt.modalities_no}': gens[f'G{opt.modalities_no}']} if f'G{opt.modalities_no}' in gens else {} @@ -342,7 +348,7 @@ def forward(input, model): res[f'G{opt.mod_id_seg}'] = seg else: - if mod_only: + if mod_only or not opt.seg_gen: res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()} elif seg_only and opt.modalities_no > 0: res = {f'G{opt.modalities_no}': tensor_to_pil(gens[f'G{opt.modalities_no}'].to(torch.device('cpu')))} if f'G{opt.modalities_no}' in gens else {} @@ -399,7 +405,7 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt= f'G{opt.modalities_no}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[-1]), f'G{opt.mod_id_seg}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)), } - elif mod_only: + elif mod_only or not opt.seg_gen: res = {f'G{i+1}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[i]) for i in range(opt.modalities_no)} else : @@ -464,11 +470,17 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, for k,v in opt_args.items(): setattr(opt,k,v) #print_options(opt) + + if hasattr(opt,'seg_gen') and opt.seg_gen == False: + if seg_only == True or return_seg_intermediate == True: + seg_only = False + return_seg_intermediate = False + print('option seg_gen is False, disabled seg_only and return_seg_intermediate') run_fn = run_torchserve if use_torchserve else run_dask - if opt.model == 'SDG': - # SDG could have multiple input images/modalities, hence the input could be a rectangle. + if opt.input_no > 1 or opt.model == 'SDG': + # models such as SDG could have multiple input images/modalities, hence the input could be a rectangle. # We split the input to get each modality image then create tiles for each set of input images. w, h = int(img.width / opt.input_no), img.height orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)] @@ -485,10 +497,10 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, if opt.model in ['DeepLIIF','DeepLIIFKD']: # check if both the elements and the order are exactly the same l_modname = [f'mod{i+1}' for i in range(opt.modalities_no)] - if l_modname != opt.modalities_names[1:]: + if l_modname != opt.modalities_names[opt.input_no:]: # if not, append modalities_names to mod names - l_modname = [f'mod{i+1}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names[1:])] - d_modname2id = {mod_name:f'G{i+1}' for i,mod_name in enumerate(l_modname)} + l_modname = [f'mod{i+1}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names[opt.input_no:])] + d_modname2id = {mod_name:f'G{i+1}' for i,mod_name in enumerate(l_modname)} if opt.seg_gen: l_modname_seg = [f'mod{i}' for i in range(opt.modalities_no+1)] @@ -500,11 +512,9 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, else: d_modname2id_seg = {mod_name:f'G{opt.mod_id_seg}{i+1}' for i,mod_name in enumerate(l_modname_seg)} - if not mod_only: + if not mod_only and opt.seg_gen: d_modname2id['Seg'] = f'G{opt.mod_id_seg}' - #print('d_modname2id:',d_modname2id) - if seg_only: images = {'Seg': results[d_modname2id['Seg']]} marker_key = find_marker_key(d_modname2id) @@ -526,7 +536,7 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, # images['Seg'] = results[f'G{opt.modalities_no+1}'] images = {mod_name: results[mod_id] for mod_name,mod_id in d_modname2id.items()} - if return_seg_intermediate and not seg_only: + if opt.seg_gen and return_seg_intermediate and not seg_only: # images.update({'IHC_s':results['G51'], # 'Hema_s':results['G52'], # 'DAPI_s':results['G53'], diff --git a/deepliif/options/__init__.py b/deepliif/options/__init__.py index 9f42a3d..4bc2bcb 100644 --- a/deepliif/options/__init__.py +++ b/deepliif/options/__init__.py @@ -91,6 +91,8 @@ def __init__(self, d_params=None, path_file=None, mode='train'): if self.model in ['DeepLIIF','DeepLIIFKD']: self.mod_id_seg, self.input_id = init_input_and_mod_id(self, os.path.dirname(path_file)) + if hasattr(self,'seg_gen') and self.seg_gen == False: + self.mod_id_seg = None self.input_id = int(self.input_id) print('mod id seg:', self.mod_id_seg, '; input id:', self.input_id) @@ -99,7 +101,7 @@ def __init__(self, d_params=None, path_file=None, mode='train'): if not hasattr(self,'modalities_names'): self.modalities_names = ['IHC','Hema','DAPI','Lap2','Marker'] self.seg_weights = [0.5,0,0,0,0.5] - elif not hasattr(self,'modalities_names') or len(self.modalities_names)==0: + if not hasattr(self,'modalities_names') or len(self.modalities_names)==0: # if self.model == 'DeepLIIFKD': # # try find the modalities names from the teacher model # d_params_teacher = read_model_params(os.path.join(self.model_dir_teacher,'train_opt.txt')) @@ -107,7 +109,7 @@ def __init__(self, d_params=None, path_file=None, mode='train'): # self.modalities_names = d_params_teacher['modalities_names'] # # check again # if not hasattr(self,'modalities_names') or len(self.modalities_names)==0: - self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)] + self.modalities_names = [f'input{i+1}' for i in range(self.input_no)] + [f'mod{i+1}' for i in range(self.modalities_no)] else: self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)] @@ -133,7 +135,7 @@ def __init__(self, d_params=None, path_file=None, mode='train'): # to account for old settings: same as in cli.py train if not hasattr(self,'seg_no'): - if self.model == 'DeepLIIF': + if self.model == 'DeepLIIF': # for newer models where deepliif accept seg_gen=False (2026feb), seg_no should exist so this auto-determination will not be triggered self.seg_no = 1 self.seg_gen = True elif self.model == 'DeepLIIFExt': @@ -172,6 +174,10 @@ def __init__(self, d_params=None, path_file=None, mode='train'): # weights of the modalities used to calculate the final loss self.loss_G_weights = [1 / self.modalities_no] * self.modalities_no if not hasattr(self,'loss_G_weights') else self.loss_G_weights self.loss_D_weights = [1 / self.modalities_no] * self.modalities_no if not hasattr(self,'loss_D_weights') else self.loss_D_weights + + # upsample + if not hasattr(self,'upsample'): + self.upsample = 'convtranspose' diff --git a/deepliif/scripts/train.py b/deepliif/scripts/train.py index 3b61e2d..e28c6bb 100644 --- a/deepliif/scripts/train.py +++ b/deepliif/scripts/train.py @@ -202,7 +202,10 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd """ assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN','DeepLIIFKD'], f'model class {model} is not implemented' if model in ['DeepLIIF','DeepLIIFKD']: - seg_no = 1 + if seg_gen == True: + seg_no = 1 + else: + seg_no = 0 elif model == 'DeepLIIFExt': if seg_gen: seg_no = modalities_no @@ -212,6 +215,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd seg_no = 0 seg_gen = False + # validation currently is only supported for segmentation results + if seg_gen == False: + with_val = False if model == 'CycleGAN': dataset_mode = "unaligned" @@ -294,6 +300,10 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd modalities_names = [name.strip() for name in modalities_names.split(',') if len(name) > 0] assert len(modalities_names) == 0 or len(modalities_names) == input_no + modalities_no, f'--modalities-names has {len(modalities_names)} entries ({modalities_names}), expecting 0 or {input_no + modalities_no} entries' + if len(modalities_names) == 0 and model == 'DeepLIIFKD': + # inherit this property from teacher model + opt_teacher = get_opt(model_dir_teacher, mode='test') + modalities_names = opt_teacher.modalities_names d_params['input_no'] = input_no d_params['modalities_names'] = modalities_names diff --git a/deepliif/stat/__init__.py b/deepliif/stat/__init__.py new file mode 100644 index 0000000..6a8f703 --- /dev/null +++ b/deepliif/stat/__init__.py @@ -0,0 +1,96 @@ +import os + +from PIL import Image +import json +from ..models import postprocess +import re + +def get_cell_count_metrics(dir_seg, + dir_input = None, + dir_save = None, + model = "DeepLIIF", + tile_size = 512, + single_tile = False, + use_marker = False, + suffix_seg = '5', + suffix_marker = '4', + save_individual = False): + """ + Obtain cell count metrics through postprocess functions. + Currenlty implemented only for ground truth tiles. + Eligible for data used in model type: + DeepLIIF (with segmentation task) + DeepLIIFKD (with segmentation task) + + dir_seg: directory to find segmentation (and marker if specified) images + dir_input: directory to find input tile images, only used when single_tile + is True and the filenames are expected to be .png; if + not specified, default to dir_seg + dir_save: directory to save the results out; if not specified, default to + dir_seg + model: model type for postprocess function to understand how to handle the + data (DeepLIIF, DeepLIIFExt) + tile_size: tile size used for postprocess calculation + single_tile: True if the images are single-tile images, and image name + should follow _.png; use False if the + images consisting of a row of multiple tiles like those used in + training or validation (in this case the segmentation tile is + the last one or several tiles) + use_marker: whether to use the marker image (if True, assumes the marker + image is the second last tile (single_tile=False) or has a + suffix of (single_tile=True)) + suffix_seg: filename suffix for segmentation images if single_tile is True + suffix_marker: filename suffix for marker images if single_tile is True + save_individual: save cell count statistics for each individual image + """ + dir_save = dir_save if dir_seg is None else dir_save + + if single_tile: + fns = [x for x in os.listdir(dir_seg) if x.endswith(f'_{suffix_seg}.png') or x.endswith(f'_{suffix_marker}.png')] + fns = list(set(['_'.join(x.split('_')[:-1]) for x in fns])) # fns do not have extention and mod suffix + else: + fns = [x for x in os.listdir(dir_seg) if x.endswith('.png')] # fns have extension + + d_metrics = {} + count = 0 + for fn in fns: + if single_tile: + dir_input = dir_input if dir_seg is None else dir_input + img_gt = Image.open(os.path.join(dir_seg, fn + f'_{suffix_seg}.png')) + img_marker = Image.open(os.path.join(dir_seg, fn + f'_{suffix_marker}.png')) + img_input = Image.open(os.path.join(dir_input,fn+'.png')) + k = fn + else: + img = Image.open(os.path.join(dir_seg,fn)) + w, h = img.size + + # assume in the row of tiles, the first is the input and the last is the ground truth + img_input = img.crop((0,0,h,h)) + img_gt = img.crop((w-h,0,w,h)) + img_marker = img.crop((w-h*2,0,w-h,h)) # the second last is marker, if marker is included + k = os.path.splitext(fn)[0] #re.sub('\..*?$','',fn) # remove extension + + images = {'Seg':img_gt} + if use_marker: + images['Marker'] = img_marker + + post_images, scoring = postprocess(img_input, images, tile_size, model) + d_metrics[k] = scoring + + if save_individual: + with open(os.path.join( + dir_save, + k+'.json' + ), 'w') as f: + json.dump(scoring, f, indent=2) + + count += 1 + + if count % 100 == 0 or count == len(fns): + print(count,'/',len(fns)) + + with open(os.path.join( + dir_save, + 'metrics.json' + ), 'w') as f: + json.dump(d_metrics, f, indent=2) diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py index 80d8118..bfa6b6d 100644 --- a/tests/test_cli_train.py +++ b/tests/test_cli_train.py @@ -175,27 +175,28 @@ def test_cli_train_single_gpu_netgs(tmp_path, model_info, foldername_suffix): def test_cli_train_single_gpu_withval(tmp_path, model_info, foldername_suffix): - if available_gpus > 0 and model_info["model"] not in ['CycleGAN']: + if available_gpus > 0 and True in model_info["seg_gen"]: torch.cuda.nvtx.range_push("test_cli_train_single_gpu_withval") dirs_input = model_info['dir_input_train'] for i in range(len(dirs_input)): - torch.cuda.nvtx.range_push(f"test_cli_train_single_gpu {dirs_input[i]}") - dir_save = tmp_path - - fns_input = [f for f in os.listdir(dirs_input[i] + '/train' + foldername_suffix) if os.path.isfile(os.path.join(dirs_input[i] + '/train' + foldername_suffix, f)) and f.endswith('png')] - num_input = len(fns_input) - assert num_input > 0 - - test_param = '--gpu-ids 0 --with-val' - cmd = CMD_BASIC.format(model=model_info["model"], dataroot=dirs_input[i], - modalities_no=model_info["modalities_no"][i], - seg_gen=model_info["seg_gen"][i], dir_save=dir_save) - cmd += f' {test_param}' - if model_info["model"] in ['DeepLIIFKD']: - cmd += CMD_KD.format(model_dir_teacher=model_info['model_dir_teacher'][i]) - res = subprocess.run(cmd,shell=True) - assert res.returncode == 0 - torch.cuda.nvtx.range_pop() + if model_info["seg_gen"][i]: + torch.cuda.nvtx.range_push(f"test_cli_train_single_gpu {dirs_input[i]}") + dir_save = tmp_path + + fns_input = [f for f in os.listdir(dirs_input[i] + '/train' + foldername_suffix) if os.path.isfile(os.path.join(dirs_input[i] + '/train' + foldername_suffix, f)) and f.endswith('png')] + num_input = len(fns_input) + assert num_input > 0 + + test_param = '--gpu-ids 0 --with-val' + cmd = CMD_BASIC.format(model=model_info["model"], dataroot=dirs_input[i], + modalities_no=model_info["modalities_no"][i], + seg_gen=model_info["seg_gen"][i], dir_save=dir_save) + cmd += f' {test_param}' + if model_info["model"] in ['DeepLIIFKD']: + cmd += CMD_KD.format(model_dir_teacher=model_info['model_dir_teacher'][i]) + res = subprocess.run(cmd,shell=True) + assert res.returncode == 0 + torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() else: pytest.skip(f'Detected {available_gpus} (< 1) available GPUs. Skip.')