diff --git a/src/finn/custom_op/fpgadataflow/thresholding.py b/src/finn/custom_op/fpgadataflow/thresholding.py index 12cb76be4e..8cebf613b1 100644 --- a/src/finn/custom_op/fpgadataflow/thresholding.py +++ b/src/finn/custom_op/fpgadataflow/thresholding.py @@ -243,16 +243,29 @@ def execute_node(self, context, graph): inp_values = context[node.input[0]] th_val = context[node.input[1]] out_bias = self.get_nodeattr("ActVal") - # MT expects inputs to be in the shape (N,C,H,W) or (N, C) - # if 4D then input values in context are (N,H,W,C) and need to - # be transposed. - # if 2D then inputs can be passed directly to MT function - is_4d = len(inp_values.shape) == 4 - if is_4d: - inp_values = np.transpose(inp_values, (0, 3, 1, 2)) + + # Consider the data layout for transposing the input into the format + # accepted by the multithreshold function above, i.e., the channel + # dimension is along the axis with index 1. + data_layout = None + # If there is no layout annotation, guess based on rank of the tensor + # TODO: Currently there is no mechanism here to get the layout + # annotation, we always guess, but this matches the previous behavior. + if len(inp_values.shape) < 5: + # Maps tensor rank to layout annotation + rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NHWC"} + # Lookup the layout required by this input shape + data_layout = rank_to_layout[len(inp_values.shape)] + # Lookup the index of the channel dimension in the data layout + # Note: Assumes there is at most one "C" which denotes the channel + # dimension + cdim = data_layout.index("C") if "C" in data_layout else 1 + # Rearrange the input to the expected (N, C, ...) layout + inp_values = inp_values.swapaxes(cdim, 1) y = multithreshold(inp_values, th_val, out_bias=out_bias) - if is_4d: - y = y.transpose(0, 2, 3, 1) + # Rearrange the output back to the original layout + y = y.swapaxes(cdim, 1) + act = DataType[self.get_nodeattr("outputDataType")] if act == DataType["BIPOLAR"]: # binary to bipolar