diff --git a/packages/orchestrator/pkg/sandbox/map.go b/packages/orchestrator/pkg/sandbox/map.go index 13430d637d..12e65be690 100644 --- a/packages/orchestrator/pkg/sandbox/map.go +++ b/packages/orchestrator/pkg/sandbox/map.go @@ -88,17 +88,34 @@ func (m *Map) Get(sandboxID string) (*Sandbox, bool) { } // GetByHostPort looks up a sandbox by its host IP address parsed from hostPort. -// It matches any sandbox in the map (starting, running, or stopping). +// It prefers a running sandbox and only falls back to a non-running one when +// no running sandbox matches. func (m *Map) GetByHostPort(hostPort string) (*Sandbox, error) { reqIP, _, err := net.SplitHostPort(hostPort) if err != nil { return nil, fmt.Errorf("error parsing remote address %s: %w", hostPort, err) } + var fallback *Sandbox for _, sbx := range m.sandboxes.Items() { - if sbx.Slot.HostIPString() == reqIP { + if sbx.Slot.HostIPString() != reqIP { + continue + } + + if sbx.IsRunning() { return sbx, nil } + + // Prefer a starting sandbox over a stopping one so that when an IP + // slot is freed by a stopping sandbox and immediately reused by a + // new starting sandbox, we route to the new sandbox. + if fallback == nil || SandboxStatus(sbx.status.Load()) == StatusStarting { + fallback = sbx + } + } + + if fallback != nil { + return fallback, nil } return nil, fmt.Errorf("sandbox with address %s not found", hostPort) diff --git a/packages/orchestrator/pkg/sandbox/map_test.go b/packages/orchestrator/pkg/sandbox/map_test.go new file mode 100644 index 0000000000..0d698a054c --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/map_test.go @@ -0,0 +1,80 @@ +package sandbox + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/network" +) + +func TestGetByHostPortPrefersRunningSandbox(t *testing.T) { + t.Parallel() + + m := NewSandboxesMap() + + stopping := &Sandbox{ + Resources: &Resources{ + Slot: &network.Slot{HostIP: net.ParseIP("10.11.0.2")}, + }, + } + stopping.status.Store(int32(StatusStopping)) + m.sandboxes.Insert("stopping", stopping) + + running := &Sandbox{ + Resources: &Resources{ + Slot: &network.Slot{HostIP: net.ParseIP("10.11.0.2")}, + }, + } + running.status.Store(int32(StatusRunning)) + m.sandboxes.Insert("running", running) + + sbx, err := m.GetByHostPort("10.11.0.2:2049") + require.NoError(t, err) + require.Same(t, running, sbx) +} + +func TestGetByHostPortPrefersStartingOverStopping(t *testing.T) { + t.Parallel() + + m := NewSandboxesMap() + + stopping := &Sandbox{ + Resources: &Resources{ + Slot: &network.Slot{HostIP: net.ParseIP("10.11.0.2")}, + }, + } + stopping.status.Store(int32(StatusStopping)) + m.sandboxes.Insert("stopping", stopping) + + starting := &Sandbox{ + Resources: &Resources{ + Slot: &network.Slot{HostIP: net.ParseIP("10.11.0.2")}, + }, + } + starting.status.Store(int32(StatusStarting)) + m.sandboxes.Insert("starting", starting) + + sbx, err := m.GetByHostPort("10.11.0.2:2049") + require.NoError(t, err) + require.Same(t, starting, sbx) +} + +func TestGetByHostPortFallsBackToStoppingSandbox(t *testing.T) { + t.Parallel() + + m := NewSandboxesMap() + + stopping := &Sandbox{ + Resources: &Resources{ + Slot: &network.Slot{HostIP: net.ParseIP("10.11.0.3")}, + }, + } + stopping.status.Store(int32(StatusStopping)) + m.sandboxes.Insert("stopping", stopping) + + sbx, err := m.GetByHostPort("10.11.0.3:2049") + require.NoError(t, err) + require.Same(t, stopping, sbx) +}