diff --git a/finn-rtllib/layernorm/accuf.sv b/finn-rtllib/layernorm/accuf.sv index 2c0581551c..a93175f6fe 100644 --- a/finn-rtllib/layernorm/accuf.sv +++ b/finn-rtllib/layernorm/accuf.sv @@ -17,7 +17,7 @@ module accuf #( input logic [31:0] a, input logic avld, - input logic alst, + input logic alst, // complete sum output logic [31:0] s, output logic svld ); diff --git a/finn-rtllib/layernorm/binopf.sv b/finn-rtllib/layernorm/binopf.sv index ee5aabe34a..60853dd18e 100644 --- a/finn-rtllib/layernorm/binopf.sv +++ b/finn-rtllib/layernorm/binopf.sv @@ -56,7 +56,7 @@ module binopf #( logic [1:0] ovf; logic [1:0] unf; always_ff @(posedge clk) begin - automatic logic [1:0] msk = { HAVE_MUL && Vld[2], HAVE_ADD && rvld }; + automatic logic [1:0] msk = { HAVE_MUL && Vld[1+HAVE_MUL], HAVE_ADD && rvld }; assert(!(inv & msk)) else $warning("%m generated invalid output."); assert(!(ovf & msk)) else $warning("%m generated an overflow."); assert(!(unf & msk)) else $warning("%m generated an underflow."); diff --git a/finn-rtllib/layernorm/layernorm.sv b/finn-rtllib/layernorm/layernorm.sv index abd7f8c801..d671ef78bc 100644 --- a/finn-rtllib/layernorm/layernorm.sv +++ b/finn-rtllib/layernorm/layernorm.sv @@ -56,10 +56,6 @@ module layernorm #( $error("%m: SIMD(%0d) must divide N(%0d).", SIMD, N); $finish; end - if(NN <= 12) begin - $error("%m: N/SIMD must be larger than 12 for rsqrt throughput."); - $finish; - end end typedef logic [31:0] fp32; @@ -142,20 +138,12 @@ module layernorm #( // Balancing edge delays in trees with incomplete leaf level typedef bit edge_delays_t[2*SIMD-1]; function edge_delays_t INIT_EDGE_DELAYS(); - localparam int unsigned LEVELS = 1+$clog2(SIMD); + // Use binary encoding of number of short leaves, `sig`, to infer + // most common parent for retiming on each level. 
+ localparam int unsigned FULL_FANIN = 2**$clog2(SIMD); automatic edge_delays_t d = '{ default: 0 }; - // Put delay onto leaves that are not on last level - for(int unsigned i = SIMD-1; i < 2*SIMD-1; i++) begin - if($clog2(i+2) == LEVELS) break; - d[i] = 1; - end - // Move delay shared between children to their parent - for(int unsigned i = SIMD-1; i > 0; i--) begin - if(d[2*i+1]) begin - d[2*i+1] = 0; - d[2*i+2] = 0; - d[i] = 1; - end + for(int unsigned sig = FULL_FANIN - SIMD, i = FULL_FANIN - 1; sig; i >>= 1, sig >>= 1) begin + d[i-sig] = sig[0]; end return d; endfunction : INIT_EDGE_DELAYS @@ -210,7 +198,10 @@ module layernorm #( end 1: /* Var: inverse square root */ begin uwire vrdy; - rsqrtf #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) vari_rsqurt ( + rsqrtf #( + .SUSTAINABLE_INTERVAL(NN), + .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL) + ) vari_rsqurt ( .clk, .rst, .x(total.dat), .xvld(total.vld), .xrdy(vrdy), .r(norm .dat), .rvld(norm .vld) @@ -259,7 +250,7 @@ module layernorm #( else Credit <= Credit + (issue == settle? 0 : settle? 
1 : -1); end - logic signed [$clog2(NN-1):0] Cnt = 0; // [-NN,] -NN+1, ..., -1, 0 + logic signed [$clog2(NN):0] Cnt = 0; // [-NN,] -NN+1, ..., -1, 0 assign norm0_rdy = !Cnt[$left(Cnt)]; assign issue = have_cap && (norm0.vld || Cnt[$left(Cnt)]); uwire bload = norm0.vld && norm0_rdy; diff --git a/finn-rtllib/layernorm/rsqrtf.sv b/finn-rtllib/layernorm/rsqrtf.sv index 249108b166..3a47c2d7b9 100644 --- a/finn-rtllib/layernorm/rsqrtf.sv +++ b/finn-rtllib/layernorm/rsqrtf.sv @@ -16,106 +16,101 @@ * // y.f = y.f * ( 1.5f - ( x2 * y.f * y.f ) ); // 2nd iteration * return y.f; * } + * + * Implemented compute schedule on DSP slice: + * - 12-stage schedule with three 4-stage iterations through DSP pipeline + * - 4 jobs can be interleaved when filling the entire DSP pipeline + * + * │ │ │ 1 1 │ + * Stage: │ 0 1 2 3 │ 4 5 6 7 │ 8 9 0 1 │ + * Maturity: | ITER1 | ITER2 | ITER3 | + * ▼ ▼ ▼ + * A1: y ──►A─\───A─A─A───A─\───A─A─A───A─\ // hold or re-feed when interleaving + * * * * + * BD1: x/2 ──►B─/ \ D─/ \ D─/ \ + * \ / \ / \ + * M2: M │ M │ M + * \ │ \ │ \ + * M3: M │ M │ M + * \ / \ / \ + * P4: P P P──► + * / + * C: 1.5 ─────────────────────────/─────────────────── + * ▲ ▲ ▲ + * │ │ │ + * x/2*y 1.5-x/2*y*y (1.5-x/2*y*y)*y + * ***************************************************************************/ -module rsqrtf #( +// Local DSP instantiation wrapper. 
+module rsqrtf_dspfp32 #( bit FORCE_BEHAVIORAL = 0 )( - // Global Control - input logic clk, - input logic rst, + input logic clk, + input logic rst, - input logic [31:0] x, - input logic xvld, - output logic xrdy, + input logic ena, + input logic bsel, + input logic csel, + input logic [31:0] a, + input logic [31:0] b, + input logic [31:0] c, + input logic [31:0] d, - output logic [31:0] r, - output logic rvld + input logic rvld, + output logic [31:0] r ); - logic signed [4:0] Cnt = -1; // 10, 9, ..., 0, -1 - logic Vld = 0; - always_ff @(posedge clk) begin - if(rst) begin - Cnt <= -1; - Vld <= 0; - end - else begin - Cnt <= Cnt + (xrdy && xvld? 11 : Cnt[4]? 0 : -1); - Vld <= !Cnt[3:0]; - end - end - assign xrdy = Cnt[4]; - assign rvld = Vld; - - uwire bsel = Cnt[3]; // B rather than D to MUL - uwire csel = Cnt[2]; // C rather than 0 to ADD - uwire [31:0] a = ('h5f3759df - x[31:1]); - uwire [31:0] b = { x[31], x[30:23]-1, x[22:0]}; // 0.5*x - uwire [31:0] c = $shortrealtobits(1.5); - uwire [31:0] d = r; - - logic [1:0] inv; - logic [1:0] ovf; - logic [1:0] unf; - always_ff @(posedge clk) begin - automatic logic [1:0] msk = { &Cnt[1:0] && (!xrdy || rvld), Cnt[1:0] == 0 }; - assert(!(inv & msk)) else $warning("%m generated invalid output."); - assert(!(ovf & msk)) else $warning("%m generated an overflow."); - assert(!(unf & msk)) else $warning("%m generated an underflow."); - end + logic invalid; + logic overflow; + logic underflow; + localparam logic [6:0] MODE_MUL = { 2'b00, 3'b010, 2'b01 }; + localparam logic [6:0] MODE_SUB = { 2'b01, 3'b110, 2'b01 }; if(FORCE_BEHAVIORAL) begin : genBehav logic [31:0] A1 = 'x; logic [31:0] B1 = 'x; logic [31:0] D1 = 'x; - logic BSel1 = 'x; - logic CSel3 = 'x; + logic BSel1 = 'x; + logic CSel3 = 'x; logic [31:0] M[2:3] = '{ default: 'x }; logic [31:0] P4 = 'x; + always_ff @(posedge clk) begin - if(rst) begin - end - else begin - if(xrdy) A1 <= a; - B1 <= b; - D1 <= d; - BSel1 <= bsel; - CSel3 <= csel; - M <= { - 
$shortrealtobits($bitstoshortreal(A1)*$bitstoshortreal(BSel1? B1 : D1)), - M[2] - }; - P4 <= CSel3? $shortrealtobits(1.5 - $bitstoshortreal(M[3])) : M[3]; - end + if(ena) A1 <= a; + B1 <= b; + D1 <= d; + BSel1 <= bsel; + CSel3 <= csel; + M <= { + $shortrealtobits($bitstoshortreal(A1)*$bitstoshortreal(BSel1? B1 : D1)), + M[2] + }; + P4 <= CSel3? $shortrealtobits(1.5 - $bitstoshortreal(M[3])) : M[3]; end - assign r = P4; + + assign r = P4; + always_comb begin - inv = '0; - ovf = '0; - unf = '0; + invalid = 0; + overflow = 0; + underflow = 0; + if(&r[30-:8]) begin - if(|r[0+:23]) inv[1] = 1; - else ovf[1] = 1; - end - if(&M[3][30-:8]) begin - if(|M[3][0+:23]) inv[0] = 1; - else ovf[0] = 1; + if(|r[0+:23]) invalid = 1; + else overflow = 1; end end end : genBehav else begin : genDSP DSPFP32 #( - // Feature Control Attributes: Data Path Selection - .A_FPTYPE("B32"), // B16, B32 - .A_INPUT("DIRECT"), // Selects A input source, "DIRECT" (A port) or "CASCADE" (ACIN port) - .BCASCSEL("B"), // Selects B cascade out data (B, D). 
- .B_D_FPTYPE("B32"), // B16, B32 - .B_INPUT("DIRECT"), // Selects B input source, "DIRECT" (B port) or "CASCADE" (BCIN port) - .PCOUTSEL("FPA"), // Select PCOUT output cascade of DSPFP32 (FPA, FPM) - .USE_MULT("MULTIPLY"), // Select multiplier usage (DYNAMIC, MULTIPLY, NONE) - - // Programmable Inversion Attributes: Specifies built-in programmable inversion on specific pins + .A_FPTYPE("B32"), + .A_INPUT("DIRECT"), + .BCASCSEL("B"), + .B_D_FPTYPE("B32"), + .B_INPUT("DIRECT"), + .PCOUTSEL("FPA"), + .USE_MULT("MULTIPLY"), .IS_CLK_INVERTED(1'b0), .IS_FPINMODE_INVERTED(1'b0), .IS_FPOPMODE_INVERTED(7'b0000000), @@ -128,88 +123,267 @@ module rsqrtf #( .IS_RSTFPMPIPE_INVERTED(1'b0), .IS_RSTFPM_INVERTED(1'b0), .IS_RSTFPOPMODE_INVERTED(1'b0), - - // Register Control Attributes: Pipeline Register Configuration - .ACASCREG(1), // Number of pipeline stages between A/ACIN and ACOUT (0-2) - .AREG(1), // Pipeline stages for A (0-2) - .FPA_PREG(1), // Pipeline stages for FPA output (0-1) - .FPBREG(1), // Pipeline stages for B inputs (0-1) - .FPCREG(0), // Pipeline stages for C input (0-3) - .FPDREG(1), // Pipeline stages for D inputs (0-1) - .FPMPIPEREG(1), // Selects the number of FPMPIPE registers (0-1) - .FPM_PREG(1), // Pipeline stages for FPM output (0-1) - .FPOPMREG(1), // Selects the length of the FPOPMODE pipeline (0-3) - .INMODEREG(1), // Selects the number of FPINMODE registers (0-1) - .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). 
- ) - DSPFP32_inst ( - // Cascade outputs: Cascade Ports - .ACOUT_EXP(), - .ACOUT_MAN(), - .ACOUT_SIGN(), - .BCOUT_EXP(), - .BCOUT_MAN(), - .BCOUT_SIGN(), + .ACASCREG(1), + .AREG(1), + .FPA_PREG(1), + .FPBREG(1), + .FPCREG(0), + .FPDREG(1), + .FPMPIPEREG(1), + .FPM_PREG(1), + .FPOPMREG(1), + .INMODEREG(1), + .RESET_MODE("SYNC") + ) DSPFP32_inst ( + .ACOUT_EXP(), .ACOUT_MAN(), .ACOUT_SIGN(), + .BCOUT_EXP(), .BCOUT_MAN(), .BCOUT_SIGN(), .PCOUT(), - - // Data outputs: Data Ports - .FPM_INVALID(inv[0]), - .FPM_OVERFLOW(ovf[0]), - .FPM_UNDERFLOW(unf[0]), - .FPM_OUT(), - .FPA_INVALID(inv[1]), - .FPA_OVERFLOW(ovf[1]), - .FPA_UNDERFLOW(unf[1]), - .FPA_OUT(r), - // Cascade inputs: Cascade Ports - .ACIN_EXP('x), - .ACIN_MAN('x), - .ACIN_SIGN('x), - .BCIN_EXP('x), - .BCIN_MAN('x), - .BCIN_SIGN('x), + .FPM_INVALID(), .FPM_OVERFLOW(), .FPM_UNDERFLOW(), .FPM_OUT(), + .FPA_INVALID(invalid), .FPA_OVERFLOW(overflow), .FPA_UNDERFLOW(underflow), .FPA_OUT(r), + .ACIN_EXP('x), .ACIN_MAN('x), .ACIN_SIGN('x), + .BCIN_EXP('x), .BCIN_MAN('x), .BCIN_SIGN('x), .PCIN('x), - // Control inputs: Control Inputs/Status Bits - .CLK(clk), - .FPINMODE(bsel), - .FPOPMODE({ - // (csel? sub : add)(csel? C : 0, M) - { 1'b0, csel }, { csel, 2'b10 }, 2'b01 - }), - // Data inputs: Data Ports - .A_SIGN(a[31]), - .A_EXP(a[30:23]), - .A_MAN(a[22:0]), - .B_SIGN(b[31]), - .B_EXP(b[30:23]), - .B_MAN(b[22:0]), + .CLK(clk), .FPINMODE(bsel), .FPOPMODE(csel? 
MODE_SUB : MODE_MUL), + .A_SIGN(a[31]), .A_EXP(a[30:23]), .A_MAN(a[22:0]), + .B_SIGN(b[31]), .B_EXP(b[30:23]), .B_MAN(b[22:0]), .C(c), - .D_SIGN(d[31]), - .D_EXP(d[30:23]), - .D_MAN(d[22:0]), - // Reset/Clock Enable inputs: Reset/Clock Enable Inputs + .D_SIGN(d[31]), .D_EXP(d[30:23]), .D_MAN(d[22:0]), .ASYNC_RST('0), - .CEA1('0), - .CEA2(xrdy), - .CEB('1), - .CEC('0), - .CED('1), - .CEFPA('1), - .CEFPINMODE('1), - .CEFPM('1), - .CEFPMPIPE('1), - .CEFPOPMODE('1), - - .RSTA('0), - .RSTB('0), - .RSTC('0), - .RSTD('0), - .RSTFPA('0), - .RSTFPINMODE('0), - .RSTFPM('0), - .RSTFPMPIPE('0), - .RSTFPOPMODE('0) + .CEA1('0), .CEA2(ena), + .CEB('1), .CEC('0), .CED('1), + .CEFPA('1), .CEFPINMODE('1), .CEFPM('1), .CEFPMPIPE('1), .CEFPOPMODE('1), + .RSTA('0), .RSTB('0), .RSTC('0), .RSTD('0), + .RSTFPA('0), .RSTFPINMODE('0), .RSTFPM('0), .RSTFPMPIPE('0), .RSTFPOPMODE('0) ); end : genDSP + always_ff @(posedge clk) begin + if(!rst && rvld) begin + assert(!invalid) else $warning("%m generated invalid output."); + assert(!overflow) else $warning("%m generated an overflow."); + assert(!underflow) else $warning("%m generated an underflow."); + end + end + +endmodule : rsqrtf_dspfp32 + +module rsqrtf #( + int unsigned SUSTAINABLE_INTERVAL, // Average II sustained over 12 Cycles + // Guarantee readiness at II, do not expose delays of arbitrating between iterations: + // - by intermittent input delays or + // - by revoking readiness. 
+	bit  STABLE_READINESS = 1,
+	bit  FORCE_BEHAVIORAL = 0
+)(
+	// Global Control
+	input	logic  clk,
+	input	logic  rst,
+
+	input	logic [31:0]  x,
+	input	logic  xvld,
+	output	logic  xrdy,
+
+	output	logic [31:0]  r,
+	output	logic  rvld
+);
+
+	// Isolate input from arbitration between iterations as needed
+	uwire [31:0]  xx;
+	uwire  xxvld;
+	uwire  xxrdy;
+	if(STABLE_READINESS && (1 < SUSTAINABLE_INTERVAL) && (SUSTAINABLE_INTERVAL < 9)) begin : genSkid
+		queue #(.DATA_WIDTH(32), .ELASTICITY(2)) input_queue (
+			.clk, .rst,
+			.idat(x), .ivld(xvld), .irdy(xrdy),
+			.odat(xx), .ovld(xxvld), .ordy(xxrdy)
+		);
+	end : genSkid
+	else begin : genStraight
+		assign	xx    = x;
+		assign	xxvld = xvld;
+		assign	xrdy  = xxrdy;
+	end : genStraight
+
+	uwire  xsel;	// Feed new input vs. re-feed for interleaving
+	uwire [31:0]  afb;
+	uwire [31:0]  a = (xsel? 'h5f3759df : afb) - (xsel? xx[31:1] : 0);
+	uwire [31:0]  b = { xx[31], xx[30:23]-1, xx[22:0]};	// 0.5*x
+	uwire [31:0]  c = $shortrealtobits(1.5);
+
+	case(SUSTAINABLE_INTERVAL)
+	0: initial begin
+		$error("SUSTAINABLE_INTERVAL must be positive.");
+		$finish;
+	end
+	1: begin : genII1	// three chained DSPs, one per iteration; input is always accepted
+		localparam int unsigned  DSP_LATENCY = 4;
+		localparam int unsigned  LAT = 3*DSP_LATENCY;
+
+		logic  Vld[LAT] = '{ default: 0 };
+		logic [31:0]  A[8] = '{ default: 'x };
+		uwire [31:0]  p[2];
+		always_ff @(posedge clk) begin
+			if(rst) begin
+				Vld <= '{ default: 0 };
+				A   <= '{ default: 'x };
+			end
+			else begin
+				Vld <= { xxvld, Vld[0:LAT-2] };
+				A   <= { a, A[0:6] };
+			end
+		end
+		assign	xsel  = 1;
+		assign	xxrdy = 1;
+		assign	rvld  = Vld[LAT-1];
+
+		rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSP0 (	// ITER1: x/2*y
+			.clk, .rst,
+			.ena('1), .bsel('1), .csel('0),
+			.a, .b, .c('x), .d('x),
+			.rvld('0), .r(p[0])
+		);
+
+		rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSP1 (	// ITER2: 1.5-x/2*y*y
+			.clk, .rst,
+			.ena('1), .bsel('0), .csel('1),
+			.a(A[3]), .b('x), .c, .d(p[0]),
+			.rvld('0), .r(p[1])
+		);
+
+		rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSP2 (	// ITER3: (1.5-x/2*y*y)*y
+			.clk, .rst,
+			.ena('1), .bsel('0), .csel('0),
+			.a(A[7]), .b('x), .c('x), .d(p[1]),
+			.rvld, .r
+		);
+	end : genII1
+	2: begin : genII2	// two DSPs: DSP0 serves ITER1 and ITER3, DSP1 serves ITER2
+		logic  Vld[12] = '{ default: 0 };
+		always_ff @(posedge clk) begin
+			if(rst)  Vld <= '{ default: 0 };
+			else     Vld <= { xxrdy && xxvld, Vld[0:10] };
+		end
+
+		logic [31:0]  A[8] = '{ default: 'x };
+		always_ff @(posedge clk) begin
+			if(rst)  A <= '{ default: 'x };
+			else     A <= { a, A[0:6] };
+		end
+
+		assign	rvld  = Vld[11];
+		assign	xxrdy = !Vld[7];
+		assign	xsel  = xxrdy;
+		assign	afb   = A[7];
+
+		uwire [31:0]  p;	// Second DSP Output
+		rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSP0 (	// multiplies by b on fresh input, by d=p on re-feed
+			.clk, .rst,
+			.ena('1), .bsel(xsel), .csel('0),
+			.a, .b, .c('x), .d(p),
+			.rvld, .r
+		);
+
+		rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSP1 (	// 1.5 - A[3]*r
+			.clk, .rst,
+			.ena('1), .bsel('0), .csel('1),
+			.a(A[3]), .b('x), .c, .d(r),
+			.rvld('0), .r(p)
+		);
+	end : genII2
+	default: begin : genSharedDSP
+		uwire  aload;
+		uwire  bsel;
+		uwire  csel;
+
+		if(SUSTAINABLE_INTERVAL < 9) begin : genInterleave
+			typedef enum logic [1:0] {
+				//               bsel/3  csel/1
+				IDLE  = 2'b11,	//  1       x
+				ITER1 = 2'b00,	//  0       0
+				ITER2 = 2'b01,	//  0       1
+				ITER3 = 2'b10,	//  1       0
+				BSEL  = 2'b1x,
+				CSEL  = 2'bx1
+			} maturity_t;
+
+			maturity_t  Maturity[4] = '{ default: IDLE };
+			logic [31:0]  A[4] = '{ default: 'x };
+			always_ff @(posedge clk) begin
+				if(rst) begin
+					Maturity <= '{ default: IDLE };
+					A <= '{ default: 'x };
+				end
+				else begin
+					unique casex(Maturity[3])
+					ITER1: Maturity[0] <= ITER2;
+					ITER2: Maturity[0] <= ITER3;
+					ITER3,
+					IDLE:  Maturity[0] <= xxvld? ITER1 : IDLE;
+					endcase
+					Maturity[1:3] <= Maturity[0:2];
+					A <= { a, A[0:2] };
+				end
+			end
+			assign	bsel  = Maturity[3] ==? BSEL;
+			assign	csel  = Maturity[1] ==? CSEL;
+			assign	xsel  = bsel;
+			assign	xxrdy = bsel;
+			assign	rvld  = Maturity[3] ==? ITER3;
+			assign	aload = 1;
+			assign	afb   = A[3];
+		end : genInterleave
+		else if(SUSTAINABLE_INTERVAL < 12) begin : genOverlapped
+			logic [3:0]  Cnt = 8;
+			logic [3:0]  RVld = '0;
+			uwire  cnt7 = Cnt ==? 4'bx111;
+			uwire  cnt8 = Cnt ==? 4'b1xxx;
+			always_ff @(posedge clk) begin
+				if(rst) begin
+					Cnt  <= 8;
+					RVld <= '0;
+				end
+				else begin
+					Cnt  <= Cnt + (!cnt8? 1 : xxvld? 
8 : 0); + RVld <= { cnt7, RVld[3:1] }; + end + end + assign bsel = Cnt[3]; + assign csel = Cnt[2]; + assign xsel = 1; + assign xxrdy = bsel; + assign rvld = RVld[0]; + assign aload = bsel; + end : genOverlapped + else begin : genExclusive + logic signed [3:0] Cnt = -1; + logic RVld = 0; + uwire cnt10 = Cnt ==? 4'b101x; + always_ff @(posedge clk) begin + if(rst) begin + Cnt <= -1; + RVld <= 0; + end + else begin + Cnt <= Cnt + (cnt10? 'b101 : xxvld || !bsel); + RVld <= cnt10; + end + end + assign bsel = &Cnt[3:2]; + assign csel = Cnt[2]; + assign xsel = 1; + assign xxrdy = bsel; + assign rvld = RVld; + assign aload = bsel; + end : genExclusive + + rsqrtf_dspfp32 #(.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) DSPFP32_inst ( + .clk, .rst, + .ena(aload), .bsel, .csel, + .a, .b, .c, .d(r), + .rvld, .r + ); + end : genSharedDSP + endcase + endmodule : rsqrtf diff --git a/finn-rtllib/layernorm/tb/accuf_tb.sv b/finn-rtllib/layernorm/tb/accuf_tb.sv new file mode 100644 index 0000000000..df88489af7 --- /dev/null +++ b/finn-rtllib/layernorm/tb/accuf_tb.sv @@ -0,0 +1,109 @@ +/**************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. 
Preußer + ***************************************************************************/ + +module accuf_tb; + + localparam bit FORCE_BEHAVIORAL = 1; + + typedef struct { + shortreal scale; + shortreal bias; + } cfg_t; + localparam int unsigned TESTS = 3; + localparam cfg_t TEST_CFG[TESTS] = '{ + cfg_t'{ scale: 1.0, bias: 0.0 }, + cfg_t'{ scale: 0.2, bias: 0.0 }, + cfg_t'{ scale: 0.0, bias: 4.0 } + }; + + // Global Control + logic clk = 0; + always #5ns clk = !clk; + logic rst = 1; + initial begin + repeat(12) @(posedge clk); + rst <= 0; + end + + // Test Instantiation + bit [TESTS-1:0] done = '0; + always_comb begin + if(&done) $finish; + end + for(genvar test = 0; test < TESTS; test++) begin : genTests + localparam cfg_t CFG = TEST_CFG[test]; + + // DUT + uwire [31:0] a; + logic avld; + logic alst; + uwire [31:0] s; + uwire svld; + accuf #(.SCALE(CFG.scale), .BIAS(CFG.bias), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) dut ( + .clk, .rst, + .a, .avld, .alst, + .s, .svld + ); + + // Stimulus + shortreal Q[$]; + shortreal fa, fs; + assign a = $shortrealtobits(fa); + assign fs = $bitstoshortreal(s); + initial begin + automatic shortreal s = CFG.bias; + + fa = 'x; + avld = 0; + alst = 'x; + @(posedge clk iff !rst); + + for(int unsigned i = 0; i < 417; i++) begin + while($urandom()%23 == 0) @(posedge clk); + avld <= 1; + fa <= i; + alst <= (i == 416) || ($urandom()%11 == 0); + @(posedge clk); + avld <= 0; + + s += (CFG.scale == 0.0? 
fa : CFG.scale) * fa; + if(alst) begin + Q.push_back(s); + s = CFG.bias; + end + end + + repeat(5) @(posedge clk); + assert(Q.size() == 0) else begin + $error("Test #%0d: Missing output.", test); + $stop; + end + $display("Test #%0d completed.", test); + done[test] = 1; + end + + // Checker + always_ff @(posedge clk iff svld) begin + automatic shortreal exp, err; + assert(Q.size) else begin + $error("Test #%0d: Spurious output.", test); + $stop; + end + exp = Q.pop_front(); + err = fs - exp; + err *= err; + assert(err < 1e-8) else begin + $error("Test #%0d: Output mismatch: %f instead of %f", test, fs, exp); + $stop; + end + end + + end : genTests + +endmodule : accuf_tb diff --git a/finn-rtllib/layernorm/tb/binopf_tb.sv b/finn-rtllib/layernorm/tb/binopf_tb.sv new file mode 100644 index 0000000000..a06d01fc13 --- /dev/null +++ b/finn-rtllib/layernorm/tb/binopf_tb.sv @@ -0,0 +1,151 @@ +/**************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. 
Preußer + ***************************************************************************/ + +module binopf_tb; + + localparam bit FORCE_BEHAVIORAL = 1; + typedef struct { + string op; + shortreal scale; + bit delay; + } cfg_t; + localparam int unsigned TESTS = 8; + localparam cfg_t CFGS[TESTS] = '{ + '{ "ADD", 1.0, 0 }, + '{ "ADD", 1.0, 1 }, + '{ "ADD", 1.3, 0 }, + '{ "SUB", 1.0, 0 }, + '{ "SUB", 0.4, 0 }, + '{ "SBR", 1.0, 0 }, + '{ "SBR", 3.7, 0 }, + '{ "MUL", 1.0, 0 } + }; + + // Global Control + logic clk = 0; + always #5ns clk = !clk; + logic rst = 1; + initial begin + repeat(12) @(posedge clk); + rst <= 0; + end + + // Test Instantiations + bit [TESTS-1:0] done = '0; + always_comb begin + if(&done) $finish; + end + for(genvar test = 0; test < TESTS; test++) begin : genTests + localparam cfg_t CFG = CFGS[test]; + localparam shortreal SCALE = CFG.scale; + + function shortreal compute_ref(input shortreal a, input shortreal b); + unique case(CFG.op) + "ADD": return a + SCALE*b; + "SUB": return a - SCALE*b; + "SBR": return SCALE*b - a; + "MUL": return a * b; + endcase + endfunction : compute_ref + + // DUT + logic avld; + shortreal a; + logic bload; + shortreal b; + uwire rvld; + shortreal r; + if(1) begin : blkDUT + uwire [31:0] aa = $shortrealtobits(a); + uwire [31:0] bb = $shortrealtobits(b); + uwire [31:0] rr; + binopf #( + .OP(CFG.op), .B_SCALE(SCALE), + .A_MATCH_OP_DELAY(CFG.delay), + .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL) + ) dut ( + .clk, .rst, + .avld, .a(aa), .bload, .b(bb), + .rvld, .r(rr) + ); + assign r = $bitstoshortreal(rr); + end : blkDUT + + // Stimulus + shortreal B0; + shortreal Q[$]; + initial begin + + avld = 0; + a = 'x; + bload = 0; + b = 'x; + @(posedge clk iff !rst); + + // Fork off Background Update of `b` Input + fork + forever begin + automatic shortreal val = $urandom()%10000 - 5000.0; + bload <= 1; + b <= val; + B0 <= val; + @(posedge clk); + bload <= 0; + b <= 'x; + while($urandom()%37 != 0) @(posedge clk); + end + join_none + + // Run a 
Series of `a` Values + repeat(1739) begin + while($urandom()%17 == 0) @(posedge clk); + + avld <= 1; + a <= $urandom()%10000 - 5000.0; + @(posedge clk); + fork begin + automatic shortreal a0 = a; + if(CFG.delay) repeat(2) @(posedge clk); + Q.push_back(compute_ref(a0, B0)); + end join_none + avld <= 0; + a <= 'x; + end + + repeat(7) @(posedge clk); + assert(Q.size() == 0) else begin + $error("Test #%0d: Missing output.", test); + $stop; + end + $display("Test #%0d completed.", test); + done[test] = 1; + end + + // Checker + always_ff @(posedge clk iff rvld) begin + automatic shortreal exp, err; + assert(Q.size) else begin + $error("Test #%0d: Spurious output.", test); + $stop; + end + exp = Q.pop_front(); + err = r - exp; + err *= err; + assert((err < 1e-5) || ($shortrealtobits(r) == $shortrealtobits(exp))) else begin + $error( + "Test #%0d: Output mismatch: %f/%08x instead of %f/%08x", + test, r, $shortrealtobits(r), exp, $shortrealtobits(exp) + ); + $stop; + end + end + + end : genTests + +endmodule : binopf_tb diff --git a/finn-rtllib/layernorm/tb/layernorm_tb.sv b/finn-rtllib/layernorm/tb/layernorm_tb.sv new file mode 100644 index 0000000000..3af4533fa8 --- /dev/null +++ b/finn-rtllib/layernorm/tb/layernorm_tb.sv @@ -0,0 +1,144 @@ +/**************************************************************************** + * Copyright (C) 2025-2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. 
Preußer + ***************************************************************************/ + + module layernorm_tb; + + localparam int unsigned ROUNDS = 19; + localparam bit FORCE_BEHAVIORAL = 1; + + typedef struct { + int unsigned n; + int unsigned simd; + } cfg_t; + localparam int unsigned TESTS = 9; + localparam cfg_t TEST_CFG[TESTS] = '{ + '{ 4, 4 }, // NN=1 + '{ 10, 5 }, // NN=2 + '{ 18, 6 }, // NN=3 + '{ 42, 7 }, // NN=6 + '{ 64, 8 }, // NN=8 + '{ 81, 9 }, // NN=9 + '{100, 10 }, // NN=10 + '{ 44, 4 }, // NN=11 + '{ 60, 5 } // NN=12 + }; + + //----------------------------------------------------------------------- + // Global Control + logic clk = 0; + always #5ns clk = !clk; + logic rst = 1; + initial begin + repeat(12) @(posedge clk); + rst <= 0; + end + + //----------------------------------------------------------------------- + // Test Instantiations + bit [TESTS-1:0] done = '0; + always_comb begin + if(&done) $finish(); + end + for(genvar test = 0; test < TESTS; test++) begin : genTests + localparam int unsigned N = TEST_CFG[test].n; + localparam int unsigned SIMD = TEST_CFG[test].simd; + typedef shortreal vec_t[N]; + + // DUT + logic [SIMD-1:0][31:0] xdat; + logic xvld; + uwire xrdy; + uwire [SIMD-1:0][31:0] ydat; + uwire yvld; + logic yrdy; + layernorm #(.N(N), .SIMD(SIMD), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) dut ( + .clk, .rst, + .xdat, .xvld, .xrdy, + .ydat, .yvld, .yrdy + ); + + // Stimulus + vec_t X[ROUNDS]; + initial begin + xdat = 'x; + xvld = 0; + @(posedge clk iff !rst); + + for(int unsigned r = 0; r < ROUNDS; r++) begin + static shortreal b; + static shortreal s; + b = $urandom()%129 - 53.0; + s = ($urandom()%29 + 1) / 1.3; + foreach(X[r][i]) X[r][i] = s*($urandom()%1537 - 757.0) + b; + + for(int unsigned i = 0; i < N; i += SIMD) begin + while($urandom()%23 == 0) @(posedge clk); + xvld <= 1; + for(int unsigned j = 0; j < SIMD; j++) begin + xdat[j] <= $shortrealtobits(X[r][i+j]); + end + @(posedge clk iff xrdy); + xdat <= 'x; + xvld <= 0; + end + 
end + + $display("[%0d] Input feed done.", test); + end + + // Output Checker + vec_t y, exp; + initial begin + yrdy = 0; + @(posedge clk iff !rst); + repeat(187) @(posedge clk); + + for(int unsigned r = 0; r < ROUNDS; r++) begin + static shortreal m, s; + + for(int unsigned i = 0; i < N; i += SIMD) begin + while($urandom()%5 == 0) @(posedge clk); + yrdy <= 1; + @(posedge clk iff yvld); + foreach(ydat[j]) y[i+j] = $bitstoshortreal(ydat[j]); + yrdy <= 0; + end + + m = 0.0; + foreach(X[r][i]) m += X[r][i]; + m /= N; + + s = 0.0; + foreach(exp[i]) begin + exp[i] = X[r][i] - m; + s += exp[i] * exp[i]; + end + s = 1/$sqrt(s/N); + + foreach(exp[i]) exp[i] *= s; + + foreach(y[i]) begin + static shortreal err; + err = (y[i]-exp[i])/exp[i]; + err *= err; + assert(err < 1e-5) else begin + $error("[%0d] Output mismatch: %7.4f instead of %7.4f", test, y[i], exp[i]); + $stop; + end + end + end + repeat(5) @(posedge clk); + + $display("[%0d] Test completed.", test); + done[test] <= 1; + end + + end : genTests + +endmodule : layernorm_tb diff --git a/finn-rtllib/layernorm/tb/queue_tb.sv b/finn-rtllib/layernorm/tb/queue_tb.sv new file mode 100644 index 0000000000..4029fb15f4 --- /dev/null +++ b/finn-rtllib/layernorm/tb/queue_tb.sv @@ -0,0 +1,139 @@ +/**************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. 
Preußer + ***************************************************************************/ + +module queue_tb; + localparam int unsigned TXNS = 15317; + + //----------------------------------------------------------------------- + // Global Control + logic clk = 0; + always #5ns clk = !clk; + logic rst = 1; + initial begin + repeat(16) @(posedge clk); + rst <= 0; + end + + //----------------------------------------------------------------------- + // Tests + localparam int unsigned ELASTICITY_MIN = 2; + localparam int unsigned ELASTICITY_MAX = 17; + localparam int unsigned DATA_WIDTH = 13; + typedef logic [DATA_WIDTH-1:0] dat_t; + + bit [ELASTICITY_MAX:ELASTICITY_MIN] done = '0; + always_comb begin + if(&done) $finish; + end + + for(genvar test = ELASTICITY_MIN; test <= ELASTICITY_MAX; test++) begin : genTests + localparam int unsigned ELASTICITY = test; + + //- DUT ------------------------- + dat_t idat; + logic ivld; + uwire irdy; + uwire dat_t odat; + uwire ovld; + logic ordy; + queue #(.DATA_WIDTH(DATA_WIDTH), .ELASTICITY(ELASTICITY)) dut ( + .clk, .rst, + .idat, .ivld, .irdy, + .odat, .ovld, .ordy + ); + + //- Stimulus Feed --------------- + dat_t Q[$]; // Reference Output + int unsigned BackCycles = 0; // Track induced Backpressure + initial begin + idat = 'x; + ivld = 0; + @(posedge clk iff !rst); + + repeat(TXNS) begin + automatic dat_t dat; + + if($urandom()%237 == 0) begin + repeat(2*ELASTICITY + 4) begin + @(posedge clk); + if(!irdy) begin + if(BackCycles > 0) BackCycles--; + else begin + $error("Test #%0d: Encountered unwarranted backpressure.", test); + $stop; + end + end + end + end + while($urandom()%53 == 0) begin + @(posedge clk); + if(!irdy) begin + if(BackCycles > 0) BackCycles--; + else begin + $error("Test #%0d: Encountered unwarranted backpressure.", test); + $stop; + end + end + end + + void'(std::randomize(dat)); + idat <= dat; + ivld <= 1; + Q.push_back(dat); + forever @(posedge clk) begin + if(irdy) break; + if(BackCycles > 0) BackCycles--; + 
else begin + $error("Test #%0d: Encountered unwarranted backpressure.", test); + $stop; + end + end + idat <= 'x; + ivld <= 0; + end + end + + //- Output Checker -------------- + initial begin + ordy = 0; + @(posedge clk iff !rst); + + repeat(TXNS) begin + automatic dat_t exp; + + if($urandom()%173 == 0) begin + repeat(2 * ELASTICITY + 5) begin + @(posedge clk); + BackCycles++; + end + end + while($urandom()%19 == 0) begin + @(posedge clk); + BackCycles++; + end + ordy <= 1; + @(posedge clk iff ovld); + assert(Q.size > 0) else begin + $error("Test #%0d: Spurious output.", test); + $stop; + end + exp = Q.pop_front(); + assert(odat === exp) else begin + $error("Test #%0d: Output mismatch: %0x instead of %0x.", test, odat, exp); + $stop; + end + ordy <= 0; + end + + $display("Test #%0d completed.", test); + done[test] <= 1; + end + end : genTests + +endmodule : queue_tb diff --git a/finn-rtllib/layernorm/tb/rsqrtf_tb.sv b/finn-rtllib/layernorm/tb/rsqrtf_tb.sv new file mode 100644 index 0000000000..42ed6f0342 --- /dev/null +++ b/finn-rtllib/layernorm/tb/rsqrtf_tb.sv @@ -0,0 +1,129 @@ +/**************************************************************************** + * Copyright (C) 2025-2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. 
Preußer + ***************************************************************************/ + +module rsqrtf_tb; + + localparam bit FORCE_BEHAVIORAL = 0; + localparam int unsigned MIN_SUSTAINABLE_INTERVAL = 1; + localparam int unsigned MAX_SUSTAINABLE_INTERVAL = 15; + localparam int unsigned TEST_COUNT = MAX_SUSTAINABLE_INTERVAL - MIN_SUSTAINABLE_INTERVAL + 1; + localparam int unsigned ROUNDS = 137; + + // Global Control: clock toggling every 5ns, reset released after 12 rising clock edges + logic clk = 0; + always #5ns clk = !clk; + logic rst = 1; + initial begin + repeat(12) @(posedge clk); + rst <= 0; + end + + bit [MAX_SUSTAINABLE_INTERVAL:MIN_SUSTAINABLE_INTERVAL] done = '0; + always_comb begin + if(&done) $finish; + end + + // Reference Compute: bit-level initial estimate from the constant 'h5f3759df refined by one Newton-Raphson step + function shortreal q_rsqrt(input shortreal x); + automatic shortreal y = $bitstoshortreal('h5f3759df - ($shortrealtobits(x) >> 1)); + return y * (1.5 - (0.5 * x * y * y)); + endfunction : q_rsqrt + + for(genvar t = MIN_SUSTAINABLE_INTERVAL; t <= MAX_SUSTAINABLE_INTERVAL; t++) begin : genTests + + // DUT: one rsqrtf instance per SUSTAINABLE_INTERVAL value t under test + shortreal fx; + uwire [31:0] x = $shortrealtobits(fx); + logic xvld; + uwire [31:0] r; + uwire rvld; + uwire xrdy; + rsqrtf #( + .SUSTAINABLE_INTERVAL(t), + .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL) + ) dut ( + .clk, .rst, + .x, .xvld, .xrdy, + .r, .rvld + ); + shortreal fr; + assign fr = $bitstoshortreal(r); + + // Stimulus: Q queues every accepted operand until the checker pops it + shortreal Q[$]; + initial begin + automatic int unsigned Round2Cycles = 0; + fx = 'x; + xvld = 0; + @(posedge clk iff !rst); + + // Round 1: intermittent feed with occasional stalls + for(int unsigned i = 1; i <= ROUNDS; i++) begin + while($urandom()%23 == 0) @(posedge clk); + fx <= i; + xvld <= 1; + @(posedge clk iff xrdy); + Q.push_back(fx); + fx <= 'x; + xvld <= 0; + end + repeat(12) @(posedge clk); + + // Round 2: feed as fast as the DUT accepts input, counting the clock cycles spent + xvld <= 1; + for(int unsigned i = 0; i < ROUNDS;) begin + fx <= ROUNDS + i; + @(posedge clk); + Round2Cycles++; + if(xrdy) begin + Q.push_back(fx); + i++; + end + end + xvld <= 0; + fx <= 'x; + + $display("Test #%0d: 
Round-2 cycles/input = %0d/%0d = %0.2f", t, Round2Cycles, ROUNDS, real'(Round2Cycles)/ROUNDS); + + repeat(32) @(posedge clk); + assert(Q.size() == 0) else begin + $error("Test #%0d: Missing %0d outputs.", t, Q.size()); + $stop; + end + + done[t] = 1; + end + + // Checker: each valid output is compared with q_rsqrt of the oldest queued operand (squared error below 1e-8) + int unsigned Checks = 0; + always_ff @(posedge clk iff rvld) begin + automatic shortreal x, exp, err; + assert(Q.size()) else begin + $error("Test #%0d: Spurious output.", t); + $stop; + end + x = Q.pop_front(); + exp = q_rsqrt(x); + + err = fr - exp; + err *= err; + assert(err < 1e-8) else begin + $error("Test #%0d: Output mismatch for %f: %f instead of %f", t, x, fr, exp); + $stop; + end + Checks <= Checks + 1; + end + + final begin + assert(Checks == 2*ROUNDS) $display("Test #%0d: Successfully performed %0d checks.", t, Checks); + else $error("Test #%0d: Unexpected number of checks: %0d instead of %0d.", t, Checks, 2*ROUNDS); + end + end : genTests + +endmodule : rsqrtf_tb diff --git a/src/finn/custom_op/fpgadataflow/rtl/layernorm_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/layernorm_rtl.py index 3f048d2f09..12540cc1a2 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/layernorm_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/layernorm_rtl.py @@ -43,7 +43,6 @@ def generate_hdl(self, model, fpgapart, clk): n % simd == 0 ), """Requirement N (last dim) divisable by SIMD is violated. Please set SIMD to a different value""" - assert n // simd > 12, "N/SIMD must be larger than 12 for rsqrt throughput." code_gen_dict = { "$N$": int(n), "$SIMD$": int(simd), @@ -132,8 +131,6 @@ def get_exp_cycles(self): n % simd == 0 ), """Requirement N (last dim) divisable by SIMD is violated. Please set SIMD to a different value""" - assert n // simd > 12, "N/SIMD must be larger than 12 for rsqrt throughput." 
- val_queue_len_0 = n // simd + math.ceil(math.log2(simd)) * 2 + 7 val_queue_len_1 = n // simd + math.ceil(math.log2(simd)) * 2 + 24 exp_cycles = val_queue_len_0 + val_queue_len_1 + np.prod(idim) // simd + 5 diff --git a/tests/fpgadataflow/test_fpgadataflow_layernorm.py b/tests/fpgadataflow/test_fpgadataflow_layernorm.py index a8af0e755b..4b12f0fd0b 100644 --- a/tests/fpgadataflow/test_fpgadataflow_layernorm.py +++ b/tests/fpgadataflow/test_fpgadataflow_layernorm.py @@ -97,6 +97,7 @@ def create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon): ["node_by_node", pytest.param("stitched_ip", marks=pytest.mark.xfail(reason="sim bug"))], ) def test_fpgadataflow_rtl_layernorm(idt, ishape, simd, sim_style): + """Test RTL LayerNorm with N/SIMD > 12 (original regime).""" model = create_layernorm_model( idt, ishape, has_scale=True, has_bias=True, epsilon=9.999999960041972e-13 ) @@ -153,6 +154,67 @@ def test_fpgadataflow_rtl_layernorm(idt, ishape, simd, sim_style): assert exp_cycles != 0 +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("idt", [DataType["FLOAT32"]]) +@pytest.mark.parametrize( + "ishape,simd", + [ + ([1, 4], 4), # NN=1 -> rsqrt genII1 (3 DSPs) + ([1, 10], 5), # NN=2 -> rsqrt genII2 (2 DSPs) + ([1, 18], 6), # NN=3 -> rsqrt genInterleave + ([1, 42], 7), # NN=6 -> rsqrt genInterleave + ([1, 64], 8), # NN=8 -> rsqrt genInterleave + ([1, 81], 9), # NN=9 -> rsqrt genOverlapped + ([1, 100], 10), # NN=10 -> rsqrt genOverlapped + ([1, 44], 4), # NN=11 -> rsqrt genOverlapped + ], +) +def test_fpgadataflow_rtl_layernorm_low_simd_ratio(idt, ishape, simd): + """Test RTL LayerNorm with N/SIMD <= 12, exercising the new rsqrt strategies.""" + model = create_layernorm_model( + idt, ishape, has_scale=True, has_bias=True, epsilon=9.999999960041972e-13 + ) + + # reference calculation on the untransformed ONNX model + input = gen_finn_dt_tensor(idt, ishape) + input_t = {model.graph.input[0].name: input} + + y_ref = oxe.execute_onnx(model, 
input_t)[model.graph.output[0].name] + + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + + model = model.transform(ExtractNormScaleBias()) + + model = model.transform(to_hw.InferLayerNorm()) + model = model.transform(to_hw.InferElementwiseBinaryOperation()) + input_t = {model.graph.input[0].name: input} + + y_hw = oxe.execute_onnx(model, input_t)[model.graph.output[0].name] + assert np.allclose(y_ref, y_hw, rtol=1e-3, atol=2**-4) + + model = model.transform(SpecializeLayers(test_fpga_part)) + model = model.transform(GiveUniqueNodeNames()) + + assert model.graph.node[0].op_type == "LayerNorm_rtl", "LayerNorm wasn't converted to RTL Layer" + + getCustomOp(model.graph.node[0]).set_nodeattr("SIMD", simd) + + # Execute node-by-node RTL simulation; its result is checked against the ONNX reference + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareIP(test_fpga_part, target_clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + input_t = {model.graph.input[0].name: input} + + y_rtl = oxe.execute_onnx(model, input_t)[model.graph.output[0].name] + + assert np.allclose(y_ref, y_rtl, rtol=1e-3, atol=2**-4) + + @pytest.mark.fpgadataflow @pytest.mark.vivado @pytest.mark.slow