From 3ea904103953948f8998c9d70ccc64189daa9958 Mon Sep 17 00:00:00 2001 From: Akihiro Suda Date: Tue, 7 Apr 2026 20:53:28 +0900 Subject: [PATCH] forwarder: propagate source IP Signed-off-by: Akihiro Suda --- go.mod | 2 + pkg/services/forwarder/ports.go | 14 +- pkg/services/forwarder/ports_test.go | 179 ++++++++++++++++++ pkg/services/forwarder/udp.go | 2 +- pkg/services/forwarder/udp_proxy.go | 6 +- vendor/github.com/inetaf/tcpproxy/tcpproxy.go | 8 +- .../pkg/tcpip/link/loopback/loopback.go | 148 +++++++++++++++ .../link/loopback/loopback_state_autogen.go | 44 +++++ vendor/modules.txt | 4 +- 9 files changed, 398 insertions(+), 9 deletions(-) create mode 100644 pkg/services/forwarder/ports_test.go create mode 100644 vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback.go create mode 100644 vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback_state_autogen.go diff --git a/go.mod b/go.mod index beb342878..1a24d926b 100644 --- a/go.mod +++ b/go.mod @@ -50,3 +50,5 @@ require ( golang.org/x/tools v0.41.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) + +replace github.com/inetaf/tcpproxy => ../tcpproxy diff --git a/pkg/services/forwarder/ports.go b/pkg/services/forwarder/ports.go index f5d9aa678..e8b7b1541 100644 --- a/pkg/services/forwarder/ports.go +++ b/pkg/services/forwarder/ports.go @@ -212,8 +212,12 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote if err != nil { return err } - p, err := NewUDPProxy(listener, func() (net.Conn, error) { - return gonet.DialUDP(f.stack, nil, &address, ipv4.ProtocolNumber) + p, err := NewUDPProxy(listener, func(from net.Addr) (net.Conn, error) { + var local *tcpip.FullAddress + if a, ok := from.(*net.UDPAddr); ok && a.IP.To4() != nil { + local = &tcpip.FullAddress{NIC: 1, Addr: tcpip.AddrFrom4Slice(a.IP.To4()), Port: uint16(a.Port)} + } + return gonet.DialUDP(f.stack, local, &address, ipv4.ProtocolNumber) }) if err != nil { return err @@ -235,7 +239,11 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote p.AddRoute(local, &tcpproxy.DialProxy{ Addr: remote, DialContext: func(ctx context.Context, _, _ string) (conn net.Conn, e error) { - return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber) + var local tcpip.FullAddress + if a, ok := ctx.Value(tcpproxy.SourceAddrContextKey).(*net.TCPAddr); ok && a.IP.To4() != nil { + local = tcpip.FullAddress{NIC: 1, Addr: tcpip.AddrFrom4Slice(a.IP.To4()), Port: uint16(a.Port)} + } + return gonet.DialTCPWithBind(ctx, f.stack, local, address, ipv4.ProtocolNumber) }, }) if err := p.Start(); err != nil { diff --git a/pkg/services/forwarder/ports_test.go b/pkg/services/forwarder/ports_test.go new file mode 100644 index 000000000..7ee34d5d3 --- /dev/null +++ b/pkg/services/forwarder/ports_test.go @@ -0,0 +1,179 @@ +package forwarder + +import ( + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/containers/gvisor-tap-vsock/pkg/types" + "github.com/onsi/ginkgo" + "github.com/onsi/gomega" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +func TestSuite(t *testing.T) { + gomega.RegisterFailHandler(ginkgo.Fail) + ginkgo.RunSpecs(t, "forwarder suite") +} + +func hostIP() net.IP { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { + return ipnet.IP + } + } + return nil +} + +var ( + gatewayIP = tcpip.AddrFrom4([4]byte{10, 0, 2, 1}) + childIP = tcpip.AddrFrom4([4]byte{10, 0, 2, 100}) +) + +// newTestStack creates a gvisor stack with spoofing and promiscuous mode +// enabled, matching the configuration used by virtualnetwork.New. +func newTestStack() *stack.Stack { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + }) + gomega.Expect(s.CreateNIC(1, loopback.New())).To(gomega.BeNil()) + for _, addr := range []tcpip.Address{gatewayIP, childIP} { + gomega.Expect(s.AddProtocolAddress(1, tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + }, stack.AddressProperties{})).To(gomega.BeNil()) + } + s.SetSpoofing(1, true) + s.SetPromiscuousMode(1, true) + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: 1}}) + return s +} + +// freeHostAddr returns a free "hostIP:port" address for the given network. +func freeHostAddr(network string, ip net.IP) string { + switch network { + case "tcp": + ln, err := net.Listen("tcp", ip.String()+":0") + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + addr := ln.Addr().String() + ln.Close() + return addr + case "udp": + conn, err := net.ListenPacket("udp", ip.String()+":0") + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + addr := conn.LocalAddr().String() + conn.Close() + return addr + default: + panic("unsupported network: " + network) + } +} + +var _ = ginkgo.Describe("port forwarding", func() { + ginkgo.It("should preserve the client source IP for TCP", func() { + ip := hostIP() + if ip == nil { + ginkgo.Skip("no non-loopback IPv4 address found") + } + + s := newTestStack() + + childLn, err := gonet.ListenTCP(s, tcpip.FullAddress{NIC: 1, Addr: childIP, Port: 8080}, ipv4.ProtocolNumber) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer childLn.Close() + + sourceAddrCh := make(chan string, 1) + go func() { + conn, err := childLn.Accept() + if err != nil { + return + } + defer conn.Close() + sourceAddrCh <- conn.RemoteAddr().String() + io.Copy(io.Discard, conn) + }() + + listenAddr := freeHostAddr("tcp", ip) + fw := NewPortsForwarder(s) + gomega.Expect(fw.Expose(types.TCP, listenAddr, "10.0.2.100:8080")).Should(gomega.Succeed()) + defer fw.Unexpose(types.TCP, listenAddr) + + conn, err := net.Dial("tcp", listenAddr) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + clientIP := conn.LocalAddr().(*net.TCPAddr).IP.String() + conn.Close() + + var addr string + gomega.Eventually(sourceAddrCh).Should(gomega.Receive(&addr)) + host, _, err := net.SplitHostPort(addr) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + gomega.Expect(host).To(gomega.Equal(clientIP), + fmt.Sprintf("child saw %s, expected client IP %s (gateway is 10.0.2.1)", host, clientIP)) + }) + + ginkgo.It("should preserve the client source IP for UDP", func() { + ip := hostIP() + if ip == nil { + ginkgo.Skip("no non-loopback IPv4 address found") + } + + s := newTestStack() + + childAddr := tcpip.FullAddress{NIC: 1, Addr: childIP, Port: 8081} + childConn, err := gonet.DialUDP(s, &childAddr, nil, ipv4.ProtocolNumber) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer childConn.Close() + + sourceAddrCh := make(chan string, 1) + go func() { + buf := make([]byte, 1024) + n, from, err := childConn.ReadFrom(buf) + if err != nil { + return + } + sourceAddrCh <- from.String() + // Echo back + childConn.WriteTo(buf[:n], from) + }() + + listenAddr := freeHostAddr("udp", ip) + fw := NewPortsForwarder(s) + gomega.Expect(fw.Expose(types.UDP, listenAddr, "10.0.2.100:8081")).Should(gomega.Succeed()) + defer fw.Unexpose(types.UDP, listenAddr) + + clientConn, err := net.Dial("udp", listenAddr) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer clientConn.Close() + clientIP := clientConn.LocalAddr().(*net.UDPAddr).IP.String() + + _, err = clientConn.Write([]byte("hello")) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Read echo to ensure round-trip completes. + clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 1024) + _, err = clientConn.Read(buf) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + var addr string + gomega.Eventually(sourceAddrCh).Should(gomega.Receive(&addr)) + host, _, err := net.SplitHostPort(addr) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + gomega.Expect(host).To(gomega.Equal(clientIP), + fmt.Sprintf("child saw %s, expected client IP %s (gateway is 10.0.2.1)", host, clientIP)) + }) +}) diff --git a/pkg/services/forwarder/udp.go b/pkg/services/forwarder/udp.go index 0ffdf0441..cc0915571 100644 --- a/pkg/services/forwarder/udp.go +++ b/pkg/services/forwarder/udp.go @@ -40,7 +40,7 @@ func UDP(s *stack.Stack, nat map[tcpip.Address]tcpip.Address, natLock *sync.Mute return } - p, _ := NewUDPProxy(&autoStoppingListener{underlying: gonet.NewUDPConn(&wq, ep)}, func() (net.Conn, error) { + p, _ := NewUDPProxy(&autoStoppingListener{underlying: gonet.NewUDPConn(&wq, ep)}, func(_ net.Addr) (net.Conn, error) { return net.Dial("udp", net.JoinHostPort(localAddress.String(), strconv.Itoa(int(r.ID().LocalPort)))) }) go func() { diff --git a/pkg/services/forwarder/udp_proxy.go b/pkg/services/forwarder/udp_proxy.go index fbb0029d9..120e10b19 100644 --- a/pkg/services/forwarder/udp_proxy.go +++ b/pkg/services/forwarder/udp_proxy.go @@ -52,13 +52,13 @@ type connTrackMap map[connTrackKey]net.Conn // addresses. type UDPProxy struct { listener udpConn - dialer func() (net.Conn, error) + dialer func(from net.Addr) (net.Conn, error) connTrackTable connTrackMap connTrackLock sync.Mutex } // NewUDPProxy creates a new UDPProxy. -func NewUDPProxy(listener udpConn, dialer func() (net.Conn, error)) (*UDPProxy, error) { +func NewUDPProxy(listener udpConn, dialer func(from net.Addr) (net.Conn, error)) (*UDPProxy, error) { return &UDPProxy{ listener: listener, connTrackTable: make(connTrackMap), @@ -119,7 +119,7 @@ func (proxy *UDPProxy) Run() { proxy.connTrackLock.Lock() proxyConn, hit := proxy.connTrackTable[*fromKey] if !hit { - proxyConn, err = proxy.dialer() + proxyConn, err = proxy.dialer(from) if err != nil { log.Errorf("Can't proxy a datagram to udp: %s\n", err) proxy.connTrackLock.Unlock() diff --git a/vendor/github.com/inetaf/tcpproxy/tcpproxy.go b/vendor/github.com/inetaf/tcpproxy/tcpproxy.go index d59c434d7..4dfd5ab81 100644 --- a/vendor/github.com/inetaf/tcpproxy/tcpproxy.go +++ b/vendor/github.com/inetaf/tcpproxy/tcpproxy.go @@ -293,6 +293,12 @@ type Target interface { HandleConn(net.Conn) } +type contextKey struct{} + +// SourceAddrContextKey is the context key used by DialProxy.HandleConn to +// pass the incoming connection's remote address (net.Addr) to DialContext. +var SourceAddrContextKey = contextKey{} + // To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}. func To(addr string) *DialProxy { return &DialProxy{Addr: addr} @@ -374,7 +380,7 @@ func closeWrite(c net.Conn) { // HandleConn implements the Target interface. func (dp *DialProxy) HandleConn(src net.Conn) { - ctx := context.Background() + ctx := context.WithValue(context.Background(), SourceAddrContextKey, src.RemoteAddr()) var cancel context.CancelFunc if dp.DialTimeout >= 0 { ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) diff --git a/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback.go b/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback.go new file mode 100644 index 000000000..ffab1033a --- /dev/null +++ b/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback.go @@ -0,0 +1,148 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package loopback provides the implementation of loopback data-link layer +// endpoints. Such endpoints just turn outbound packets into inbound ones. +// +// Loopback endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package loopback + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + loopbackMTU = 65536 +) + +// +stateify savable +type endpoint struct { + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + dispatcher stack.NetworkDispatcher + // +checklocks:mu + addr tcpip.LinkAddress + // +checklocks:mu + mtu uint32 +} + +// New creates a new loopback endpoint. This link-layer endpoint just turns +// outbound packets into inbound packets. +func New() stack.LinkEndpoint { + return &endpoint{ + mtu: loopbackMTU, + } +} + +// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- +// layer dispatcher for later use when packets need to be dispatched. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + defer e.mu.Unlock() + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. +func (e *endpoint) MTU() uint32 { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mtu +} + +// SetMTU implements stack.LinkEndpoint.SetMTU. It has no impact. +func (e *endpoint) SetMTU(mtu uint32) { + e.mu.Lock() + defer e.mu.Unlock() + e.mtu = mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises +// itself as supporting checksum offload, but in reality it's just omitted. +func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityRXChecksumOffload | stack.CapabilityTXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the +// loopback interface doesn't have a header, it just returns 0. +func (*endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + e.mu.RLock() + defer e.mu.RUnlock() + return e.addr +} + +// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress. +func (e *endpoint) SetLinkAddress(addr tcpip.LinkAddress) { + e.mu.Lock() + defer e.mu.Unlock() + e.addr = addr +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + +// WritePackets implements stack.LinkEndpoint.WritePackets. If the endpoint is +// not attached, the packets are not delivered. +func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + for _, pkt := range pkts.AsSlice() { + // In order to properly loop back to the inbound side we must create a + // fresh packet that only contains the underlying payload with no headers + // or struct fields set. + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: pkt.ToBuffer(), + }) + if d != nil { + d.DeliverNetworkPacket(pkt.NetworkProtocolNumber, newPkt) + } + newPkt.DecRef() + } + return pkts.Len(), nil +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +// AddHeader implements stack.LinkEndpoint. +func (*endpoint) AddHeader(*stack.PacketBuffer) {} + +// ParseHeader implements stack.LinkEndpoint. +func (*endpoint) ParseHeader(*stack.PacketBuffer) bool { return true } + +// Close implements stack.LinkEndpoint. +func (*endpoint) Close() {} + +// SetOnCloseAction implements stack.LinkEndpoint. +func (*endpoint) SetOnCloseAction(func()) {} diff --git a/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback_state_autogen.go b/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100644 index 000000000..949f20bd8 --- /dev/null +++ b/vendor/gvisor.dev/gvisor/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,44 @@ +// automatically generated by stateify. + +package loopback + +import ( + "context" + + "gvisor.dev/gvisor/pkg/state" +) + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/link/loopback.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "dispatcher", + "addr", + "mtu", + } +} + +func (e *endpoint) beforeSave() {} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.dispatcher) + stateSinkObject.Save(1, &e.addr) + stateSinkObject.Save(2, &e.mtu) +} + +func (e *endpoint) afterLoad(context.Context) {} + +// +checklocksignore +func (e *endpoint) StateLoad(ctx context.Context, stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.dispatcher) + stateSourceObject.Load(1, &e.addr) + stateSourceObject.Load(2, &e.mtu) +} + +func init() { + state.Register((*endpoint)(nil)) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index a889bc321..ee3b74d00 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -45,7 +45,7 @@ github.com/google/go-cmp/cmp/internal/value ## explicit; go 1.12 github.com/google/gopacket github.com/google/gopacket/layers -# github.com/inetaf/tcpproxy v0.0.0-20250222171855-c4b9df066048 +# github.com/inetaf/tcpproxy v0.0.0-20250222171855-c4b9df066048 => ../tcpproxy ## explicit; go 1.16 github.com/inetaf/tcpproxy # github.com/insomniacslk/dhcp v0.0.0-20240710054256-ddd8a41251c9 @@ -260,6 +260,7 @@ gvisor.dev/gvisor/pkg/tcpip/hash/jenkins gvisor.dev/gvisor/pkg/tcpip/header gvisor.dev/gvisor/pkg/tcpip/header/parse gvisor.dev/gvisor/pkg/tcpip/internal/tcp +gvisor.dev/gvisor/pkg/tcpip/link/loopback gvisor.dev/gvisor/pkg/tcpip/link/nested gvisor.dev/gvisor/pkg/tcpip/link/sniffer gvisor.dev/gvisor/pkg/tcpip/network/arp @@ -281,3 +282,4 @@ gvisor.dev/gvisor/pkg/tcpip/transport/tcp gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack gvisor.dev/gvisor/pkg/tcpip/transport/udp gvisor.dev/gvisor/pkg/waiter +# github.com/inetaf/tcpproxy => ../tcpproxy