diff --git a/cmd/rtrdump/rtrdump.go b/cmd/rtrdump/rtrdump.go index 73ffef9..e5c59b9 100644 --- a/cmd/rtrdump/rtrdump.go +++ b/cmd/rtrdump/rtrdump.go @@ -107,6 +107,7 @@ func (c *Client) HandlePDU(cs *rtr.ClientSession, pdu rtr.PDU) { Pubkey: pdu.SubjectPublicKeyInfo, Ski: skiHex, } + c.Data.Metadata.CountBgpSecKeys++ c.Data.BgpSecKeys = append(c.Data.BgpSecKeys, rj) if *LogDataPDU { @@ -122,6 +123,7 @@ func (c *Client) HandlePDU(cs *rtr.ClientSession, pdu rtr.PDU) { Providers: pdu.ProviderASNumbers, } + c.Data.Metadata.CountASPAs++ c.Data.ASPA = append(c.Data.ASPA, aj) if *LogDataPDU { diff --git a/cmd/stayrtr/stayrtr.go b/cmd/stayrtr/stayrtr.go index 9e8a507..dfa7455 100644 --- a/cmd/stayrtr/stayrtr.go +++ b/cmd/stayrtr/stayrtr.go @@ -329,8 +329,8 @@ func processData(vrplistjson []prefixfile.VRPJson, } } - // Ensure that these are sorted, otherwise they - // don't hash right. + // Ensure that Providers are sorted + // Required by RFC and also otherwise they don't hash right. sort.Slice(v.Providers, func(i, j int) bool { return v.Providers[i] < v.Providers[j] }) diff --git a/lib/client_test.go b/lib/client_test.go index affe119..7a84fdc 100644 --- a/lib/client_test.go +++ b/lib/client_test.go @@ -129,10 +129,8 @@ func TestRouterKeyEncodeDecode(t *testing.T) { func TestASPAEncodeDecode(t *testing.T) { p := &PDUASPA{ - Version: 1, + Version: 2, Flags: 1, - AFIFlags: 1, - ProviderASCount: 2, CustomerASNumber: 64497, ProviderASNumbers: []uint32{64498, 64499}, } diff --git a/lib/server.go b/lib/server.go index d8740c3..574044d 100644 --- a/lib/server.go +++ b/lib/server.go @@ -1066,24 +1066,13 @@ func (c *Client) SendData(sd SendableData) { return } - pdu4 := &PDUASPA{ + pdu := &PDUASPA{ Version: c.version, Flags: t.Flags, - AFIFlags: AFI_IPv4, - ProviderASCount: uint16(len(t.Providers)), CustomerASNumber: t.CustomerASN, ProviderASNumbers: t.Providers, } - pdu6 := &PDUASPA{ - Version: c.version, - Flags: t.Flags, - AFIFlags: AFI_IPv6, - ProviderASCount: uint16(len(t.Providers)), - CustomerASNumber: t.CustomerASN, - ProviderASNumbers: t.Providers, - } - c.SendPDU(pdu4) - c.SendPDU(pdu6) + c.SendPDU(pdu) } } diff --git a/lib/structs.go b/lib/structs.go index 6d8bf8e..cd7e58a 100644 --- a/lib/structs.go +++ b/lib/structs.go @@ -26,11 +26,11 @@ const ( // We ignore the theoretically unbounded length of SKIs for router keys. // RPs should validate that this has the correct length. // - // maximum size of ASPA PDU payload: - // * 2^16 providers * 32bit = 262144 bytes - // * length is inclusive of header: 8 bytes - // * flags/afi flags/provider as/customer AS: 16 bytes - messageMaxSize = 262168 + // Maximum size of ASPA PDU payload: + // * header + length field: 8 bytes + // * Customer ASID: 4 bytes + // * 20,002 providers * 32bit = 80,008 bytes + messageMaxSize = 80020 PROTOCOL_VERSION_0 = 0 PROTOCOL_VERSION_1 = 1 @@ -526,14 +526,12 @@ func (pdu *PDUErrorReport) Write(wr io.Writer) { type PDUASPA struct { Version uint8 Flags uint8 - AFIFlags uint8 - ProviderASCount uint16 CustomerASNumber uint32 ProviderASNumbers []uint32 } func (pdu *PDUASPA) String() string { - return fmt.Sprintf("PDU ASPA v%d TODO", pdu.Version) // TODO + return fmt.Sprintf("PDU ASPA v%d TODO", pdu.Version) // XXX: TODO } func (pdu *PDUASPA) Bytes() []byte { @@ -557,12 +555,11 @@ func (pdu *PDUASPA) GetType() uint8 { func (pdu *PDUASPA) Write(wr io.Writer) { binary.Write(wr, binary.BigEndian, uint8(pdu.Version)) binary.Write(wr, binary.BigEndian, uint8(PDU_ID_ASPA)) - binary.Write(wr, binary.BigEndian, uint16(0)) - binary.Write(wr, binary.BigEndian, uint32(16+(len(pdu.ProviderASNumbers)*4))) binary.Write(wr, binary.BigEndian, uint8(pdu.Flags)) - binary.Write(wr, binary.BigEndian, uint8(pdu.AFIFlags)) - binary.Write(wr, binary.BigEndian, uint16(pdu.ProviderASCount)) + binary.Write(wr, binary.BigEndian, uint8(0)) + binary.Write(wr, binary.BigEndian, uint32(12 + (len(pdu.ProviderASNumbers)*4))) binary.Write(wr, binary.BigEndian, uint32(pdu.CustomerASNumber)) + for _, pasn := range pdu.ProviderASNumbers { binary.Write(wr, binary.BigEndian, uint32(pasn)) } @@ -577,34 +574,41 @@ func Decode(rdr io.Reader) (PDU, error) { if rdr == nil { return nil, errors.New("reader for decoding is nil") } + var pver uint8 var pduType uint8 - var sessionId uint16 + var sessionId_or_flags uint16 var length uint32 + err := binary.Read(rdr, binary.BigEndian, &pver) if err != nil { return nil, err } + err = binary.Read(rdr, binary.BigEndian, &pduType) if err != nil { return nil, err } - err = binary.Read(rdr, binary.BigEndian, &sessionId) + + err = binary.Read(rdr, binary.BigEndian, &sessionId_or_flags) if err != nil { return nil, err } + err = binary.Read(rdr, binary.BigEndian, &length) if err != nil { return nil, err } if length < 8 { - return nil, fmt.Errorf("wrong length: %d < 8", length) + return nil, fmt.Errorf("wrong PDU length: %d < 8", length) } if length > messageMaxSize { - return nil, fmt.Errorf("wrong length: %d > %d", length, messageMaxSize) + return nil, fmt.Errorf("PDU too large: %d > %d", length, messageMaxSize) } - toread := make([]byte, length-8) + + toread := make([]byte, length - 8) + err = binary.Read(rdr, binary.BigEndian, toread) if err != nil { return nil, err @@ -615,20 +619,24 @@ func Decode(rdr io.Reader) (PDU, error) { if len(toread) != 4 { return nil, fmt.Errorf("wrong length for Serial Notify PDU: %d != 4", len(toread)) } + serial := binary.BigEndian.Uint32(toread) + return &PDUSerialNotify{ Version: pver, - SessionId: sessionId, + SessionId: sessionId_or_flags, SerialNumber: serial, }, nil case PDU_ID_SERIAL_QUERY: if len(toread) != 4 { return nil, fmt.Errorf("wrong length for Serial Query PDU: %d != 4", len(toread)) } + serial := binary.BigEndian.Uint32(toread) + return &PDUSerialQuery{ Version: pver, - SessionId: sessionId, + SessionId: sessionId_or_flags, SerialNumber: serial, }, nil case PDU_ID_RESET_QUERY: @@ -642,21 +650,25 @@ func Decode(rdr io.Reader) (PDU, error) { if len(toread) != 0 { return nil, fmt.Errorf("wrong length for Cache Response PDU: %d != 0", len(toread)) } + return &PDUCacheResponse{ Version: pver, - SessionId: sessionId, + SessionId: sessionId_or_flags, }, nil case PDU_ID_IPV4_PREFIX: - if len(toread) != 12 { - return nil, fmt.Errorf("wrong length for IPv4 Prefix PDU: %d != 12", len(toread)) + if length != 20 { + return nil, fmt.Errorf("wrong length for IPv4 Prefix PDU: %d != 20", length) } + prefixLen := int(toread[1]) ip := toread[4:8] addr, ok := netip.AddrFromSlice(ip) if !ok { return nil, fmt.Errorf("ip slice length is not 4 or 16: %+v", addr) } + asn := binary.BigEndian.Uint32(toread[8:]) + return &PDUIPv4Prefix{ Version: pver, Flags: uint8(toread[0]), @@ -665,16 +677,19 @@ func Decode(rdr io.Reader) (PDU, error) { Prefix: netip.PrefixFrom(addr, prefixLen), }, nil case PDU_ID_IPV6_PREFIX: - if len(toread) != 24 { - return nil, fmt.Errorf("wrong length for IPv6 Prefix PDU: %d != 24", len(toread)) + if length != 32 { + return nil, fmt.Errorf("wrong length for IPv6 Prefix PDU: %d != 32", length) } + prefixLen := int(toread[1]) ip := toread[4:20] addr, ok := netip.AddrFromSlice(ip) if !ok { return nil, fmt.Errorf("ip slice length is not 4 or 16: %+v", addr) } + asn := binary.BigEndian.Uint32(toread[20:]) + return &PDUIPv6Prefix{ Version: pver, Flags: uint8(toread[0]), @@ -683,14 +698,15 @@ func Decode(rdr io.Reader) (PDU, error) { Prefix: netip.PrefixFrom(addr, prefixLen), }, nil case PDU_ID_END_OF_DATA: - if len(toread) != 4 && len(toread) != 16 { - return nil, fmt.Errorf("wrong length for End of Data PDU: %d != 4 or != 16", len(toread)) + if length != 12 && length != 24 { + return nil, fmt.Errorf("wrong length for End of Data PDU: %d != 12 or != 24", length) } var serial uint32 var refreshInterval uint32 var retryInterval uint32 var expireInterval uint32 + if len(toread) == 4 { serial = binary.BigEndian.Uint32(toread) } else if len(toread) == 16 { @@ -702,83 +718,96 @@ func Decode(rdr io.Reader) (PDU, error) { return &PDUEndOfData{ Version: pver, - SessionId: sessionId, + SessionId: sessionId_or_flags, SerialNumber: serial, RefreshInterval: refreshInterval, RetryInterval: retryInterval, ExpireInterval: expireInterval, }, nil case PDU_ID_CACHE_RESET: - if len(toread) != 0 { - return nil, fmt.Errorf("wrong length for Cache Reset PDU: %d != 0", len(toread)) + if length != 8 { + return nil, fmt.Errorf("wrong length for Cache Reset PDU: %d != 8", length) } + return &PDUCacheReset{ Version: pver, }, nil case PDU_ID_ROUTER_KEY: - if len(toread) < 28 { - return nil, fmt.Errorf("wrong length for Router Key PDU: %d < 28", len(toread)) + if length < 28 { + return nil, fmt.Errorf("wrong length for Router Key PDU: %d < 28", length) } + asn := binary.BigEndian.Uint32(toread[20:24]) spki := toread[24:] ski := make([]byte, 20) copy(ski[:], toread[0:20]) + return &PDURouterKey{ Version: pver, SubjectKeyIdentifier: ski, - // Router Key uses a rarely used spot that is also used by the SessionID, So we we will just bitshift - Flags: uint8(sessionId >> 8), + // Flags is in a spot that is also used by the SessionID, So we we will just bitshift + Flags: uint8(sessionId_or_flags >> 8), ASN: asn, SubjectPublicKeyInfo: spki, }, nil case PDU_ID_ERROR_REPORT: - if len(toread) < 8 { - return nil, fmt.Errorf("wrong length for Error Report PDU: %d < 8", len(toread)) + if length < 24 { + return nil, fmt.Errorf("wrong length for Error Report PDU: %d < 24", length) } + lenPdu := binary.BigEndian.Uint32(toread[0:4]) - if len(toread) < int(lenPdu)+8 { - return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+4) + if len(toread) < int(lenPdu) + 8 { + return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu + 4) } + errPdu := toread[4 : lenPdu+4] lenErrText := binary.BigEndian.Uint32(toread[lenPdu+4 : lenPdu+8]) + // int casting for each value is needed here to prevent an uint32 overflow that could result in // upper bound being lower than lower bound causing a crash if len(toread) < int(lenPdu)+8+int(lenErrText) { - return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu+8+lenErrText) + return nil, fmt.Errorf("wrong length for Error Report PDU: %d < %d", len(toread), lenPdu + 8 + lenErrText) } errMsg := string(toread[lenPdu+8 : lenPdu+8+lenErrText]) + return &PDUErrorReport{ Version: pver, - ErrorCode: sessionId, + ErrorCode: sessionId_or_flags, PDUCopy: errPdu, ErrorMsg: errMsg, }, nil case PDU_ID_ASPA: - if len(toread) < 8 { - return nil, fmt.Errorf("wrong length for ASPA PDU: %d < 16", len(toread)) + if length < 12 { + return nil, fmt.Errorf("wrong length for ASPA PDU: %d < 12", length) } - aspaFlag := uint8(toread[0]) - aspaAFIFlag := uint8(toread[1]) - PASCount := binary.BigEndian.Uint16(toread[2:4]) - CASN := binary.BigEndian.Uint32(toread[4:8]) + CASN := binary.BigEndian.Uint32(toread[0:4]) PASNs := make([]uint32, 0) - rbuf := bytes.NewReader(toread[8:]) - for i := 0; i < int(PASCount); i++ { - var asn uint32 + rbuf := bytes.NewReader(toread[4:]) + var prev_asn uint32 + var asn uint32 + for i := 0; i < int((length - 12) / 4); i++ { + if i == 0 { + prev_asn = asn + } err := binary.Read(rbuf, binary.BigEndian, &asn) if err != nil { return nil, err } PASNs = append(PASNs, asn) + if i > 0 { + if !(asn > prev_asn) { + return nil, fmt.Errorf("Sorting issue in ASPA Providers: %d > %d", asn, prev_asn) + } + prev_asn = asn + } } return &PDUASPA{ Version: pver, - Flags: aspaFlag, - AFIFlags: aspaAFIFlag, - ProviderASCount: PASCount, + // Flags is in a spot that is also used by the SessionID, So we we will just bitshift + Flags: uint8(sessionId_or_flags >> 8), CustomerASNumber: CASN, ProviderASNumbers: PASNs, }, nil