diff --git a/rdock-utils/pyproject.toml b/rdock-utils/pyproject.toml index a3d96bde..fe3b7f11 100644 --- a/rdock-utils/pyproject.toml +++ b/rdock-utils/pyproject.toml @@ -6,10 +6,12 @@ description = "Utilities for working with RDock and operating on SD files" requires-python = ">=3.10.0" [tool.setuptools.dynamic] -dependencies = {file = ["requirements.txt"]} -optional-dependencies = { dev = {file = ["requirements-dev.txt"]} } +dependencies = { file = ["requirements.txt"] } +optional-dependencies = { dev = { file = ["requirements-dev.txt"] } } [project.scripts] +rbhtfinder = "rdock_utils.rbhtfinder:main" +rbhtfinder_old = "rdock_utils.rbhtfinder_original_copy:main" sdfield = "rdock_utils.sdfield:main" sdrmsd_old = "rdock_utils.sdrmsd_original:main" sdrmsd = "rdock_utils.sdrmsd.main:main" @@ -26,11 +28,16 @@ Repository = "https://github.com/CBDD/rDock.git" [tool.ruff] line-length = 119 target-version = "py312" -exclude = [".git", "__pycache__", "rdock_utils/sdrmsd_original.py", "rdock_utils/sdtether_original.py"] +exclude = [ + ".git", + "__pycache__", + "rdock_utils/sdrmsd_original.py", + "rdock_utils/sdtether_original.py", +] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I"] -ignore = ["E231","E501","E203"] +ignore = ["E231", "E501", "E203"] [tool.ruff.format] quote-style = "double" @@ -67,4 +74,12 @@ no_implicit_reexport = false strict_equality = true -exclude = ["build/*", "rdock_utils/sdrmsd_original.py", "tests/", "rdock_utils/sdtether_original.py"] +exclude = [ + "build/*", + "rdock_utils/sdrmsd_original.py", + "tests/", + "rdock_utils/sdtether_original.py", + "rdock_utils/rbhtfinder_original_copy.py", +] + +plugins = "numpy.typing.mypy_plugin" diff --git a/rdock-utils/rdock_utils/common/__init__.py b/rdock-utils/rdock_utils/common/__init__.py index 8f9411a0..ca2b713b 100644 --- a/rdock-utils/rdock_utils/common/__init__.py +++ b/rdock-utils/rdock_utils/common/__init__.py @@ -2,11 +2,21 @@ from .SDFParser import FastSDMol, molecules_with_progress_log, read_molecules, read_molecules_from_all_inputs from .superpose3d import MolAlignmentData, Superpose3D, update_coordinates from .types import ( + Array1DFloat, + Array1DInt, + Array1DStr, + Array2DFloat, + Array3DFloat, AtomsMapping, AutomorphismRMSD, + ColumnNamesArray, CoordsArray, + FilterCombination, FloatArray, + InputData, Matrix3x3, + MinScoreIndices, + SDReportArray, SingularValueDecomposition, Superpose3DResult, Vector3D, @@ -25,11 +35,21 @@ "MolAlignmentData", "Superpose3D", # -- types -- + "Array1DFloat", + "Array1DInt", + "Array1DStr", + "Array2DFloat", + "Array3DFloat", + "AtomsMapping", "AutomorphismRMSD", + "ColumnNamesArray", "CoordsArray", + "FilterCombination", "FloatArray", - "AtomsMapping", + "InputData", "Matrix3x3", + "MinScoreIndices", + "SDReportArray", "SingularValueDecomposition", "Superpose3DResult", "Vector3D", diff --git a/rdock-utils/rdock_utils/common/types.py b/rdock-utils/rdock_utils/common/types.py index f972fd6e..392df4b7 100644 --- a/rdock-utils/rdock_utils/common/types.py +++ b/rdock-utils/rdock_utils/common/types.py @@ -1,7 +1,10 @@ from typing import Any import numpy +import numpy.typing +# TODO: Review common types for all rdock_utils scripts +# SDRMSD types FloatArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] CoordsArray = numpy.ndarray[Any, numpy.dtype[numpy.float64]] AutomorphismRMSD = tuple[float, CoordsArray | None] @@ -11,8 +14,20 @@ Superpose3DResult = tuple[CoordsArray, float, Matrix3x3] AtomsMapping = tuple[tuple[int, int], ...] -## Shape support for type hinting is not yet avaialable in numpy -## let's keep this as a guide for numpy 2.0 release +# RBHTFinder types +SDReportArray = numpy.ndarray[list[int | str | float], numpy.dtype[numpy.object_]] +Array1DFloat = numpy.typing.NDArray[numpy.float_] +Array2DFloat = numpy.typing.NDArray[numpy.float_] +Array3DFloat = numpy.typing.NDArray[numpy.float_] +Array1DStr = numpy.typing.NDArray[numpy.str_] +Array1DInt = numpy.typing.NDArray[numpy.int_] +ColumnNamesArray = Array1DStr | list[str] +InputData = tuple[SDReportArray, ColumnNamesArray] +MinScoreIndices = dict[int, Array1DInt] +FilterCombination = tuple[float, float] + +## Shape support for type hinting is not yet avaialable in np +## let's keep this as a guide for np 2.0 release # FloatArray = numpy.ndarray[Literal["N"], numpy.dtype[float]] # BoolArray = numpy.ndarray[Literal["N"], numpy.dtype[bool]] # CoordsArray = numpy.ndarray[Literal["N", 3], numpy.dtype[float]] diff --git a/rdock-utils/rdock_utils/rbhtfinder/__init__.py b/rdock-utils/rdock_utils/rbhtfinder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rdock-utils/rdock_utils/rbhtfinder/main.py b/rdock-utils/rdock_utils/rbhtfinder/main.py new file mode 100644 index 00000000..68264f20 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/main.py @@ -0,0 +1,12 @@ +from .parser import get_config +from .rbhtfinder import RBHTFinder + + +def main(argv: list[str] | None = None) -> None: + config = get_config(argv) + rbhtfinder = RBHTFinder(config) + rbhtfinder.run() + + +if __name__ == "__main__": + main() diff --git a/rdock-utils/rdock_utils/rbhtfinder/parser.py b/rdock-utils/rdock_utils/rbhtfinder/parser.py new file mode 100644 index 00000000..b0a9c5a9 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/parser.py @@ -0,0 +1,145 @@ +import argparse +from dataclasses import dataclass + +from .schemas import Filter + + +@dataclass +class RBHTFinderConfig: + input: str + output: str + threshold: str | None + name: int + filters: list[Filter] + validation: int + header: bool + max_time: float + min_percentage: float + cpu_count: int + + def __post_init__(self) -> None: + self.filters = self.get_parsed_filters() + + def get_parsed_filters(self) -> list[Filter]: + filter_args: list[str] = self.filters # type: ignore + parsed_filters = [self._parse_filter(arg) for arg in filter_args] + # sort filters by step at which they are applied + parsed_filters.sort(key=lambda filter: filter.steps) + return parsed_filters + + @staticmethod + def _parse_filter(argument: str) -> Filter: + parsed_filter = Filter() + + for item in argument.split(","): + key, value = item.split("=") + setattr(parsed_filter, key, float(value) if key in ("interval", "min", "max") else int(value)) + # User inputs with 1-based numbering whereas python uses 0-based + parsed_filter.column -= 1 + parsed_filter.interval = parsed_filter.interval or 1.0 + return parsed_filter + + +def get_parser() -> argparse.ArgumentParser: + description = """ + Estimate the results and computation time of an rDock high-throughput protocol. + + Steps: + 1. Perform exhaustive docking of a small representative part of the entire library. + 2. Store the result of sdreport -t from that exhaustive docking run in a file + , which will be the input of this script. + 3. Run rbhtfinder, specifying -i and an arbitrary number of filters + using the -f option, for example, "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". + This example would simulate the effect of applying thresholds on column 6 after 5 poses + have been generated, for values between 0.5 and 1.0 (i.e., 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). + More than one threshold can be specified, e.g., + "-f column=4,steps=5,min=-12,max=-10,interval=1 column=4,steps=15,min=-16,max=-15,interval=1" + will test the following combinations of thresholds on column 4: + 5 -10 15 -15 + 5 -11 15 -15 + 5 -12 15 -15 + 5 -10 15 -16 + 5 -11 15 -16 + 5 -12 15 -16 + The number of combinations will increase very rapidly, the more filters are used and the + larger the range of values specified for each. It may be sensible to run rbhtfinder several + times to explore the effects of various filters independently. + + Output: + The output of the program consists of the following columns: + FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME + SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 + SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 + The four columns are repeated for each filter specified with the -f option: + name of the column on which the filter is applied (FILTER1), + number of steps at which the threshold is applied (NSTEPS1), + value of the threshold (THR1) + and the percentage of poses which pass this filter (PERC1). + Additional filters (FILTER2, FILTER3 etc.) are listed in the order that they are applied + (i.e., by NSTEPS). + + The final columns provide some overall statistics for the combination of thresholds + specified in a row. TOP500_SCORE.INTER gives the percentage of the top-scoring 500 poses, + measured by SCORE.INTER, from the whole of which are retained after the + thresholds are applied. This can be contrasted with the final PERC column. The higher the + ratio (the 'enrichment factor'), the better the combination of thresholds. If thresholds are + applied on multiple columns, this column will be duplicated for each, e.g. TOP500_SCORE.INTER + and TOP500_SCORE.RESTR will give the percentage of the top-scoring poses retained for both of + these scoring methods. The exact number of poses used for this validation can be changed from + the default 500 using the --validation flag. + ENRICH_SCORE.INTER gives the enrichment factor as a quick rule-of-thumb to assess the best + choice of thresholds. The final column TIME provides an estimate of the time taken to perform + docking, as a proportion of the time taken for exhaustive docking. This value should be below + 0.1. + + After a combination of thresholds has been selected, they need to be encoded into a threshold + file which rDock can use as an input. rbhtfinder attempts to help with this task by + automatically selecting a combination and writing a threshold file. The combination chosen is + that which provides the highest enrichment factor, after all options with a TIME value over + 0.1 are excluded. This choice should not be blindly followed, so the threshold file should be + considered a template that the user modifies as needed. + + Requirements: + rbhtfinder requires NumPy. Installation of Pandas is recommended, but optional; if Pandas is + not available, loading the input file for calculations will be considerably slower. + """ + input_help = "Input from sdreport (tabular separated format)." + output_help = "Output file for report on threshold combinations." + threshold_help = "Threshold file used by rDock as input." + name_help = "Index of column containing the molecule name (0 indexed). Default is 1." + filter_help = "Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given." + validation_help = "Top-scoring N molecules from input to use for validating threshold combinations. Default 500." + header_help = "Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5." + max_time_help = "Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take." + min_perc_help = "Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1." + cpu_count_help = "Specify the number of CPU cores to use for multiprocessing. Defaults to '1' if not provided." + + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-i", "--input", help=input_help, type=str, required=True) + parser.add_argument("-o", "--output", help=output_help, type=str, required=True) + parser.add_argument("-t", "--threshold", help=threshold_help, type=str) + parser.add_argument("-n", "--name", type=int, default=1, help=name_help) + parser.add_argument("-f", "--filters", nargs="+", type=str, help=filter_help, required=True) # Review 'required' + parser.add_argument("-v", "--validation", type=int, default=500, help=validation_help) + parser.add_argument("-c", "--cpu-count", type=int, default=1, help=cpu_count_help) + parser.add_argument("--header", action="store_true", help=header_help) + parser.add_argument("--max-time", type=float, default=0.1, help=max_time_help) + parser.add_argument("--min-perc", type=float, default=1.0, help=min_perc_help) + return parser + + +def get_config(argv: list[str] | None = None) -> RBHTFinderConfig: + parser = get_parser() + args = parser.parse_args(argv) + return RBHTFinderConfig( + input=args.input, + output=args.output, + threshold=args.threshold, + name=args.name, + filters=args.filters, + validation=args.validation, + header=args.header, + max_time=args.max_time, + min_percentage=args.min_perc, + cpu_count=args.cpu_count, + ) diff --git a/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py new file mode 100644 index 00000000..a3b37951 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/rbhtfinder.py @@ -0,0 +1,361 @@ +import itertools +import logging +import multiprocessing +from collections import Counter, defaultdict +from functools import partial + +import numpy + +from rdock_utils.common import ( + Array1DFloat, + Array1DInt, + Array2DFloat, + Array3DFloat, + ColumnNamesArray, + FilterCombination, + InputData, + MinScoreIndices, + SDReportArray, +) + +from .parser import RBHTFinderConfig +from .schemas import Filter, FilterCombinationResult, MinMaxValues + +try: + import pandas +except ImportError: + PANDAS_IS_AVAILABLE = False +else: + PANDAS_IS_AVAILABLE = True + +logger = logging.getLogger("RBHTFinder") + + +class RBHTFinder: + def __init__(self, config: RBHTFinderConfig) -> None: + self.config = config + + def run(self) -> None: + filters_combinations = self.generate_filters_combinations(self.config.filters) + print(f"{len(filters_combinations)} combinations of filters calculated.") + distinct_combinations = self.remove_redundant_combinations(filters_combinations, self.config.filters) + + if len(distinct_combinations) == 0: + raise RuntimeError("No filter combinations could be calculated - check the thresholds specified.") + + print( + f"{len(distinct_combinations)} combinations of filters remain after removal of redundant combinations. " + "Starting calculations..." + ) + sdreport_array, column_names = self.read_data() + print(f"First few rows of the input array:\n{sdreport_array[:5]}") + print("Data read in from input file.") + # Convert to 3D array (molecules x poses x columns) + molecule_array = self.prepare_array(sdreport_array, self.config.name) + results = self.process_filter_combinations(molecule_array, distinct_combinations) + self.write_output(results, column_names) + self.handle_threshold(results, distinct_combinations, column_names, molecule_array.shape[1]) + + def generate_filters_combinations(self, filters: list[Filter]) -> list[FilterCombination]: + filter_ranges = ((filter.min, filter.max + filter.interval, filter.interval) for filter in filters) + combinations = (numpy.arange(*range) for range in filter_ranges) + filters_combinations = list(itertools.product(*combinations)) + return filters_combinations + + def remove_redundant_combinations( + self, all_combinations: list[FilterCombination], filters: list[Filter] + ) -> Array2DFloat: + all_combinations_array = numpy.array(all_combinations) + columns = [filter.column for filter in filters] + indices_per_col = {col: [i for i, c in enumerate(columns) if c == col] for col in set(columns)} + # Create a mask to keep only valid combinations + mask = numpy.ones(len(all_combinations_array), dtype=bool) + + for indices in indices_per_col.values(): + col_data = all_combinations_array[:, indices] + sorted_data = numpy.sort(col_data, axis=1)[:, ::-1] # Sort descending + is_valid = numpy.all(col_data == sorted_data, axis=1) # Check if sorted matches original + is_unique = numpy.apply_along_axis(lambda x: len(set(x)) == len(x), 1, col_data) + mask &= is_valid & is_unique + + valid_combinations: Array2DFloat = all_combinations_array[mask] + return valid_combinations + + def read_data(self) -> InputData: + if PANDAS_IS_AVAILABLE: + data_array, column_names = self.read_data_using_pandas() + else: + logging.warning("Pandas is not available to read the data") + data_array, column_names = self.read_data_using_numpy() + return data_array, column_names + + def read_data_using_pandas(self) -> InputData: + sdreport_dataframe = pandas.read_csv(self.config.input, sep="\t", header=0 if self.config.header else None) + + if self.config.header: + column_names = sdreport_dataframe.columns.values + else: + # Use index names; add 1 to deal with zero-based numbering + column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] + + sdreport_array = sdreport_dataframe.values + return sdreport_array, column_names + + def read_data_using_numpy(self) -> InputData: + np_array = numpy.loadtxt(self.config.input, dtype=str) + + if self.config.header: + column_names = np_array[0] + sdreport_array = np_array[1:] + else: + column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] + sdreport_array = np_array + + return sdreport_array, column_names + + def prepare_array(self, data_array: SDReportArray, name_column: int) -> Array3DFloat: + """ + Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses + """ + split_indices = ( + numpy.where( + data_array[:, name_column] != numpy.hstack((data_array[1:, name_column], data_array[0, name_column])) + )[0] + + 1 + ) + split_array = numpy.split(data_array, split_indices) + modal_shape = Counter([array.shape for array in split_array]).most_common(1)[0] + number_of_poses = modal_shape[0][0] # Find modal number of poses per molecule in the array + valid_split_arrays = [ + numpy.array_split(array, array.shape[0] / number_of_poses) # type: ignore + for array in split_array + if not array.shape[0] % number_of_poses and array.shape[0] + ] + flattened_split_array = numpy.concatenate(valid_split_arrays) + + if len(flattened_split_array) * number_of_poses < data_array.shape[0] * 0.99: + message = ( + "WARNING: The number of poses provided per molecule is inconsistent. " + f"Only {len(flattened_split_array)} of {int(data_array.shape[0] / number_of_poses)} molecules have {number_of_poses} poses." + ) + logger.warning(message) + + molecule_3d_array = numpy.array(flattened_split_array) + # Overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_3d_array[:, :, name_column] = 0 + final_array = molecule_3d_array.astype(float) + return final_array + + def process_filter_combinations( + self, molecule_array: Array3DFloat, distinct_combinations: Array2DFloat + ) -> list[FilterCombinationResult]: + # Find the top scoring compounds for validation of the filter combinations + columns = set(filter.column for filter in self.config.filters) + min_score_indices = { + column: numpy.argpartition(numpy.min(molecule_array[:, :, column], axis=1), self.config.validation)[ + : self.config.validation + ] + for column in columns + } + with multiprocessing.Pool(self.config.cpu_count) as pool: + process_combination = partial( + self.calculate_results_for_filter_combination, + molecule_array=molecule_array, + min_score_indices=min_score_indices, + ) + results = pool.map(process_combination, distinct_combinations) + return results + + def calculate_results_for_filter_combination( + self, filter_combination: Array2DFloat, molecule_array: Array3DFloat, min_score_indices: MinScoreIndices + ) -> FilterCombinationResult: + """ + For a particular combination of filters, calculate the percentage of molecules that will be filtered, + the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + """ + num_molecules = molecule_array.shape[0] + num_steps = molecule_array.shape[1] + # Passing_molecule_indices is a list of indices of molecules which have passed the applied filters. + # As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + passing_molecule_indices = numpy.arange(num_molecules) + filter_percentages = [] + number_of_simulated_poses = 0 # Number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + + for i, threshold in enumerate(filter_combination): + number_of_simulated_poses += self.calculate_simulated_poses_increment(i, passing_molecule_indices) + column = self.config.filters[i].column + step = self.config.filters[i].steps + passing_indices = self.apply_threshold(molecule_array, column, step, threshold) + # All mols which pass the threshold and which were already in passing_molecule_indices, i.e. passed all previous filters + passing_molecule_indices = numpy.intersect1d(passing_molecule_indices, passing_indices, assume_unique=True) + filter_percentages.append(len(passing_molecule_indices) / num_molecules) + + number_of_simulated_poses += len(passing_molecule_indices) * (num_steps - self.config.filters[-1].steps) + perc_val = { + k: len(numpy.intersect1d(v, passing_molecule_indices, assume_unique=True)) / self.config.validation + for k, v in min_score_indices.items() + } + time = float(number_of_simulated_poses / numpy.prod(molecule_array.shape[:2])) + result = FilterCombinationResult( + combination=filter_combination, + perc_val=perc_val, + percentages=filter_percentages, + time=time, + ) + return result + + def calculate_simulated_poses_increment(self, index: int, passing_molecule_indices: Array1DInt) -> int: + if index: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + increment = len(passing_molecule_indices) * ( + self.config.filters[index].steps - self.config.filters[index - 1].steps + ) + else: + increment = len(passing_molecule_indices) * self.config.filters[index].steps + return increment + + def apply_threshold(self, scored_poses: Array3DFloat, column: int, steps: int, threshold: float) -> Array1DInt: + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. + """ + # Minimum score after `steps` per molecule + mins = numpy.min(scored_poses[:, :steps, column], axis=1) + # Return those molecules where the minimum score is less than the threshold + passing_molecules = numpy.where(mins < threshold)[0] + return passing_molecules + + def write_output( + self, results: list[FilterCombinationResult], column_names: ColumnNamesArray, sep: str = "\t", end: str = "\n" + ) -> None: + """ + Print results as a table. The number of columns varies depending how many columns the user picked. + """ + with open(self.config.output, "w") as f: + header = self.get_output_header(results[0], column_names) + f.write(sep.join(header) + end) + content_lines = [sep.join(self.get_output_content(result, column_names)) + end for result in results] + f.writelines(content_lines) + + def get_output_header(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: + header = [] + for i in range(len(result.combination)): + header.extend([f"FILTER{i + 1}", f"NSTEPS{i + 1}", f"THR{i + 1}", f"PERC{i + 1}"]) + + for col_index in result.perc_val.keys(): + column_name = column_names[col_index] + header.extend([f"TOP{self.config.validation}_{column_name}", f"ENRICH_{column_name}"]) + + header.append("TIME") + + return header + + def get_output_content(self, result: FilterCombinationResult, column_names: ColumnNamesArray) -> list[str]: + content = [] + + for i, threshold in enumerate(result.combination): + column_name = column_names[self.config.filters[i].column] + steps = self.config.filters[i].steps + filter_percentage = result.percentages[i] * 100 + content.extend([f"{column_name}", f"{steps}", f"{threshold:.2f}", f"{filter_percentage:.2f}"]) + + for value in result.perc_val.values(): + perc_val_percent = value * 100 + enrichment = value / result.percentages[-1] if result.percentages[-1] else float("nan") + content.extend([f"{perc_val_percent:.2f}", f"{enrichment:.2f}"]) + + content.append(f"{result.time:.4f}") + + return content + + def handle_threshold( + self, + filter_combinations: list[FilterCombinationResult], + distinct_combinations: Array2DFloat, + column_names: ColumnNamesArray, + num_poses: int, + ) -> None: + threshold_file = self.config.threshold or "" + if not threshold_file: + return + + best_combination_index = self.get_best_filter_combination_index(filter_combinations) + if best_combination_index: + best_combination = distinct_combinations[best_combination_index] + self.write_threshold(best_combination, column_names, num_poses) + else: + message = "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + logger.warning(message) + + def get_best_filter_combination_index(self, results: list[FilterCombinationResult]) -> int: + """ + Very debatable how to do this... + Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" + (= percentage of validation compounds / percentage of all compounds); we select the + threshold with the highest enrichment factor + """ + min_max_values = MinMaxValues() + # Transpose the `perc_val` data to get columns + perc_vals = {col: [result.perc_val[col] for result in results] for col in results[0].perc_val} + for col, vals in perc_vals.items(): + min_max_values.update(col, min(vals), max(vals)) + time_vals = [result.time for result in results] + min_max_values.update("time", min(time_vals), max(time_vals)) + combination_scores = [self.calculate_combination_score(result, min_max_values) for result in results] + index = numpy.argmax(combination_scores) + return int(index) + + def calculate_combination_score(self, result: FilterCombinationResult, min_max_values: MinMaxValues) -> float: + if result.time < self.config.max_time and result.percentages[-1] >= self.config.min_percentage / 100: + col_scores = [ + (result.perc_val[col] - min_max_values.get(col).min) + / (min_max_values.get(col).max - min_max_values.get(col).min) + for col in min_max_values.values + if isinstance(col, int) + ] + time_score = (min_max_values.get("time").max - result.time) / ( + min_max_values.get("time").max - min_max_values.get("time").min + ) + score = sum(col_scores) + time_score + else: + score = 0 + + return score + + def write_threshold( + self, + best_filter_combination: Array1DFloat, + column_names: ColumnNamesArray, + max_number_of_runs: int, + sep: str = "\n", + end: str = "\n", + ) -> None: + path = self.config.threshold or "default_threshold.txt" + with open(path, "w") as f: + content = self.get_threshold_content(best_filter_combination, column_names, max_number_of_runs) + f.write(sep.join(content) + end) + + def get_threshold_content( + self, best_filter_combination: Array1DFloat, column_names: ColumnNamesArray, max_number_of_runs: int + ) -> list[str]: + content = [] + # Number of filters to apply + content.append(f"{len(self.config.filters) + 1}") + # Get each filter to a separate line + filter_lines = [ + f"if - {best_filter_combination[i]:.2f} {column_names[filter.column]} 1.0 " + f"if - SCORE.NRUNS {filter.steps} 0.0 -1.0," + for i, filter in enumerate(self.config.filters) + ] + content.extend(filter_lines) + # Filter to terminate docking when NRUNS reaches the number of runs used in the input file + content.append(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0") + # Find strictest filters for all columns and apply them again + filters_by_column = defaultdict(list) + for i, filter in enumerate(self.config.filters): + filters_by_column[filter.column].append(best_filter_combination[i]) + # Number of filters (same as number of columns filtered on) + content.append(f"{len(filters_by_column)}") + # Filter + filter_min_values = [f"- {column_names[col]} {min(values)}," for col, values in filters_by_column.items()] + content.extend(filter_min_values) + return content diff --git a/rdock-utils/rdock_utils/rbhtfinder/schemas.py b/rdock-utils/rdock_utils/rbhtfinder/schemas.py new file mode 100644 index 00000000..43ce073d --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder/schemas.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass, field + +from rdock_utils.common import Array1DFloat + + +@dataclass +class Filter: + column: int = 0 + steps: int = 0 + min: float = 0.0 + max: float = 0.0 + interval: float = 0.0 + + +@dataclass +class FilterCombinationResult: + combination: Array1DFloat + perc_val: dict[int, float] + percentages: list[float] + time: float + + +@dataclass +class MinMax: + min: float + max: float + + +@dataclass +class MinMaxValues: + values: dict[int | str, MinMax] = field(default_factory=dict) + + def update(self, column: int | str, min_val: float, max_val: float) -> None: + self.values[column] = MinMax(min=min_val, max=max_val) + + def get(self, column: int | str) -> MinMax: + return self.values[column] diff --git a/rdock-utils/rdock_utils/rbhtfinder b/rdock-utils/rdock_utils/rbhtfinder_original similarity index 89% rename from rdock-utils/rdock_utils/rbhtfinder rename to rdock-utils/rdock_utils/rbhtfinder_original index 9e2123c5..c5162bb9 100755 --- a/rdock-utils/rdock_utils/rbhtfinder +++ b/rdock-utils/rdock_utils/rbhtfinder_original @@ -11,7 +11,6 @@ import argparse import itertools import multiprocessing import os -import sys from collections import Counter from functools import partial from pathlib import Path @@ -37,16 +36,12 @@ def prepare_array(sdreport_array, name_column): sdreport_array, np.where( sdreport_array[:, name_column] - != np.hstack( - (sdreport_array[1:, name_column], sdreport_array[0, name_column]) - ) + != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) )[0] + 1, ) modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] - number_of_poses = modal_shape[0][ - 0 - ] # find modal number of poses per molecule in the array + number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array split_array_clean = sum( [ @@ -85,24 +80,16 @@ def calculate_results_for_filter_combination( for n, threshold in enumerate(filter_combination): if n: # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses - number_of_simulated_poses += len(mols_passed_threshold) * ( - filters[n]["steps"] - filters[n - 1]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) else: - number_of_simulated_poses += ( - len(mols_passed_threshold) * filters[n]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. passed all previous filters n - for n in apply_threshold( - molecule_array, filters[n]["column"], filters[n]["steps"], threshold - ) + for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) if n in mols_passed_threshold ] filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) - number_of_simulated_poses += len(mols_passed_threshold) * ( - molecule_array.shape[1] - filters[-1]["steps"] - ) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) perc_val = { k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols for k, v in min_score_indices.items() @@ -115,9 +102,7 @@ def calculate_results_for_filter_combination( } -def write_output( - results, filters, number_of_validation_mols, output_file, column_names -): +def write_output(results, filters, number_of_validation_mols, output_file, column_names): """ Print results as a table. The number of columns varies depending how many columns the user picked. """ @@ -139,9 +124,7 @@ def write_output( for n in result["perc_val"]: f.write(f"{result['perc_val'][n]*100:.2f}\t") if result["filter_percentages"][-1]: - f.write( - f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t" - ) + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") else: f.write("NaN\t") f.write(f"{result['time']:.4f}\n") @@ -176,17 +159,14 @@ def select_best_filter_combination(results, max_time, min_perc): / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) ] ) - if result["time"] < max_time - and result["filter_percentages"][-1] >= min_perc / 100 + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 else 0 for result in results ] return np.argmax(combination_scores) -def write_threshold_file( - filters, best_filter_combination, threshold_file, column_names, max_number_of_runs -): +def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): with open(threshold_file, "w") as f: # write number of filters to apply f.write(f"{len(filters) + 1}\n") @@ -200,11 +180,7 @@ def write_threshold_file( # write final filters - find strictest filters for all columns and apply them again filters_by_column = { - col: [ - best_filter_combination[n] - for n, filtr in enumerate(filters) - if filtr["column"] == col - ] + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] for col in set([filtr["column"] for filtr in filters]) } # write number of filters (same as number of columns filtered on) @@ -355,15 +331,9 @@ throughput protocol. The following steps should be followed: args.name -= 1 # because np arrays need 0-based indices # create filters dictionary from args.filter passed in + filters = [dict([n.split("=") for n in filtr[0].split(",")]) for filtr in args.filter] filters = [ - dict([n.split("=") for n in filtr[0].split(",")]) for filtr in args.filter - ] - filters = [ - { - k: float(v) if k in ["interval", "min", "max"] else int(v) - for k, v in filtr.items() - } - for filtr in filters + {k: float(v) if k in ["interval", "min", "max"] else int(v) for k, v in filtr.items()} for filtr in filters ] for filtr in filters: @@ -394,10 +364,7 @@ throughput protocol. The following steps should be followed: # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps filter_combinations = np.array(filter_combinations) cols = [filtr["column"] for filtr in filters] - indices_per_col = { - col: [n for n, filter_col in enumerate(cols) if col == filter_col] - for col in set(cols) - } + indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} filter_combination_indices_to_keep = range(len(filter_combinations)) for col, indices in indices_per_col.items(): filter_combination_indices_to_keep = [ @@ -414,9 +381,7 @@ throughput protocol. The following steps should be followed: f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." ) else: - print( - "No filter combinations could be calculated - check the thresholds specified." - ) + print("No filter combinations could be calculated - check the thresholds specified.") exit(1) if pd: @@ -446,9 +411,7 @@ throughput protocol. The following steps should be followed: min_score_indices = {} for column in set(filtr["column"] for filtr in filters): min_scores = np.min(molecule_array[:, :, column], axis=1) - min_score_indices[column] = np.argpartition(min_scores, args.validation)[ - : args.validation - ] + min_score_indices[column] = np.argpartition(min_scores, args.validation)[: args.validation] results = [] @@ -466,9 +429,7 @@ throughput protocol. The following steps should be followed: write_output(results, filters, args.validation, args.output, column_names) - best_filter_combination = select_best_filter_combination( - results, args.max_time, args.min_perc - ) + best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) if args.threshold: if best_filter_combination: write_threshold_file( diff --git a/rdock-utils/rdock_utils/rbhtfinder_original_copy.py b/rdock-utils/rdock_utils/rbhtfinder_original_copy.py new file mode 100755 index 00000000..07df5f52 --- /dev/null +++ b/rdock-utils/rdock_utils/rbhtfinder_original_copy.py @@ -0,0 +1,443 @@ +import numpy as np + +try: + import pandas as pd +except ImportError: + pd = None +import argparse +import itertools +import multiprocessing +import os +from collections import Counter +from functools import partial + +Filter = dict[str, float] + + +def apply_threshold(scored_poses, column, steps, threshold): + """ + Filter out molecules from `scored_poses`, where the minimum score reached (for a specified `column`) after `steps` is more negative than `threshold`. + """ + # minimum score after `steps` per molecule + mins = np.min(scored_poses[:, :steps, column], axis=1) + # return those molecules where the minimum score is less than the threshold + passing_molecules = np.where(mins < threshold)[0] + return passing_molecules + + +def prepare_array(sdreport_array: np.ndarray, name_column: int) -> np.ndarray: + """ + Convert `sdreport_array` (read directly from the tsv) to 3D array (molecules x poses x columns) and filter out molecules with too few/many poses + """ + # print(sdreport_array.shape[1]) + # if name_column >= sdreport_array.shape[1]: + # raise IndexError( + # f"name_column index {name_column} is out of bounds for array with shape {sdreport_array.shape}" + # ) + + # find points in the array where the name_column changes (i.e. we are dealing with a new molecule) and split the array + split_indices = ( + np.where( + sdreport_array[:, name_column] + != np.hstack((sdreport_array[1:, name_column], sdreport_array[0, name_column])) + )[0] + + 1 + ) + split_array = np.split(sdreport_array, split_indices) + + modal_shape = Counter([n.shape for n in split_array]).most_common(1)[0] + number_of_poses = modal_shape[0][0] # find modal number of poses per molecule in the array + + split_array_clean = sum( + [ + np.array_split(n, n.shape[0] / number_of_poses) + for n in split_array + if not n.shape[0] % number_of_poses and n.shape[0] + ], + [], + ) + + if len(split_array_clean) * number_of_poses < sdreport_array.shape[0] * 0.99: + print( + f"WARNING: the number of poses provided per molecule is inconsistent. Only {len(split_array_clean)} of {int(sdreport_array.shape[0] / number_of_poses)} moleules have {number_of_poses} poses." + ) + + molecule_array = np.array(split_array_clean) + # overwrite the name column (should be the only one with dtype=str) so we can force everything to float + molecule_array[:, :, name_column] = 0 + return np.array(molecule_array, dtype=float) + + +def calculate_results_for_filter_combination( + filter_combination, + molecule_array, + filters, + min_score_indices, + number_of_validation_mols, +): + """ + For a particular combination of filters, calculate the percentage of molecules that will be filtered, the percentage of top-scoring molecules that will be filtered, and the time taken relative to exhaustive docking + """ + # mols_passed_threshold is a list of indices of molecules which have passed the applied filters. As more filters are applied, it gets smaller. Before any iteration, we initialise with all molecules passing + mols_passed_threshold = list(range(molecule_array.shape[0])) + filter_percentages = [] + number_of_simulated_poses = 0 # number of poses which we calculate would be generated, we use this to calculate the TIME column in the final output + for n, threshold in enumerate(filter_combination): + if n: + # e.g. if there are 5000 mols left after 15 steps and the last filter was at 5 steps, append 5000 * (15 - 5) to number_of_simulated_poses + number_of_simulated_poses += len(mols_passed_threshold) * (filters[n]["steps"] - filters[n - 1]["steps"]) + else: + number_of_simulated_poses += len(mols_passed_threshold) * filters[n]["steps"] + mols_passed_threshold = [ # all mols which pass the threshold and which were already in mols_passed_threshold, i.e. passed all previous filters + n + for n in apply_threshold(molecule_array, filters[n]["column"], filters[n]["steps"], threshold) + if n in mols_passed_threshold + ] + filter_percentages.append(len(mols_passed_threshold) / molecule_array.shape[0]) + number_of_simulated_poses += len(mols_passed_threshold) * (molecule_array.shape[1] - filters[-1]["steps"]) + perc_val = { + k: len([n for n in v if n in mols_passed_threshold]) / number_of_validation_mols + for k, v in min_score_indices.items() + } + return { + "filter_combination": filter_combination, + "perc_val": perc_val, + "filter_percentages": filter_percentages, + "time": number_of_simulated_poses / np.product(molecule_array.shape[:2]), + } + + +def write_output(results, filters, number_of_validation_mols, output_file, column_names): + """ + Print results as a table. The number of columns varies depending how many columns the user picked. + """ + with open(output_file, "w") as f: + # write header + for n in range(len(results[0]["filter_combination"])): + f.write(f"FILTER{n+1}\tNSTEPS{n+1}\tTHR{n+1}\tPERC{n+1}\t") + for n in results[0]["perc_val"]: + f.write(f"TOP{number_of_validation_mols}_{column_names[n]}\t") + f.write(f"ENRICH_{column_names[n]}\t") + f.write("TIME\n") + + # write results + for result in results: + for n, threshold in enumerate(result["filter_combination"]): + f.write( + f"{column_names[filters[n]['column']]}\t{filters[n]['steps']}\t{threshold:.2f}\t{result['filter_percentages'][n]*100:.2f}\t" + ) + for n in result["perc_val"]: + f.write(f"{result['perc_val'][n]*100:.2f}\t") + if result["filter_percentages"][-1]: + f.write(f"{result['perc_val'][n]/result['filter_percentages'][-1]:.2f}\t") + else: + f.write("NaN\t") + f.write(f"{result['time']:.4f}\n") + return + + +def select_best_filter_combination(results, max_time, min_perc): + """ + Very debatable how to do this... + Here we exclude all combinations with TIME < max_time and calculate an "enrichment factor" + (= percentage of validation compounds / percentage of all compounds); we select the + threshold with the highest enrichment factor + """ + min_max_values = {} + for col in results[0]["perc_val"].keys(): + vals = [result["perc_val"][col] for result in results] + min_max_values[col] = {"min": min(vals), "max": max(vals)} + time_vals = [result["time"] for result in results] + min_max_values["time"] = {"min": min(time_vals), "max": max(time_vals)} + + combination_scores = [ + sum( + [ + ( + (result["perc_val"][col] - min_max_values[col]["min"]) + / (min_max_values[col]["max"] - min_max_values[col]["min"]) + ) + for col in results[0]["perc_val"].keys() + ] + + [ + (min_max_values["time"]["max"] - result["time"]) + / (min_max_values["time"]["max"] - min_max_values["time"]["min"]) + ] + ) + if result["time"] < max_time and result["filter_percentages"][-1] >= min_perc / 100 + else 0 + for result in results + ] + return np.argmax(combination_scores) + + +def write_threshold_file(filters, best_filter_combination, threshold_file, column_names, max_number_of_runs): + with open(threshold_file, "w") as f: + # write number of filters to apply + f.write(f"{len(filters) + 1}\n") + # write each filter to a separate line + for n, filtr in enumerate(filters): + f.write( + f'if - {best_filter_combination[n]:.2f} {column_names[filtr["column"]]} 1.0 if - SCORE.NRUNS {filtr["steps"]} 0.0 -1.0,\n' + ) + # write filter to terminate docking when NRUNS reaches the number of runs used in the input file + f.write(f"if - SCORE.NRUNS {max_number_of_runs - 1} 0.0 -1.0\n") + + # write final filters - find strictest filters for all columns and apply them again + filters_by_column = { + col: [best_filter_combination[n] for n, filtr in enumerate(filters) if filtr["column"] == col] + for col in set([filtr["column"] for filtr in filters]) + } + # write number of filters (same as number of columns filtered on) + f.write(f"{len(filters_by_column)}\n") + # write filter + for col, values in filters_by_column.items(): + f.write(f"- {column_names[col]} {min(values)},\n") + + +def parse_filter(filter_str: str) -> Filter: + parsed_filter = {} + for item in filter_str.split(","): + key, value = item.split("=") + parsed_filter[key] = float(value) if key in ["interval", "min", "max"] else int(value) + parsed_filter["column"] -= 1 + return parsed_filter + + +def main(argv: list[str] | None = None): + """ + Parse arguments; read in data; calculate filter combinations and apply them; print results + """ + parser = argparse.ArgumentParser( + description="""Estimate the results and computation time of an rDock high +throughput protocol. The following steps should be followed: +1) exhaustive docking of a small representative part of the entire + library. +2) Store the result of sdreport -t over that exhaustive docking run + in a file which will be the input of this script. +3) Run rbhtfinder, specifying -i and an arbitrary + number of filters specified using the -f option, for example + "-f column=6,steps=5,min=0.5,max=1.0,interval=0.1". This example + would simulate the effect of applying thresholds on column 6 after + 5 poses have been generated, for values between 0.5 and 1.0 (i.e. + 0.5, 0.6, 0.7, 0.8, 0.9, 1.0). More than one threshold can be + specified, e.g., "-f column=4,steps=5,min=-12,max=-10,interval=1 + -f column=4,steps=15,min=-16,max=-15,interval=1" will test the + following combinations of thresholds on column 4: + 5 -10 15 -15 + 5 -11 15 -15 + 5 -12 15 -15 + 5 -10 15 -16 + 5 -11 15 -16 + 5 -12 15 -16 + The number of combinations will increase very rapidly, the more + filters are used and the larger the range of values specified for + each. It may be sensible to run rbhtfinder several times to explore + the effects of various filters independently. + + The output of the program consists of the following columns. + FILTER1 NSTEPS1 THR1 PERC1 TOP500_SCORE.INTER ENRICH_SCORE.INTER TIME + SCORE.INTER 5 -13.00 6.04 72.80 12.05 0.0500 + SCORE.INTER 5 -12.00 9.96 82.80 8.31 0.0500 + The four columns are repeated for each filter specified with the -f + option: name of the column on which the filter is applied + (FILTER1), number of steps at which the threshold is applied + (NSTEPS1), value of the threshold (THR1) and the percentage of + poses which pass this filter (PERC1). Additional filters (FILTER2, + FILTER3 etc.) are listed in the order that they are applied (i.e. + by NSTEPS). + + The final columns provide some overall statistics for the + combination of thresholds specified in a row. TOP500_SCORE.INTER + gives the percentage of the top-scoring 500 poses, measured by + SCORE.INTER, from the whole of which are retained + after the thresholds are applied. This can be contrasted with the + final PERC column. The higher the ratio (the 'enrichment factor'), + the better the combination of thresholds. If thresholds are applied + on multiple columns, this column will be duplicated for each, e.g. + TOP500_SCORE.INTER and TOP500_SCORE.RESTR will give the percentage + of the top-scoring poses retained for both of these scoring + methods. The exact number of poses used for this validation can be + changed from the default 500 using the --validation flag. + ENRICH_SCORE.INTER gives the enrichment factor as a quick + rule-of-thumb to assess the best choice of thresholds. The final + column TIME provides an estimate of the time taken to perform + docking, as a proportion of the time taken for exhaustive docking. + This value should be below 0.1. + + After a combination of thresholds has been selected, they need to + be encoded into a threshold file which rDock can use as an input. + rbhtfinder attempts to help with this task by automatically + selecting a combination and writing a threshold file. The + combination chosen is that which provides the highest enrichment + factor, after all options with a TIME value over 0.1 are excluded. + This choice should not be blindly followed, so the threshold file + should be considered a template that the user modifies as needed. + + rbhtfinder requires NumPy. Installation of pandas is recommended, + but optional; if pandas is not available, loading the input file + for calculations will be considerably slower. + + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-i", + "--input", + help="Input from sdreport (tabular separated format).", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output", + help="Output file for report on threshold combinations.", + type=str, + required=True, + ) + parser.add_argument( + "-t", + "--threshold", + help="Threshold file used by rDock as input.", + type=str, + ) + parser.add_argument( + "-n", + "--name", + type=int, + default=1, # Index of molecule name in input file is 1 by default + help="Index of column containing the molecule name (0 indexed). Default is 1.", + ) + parser.add_argument( + "-f", + "--filter", + nargs="+", + type=str, + help="Filter to apply, e.g. column=4,steps=5,min=-10,max=-15,interval=1 will test applying a filter to column 4 after generation of 5 poses, with threshold values between -10 and -15 tested. The variables column, steps, min and max must all be specified; interval defaults to 1 if not given.", + ) # Removed action 'append' to avoid unnecessary nested structure + parser.add_argument( + "-v", + "--validation", + type=int, + default=500, + help="Top-scoring N molecules from input to use for validating threshold combinations. Default is 500.", + ) + parser.add_argument( + "--header", + action="store_true", + help="Specify if the input file from sdreport contains a header line with column names. If not, output files will describe columns using indices, e.g. COL4, COL5.", + ) + parser.add_argument( + "--max-time", + type=float, + default=0.1, + help="Maximum value for time to use when autogenerating a high-throughput protocol - default is 0.1, i.e. 10%% of the time exhaustive docking would take.", + ) + parser.add_argument( + "--min-perc", + type=float, + default=1.0, + help="Minimum value for the estimated final percentage of compounds to use when autogenerating a high-throughput protocol - default is 1.", + ) + + args = parser.parse_args(argv) + + # create filters dictionary from args.filter passed in + filters = [parse_filter(filter) for filter in args.filter] + + # sort filters by step at which they are applied + filters.sort(key=lambda n: n["steps"]) + + # generates all possible combinations from filters provided + fils = [(filtr["min"], filtr["max"] + filtr.get("interval", 1.0), filtr.get("interval", 1.0)) for filtr in filters] + filter_combinations = list(itertools.product(*(np.arange(*n) for n in fils))) + print(f"{len(filter_combinations)} combinations of filters calculated.") + + # remove redundant combinations, i.e. where filters for later steps are less or equally strict to earlier steps + filter_combinations = np.array(filter_combinations) + cols = [filtr["column"] for filtr in filters] + indices_per_col = {col: [n for n, filter_col in enumerate(cols) if col == filter_col] for col in set(cols)} + filter_combination_indices_to_keep = range(len(filter_combinations)) + for col, indices in indices_per_col.items(): + filter_combination_indices_to_keep = [ + n + for n, comb in enumerate(filter_combinations[:, indices]) + if list(comb) == sorted(comb, reverse=True) + and len(set(comb)) == comb.shape[0] + and n in filter_combination_indices_to_keep + ] + filter_combinations = filter_combinations[filter_combination_indices_to_keep] + + if len(filter_combinations): + print( + f"{len(filter_combinations)} combinations of filters remain after removal of redundant combinations. Starting calculations..." + ) + else: + print("No filter combinations could be calculated - check the thresholds specified.") + exit(1) + + if pd: + # pandas is weird... i.e., skip line 0 if there's a header, else read all lines + header = 0 if args.header else None + sdreport_dataframe = pd.read_csv(args.input, sep="\t", header=header) + if args.header: + column_names = sdreport_dataframe.columns.values + else: + # use index names; add 1 to deal with zero-based numbering + column_names = [f"COL{n+1}" for n in range(len(sdreport_dataframe.columns))] + sdreport_array = sdreport_dataframe.values + print(f"First few rows of the input array:\n{sdreport_array[:5]}") + else: # pd not available + np_array = np.loadtxt(args.input, dtype=str) + if args.header: + column_names = np_array[0] + sdreport_array = np_array[1:] + else: + column_names = [f"COL{n+1}" for n in range(np_array.shape[1])] + sdreport_array = np_array + print("Data read in from input file.") + + # convert to 3D array (molecules x poses x columns) + molecule_array = prepare_array(sdreport_array, args.name) + + # find the top scoring compounds for validation of the filter combinations + min_score_indices = {} + for column in set(filtr["column"] for filtr in filters): + min_scores = np.min(molecule_array[:, :, column], axis=1) + min_score_indices[column] = np.argpartition(min_scores, args.validation)[: args.validation] + + results = [] + + pool = multiprocessing.Pool(os.cpu_count()) + results = pool.map( + partial( + calculate_results_for_filter_combination, + molecule_array=molecule_array, + filters=filters, + min_score_indices=min_score_indices, + number_of_validation_mols=args.validation, + ), + filter_combinations, + ) + + write_output(results, filters, args.validation, args.output, column_names) + + best_filter_combination = select_best_filter_combination(results, args.max_time, args.min_perc) + if args.threshold: + if best_filter_combination: + write_threshold_file( + filters, + filter_combinations[best_filter_combination], + args.threshold, + column_names, + molecule_array.shape[1], + ) + else: + print( + "Filter combinations defined are too strict or would take too long to run; no threshold file was written." + ) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/rdock-utils/requirements-dev.txt b/rdock-utils/requirements-dev.txt index 16263576..531ad766 100644 --- a/rdock-utils/requirements-dev.txt +++ b/rdock-utils/requirements-dev.txt @@ -1,3 +1,3 @@ mypy==1.8.0 pytest==7.4.4 -ruff==0.1.14 +ruff==0.5.4 diff --git a/rdock-utils/requirements.txt b/rdock-utils/requirements.txt index a113d2c3..1b76bf47 100644 --- a/rdock-utils/requirements.txt +++ b/rdock-utils/requirements.txt @@ -1,2 +1,3 @@ numpy==1.26.2 -openbabel==3.1.1.1 \ No newline at end of file +openbabel==3.1.1.1 +pandas==2.2.2 \ No newline at end of file diff --git a/rdock-utils/tests/fixtures/rbhtfinder/input.txt b/rdock-utils/tests/fixtures/rbhtfinder/input.txt new file mode 100644 index 00000000..3b261f8f --- /dev/null +++ b/rdock-utils/tests/fixtures/rbhtfinder/input.txt @@ -0,0 +1,101 @@ +REC _TITLE1 TOTAL INTER INTRA RESTR VDW +001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 +002 mol00 2.595 -0.601 -1.152 4.347 -11.001 +003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 +004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 +005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 +006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 +007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 +008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 +009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 +010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 +011 mol01 -14.764 -12.244 -3.069 0.550 -16.362 +012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 +013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 +014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 +015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 +016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 +017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 +018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 +019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 +020 mol01 -15.504 -12.806 -3.140 0.442 -12.497 +021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 +022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 +023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 +024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 +025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 +026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 +027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 +028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 +029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 +030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 +031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 +032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 +033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 +034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 +035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 +036 mol03 -16.844 -13.744 -3.429 0.329 -14.531 +037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 +038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 +039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 +040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 +041 mol04 -7.776 -6.296 -2.270 0.790 -16.535 +042 mol04 6.644 5.519 -0.566 1.691 -0.734 +043 mol04 -3.363 -7.773 0.964 3.446 -13.299 +044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 +045 mol04 -2.875 -5.317 0.643 1.799 -13.852 +046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 +047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 +048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 +049 mol04 -8.574 -9.947 1.073 0.301 -18.351 +050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 +051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 +052 mol05 -5.265 -9.689 0.036 4.387 -16.474 +053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 +054 mol05 -3.375 -5.926 -1.281 3.832 -14.547 +055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 +056 mol05 -12.771 -15.095 1.703 0.621 -17.161 +057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 +058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 +059 mol05 -3.387 -7.638 1.574 2.678 -16.308 +060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 +061 mol06 -15.764 -17.717 0.853 1.101 -21.131 +062 mol06 -2.956 -7.275 0.313 4.006 -14.833 +063 mol06 -6.103 -12.909 2.281 4.526 -17.262 +064 mol06 1.370 -1.589 -0.619 3.579 -9.989 +065 mol06 0.980 -14.709 0.605 15.084 -20.358 +066 mol06 3.784 -6.808 8.337 2.255 -14.995 +067 mol06 -5.845 -12.679 2.130 4.704 -17.065 +068 mol06 -5.255 -12.309 4.456 2.598 -17.557 +069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 +070 mol06 -8.737 -13.409 3.272 1.400 -17.974 +071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 +072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 +073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 +074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 +075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 +076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 +077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 +078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 +079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 +080 mol07 -2.593 -7.660 3.906 1.162 -10.076 +081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 +082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 +083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 +084 mol08 -24.337 -1.398 -32.330 9.391 -13.655 +085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 +086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 +087 mol08 -10.119 5.962 -25.259 9.178 -7.399 +088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 +089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 +090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 +091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 +092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 +093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 +094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 +095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 +096 mol09 -29.025 -10.924 -25.192 7.091 -17.062 +097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 +098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 +099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 +100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_output.txt b/rdock-utils/tests/fixtures/rbhtfinder/output.txt similarity index 100% rename from rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_output.txt rename to rdock-utils/tests/fixtures/rbhtfinder/output.txt diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_input.txt b/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_input.txt deleted file mode 100644 index 0d9277c6..00000000 --- a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_input.txt +++ /dev/null @@ -1,101 +0,0 @@ -REC _TITLE1 TOTAL INTER INTRA RESTR VDW -001 mol00 -16.905 -11.204 -6.416 0.715 -18.926 -002 mol00 2.595 -0.601 -1.152 4.347 -11.001 -003 mol00 -13.022 -12.572 -8.953 8.502 -20.443 -004 mol00 -16.128 -12.742 -8.977 5.591 -17.353 -005 mol00 -10.576 -4.606 -6.451 0.481 -16.707 -006 mol00 -18.429 -11.402 -8.179 1.152 -18.191 -007 mol00 -18.316 -12.749 -6.842 1.275 -21.002 -008 mol00 -13.123 -6.272 -9.001 2.150 -16.672 -009 mol00 -6.763 -7.234 -4.006 4.478 -15.995 -010 mol00 -16.302 -11.451 -5.042 0.192 -21.602 -011 mol01 -14.764 -12.244 -3.069 0.550 -16.362 -012 mol01 -8.102 -9.014 -2.509 3.421 -13.535 -013 mol01 -17.136 -13.983 -4.509 1.356 -15.128 -014 mol01 -10.791 -7.401 -4.334 0.944 -12.455 -015 mol01 -15.107 -11.770 -3.681 0.343 -12.760 -016 mol01 -15.348 -12.600 -3.085 0.337 -12.213 -017 mol01 -13.234 -9.356 -4.039 0.161 -13.449 -018 mol01 -12.883 -10.593 -2.692 0.401 -14.155 -019 mol01 -14.937 -12.053 -3.622 0.738 -16.503 -020 mol01 -15.504 -12.806 -3.140 0.442 -12.497 -021 mol02 -12.446 -11.333 -4.405 3.291 -15.701 -022 mol02 -13.334 -11.044 -2.708 0.418 -13.332 -023 mol02 -12.298 -8.953 -4.006 0.662 -13.422 -024 mol02 -10.855 -8.415 -3.033 0.593 -12.782 -025 mol02 -12.506 -9.802 -3.198 0.494 -14.579 -026 mol02 -13.582 -11.559 -2.422 0.399 -15.628 -027 mol02 -14.966 -11.346 -4.361 0.741 -16.671 -028 mol02 -15.302 -12.238 -3.389 0.324 -13.782 -029 mol02 -9.849 -9.111 -4.596 3.858 -14.011 -030 mol02 -13.621 -11.178 -2.870 0.427 -15.527 -031 mol03 -10.492 -8.634 -2.412 0.554 -12.702 -032 mol03 -16.369 -12.611 -3.925 0.166 -15.707 -033 mol03 -16.074 -12.018 -4.147 0.091 -14.921 -034 mol03 -6.623 -8.868 -2.337 4.582 -13.383 -035 mol03 -4.061 -4.354 -4.135 4.428 -11.803 -036 mol03 -16.844 -13.744 -3.429 0.329 -14.531 -037 mol03 -16.759 -14.229 -2.994 0.464 -15.433 -038 mol03 -15.680 -11.976 -3.889 0.185 -15.065 -039 mol03 -11.919 -9.693 -2.623 0.398 -14.239 -040 mol03 -8.137 -7.516 -3.235 2.614 -11.614 -041 mol04 -7.776 -6.296 -2.270 0.790 -16.535 -042 mol04 6.644 5.519 -0.566 1.691 -0.734 -043 mol04 -3.363 -7.773 0.964 3.446 -13.299 -044 mol04 -4.351 -4.121 -1.905 1.675 -11.049 -045 mol04 -2.875 -5.317 0.643 1.799 -13.852 -046 mol04 -7.823 -9.622 -0.031 1.830 -14.752 -047 mol04 -2.534 -1.876 -2.013 1.354 -10.910 -048 mol04 -13.193 -11.516 -2.048 0.371 -17.047 -049 mol04 -8.574 -9.947 1.073 0.301 -18.351 -050 mol04 -9.966 -9.181 -1.811 1.027 -14.498 -051 mol05 -5.717 -12.369 -0.344 6.997 -20.154 -052 mol05 -5.265 -9.689 0.036 4.387 -16.474 -053 mol05 -11.101 -9.229 -2.354 0.483 -17.823 -054 mol05 -3.375 -5.926 -1.281 3.832 -14.547 -055 mol05 -9.546 -12.438 -1.927 4.819 -17.671 -056 mol05 -12.771 -15.095 1.703 0.621 -17.161 -057 mol05 -19.198 -19.152 -0.788 0.743 -17.933 -058 mol05 -12.564 -13.726 -0.425 1.587 -19.786 -059 mol05 -3.387 -7.638 1.574 2.678 -16.308 -060 mol05 -14.882 -17.451 -0.477 3.045 -19.050 -061 mol06 -15.764 -17.717 0.853 1.101 -21.131 -062 mol06 -2.956 -7.275 0.313 4.006 -14.833 -063 mol06 -6.103 -12.909 2.281 4.526 -17.262 -064 mol06 1.370 -1.589 -0.619 3.579 -9.989 -065 mol06 0.980 -14.709 0.605 15.084 -20.358 -066 mol06 3.784 -6.808 8.337 2.255 -14.995 -067 mol06 -5.845 -12.679 2.130 4.704 -17.065 -068 mol06 -5.255 -12.309 4.456 2.598 -17.557 -069 mol06 -5.051 -8.500 -1.065 4.515 -12.298 -070 mol06 -8.737 -13.409 3.272 1.400 -17.974 -071 mol07 -5.945 -6.564 -0.932 1.551 -15.670 -072 mol07 -11.177 -12.429 -1.525 2.777 -15.118 -073 mol07 -3.446 -1.734 -2.958 1.246 -7.623 -074 mol07 -4.229 -5.796 -0.264 1.831 -14.220 -075 mol07 -14.958 -15.847 -0.333 1.222 -18.956 -076 mol07 -8.390 -8.507 -0.927 1.045 -14.022 -077 mol07 -5.093 -5.862 -1.992 2.761 -15.437 -078 mol07 -9.813 -12.418 -0.122 2.726 -17.489 -079 mol07 -10.936 -10.623 -1.940 1.626 -16.272 -080 mol07 -2.593 -7.660 3.906 1.162 -10.076 -081 mol08 -30.625 -10.460 -24.533 4.369 -21.331 -082 mol08 -34.896 -10.897 -28.333 4.334 -24.000 -083 mol08 -37.535 -5.959 -32.574 0.998 -17.627 -084 mol08 -24.337 -1.398 -32.330 9.391 -13.655 -085 mol08 -33.982 -6.759 -29.808 2.584 -20.003 -086 mol08 -22.908 -5.812 -32.172 15.076 -17.519 -087 mol08 -10.119 5.962 -25.259 9.178 -7.399 -088 mol08 -36.286 -7.066 -31.019 1.799 -19.466 -089 mol08 -32.439 -4.421 -28.944 0.926 -16.742 -090 mol08 -33.056 -3.138 -31.632 1.714 -16.795 -091 mol09 -37.922 -11.009 -28.015 1.102 -14.514 -092 mol09 -33.961 -11.278 -28.396 5.713 -18.027 -093 mol09 -30.177 -6.085 -27.327 3.235 -11.667 -094 mol09 -36.755 -10.942 -27.524 1.710 -16.747 -095 mol09 -27.609 -3.028 -27.462 2.881 -5.874 -096 mol09 -29.025 -10.924 -25.192 7.091 -17.062 -097 mol09 -28.521 -6.851 -28.559 6.889 -12.872 -098 mol09 -37.849 -18.828 -26.348 7.327 -18.185 -099 mol09 -33.968 -11.233 -28.349 5.614 -17.982 -100 mol09 -37.434 -10.703 -28.080 1.348 -16.012 diff --git a/rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_threshold.txt b/rdock-utils/tests/fixtures/rbhtfinder/threshold.txt similarity index 100% rename from rdock-utils/tests/fixtures/rbhtfinder/rbhtfinder_threshold.txt rename to rdock-utils/tests/fixtures/rbhtfinder/threshold.txt diff --git a/rdock-utils/tests/rbhtfinder/__init__.py b/rdock-utils/tests/rbhtfinder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rdock-utils/tests/rbhtfinder/conftest.py b/rdock-utils/tests/rbhtfinder/conftest.py new file mode 100644 index 00000000..9d688bb3 --- /dev/null +++ b/rdock-utils/tests/rbhtfinder/conftest.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import pytest + +from ..conftest import FIXTURES_FOLDER + +RBHTFINDER_FIXTURES_FOLDER = FIXTURES_FOLDER / "rbhtfinder" + +INPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "input.txt") +EXPECTED_THRESHOLD_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "threshold.txt") +EXPECTED_OUTPUT_FILE = str(RBHTFINDER_FIXTURES_FOLDER / "output.txt") + + +@pytest.fixture +def output_temp(tmp_path: Path) -> Path: + output_path = tmp_path / "output.txt" + return output_path + + +@pytest.fixture +def threshold_temp(tmp_path: Path) -> Path: + threshold_path = tmp_path / "threshold.txt" + return threshold_path + + +@pytest.fixture +def argv(output_temp: Path, threshold_temp: Path) -> list[str]: + return [ + "-i", + INPUT_FILE, + "-o", + str(output_temp), + "-t", + str(threshold_temp), + "-f", + "column=4,steps=3,min=-10.0,max=0.0,interval=5.0", + "column=6,steps=5,min=1.0,max=5.0,interval=5.0", + "--max-time", + "1", + "--min-perc", + "1.0", + "-v", + "5", + "--header", + ] + + +def get_file_content(file: str | Path) -> str: + with open(file, "r") as f: + return f.read() diff --git a/rdock-utils/tests/rbhtfinder/test_integration.py b/rdock-utils/tests/rbhtfinder/test_integration.py new file mode 100644 index 00000000..9e4d637a --- /dev/null +++ b/rdock-utils/tests/rbhtfinder/test_integration.py @@ -0,0 +1,34 @@ +from pathlib import Path +from typing import Callable + +import pytest + +from rdock_utils.rbhtfinder.main import main as rbhtfinder_main +from rdock_utils.rbhtfinder_original_copy import main as rbhtfinder_old_main + +from .conftest import EXPECTED_OUTPUT_FILE, EXPECTED_THRESHOLD_FILE, get_file_content + +parametrize_main = pytest.mark.parametrize( + "main", + [ + pytest.param(rbhtfinder_old_main, id="Original version Python 3"), + pytest.param(rbhtfinder_main, id="Improved version Python 3.12"), + ], +) + + +@parametrize_main +def test_do_nothing(main: Callable[[list[str]], None]): + with pytest.raises(SystemExit): + main() + + +@parametrize_main +def test_integration(main: Callable[[list[str]], None], output_temp: Path, threshold_temp: Path, argv: list[str]): + main(argv) + output = get_file_content(output_temp) + threshold = get_file_content(threshold_temp) + expected_output = get_file_content(EXPECTED_OUTPUT_FILE) + expected_threshold = get_file_content(EXPECTED_THRESHOLD_FILE) + assert output == expected_output + assert threshold == expected_threshold