diff --git a/backend/pkg/api/activity_test.go b/backend/pkg/api/activity_test.go index 9488dda8f..28fbde994 100644 --- a/backend/pkg/api/activity_test.go +++ b/backend/pkg/api/activity_test.go @@ -20,9 +20,9 @@ func TestGetActivity(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp.ID, PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID) - tFakeInstance, _ := a.RegisterInstance("{"+uuid.New().String()+"}", "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID, "", "") + tFakeInstance, _ := a.RegisterInstance("{"+uuid.New().String()+"}", "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID, "", "") _ = a.newGroupActivityEntry(activityRolloutStarted, activitySuccess, tVersion, tApp.ID, tGroup.ID) _ = a.newGroupActivityEntry(activityRolloutStarted, activitySuccess, tVersion, tApp.ID, tGroup2.ID) diff --git a/backend/pkg/api/applications_test.go b/backend/pkg/api/applications_test.go index 4a6ded237..4bb5dd454 100644 --- a/backend/pkg/api/applications_test.go +++ b/backend/pkg/api/applications_test.go @@ -158,7 +158,7 @@ func TestGetApp(t *testing.T) { assert.NoError(t, err) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") app, err := a.GetApp(tApp.ID) assert.NoError(t, err) @@ -290,8 +290,8 @@ func TestGetAppsFiltered(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) realInstanceID := uuid.New().String() fakeInstanceID := "{" + uuid.New().String() + "}" - _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(fakeInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") // should ignore fake instance in Instances count apps, err := a.GetApps(tTeam.ID, 1, 10) diff --git a/backend/pkg/api/db/migrations/0021_add_instance_oem.sql b/backend/pkg/api/db/migrations/0021_add_instance_oem.sql new file mode 100644 index 000000000..26720ac0a --- /dev/null +++ b/backend/pkg/api/db/migrations/0021_add_instance_oem.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE instance ADD COLUMN oem VARCHAR(256) DEFAULT ''; +ALTER TABLE instance ADD COLUMN oem_version VARCHAR(256) DEFAULT ''; + +-- +migrate Down + +ALTER TABLE instance DROP COLUMN oem_version; +ALTER TABLE instance DROP COLUMN oem; diff --git a/backend/pkg/api/events_test.go b/backend/pkg/api/events_test.go index a343eceab..5ea2431b3 100644 --- a/backend/pkg/api/events_test.go +++ b/backend/pkg/api/events_test.go @@ -19,7 +19,7 @@ func TestRegisterEvent_InvalidParams(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") err := a.RegisterEvent(uuid.New().String(), tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") assert.Equal(t, ErrInvalidInstance, err) @@ -33,7 +33,7 @@ func TestRegisterEvent_InvalidParams(t *testing.T) { err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") assert.Equal(t, ErrNoUpdateInProgress, err) - _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, 1000, ResultSuccess, "", "") assert.Equal(t, ErrInvalidEventTypeOrResult, err) @@ -51,10 +51,10 @@ func TestRegisterEvent_TriggerEventConsequences(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID, "", "") - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") @@ -77,7 +77,7 @@ func TestRegisterEvent_TriggerEventConsequences(t *testing.T) { instance, _ = a.GetInstance(tInstance.ID, tApp.ID) assert.Equal(t, null.IntFrom(int64(InstanceStatusComplete)), instance.Application.Status) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance2.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultFailed, "", "") @@ -97,9 +97,9 @@ func TestRegisterEvent_TriggerEventConsequences_FirstUpdateAttemptFailed(t *test tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultFailed, "", "") @@ -115,10 +115,10 @@ func TestRegisterEvent_CheckSuccessResult(t *testing.T) { defer a.Close() performUpdate := func(tApp *Application, tGroup *Group, resultType int) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") @@ -159,10 +159,10 @@ func TestRegisterEvent_CheckFlatcarSuccessResult(t *testing.T) { defer a.Close() performUpdate := func(tApp *Application, tGroup *Group, resultType, expectedInstanceStatus int) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "11.0.0", "") @@ -207,10 +207,10 @@ func TestRegisterEvent_CheckFlatcarIgnoredUpdate(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group9", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) performUpdate := func(previousVersion string) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, previousVersion, "") @@ -247,9 +247,9 @@ func TestRegisterEvent_GetEvent(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) _, err = a.GetEvent(tInstance.ID, tApp.ID, time.Now()) diff --git a/backend/pkg/api/groups_test.go b/backend/pkg/api/groups_test.go index 2ab4a2e97..cb154f725 100644 --- a/backend/pkg/api/groups_test.go +++ b/backend/pkg/api/groups_test.go @@ -168,9 +168,9 @@ func TestGetGroupsFiltered(t *testing.T) { realInstanceID := uuid.New().String() fakeInstanceID1 := "{" + uuid.New().String() + "}" fakeInstanceID2 := "{" + uuid.New().String() + "}" - _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID1, "", "10.0.0.1", "2.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID2, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(fakeInstanceID1, "", "10.0.0.1", "2.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(fakeInstanceID2, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") groups, err := a.GetGroups(tApp.ID, 0, 0) assert.NoError(t, err) @@ -226,7 +226,7 @@ func TestGetVersionCountTimeline(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "test_group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) instanceID := uuid.New().String() - _, _ = a.RegisterInstance(instanceID, "", "10.0.0.1", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(instanceID, "", "10.0.0.1", version, tApp.ID, tGroup.ID, "", "") instance, err := a.GetInstance(instanceID, tApp.ID) assert.NoError(t, err) @@ -304,14 +304,14 @@ func TestGetStatusCountTimeline(t *testing.T) { instanceID1 := uuid.New().String() instanceID2 := uuid.New().String() - _, _ = a.RegisterInstance(instanceID1, "", "10.0.0.1", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(instanceID1, "", "10.0.0.1", version, tApp.ID, tGroup.ID, "", "") instance1, err := a.GetInstance(instanceID1, tApp.ID) assert.NoError(t, err) _ = a.grantUpdate(instance1, version) _ = a.updateInstanceStatus(instanceID1, tApp.ID, InstanceStatusComplete) - _, _ = a.RegisterInstance(instanceID2, "", "10.0.0.2", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(instanceID2, "", "10.0.0.2", version, tApp.ID, tGroup.ID, "", "") instance2, err := a.GetInstance(instanceID2, tApp.ID) assert.NoError(t, err) diff --git a/backend/pkg/api/instances.go b/backend/pkg/api/instances.go index b712db80c..ff77b37ec 100644 --- a/backend/pkg/api/instances.go +++ b/backend/pkg/api/instances.go @@ -57,6 +57,8 @@ const ( type Instance struct { ID string `db:"id" json:"id"` IP string `db:"ip" json:"ip"` + OEM string `db:"oem" json:"oem,omitempty"` + OEMVersion string `db:"oem_version" json:"oem_version,omitempty"` CreatedTs time.Time `db:"created_ts" json:"created_ts"` Application InstanceApplication `db:"application" json:"application,omitempty"` Alias string `db:"alias" json:"alias,omitempty"` @@ -161,7 +163,7 @@ func sanitizeSortFilterParams(sortFilter string) string { } // RegisterInstance registers an instance into Nebraska. -func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) (*Instance, error) { +func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID, instanceOEM, instanceOEMVersion string) (*Instance, error) { if !isValidSemver(instanceVersion) { return nil, ErrInvalidSemver } @@ -183,8 +185,15 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance if instanceAlias == "" { instanceAlias = instance.Alias } - // The instance exists, so we just update it if its IP or Alias changed - updateInstance = instance.IP != instanceIP || instance.Alias != instanceAlias + // Give precedence to existing OEM values over omitted or empty fields + if instanceOEM == "" { + instanceOEM = instance.OEM + } + if instanceOEMVersion == "" { + instanceOEMVersion = instance.OEMVersion + } + // The instance exists, so we just update it if its IP, Alias, OEM or OEMVersion changed + updateInstance = instance.IP != instanceIP || instance.Alias != instanceAlias || instance.OEM != instanceOEM || instance.OEMVersion != instanceOEMVersion recent := nowUTC().Add(-5 * time.Minute) @@ -200,9 +209,9 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance } upsertInstance, _, err := goqu.Insert("instance"). - Cols("id", "ip", "alias"). - Vals(goqu.Vals{instanceID, instanceIP, instanceAlias}). - OnConflict(goqu.DoUpdate("id", goqu.Record{"id": instanceID, "ip": instanceIP, "alias": instanceAlias})). + Cols("id", "ip", "alias", "oem", "oem_version"). + Vals(goqu.Vals{instanceID, instanceIP, instanceAlias, instanceOEM, instanceOEMVersion}). + OnConflict(goqu.DoUpdate("id", goqu.Record{"id": instanceID, "ip": instanceIP, "alias": instanceAlias, "oem": instanceOEM, "oem_version": instanceOEMVersion})). ToSQL() if err != nil { return nil, err diff --git a/backend/pkg/api/instances_test.go b/backend/pkg/api/instances_test.go index be3e766c3..496f5ad34 100644 --- a/backend/pkg/api/instances_test.go +++ b/backend/pkg/api/instances_test.go @@ -25,47 +25,53 @@ func TestRegisterInstance(t *testing.T) { instanceID := uuid.New().String() - _, err := a.RegisterInstance("", "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, err := a.RegisterInstance("", "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.Error(t, err, "Using empty string as instance id.") - _, err = a.RegisterInstance(instanceID, "", "invalidIP", "1.0.0", tApp.ID, tGroup.ID) + _, err = a.RegisterInstance(instanceID, "", "invalidIP", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.Error(t, err, "Using an invalid instance ip.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", "invalidAppID", tGroup.ID) + _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", "invalidAppID", tGroup.ID, "", "") assert.Error(t, err, "Using an invalid application id.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID", "", "") assert.Error(t, err, "Using an invalid group id.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "", tApp.ID, "invalidGroupID", "", "") assert.Error(t, err, "Using an empty instance version.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "aaa1.0.0", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "aaa1.0.0", tApp.ID, "invalidGroupID", "", "") assert.Equal(t, ErrInvalidSemver, err, "Using an invalid instance version.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID) + _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID, "", "") assert.Equal(t, ErrInvalidApplicationOrGroup, err, "The group provided doesn't belong to the application provided.") - instance, err := a.RegisterInstance(instanceID, "myalias", "10.0.0.1", "1.0.0", "{"+tApp.ID+"}", "{"+tGroup.ID+"}") + instance, err := a.RegisterInstance(instanceID, "myalias", "10.0.0.1", "1.0.0", "{"+tApp.ID+"}", "{"+tGroup.ID+"}", "azure", "2.9.1.1-r1") assert.NoError(t, err) assert.Equal(t, instanceID, instance.ID) assert.Equal(t, "myalias", instance.Alias) assert.Equal(t, "10.0.0.1", instance.IP) + assert.Equal(t, "azure", instance.OEM) + assert.Equal(t, "2.9.1.1-r1", instance.OEMVersion) - instance, err = a.RegisterInstance(instanceID, "mynewalias", "10.0.0.2", "1.0.2", tApp.ID, tGroup.ID) + instance, err = a.RegisterInstance(instanceID, "mynewalias", "10.0.0.2", "1.0.2", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err, "Registering an already registered instance with some updates, that's fine.") assert.Equal(t, "mynewalias", instance.Alias) assert.Equal(t, "10.0.0.2", instance.IP) assert.Equal(t, "1.0.2", instance.Application.Version) + assert.Equal(t, "azure", instance.OEM, "OEM should be preserved when not provided") + assert.Equal(t, "2.9.1.1-r1", instance.OEMVersion, "OEMVersion should be preserved when not provided") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.2", "1.0.2", tApp2.ID, tGroup.ID) + _, err = a.RegisterInstance(instanceID, "", "10.0.0.2", "1.0.2", tApp2.ID, tGroup.ID, "", "") assert.Error(t, err, "Application id cannot be updated.") - instance, err = a.RegisterInstance(instanceID, "", "10.0.0.3", "1.0.3", tApp.ID, tGroup3.ID) + instance, err = a.RegisterInstance(instanceID, "", "10.0.0.3", "1.0.3", tApp.ID, tGroup3.ID, "gcp", "3.0.0") assert.NoError(t, err, "Registering an already registered instance using a different group, that's fine.") assert.Equal(t, "10.0.0.3", instance.IP) assert.Equal(t, "1.0.3", instance.Application.Version) assert.Equal(t, null.StringFrom(tGroup3.ID), instance.Application.GroupID) + assert.Equal(t, "gcp", instance.OEM, "OEM should be updated when provided") + assert.Equal(t, "3.0.0", instance.OEMVersion, "OEMVersion should be updated when provided") } func TestGetInstance(t *testing.T) { @@ -77,7 +83,7 @@ func TestGetInstance(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") _, err := a.GetInstance(uuid.New().String(), tApp.ID) assert.Error(t, err, "Using non existent instance id.") @@ -106,9 +112,9 @@ func TestGetInstances(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp.ID, PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup2.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup2.ID, "", "") result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Version: "1.0.0", Page: 1, PerPage: 10}, testDuration) assert.NoError(t, err) @@ -162,7 +168,7 @@ func TestGetInstances(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 0, len(result.Instances)) - _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") _ = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") result, err = a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Status: InstanceStatusComplete, Page: 1, PerPage: 10}, testDuration) @@ -191,12 +197,12 @@ func TestGetInstancesSearch(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup.ID, "", "") instanceAlias := "instance_alias" - _, _ = a.RegisterInstance(uuid.New().String(), instanceAlias, "10.0.0.4", "1.0.4", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(uuid.New().String(), instanceAlias, "10.0.0.4", "1.0.4", tApp.ID, tGroup.ID, "", "") result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Page: 1, PerPage: 10, SearchFilter: "All", SearchValue: tInstance.ID}, testDuration) assert.NoError(t, err) @@ -232,7 +238,7 @@ func TestGetInstancesFiltered(t *testing.T) { instanceID4 := "8d180b2a07344406af029a4f86bd1ee3" for idx, id := range []string{instanceID1, instanceID2, instanceID3, instanceID4} { ip := fmt.Sprintf("10.0.0.%d", idx+1) - _, _ = a.RegisterInstance(id, "", ip, "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(id, "", ip, "1.0.0", tApp.ID, tGroup.ID, "", "") } result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Version: "1.0.0", Page: 1, PerPage: 10}, testDuration) @@ -264,7 +270,7 @@ func TestGetInstanceStatusHistory(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) newInstance1ID := uuid.New().String() - tInstance, _ := a.RegisterInstance(newInstance1ID, "analias", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(newInstance1ID, "analias", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, tInstance.Alias, "analias") instance, err := a.GetInstance(tInstance.ID, tApp.ID) @@ -312,14 +318,14 @@ func TestUpdateInstanceStats(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID, Arch: ArchAMD64}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID), Arch: ArchAMD64}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.1", tApp.ID, tGroup.ID) + tInstance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.1", tApp.ID, tGroup.ID, "", "") - _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) ts := time.Now().UTC() @@ -343,13 +349,13 @@ func TestUpdateInstanceStats(t *testing.T) { // Next test case: Switch tInstance1 and tInstance2 versions to workaround the 5-minutes-rate-limiting of the check-in time and add new instance ts2 := time.Now().UTC() - _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.3", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.3", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.4", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.4", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.4", "1.0.5", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.4", "1.0.5", tApp.ID, tGroup.ID, "", "") ts3 := time.Now().UTC() elapsed = ts3.Sub(ts2) @@ -381,7 +387,7 @@ func TestUpdateInstanceStatsNoArch(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") ts := time.Now().UTC() // Use large duration to have some test coverage for durationToInterval diff --git a/backend/pkg/api/packages.go b/backend/pkg/api/packages.go index 4820d6ac1..b0dd89fc9 100644 --- a/backend/pkg/api/packages.go +++ b/backend/pkg/api/packages.go @@ -382,7 +382,7 @@ func (api *API) DeletePackage(pkgID string) error { var exists bool var isFloor bool err = api.db.QueryRow(` - SELECT + SELECT EXISTS(SELECT 1 FROM package WHERE id = $1), EXISTS(SELECT 1 FROM channel_package_floors WHERE package_id = $1) `, pkgID).Scan(&exists, &isFloor) diff --git a/backend/pkg/api/packages_floors.go b/backend/pkg/api/packages_floors.go index a01875fb1..fce12354a 100644 --- a/backend/pkg/api/packages_floors.go +++ b/backend/pkg/api/packages_floors.go @@ -226,13 +226,16 @@ const ( DefaultMaxFloorsPerResponse = 5 ) -// GetRequiredChannelFloors returns floor packages between instance and target versions for a channel -func (api *API) GetRequiredChannelFloors(channel *Channel, instanceVersion string) ([]*Package, error) { +// GetRequiredChannelFloorsWithLimit returns floor packages between instance and target versions, +// along with a boolean indicating if more floors remain beyond the limit. +// This uses LIMIT+1 approach: query for limit+1 rows, if we get more than limit, there are more. +// This is more efficient than a separate COUNT query. +func (api *API) GetRequiredChannelFloorsWithLimit(channel *Channel, instanceVersion string) ([]*Package, bool, error) { if channel == nil || channel.Package == nil { - return nil, ErrNoPackageFound + return nil, false, ErrNoPackageFound } if instanceVersion == "" { - return nil, fmt.Errorf("instance version cannot be empty") + return nil, false, fmt.Errorf("instance version cannot be empty") } targetVersion := channel.Package.Version @@ -245,19 +248,20 @@ func (api *API) GetRequiredChannelFloors(channel *Channel, instanceVersion strin // No blacklist check needed for floors gtExpr, err := versionCompareExpr("p.version", ">", instanceVersion) if err != nil { - return nil, err + return nil, false, err } lteExpr, err := versionCompareExpr("p.version", "<=", targetVersion) if err != nil { - return nil, err + return nil, false, err } semverExpr, err := semverToIntArray("p.version") if err != nil { - return nil, err + return nil, false, err } + // Query for LIMIT+1 to detect if more floors exist query, _, err := goqu.From(goqu.L(` package p JOIN channel_package_floors cpf ON p.id = cpf.package_id @@ -273,14 +277,26 @@ func (api *API) GetRequiredChannelFloors(channel *Channel, instanceVersion strin lteExpr, )). Order(goqu.L(semverExpr).Asc()). - Limit(uint(maxFloorsPerResponse)). + Limit(uint(maxFloorsPerResponse + 1)). ToSQL() if err != nil { - return nil, err + return nil, false, err } - return api.getPackagesFromQuery(query) + floors, err := api.getPackagesFromQuery(query) + if err != nil { + return nil, false, err + } + + // If we got more than the limit, there are more floors remaining + if len(floors) > maxFloorsPerResponse { + // Return only up to the limit, indicate more remain + return floors[:maxFloorsPerResponse], true, nil + } + + // All floors returned + return floors, false, nil } // GetChannelFloorPackagesCount returns the count of floor packages for a channel diff --git a/backend/pkg/api/packages_floors_test.go b/backend/pkg/api/packages_floors_test.go index 06e3fcd6b..6a76c4b05 100644 --- a/backend/pkg/api/packages_floors_test.go +++ b/backend/pkg/api/packages_floors_test.go @@ -55,7 +55,7 @@ func TestFloorOperations(t *testing.T) { for instance, expected := range testCases { ch, err := a.GetChannel(setup.Channel.ID) assert.NoError(t, err) - floors, err := a.GetRequiredChannelFloors(ch, instance) + floors, _, err := a.GetRequiredChannelFloorsWithLimit(ch, instance) assert.NoError(t, err) assert.Len(t, floors, expected, "instance %s", instance) } @@ -84,9 +84,10 @@ func TestFloorMaxLimit(t *testing.T) { // Should only get 3 floors due to limit ch, err := a.GetChannel(setup.Channel.ID) assert.NoError(t, err) - floors, err := a.GetRequiredChannelFloors(ch, "0.0.0") + floors, hasMore, err := a.GetRequiredChannelFloorsWithLimit(ch, "0.0.0") assert.NoError(t, err) assert.Len(t, floors, 3) + assert.True(t, hasMore, "Should indicate more floors remain beyond limit") } // TestFloorPagination tests paginated floor retrieval @@ -133,7 +134,7 @@ func TestNonStandardVersions(t *testing.T) { for instance, expected := range testCases { ch, err := a.GetChannel(setup.Channel.ID) assert.NoError(t, err) - floors, err := a.GetRequiredChannelFloors(ch, instance) + floors, _, err := a.GetRequiredChannelFloorsWithLimit(ch, instance) assert.NoError(t, err) assert.Len(t, floors, expected, "instance %s", instance) } @@ -187,12 +188,12 @@ func TestFloorRolloutPolicy(t *testing.T) { assert.NoError(t, err) // First client gets floor - pkg1, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "1000.0.0", setup.AppID, group.ID) + pkg1, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "1000.0.0", setup.AppID, group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg1.Version) // Second client blocked by policy - _, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "1000.0.0", setup.AppID, group.ID) + _, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "1000.0.0", setup.AppID, group.ID, "", "") assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err) } @@ -232,7 +233,7 @@ func TestTargetAsFloor(t *testing.T) { for instance, expected := range testCases { ch, err := a.GetChannel(setup.Channel.ID) assert.NoError(t, err) - floors, err := a.GetRequiredChannelFloors(ch, instance) + floors, _, err := a.GetRequiredChannelFloorsWithLimit(ch, instance) assert.NoError(t, err) assert.Len(t, floors, expected.expectedCount, "instance %s", instance) if expected.expectedCount > 0 { @@ -245,11 +246,11 @@ func TestTargetAsFloor(t *testing.T) { } // Test that regular client gets the appropriate update - pkg, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "1000.0.0", pkg.Version) // Gets first floor - pkg, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "2500.0.0", setup.AppID, setup.Group.ID) + pkg, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "2500.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "3000.0.0", pkg.Version) // Gets target-floor directly } diff --git a/backend/pkg/api/updates.go b/backend/pkg/api/updates.go index 0fe4c7301..6b113fff5 100644 --- a/backend/pkg/api/updates.go +++ b/backend/pkg/api/updates.go @@ -62,8 +62,8 @@ var ( // GetUpdatePackage returns an update package for the instance/application // provided. The instance details and the application it's running will be // registered in Nebraska (or updated if it's already registered). -func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) (*Package, error) { - instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID) +func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID, instanceOEM, instanceOEMVersion string) (*Package, error) { + instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID, instanceOEM, instanceOEMVersion) if err != nil { l.Error().Err(err).Msg("GetUpdatePackage - could not register instance") return nil, ErrRegisterInstanceFailed @@ -116,12 +116,15 @@ func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instance // Instance hasn't reached granted version yet - return what's next to install // This will be the first floor/target above current instance version - packages, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) + floors, target, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) if err != nil { return nil, err } - // packages[0] should be the granted version since instance < granted - return packages[0], nil + // Return first floor if any, otherwise target + if len(floors) > 0 { + return floors[0], nil + } + return target, nil } // No granted version tracked (old instances) - safer fallback @@ -141,15 +144,22 @@ func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instance return nil, ErrNoUpdatePackageAvailable } - packages, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) + floors, target, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) if err != nil { return nil, err } + // Determine the next package to return (first floor or target) + var nextPkg *Package + if len(floors) > 0 { + nextPkg = floors[0] + } else { + nextPkg = target + } + // Safety check: verify the next package isn't blacklisted for this channel // This should never happen (floors/targets can't be blacklisted for their own channel) // but we check anyway for data consistency - nextPkg := packages[0] if slices.Contains(nextPkg.ChannelsBlacklist, group.Channel.ID) { l.Error().Str("package", nextPkg.Version).Str("channel", group.Channel.ID). Msg("Package is blacklisted for its own channel - data inconsistency!") @@ -161,7 +171,7 @@ func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instance } // Grant the update using the version we're actually returning - version := packages[0].Version + version := nextPkg.Version if err := api.grantUpdate(instance, version); err != nil { l.Error().Err(err).Str("version", version).Str("instance", instance.ID).Msg("GetUpdatePackage - grantUpdate error") return nil, ErrUpdateGrantFailed @@ -186,79 +196,99 @@ func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instance } } - return packages[0], nil + return nextPkg, nil } -// GetUpdatePackagesForSyncer returns all packages (floors + target) for a syncer client -func (api *API) GetUpdatePackagesForSyncer(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) ([]*Package, error) { - instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID) +// GetUpdatePackagesForSyncer returns floor packages and target for a syncer client. +// Returns: +// - floors: Required floor packages (may be empty) +// - target: The target package, or nil if more floors remain beyond the limit +// - error: Any error that occurred +// +// When target is nil, the syncer should request again with the highest floor version. +// When target is not nil, all required floors have been sent and the channel can be updated. +func (api *API) GetUpdatePackagesForSyncer(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID, instanceOEM, instanceOEMVersion string) ([]*Package, *Package, error) { + instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID, instanceOEM, instanceOEMVersion) if err != nil { l.Error().Err(err).Msg("GetUpdatePackagesForSyncer - could not register instance") - return nil, ErrRegisterInstanceFailed + return nil, nil, ErrRegisterInstanceFailed } if instance.Application.Status.Valid { switch int(instance.Application.Status.Int64) { case InstanceStatusDownloading, InstanceStatusDownloaded, InstanceStatusInstalled: - return nil, ErrUpdateInProgressOnInstance + return nil, nil, ErrUpdateInProgressOnInstance } } group, err := api.GetGroup(groupID) if err != nil { - return nil, err + return nil, nil, err } if group.Channel == nil || group.Channel.Package == nil { if err := api.newGroupActivityEntry(activityPackageNotFound, activityWarning, "0.0.0", appID, groupID); err != nil { l.Error().Err(err).Msg("GetUpdatePackagesForSyncer - could not add new group activity entry") } - return nil, ErrNoPackageFound + return nil, nil, ErrNoPackageFound } // Check if update is needed instanceSemver, _ := semver.Make(instanceVersion) packageSemver, _ := semver.Make(group.Channel.Package.Version) if !instanceSemver.LT(packageSemver) { - return nil, ErrNoUpdatePackageAvailable + return nil, nil, ErrNoUpdatePackageAvailable } - packages, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) + floors, target, err := api.getPackagesWithFloorsForUpdate(group, instanceVersion) if err != nil { - return nil, err + return nil, nil, err } // Safety check: verify no packages are blacklisted for this channel // Syncers need all packages, so if any is blacklisted we can't send a valid manifest // This should never happen (floors/targets can't be blacklisted for their own channel) // but we check anyway for data consistency - for _, pkg := range packages { + for _, pkg := range floors { if slices.Contains(pkg.ChannelsBlacklist, group.Channel.ID) { l.Error().Str("package", pkg.Version).Str("channel", group.Channel.ID). Msg("Package is blacklisted for its own channel - data inconsistency!") - return nil, ErrNoUpdatePackageAvailable + return nil, nil, ErrNoUpdatePackageAvailable } } + if target != nil && slices.Contains(target.ChannelsBlacklist, group.Channel.ID) { + l.Error().Str("package", target.Version).Str("channel", group.Channel.ID). + Msg("Package is blacklisted for its own channel - data inconsistency!") + return nil, nil, ErrNoUpdatePackageAvailable + } if err := api.enforceRolloutPolicy(instance, group); err != nil { - return nil, err + return nil, nil, err } - // Grant the update using target version - targetVersion := packages[len(packages)-1].Version - if err := api.grantUpdate(instance, targetVersion); err != nil { - l.Error().Err(err).Str("version", targetVersion).Str("instance", instance.ID).Msg("GetUpdatePackagesForSyncer - grantUpdate error") - return nil, ErrUpdateGrantFailed + // Grant the update using the highest version we're returning + // (last floor if no target, or target if present) + var grantVersion string + if target != nil { + grantVersion = target.Version + } else if len(floors) > 0 { + grantVersion = floors[len(floors)-1].Version + } else { + return nil, nil, ErrNoUpdatePackageAvailable + } + if err := api.grantUpdate(instance, grantVersion); err != nil { + l.Error().Err(err).Str("version", grantVersion).Str("instance", instance.ID).Msg("GetUpdatePackagesForSyncer - grantUpdate error") + return nil, nil, ErrUpdateGrantFailed } // Record activity if !api.hasRecentActivity(activityRolloutStarted, ActivityQueryParams{ Severity: activityInfo, AppID: appID, - Version: targetVersion, + Version: grantVersion, GroupID: groupID, }) { - if err := api.newGroupActivityEntry(activityRolloutStarted, activityInfo, targetVersion, appID, groupID); err != nil { + if err := api.newGroupActivityEntry(activityRolloutStarted, activityInfo, grantVersion, appID, groupID); err != nil { l.Error().Err(err).Msg("GetUpdatePackagesForSyncer - could not add new group activity entry") } } @@ -270,7 +300,7 @@ func (api *API) GetUpdatePackagesForSyncer(instanceID, instanceAlias, instanceIP } } - return packages, nil + return floors, target, nil } // enforceRolloutPolicy validates if an update should be provided to the @@ -355,29 +385,50 @@ func inOfficeHoursNow(tz string) bool { return true } -// getPackagesWithFloorsForUpdate returns floors + target for the given group and instance version -// This is a helper method extracted from the UpdateHandler logic -func (api *API) getPackagesWithFloorsForUpdate(group *Group, instanceVersion string) ([]*Package, error) { +// getPackagesWithFloorsForUpdate returns floors and target for the given group and instance version. +// This is a helper method extracted from the UpdateHandler logic. +// +// Returns: +// - floors: Required floor packages between instance version and target (may be empty) +// - target: The target package, or nil if more floors remain beyond the limit +// +// IMPORTANT: When there are more floors remaining than NEBRASKA_MAX_FLOORS_PER_RESPONSE, +// target will be nil. This signals that the syncer should request again with the highest +// floor version to get remaining floors. +func (api *API) getPackagesWithFloorsForUpdate(group *Group, instanceVersion string) (floors []*Package, target *Package, err error) { if group.Channel == nil || group.Channel.Package == nil { - return nil, ErrNoPackageFound + return nil, nil, ErrNoPackageFound } - // Get required floors using the channel - requiredFloors, err := api.GetRequiredChannelFloors( + // Get required floors using LIMIT+1 to detect if more floors remain + // This is more efficient than a separate COUNT query + requiredFloors, hasMoreFloors, err := api.GetRequiredChannelFloorsWithLimit( group.Channel, instanceVersion, ) if err != nil { - return nil, err + return nil, nil, err } targetPkg := group.Channel.Package + // Check if target is already included (when target is also a floor) - if len(requiredFloors) > 0 && requiredFloors[len(requiredFloors)-1].ID == targetPkg.ID { - return requiredFloors, nil + var lastFloor *Package + if len(requiredFloors) > 0 { + lastFloor = requiredFloors[len(requiredFloors)-1] + } + targetIsLastFloor := lastFloor != nil && lastFloor.ID == targetPkg.ID + if targetIsLastFloor { + return requiredFloors, lastFloor, nil + } + + // If more floors remain, don't include target yet + // Syncer will request again with highest floor version + if hasMoreFloors { + return requiredFloors, nil, nil } - // Append target if not already included - return append(requiredFloors, targetPkg), nil + // All floors sent (or no floors) - include target + return requiredFloors, targetPkg, nil } diff --git a/backend/pkg/api/updates_test.go b/backend/pkg/api/updates_test.go index b93956995..d75c22da6 100644 --- a/backend/pkg/api/updates_test.go +++ b/backend/pkg/api/updates_test.go @@ -25,28 +25,28 @@ func TestGetUpdatePackage(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp2.ID, ChannelID: null.StringFrom(tChannel2.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", "invalidApplicationID", tGroup.ID) + _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", "invalidApplicationID", tGroup.ID, "", "") assert.Error(t, ErrInvalidApplicationOrGroup, err, "Invalid application id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID") + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID", "", "") assert.Error(t, err, "Invalid group id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", uuid.New().String(), tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", uuid.New().String(), tGroup.ID, "", "") assert.Error(t, err, "Non existent application id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, uuid.New().String()) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, uuid.New().String(), "", "") assert.Error(t, err, "Non existent group id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID, "", "") assert.Error(t, err, "Group doesn't belong to the application provided.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp2.ID, tGroup2.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp2.ID, tGroup2.ID, "", "") assert.Equal(t, ErrNoPackageFound, err, "Group's channel has no package bound.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Instance version is up to date.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1010.5.0+2016-05-27-1832", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1010.5.0+2016-05-27-1832", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Instance version is up to date.") } @@ -58,7 +58,7 @@ func TestGetUpdatePackage_GroupNoChannel(t *testing.T) { tApp, _ := a.AddApp(&Application{Name: "test_app", TeamID: tTeam.ID}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, PolicyUpdatesEnabled: false, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Error(t, ErrNoPackageFound) } @@ -72,7 +72,7 @@ func TestGetUpdatePackage_UpdatesDisabled(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: false, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrUpdatesDisabled, err) } @@ -88,10 +88,10 @@ func TestGetUpdatePackage_MaxUpdatesPerPeriodLimitReached_SafeMode(t *testing.T) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: safeMode, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err, "Safe mode is enabled, first update should be completed before letting more through.") } @@ -106,16 +106,16 @@ func TestGetUpdatePackage_MaxUpdatesPerPeriodLimitReached_LimitUpdated(t *testin tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 1, PolicyUpdateTimeout: "60 minutes"}) instanceID := uuid.New().String() - _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err, "Max 1 update per period, limit reached") tGroup.PolicyMaxUpdatesPerPeriod = 2 _ = a.UpdateGroup(tGroup) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) } @@ -136,23 +136,23 @@ func TestGetUpdatePackage_MaxUpdatesLimitsReached(t *testing.T) { newInstance1ID := uuid.New().String() - _, err := a.GetUpdatePackage(newInstance1ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(newInstance1ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err, "Period interval is over, but there are still two updates not completed or failed.") _ = a.updateInstanceStatus(newInstance1ID, tApp.ID, InstanceStatusComplete) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) } @@ -172,20 +172,20 @@ func TestGetUpdatePackage_MaxTimedOutUpdatesLimitReached_SafeMode(t *testing.T) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: periodIntervalSetting, PolicyMaxUpdatesPerPeriod: 1, PolicyUpdateTimeout: updateTimeoutSetting}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err) time.Sleep(updateTimeout - periodInterval + extraWaitPeriod) // ensure that update timeout is over - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxTimedOutUpdatesLimitReached, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrUpdatesDisabled, err) } @@ -206,20 +206,20 @@ func TestGetUpdatePackage_ResumeUpdates(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: periodIntervalSetting, PolicyMaxUpdatesPerPeriod: maxUpdatesPerPeriod, PolicyUpdateTimeout: updateTimeoutSetting}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err) time.Sleep(updateTimeout - periodInterval + extraWaitPeriod) // ensure that update timeout is over - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) } @@ -233,13 +233,13 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 4, PolicyUpdateTimeout: "60 minutes"}) - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - instance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - instance3, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + instance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") + instance3, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") - _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") group, _ := a.GetGroup(tGroup.ID) assert.True(t, group.RolloutInProgress) @@ -257,8 +257,8 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { assert.Equal(t, int64(2), stats.OnHold.Int64) _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") - _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID, "", "") + _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") group, _ = a.GetGroup(tGroup.ID) assert.True(t, group.RolloutInProgress) @@ -275,7 +275,7 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { assert.Equal(t, int64(2), stats.Complete.Int64) assert.Equal(t, int64(1), stats.Error.Int64) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID, "", "") _ = a.RegisterEvent(instance3.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") group, _ = a.GetGroup(tGroup.ID) @@ -295,10 +295,10 @@ func TestGetUpdatePackage_CompletionStats(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 4, PolicyUpdateTimeout: "60 minutes"}) addAndUpdateInstance := func() { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "11.0.0", "") @@ -329,19 +329,19 @@ func TestGetUpdatePackage_CompletionStats(t *testing.T) { // This instance has the group's current package's version already and reports no status. // We need to make sure it doesn't show up as undefined. - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID, "", "") stats, _ = a.GetGroupInstancesStats(tGroup.ID, testDuration) assert.Equal(t, int64(0), stats.Undefined.Int64) assert.Equal(t, int64(2), stats.Complete.Int64) // Just ensuring that a call for getting an update in an already up to date instance won't change its status - _, err := a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID, "", "") assert.Error(t, err, "nebraska: no update package available") // This version has a version different from the group's current one, and reports no status, so the // status should be undefined. - _, err = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "0.1.0", tApp.ID, tGroup.ID) + _, err = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "0.1.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) stats, _ = a.GetGroupInstancesStats(tGroup.ID, testDuration) @@ -374,10 +374,10 @@ func TestGetUpdatePackage_UpdateInProgressOnInstance(t *testing.T) { instanceID := uuid.New().String() - p1, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + p1, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) - p2, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + p2, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) assert.Equal(t, p1, p2) @@ -388,7 +388,7 @@ func TestGetUpdatePackage_UpdateInProgressOnInstance(t *testing.T) { err = a.updateInstanceStatus(instanceID, tApp.ID, InstanceStatusDownloading) assert.NoError(t, err) - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrUpdateInProgressOnInstance, err) } @@ -404,7 +404,7 @@ func TestGetUpdatePackage_CheckVersionForGrantedUpdate(t *testing.T) { instanceID := uuid.New().String() - _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") assert.NoError(t, err) instance, err := a.GetInstance(instanceID, tApp.ID) @@ -414,7 +414,7 @@ func TestGetUpdatePackage_CheckVersionForGrantedUpdate(t *testing.T) { assert.Equal(t, "12.1.0", instance.Application.LastUpdateVersion.String) assert.Equal(t, "12.0.0", instance.Application.Version) - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err) instanceStatusHistory, err := a.GetInstanceStatusHistory(instanceID, tApp.ID, tGroup.ID, 1) @@ -433,9 +433,9 @@ func TestGetUpdatePackage_InstanceStatusHistory(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 3, PolicyUpdateTimeout: "60 minutes"}) - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID, "", "") - _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID, "", "") _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") @@ -462,7 +462,7 @@ func TestMultiStepUpdateProgression(t *testing.T) { groupID := setup.Group.ID // Step 1: Instance at 1000 → should get first floor (2000) - pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID) + pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID, "", "") assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get first floor") @@ -472,33 +472,33 @@ func TestMultiStepUpdateProgression(t *testing.T) { assert.Equal(t, "2000.0.0", instance.Application.LastUpdateVersion.String) // Step 2: Still at 1000 (already-granted) → should get 2000 again - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID, "", "") assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get same floor when not updated") // Step 3: Instance updates to 2000 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when floor reached") instance, _ = a.GetInstance(instanceID, appID) assert.Equal(t, InstanceStatusComplete, int(instance.Application.Status.Int64)) // Step 4: Instance at 2000 → should get second floor (2500) - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID, "", "") assert.NoError(t, err) assert.Equal(t, "2500.0.0", pkg.Version, "Should get second floor") // Step 5: Instance updates to 2500 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err) // Step 6: Instance at 2500 → should get target (3000) - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID, "", "") assert.NoError(t, err) assert.Equal(t, "3000.0.0", pkg.Version, "Should get target after all floors") // Step 7: Instance updates to 3000 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "3000.0.0", appID, groupID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "3000.0.0", appID, groupID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when target reached") } @@ -511,7 +511,7 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { instanceID := "old-instance" // Get initial update - pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version) @@ -520,12 +520,12 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { assert.Equal(t, "2000.0.0", instance.Application.LastUpdateVersion.String) // Simulate old instance: clear LastUpdateVersion but keep UpdateGranted status - _, err = a.db.Exec(`UPDATE instance_application SET last_update_version = NULL + _, err = a.db.Exec(`UPDATE instance_application SET last_update_version = NULL WHERE instance_id = $1 AND application_id = $2`, instanceID, setup.AppID) assert.NoError(t, err) // Call with already-granted but no LastUpdateVersion - should complete and return error - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID, "", "") assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when LastUpdateVersion is NULL") // Verify status is now Complete @@ -533,7 +533,7 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { assert.Equal(t, InstanceStatusComplete, int(instance.Application.Status.Int64)) // Next call should go through normal grant path - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get floor through normal grant path") @@ -551,12 +551,12 @@ func TestSafetyRulesValidation(t *testing.T) { setup := setupFloors(t, a, "safety-test", []string{"1000.0.0", "2000.0.0"}, "3000.0.0") // Register an instance to test update behavior - _, err := a.RegisterInstance("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + _, err := a.RegisterInstance("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) t.Run("floor_never_blacklisted_for_own_channel", func(t *testing.T) { // Get update should work normally - pkg, err := a.GetUpdatePackage("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID, "", "") assert.NoError(t, err) assert.Equal(t, "1000.0.0", pkg.Version) diff --git a/backend/pkg/omaha/helpers_test.go b/backend/pkg/omaha/helpers_test.go index 6590919e2..219eecad6 100644 --- a/backend/pkg/omaha/helpers_test.go +++ b/backend/pkg/omaha/helpers_test.go @@ -1,8 +1,11 @@ package omaha import ( + "bytes" + "encoding/xml" "testing" + omahaSpec "github.com/flatcar/go-omaha/omaha" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" @@ -68,3 +71,31 @@ func setupOmahaFloorTest(t *testing.T, a *api.API, name string, floorVersions [] return group, pkgs } + +// doSyncerRequest sends a syncer Omaha request and returns the response +func doSyncerRequest(t *testing.T, h *Handler, version, groupID string, multiManifestOK bool) *omahaSpec.Response { + t.Helper() + + req := omahaSpec.NewRequest() + req.OS.Version = "3" + req.OS.Platform = "CoreOS" + req.OS.ServicePack = "linux" + req.OS.Arch = "x64" + req.Version = "CoreOSUpdateEngine-0.1.0.0" + req.InstallSource = "scheduler" + app := req.AddApp(flatcarAppID, version) + app.MachineID = "syncer-" + version + app.Track = groupID + app.AddUpdateCheck() + app.MultiManifestOK = multiManifestOK + + buf := bytes.NewBuffer(nil) + require.NoError(t, xml.NewEncoder(buf).Encode(req)) + + respBuf := bytes.NewBuffer(nil) + require.NoError(t, h.Handle(buf, respBuf, "10.0.0.1")) + + var resp omahaSpec.Response + require.NoError(t, xml.NewDecoder(respBuf).Decode(&resp)) + return &resp +} diff --git a/backend/pkg/omaha/omaha.go b/backend/pkg/omaha/omaha.go index 9ab2e8f93..a099df95f 100644 --- a/backend/pkg/omaha/omaha.go +++ b/backend/pkg/omaha/omaha.go @@ -148,7 +148,7 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o } if reqApp.Ping != nil { - if _, err := h.crAPI.RegisterInstance(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group); err != nil { + if _, err := h.crAPI.RegisterInstance(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group, reqApp.OEM, reqApp.OEMVersion); err != nil { l.Debug().Str("machineId", reqApp.MachineID).Msgf("processPing error %s", err.Error()) } respApp.AddPing() @@ -156,8 +156,9 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o if reqApp.UpdateCheck != nil { if isSyncerClient(omahaReq) { - // Syncer - get all packages - packages, err := h.crAPI.GetUpdatePackagesForSyncer(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group) + // Syncer - get floors and target separately + // target is nil when more floors remain beyond NEBRASKA_MAX_FLOORS_PER_RESPONSE limit + floors, target, err := h.crAPI.GetUpdatePackagesForSyncer(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group, reqApp.OEM, reqApp.OEMVersion) if err != nil { if err == api.ErrNoUpdatePackageAvailable || err == api.ErrUpdateGrantFailed { respApp.AddUpdateCheck(omahaSpec.NoUpdate) @@ -168,35 +169,26 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o continue } - // Check if we got any packages - if len(packages) == 0 { + // Check if we got anything to send + if len(floors) == 0 && target == nil { respApp.AddUpdateCheck(omahaSpec.NoUpdate) continue } // Critical safety rule: old syncers without MultiManifestOK cannot skip floors - // Check if ANY package is a floor (including the target which might also be a floor) - hasFloors := false - for _, pkg := range packages { - if pkg.IsFloor { - hasFloors = true - break - } - } - - if hasFloors && !reqApp.MultiManifestOK { + if len(floors) > 0 && !reqApp.MultiManifestOK { l.Warn().Str("instanceID", reqApp.MachineID). - Int("packageCount", len(packages)). + Int("floorCount", len(floors)). Msg("Syncer without multi-manifest support blocked due to floor requirements") respApp.AddUpdateCheck(omahaSpec.NoUpdate) continue } // Either multi-manifest capable syncer or no floors exist - h.prepareMultiManifestUpdateCheck(respApp, packages) + h.prepareMultiManifestUpdateCheck(respApp, floors, target) } else { // Regular client - get single package - pkg, err := h.crAPI.GetUpdatePackage(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group) + pkg, err := h.crAPI.GetUpdatePackage(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group, reqApp.OEM, reqApp.OEMVersion) if err != nil { if err == api.ErrNoUpdatePackageAvailable || err == api.ErrUpdateGrantFailed { respApp.AddUpdateCheck(omahaSpec.NoUpdate) @@ -337,25 +329,35 @@ func (h *Handler) prepareUpdateCheck(appResp *omahaSpec.AppResponse, pkg *api.Pa updateCheck.AddURL(pkg.URL) } -// prepareMultiManifestUpdateCheck creates a response with multiple manifests -// Each package gets one manifest with appropriate floor/target metadata based on its properties -func (h *Handler) prepareMultiManifestUpdateCheck(appResp *omahaSpec.AppResponse, packages []*api.Package) { - if len(packages) == 0 { +// prepareMultiManifestUpdateCheck creates a response with multiple manifests. +// When target is nil, more floors remain and syncer should request again. +func (h *Handler) prepareMultiManifestUpdateCheck(appResp *omahaSpec.AppResponse, floors []*api.Package, target *api.Package) { + if len(floors) == 0 && target == nil { appResp.AddUpdateCheck(omahaSpec.NoUpdate) return } - // The last package in the array is the target (by convention from GetUpdatePackagesForSyncer) - targetPkg := packages[len(packages)-1] + targetAlreadyInFloors := target != nil && len(floors) > 0 && floors[len(floors)-1].ID == target.ID + + var packages []*api.Package + packages = append(packages, floors...) + if target != nil && !targetAlreadyInFloors { + packages = append(packages, target) + } + + var codeBase string + if target != nil { + codeBase = target.URL + } else { + codeBase = packages[len(packages)-1].URL + } updateCheck := appResp.AddUpdateCheck(omahaSpec.UpdateOK) - updateCheck.AddURL(targetPkg.URL) + updateCheck.AddURL(codeBase) - // Create manifest for each package with appropriate flags - for i, pkg := range packages { + for _, pkg := range packages { manifest := updateCheck.AddManifest(pkg.Version) - // Set IsFloor if package has floor metadata if pkg.IsFloor { manifest.IsFloor = true if pkg.FloorReason.Valid { @@ -365,9 +367,8 @@ func (h *Handler) prepareMultiManifestUpdateCheck(appResp *omahaSpec.AppResponse } } - // Set IsTarget if this is the last package (the target) - // Note: A package can have both IsFloor and IsTarget flags - if i == len(packages)-1 { + isTarget := target != nil && pkg.ID == target.ID + if isTarget { manifest.IsTarget = true } diff --git a/backend/pkg/omaha/omaha_floors_test.go b/backend/pkg/omaha/omaha_floors_test.go index 61f3e4843..866c66328 100644 --- a/backend/pkg/omaha/omaha_floors_test.go +++ b/backend/pkg/omaha/omaha_floors_test.go @@ -3,6 +3,7 @@ package omaha import ( "bytes" "encoding/xml" + "os" "testing" omahaSpec "github.com/flatcar/go-omaha/omaha" @@ -11,159 +12,116 @@ import ( "gopkg.in/guregu/null.v4" ) -// TestFloorUpdateScenarios tests all floor-based update scenarios func TestFloorUpdateScenarios(t *testing.T) { a := newForTest(t) defer a.Close() h := NewHandler(a) - // Helper for syncer requests - syncerRequest := func(h *Handler, version, group string, multiManifestOK bool) *omahaSpec.Response { - req := omahaSpec.NewRequest() - req.OS.Version = "3" - req.OS.Platform = "CoreOS" - req.OS.ServicePack = "linux" - req.OS.Arch = "x64" - req.Version = "CoreOSUpdateEngine-0.1.0.0" - req.InstallSource = "scheduler" - app := req.AddApp(flatcarAppID, version) - app.MachineID = "syncer-" + version - app.Track = group - app.AddUpdateCheck() - app.MultiManifestOK = multiManifestOK - - buf := bytes.NewBuffer(nil) - err := xml.NewEncoder(buf).Encode(req) - if err != nil { - t.Fatalf("Failed to encode request: %v", err) - } - respBuf := bytes.NewBuffer(nil) - err = h.Handle(buf, respBuf, "10.0.0.1") - if err != nil { - t.Fatalf("Failed to handle request: %v", err) - } - var resp omahaSpec.Response - err = xml.NewDecoder(respBuf).Decode(&resp) - if err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - return &resp - } - - t.Run("RegularClientWithUpdate", func(t *testing.T) { - group, pkgs := setupOmahaFloorTest(t, a, "regular", []string{"2000.0.0", "2500.0.0"}, "3000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 3) + // RegularClient: Tests that regular Flatcar clients receive ONE package at a time + // Setup: floors [2000, 2500], target 3000 + // Regular clients progress through floors sequentially with reboots between each step + t.Run("RegularClient", func(t *testing.T) { + group, _ := setupOmahaFloorTest(t, a, "regular", []string{"2000.0.0", "2500.0.0"}, "3000.0.0") tests := []struct { instance, expected string }{ - {"1500.0.0", "2000.0.0"}, // below floors -> floor1 - {"2000.0.0", "2500.0.0"}, // at floor1 -> floor2 - {"2200.0.0", "2500.0.0"}, // between floors -> floor2 - {"2500.0.0", "3000.0.0"}, // at floor2 -> target - {"2700.0.0", "3000.0.0"}, // above floors -> target + {"1500.0.0", "2000.0.0"}, // below all floors -> gets first floor + {"2000.0.0", "2500.0.0"}, // at first floor -> gets second floor + {"2200.0.0", "2500.0.0"}, // between floors -> gets next floor above + {"2500.0.0", "3000.0.0"}, // at last floor -> gets target + {"2700.0.0", "3000.0.0"}, // above floors but below target -> gets target } - for _, tc := range tests { resp := doOmahaRequest(t, h, flatcarAppID, tc.instance, "client-"+tc.instance, group.ID, "10.0.0.1", false, true, nil) - require.NotNil(t, resp) - filename := "flatcar_" + tc.expected + ".gz" - url := "http://sample.url/" + tc.expected - checkOmahaUpdateResponse(t, resp, tc.expected, filename, url, omahaSpec.UpdateOK) + checkOmahaUpdateResponse(t, resp, tc.expected, "flatcar_"+tc.expected+".gz", "http://sample.url/"+tc.expected, omahaSpec.UpdateOK) assert.Len(t, resp.Apps[0].UpdateCheck.Manifests, 1) } }) + // RegularClientNoUpdate: Client already at target version gets NoUpdate t.Run("RegularClientNoUpdate", func(t *testing.T) { - group, pkgs := setupOmahaFloorTest(t, a, "noupdate", []string{}, "5000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 1) + group, _ := setupOmahaFloorTest(t, a, "noupdate", []string{}, "5000.0.0") resp := doOmahaRequest(t, h, flatcarAppID, "5000.0.0", "client", group.ID, "10.0.0.1", false, true, nil) checkOmahaUpdateResponse(t, resp, "", "", "", omahaSpec.NoUpdate) }) + // SyncerMultiManifest: Modern syncers (MultiManifestOK=true) receive ALL packages in one response + // Setup: floors [9000, 9500], target 10000 + // Unlike regular clients, syncers get floors + target together to sync them all at once t.Run("SyncerMultiManifest", func(t *testing.T) { - group, pkgs := setupOmahaFloorTest(t, a, "syncer", []string{"9000.0.0", "9500.0.0"}, "10000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 3) - - testCases := map[string]int{ - "8500.0.0": 3, // below floors: all manifests - "9200.0.0": 2, // between: floor2 + target - "9700.0.0": 1, // above: target only - } + group, _ := setupOmahaFloorTest(t, a, "syncer", []string{"9000.0.0", "9500.0.0"}, "10000.0.0") - for version, expectedManifests := range testCases { - resp := syncerRequest(h, version, group.ID, true) - assert.Len(t, resp.Apps[0].UpdateCheck.Manifests, expectedManifests) - // Verify last is target - last := resp.Apps[0].UpdateCheck.Manifests[expectedManifests-1] - assert.True(t, last.IsTarget) - // Verify others are floors - for i := 0; i < expectedManifests-1; i++ { + tests := []struct { + version string + expectedManifests int + }{ + {"8500.0.0", 3}, // below all -> gets 2 floors + target + {"9200.0.0", 2}, // between floors -> gets 1 floor + target + {"9700.0.0", 1}, // above floors -> gets only target + } + for _, tc := range tests { + resp := doSyncerRequest(t, h, tc.version, group.ID, true) + require.Len(t, resp.Apps[0].UpdateCheck.Manifests, tc.expectedManifests) + assert.True(t, resp.Apps[0].UpdateCheck.Manifests[tc.expectedManifests-1].IsTarget) + for i := 0; i < tc.expectedManifests-1; i++ { assert.True(t, resp.Apps[0].UpdateCheck.Manifests[i].IsFloor) } } }) - t.Run("OldSyncerBlocked", func(t *testing.T) { - group, pkgs := setupOmahaFloorTest(t, a, "oldsyncer", []string{"6000.0.0"}, "7000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 2) - resp := syncerRequest(h, "1500.0.0", group.ID, false) // multiPkgOK=false + // LegacySyncerBlocked: Old syncers without MultiManifestOK get NoUpdate when floors exist + // This prevents legacy syncers from skipping mandatory floor versions + t.Run("LegacySyncerBlocked", func(t *testing.T) { + group, _ := setupOmahaFloorTest(t, a, "oldsyncer", []string{"6000.0.0"}, "7000.0.0") + resp := doSyncerRequest(t, h, "1500.0.0", group.ID, false) // multiManifestOK=false checkOmahaUpdateResponse(t, resp, "", "", "", omahaSpec.NoUpdate) }) + // ModernSyncerNoUpdate: Modern syncer already at target version gets NoUpdate t.Run("ModernSyncerNoUpdate", func(t *testing.T) { - group, pkgs := setupOmahaFloorTest(t, a, "syncernoup", []string{}, "8000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 1) - resp := syncerRequest(h, "8000.0.0", group.ID, true) // At target version + group, _ := setupOmahaFloorTest(t, a, "syncernoup", []string{}, "8000.0.0") + resp := doSyncerRequest(t, h, "8000.0.0", group.ID, true) // at target version checkOmahaUpdateResponse(t, resp, "", "", "", omahaSpec.NoUpdate) }) + // NoFloors: When no floors are configured, both clients and syncers get direct update to target t.Run("NoFloors", func(t *testing.T) { - // Setup without floors - group, pkgs := setupOmahaFloorTest(t, a, "nofloor", []string{}, "4000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 1) + group, _ := setupOmahaFloorTest(t, a, "nofloor", []string{}, "4000.0.0") // Regular client gets direct update resp := doOmahaRequest(t, h, flatcarAppID, "1000.0.0", "client", group.ID, "10.0.0.1", false, true, nil) checkOmahaUpdateResponse(t, resp, "4000.0.0", "flatcar_4000.0.0.gz", "http://sample.url/4000.0.0", omahaSpec.UpdateOK) - // Syncer gets single manifest - resp = syncerRequest(h, "1000.0.0", group.ID, true) - assert.Len(t, resp.Apps[0].UpdateCheck.Manifests, 1) + // Syncer gets single manifest with IsTarget=true + resp = doSyncerRequest(t, h, "1000.0.0", group.ID, true) + require.Len(t, resp.Apps[0].UpdateCheck.Manifests, 1) assert.True(t, resp.Apps[0].UpdateCheck.Manifests[0].IsTarget) }) + // TargetAsFloor: Target package can also be marked as a floor + // Setup: floors [11000, 12000], target 13000, then mark 13000 as floor too + // Regular clients progress through floors including the target-floor + // Syncers receive all manifests with the last one having BOTH IsFloor=true AND IsTarget=true t.Run("TargetAsFloor", func(t *testing.T) { - // Setup where target is also a floor (critical mandatory version) group, pkgs := setupOmahaFloorTest(t, a, "targetfloor", []string{"11000.0.0", "12000.0.0"}, "13000.0.0") - require.NotNil(t, group) - require.Len(t, pkgs, 3) + // Mark target as also being a floor (critical version that must be installed) + require.NoError(t, a.AddChannelPackageFloor(group.ChannelID.String, pkgs[2].ID, null.StringFrom("Critical mandatory version"))) - // Mark the target as ALSO being a floor - err := a.AddChannelPackageFloor(group.ChannelID.String, pkgs[2].ID, - null.StringFrom("Critical mandatory version")) - require.NoError(t, err) - - // Regular client below all versions + // Regular client below all versions -> gets first floor resp := doOmahaRequest(t, h, flatcarAppID, "10000.0.0", "client-low", group.ID, "10.0.0.1", false, true, nil) checkOmahaUpdateResponse(t, resp, "11000.0.0", "flatcar_11000.0.0.gz", "http://sample.url/11000.0.0", omahaSpec.UpdateOK) - // Regular client between floors + // Regular client between floors -> gets second floor resp = doOmahaRequest(t, h, flatcarAppID, "11500.0.0", "client-mid", group.ID, "10.0.0.2", false, true, nil) checkOmahaUpdateResponse(t, resp, "12000.0.0", "flatcar_12000.0.0.gz", "http://sample.url/12000.0.0", omahaSpec.UpdateOK) - // Regular client above regular floors but below target-floor + // Regular client above regular floors but below target-floor -> gets target (which is also a floor) resp = doOmahaRequest(t, h, flatcarAppID, "12500.0.0", "client-high", group.ID, "10.0.0.3", false, true, nil) checkOmahaUpdateResponse(t, resp, "13000.0.0", "flatcar_13000.0.0.gz", "http://sample.url/13000.0.0", omahaSpec.UpdateOK) // Syncer should get all manifests with correct flags - resp = syncerRequest(h, "10000.0.0", group.ID, true) + resp = doSyncerRequest(t, h, "10000.0.0", group.ID, true) require.Len(t, resp.Apps[0].UpdateCheck.Manifests, 3) // First two are floors only @@ -171,7 +129,6 @@ func TestFloorUpdateScenarios(t *testing.T) { assert.False(t, resp.Apps[0].UpdateCheck.Manifests[0].IsTarget) assert.True(t, resp.Apps[0].UpdateCheck.Manifests[1].IsFloor) assert.False(t, resp.Apps[0].UpdateCheck.Manifests[1].IsTarget) - // Last one is BOTH floor AND target assert.True(t, resp.Apps[0].UpdateCheck.Manifests[2].IsFloor) assert.True(t, resp.Apps[0].UpdateCheck.Manifests[2].IsTarget) @@ -179,6 +136,74 @@ func TestFloorUpdateScenarios(t *testing.T) { }) } +func TestFloorLimitPagination(t *testing.T) { + oldMax := os.Getenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE") + defer os.Setenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE", oldMax) + os.Setenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE", "2") + + a := newForTest(t) + defer a.Close() + h := NewHandler(a) + + group, _ := setupOmahaFloorTest(t, a, "floor-limit", + []string{"1000.0.0", "2000.0.0", "3000.0.0", "4000.0.0", "5000.0.0"}, "6000.0.0") + + // Test scenarios for floor limit pagination with limit=2 and floors [1000, 2000, 3000, 4000, 5000] + target 6000 + // + // Scenario 1 (round1): Syncer at 0.0.0 + // - 5 floors remain (1000-5000), exceeds limit of 2 + // - Returns floors [1000, 2000], NO target (hasTarget=false) + // - Syncer should request again with version 2000.0.0 + // + // Scenario 2 (round2): Syncer at 2000.0.0 (after processing round1) + // - 3 floors remain (3000-5000), exceeds limit of 2 + // - Returns floors [3000, 4000], NO target (hasTarget=false) + // - Syncer should request again with version 4000.0.0 + // + // Scenario 3 (round3): Syncer at 4000.0.0 (after processing round2) + // - 1 floor remains (5000), under limit + // - Returns floor [5000] + target [6000] (hasTarget=true) + // - All floors sent, syncer can update channel to target + // + // Scenario 4 (at_limit): Syncer at 3000.0.0 + // - 2 floors remain (4000, 5000), exactly at limit + // - Returns floors [4000, 5000] + target [6000] (hasTarget=true) + // - All floors sent, syncer can update channel to target + tests := []struct { + name string + version string // syncer's current version + expectedCount int // number of manifests in response + expectedFirst string // first manifest version + expectedLast string // last manifest version + hasTarget bool // whether target is included (all floors sent) + }{ + {"round1", "0.0.0", 2, "1000.0.0", "2000.0.0", false}, + {"round2", "2000.0.0", 2, "3000.0.0", "4000.0.0", false}, + {"round3", "4000.0.0", 2, "5000.0.0", "6000.0.0", true}, + {"at_limit", "3000.0.0", 3, "4000.0.0", "6000.0.0", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp := doSyncerRequest(t, h, tc.version, group.ID, true) + + require.Len(t, resp.Apps[0].UpdateCheck.Manifests, tc.expectedCount) + assert.Equal(t, tc.expectedFirst, resp.Apps[0].UpdateCheck.Manifests[0].Version) + assert.Equal(t, tc.expectedLast, resp.Apps[0].UpdateCheck.Manifests[tc.expectedCount-1].Version) + + for i, m := range resp.Apps[0].UpdateCheck.Manifests { + isLast := i == tc.expectedCount-1 + if tc.hasTarget && isLast { + assert.True(t, m.IsTarget, "last manifest should be target") + } else { + assert.True(t, m.IsFloor, "manifest %d should be floor", i) + assert.False(t, m.IsTarget, "manifest %d should not be target when more floors remain", i) + } + } + }) + } +} + // TestLegacySyncerBlockedWithFloors tests that legacy syncers without MultiManifestOK are blocked when floors exist func TestLegacySyncerBlockedWithFloors(t *testing.T) { a := newForTest(t) diff --git a/backend/pkg/syncer/syncer.go b/backend/pkg/syncer/syncer.go index 44c862cd8..ef0d5f9be 100644 --- a/backend/pkg/syncer/syncer.go +++ b/backend/pkg/syncer/syncer.go @@ -619,6 +619,9 @@ func (s *Syncer) processMultiManifestUpdate(descriptor channelDescriptor, update } // If still no target found, all manifests are floors - targetVersion remains empty + // Track highest floor version for version reporting when no target is present + var highestFloorVersion string + // Process each manifest in the response for _, manifest := range update.Manifests { version := manifest.Version @@ -641,6 +644,8 @@ func (s *Syncer) processMultiManifestUpdate(descriptor channelDescriptor, update if err := s.markPackageAsFloor(descriptor, pkg, manifest); err != nil { return fmt.Errorf("failed to mark package %s as floor: %w", version, err) } + // Track highest floor version (manifests are ordered ascending) + highestFloorVersion = version } // Track target package if this is the target @@ -653,15 +658,20 @@ func (s *Syncer) processMultiManifestUpdate(descriptor channelDescriptor, update if targetPkg == nil { // All manifests were floors with no target package identified. // This is a VALID scenario where upstream wants to establish mandatory floors - // without changing the current channel target yet (target may come later). - // We've successfully processed and marked all floors, but there's no new - // version to point the channel to, so it remains at its current version. + // without changing the current channel target yet (more floors may remain). + // We update the tracked version to the highest floor so the next sync request + // will fetch remaining floors. + if highestFloorVersion != "" { + s.versions[descriptor] = highestFloorVersion + s.bootIDs[descriptor] = "{" + uuid.New().String() + "}" + } l.Info(). Str("channel", descriptor.name). Str("arch", descriptor.arch.String()). Int("floors_processed", len(update.Manifests)). - Msg("processMultiManifestUpdate - all manifests are floors, channel remains at current version") - return nil // Success - floors processed, just no channel update + Str("highestFloor", highestFloorVersion). + Msg("processMultiManifestUpdate - all manifests are floors, tracking highest floor for next sync") + return nil // Success - floors processed, channel not updated yet } // Update channel to point to the target package diff --git a/backend/pkg/syncer/syncer_multimanifest_test.go b/backend/pkg/syncer/syncer_multimanifest_test.go index a696f1416..ed927ca43 100644 --- a/backend/pkg/syncer/syncer_multimanifest_test.go +++ b/backend/pkg/syncer/syncer_multimanifest_test.go @@ -1,6 +1,7 @@ package syncer import ( + "os" "testing" "github.com/flatcar/go-omaha/omaha" @@ -416,3 +417,155 @@ func TestSyncer_EmptyManifestError(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "no manifests") } + +// TestSyncer_FloorLimitVersionTracking tests that when there are more floors than +// NEBRASKA_MAX_FLOORS_PER_RESPONSE, the syncer correctly syncs ALL floors across +// multiple sync rounds. +// +// When there are more floors remaining beyond the limit, the server +// sends ONLY floors (no target). This way: +// - Syncer processes floors and updates channel to highest floor +// - Syncer tracks highest floor version for next request +// - Next request fetches remaining floors +// - Only when all floors are sent does the server include the target +// +// Scenario with limit=2 and 5 floors: +// - Round 1: syncer at 0.0.0 -> server sends floors 1,2 (no target, more floors remain) +// - Round 2: syncer at 2000.0.0 -> server sends floors 3,4 (no target, more floors remain) +// - Round 3: syncer at 4000.0.0 -> server sends floor 5 + target (all floors sent) +func TestSyncer_FloorLimitVersionTracking(t *testing.T) { + // Set a low limit to test pagination behavior + oldMax := os.Getenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE") + defer os.Setenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE", oldMax) + os.Setenv("NEBRASKA_MAX_FLOORS_PER_RESPONSE", "2") + + syncer := newForTest(t, &Config{}) + a := syncer.api + t.Cleanup(func() { a.Close() }) + + tGroup := setupFlatcarAppStableGroup(t, a) + tChannel := tGroup.Channel + require.NoError(t, syncer.initialize()) + + desc := channelDescriptor{name: tChannel.Name, arch: tChannel.Arch} + + // Round 1: Server sends only floors (no target) because more floors remain + // This simulates what upstream Nebraska would send when there are 5 floors but limit is 2 + round1 := &omaha.UpdateResponse{ + Status: "ok", + URLs: []*omaha.URL{{CodeBase: "https://example.com"}}, + Manifests: []*omaha.Manifest{ + { + Version: "1000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-1000.0.0.gz", SHA1: "hash1000", Size: 1000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsFloor: true, + FloorReason: "Floor 1", + }, + { + Version: "2000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-2000.0.0.gz", SHA1: "hash2000", Size: 2000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsFloor: true, + FloorReason: "Floor 2", + }, + // NO TARGET - more floors remain + }, + } + + // Process round 1 + err := syncer.processMultiManifestUpdate(desc, round1) + require.NoError(t, err) + + // After round 1: syncer should track highest floor (2000.0.0) + // Channel should NOT be updated (no target in response) + trackedVersion := syncer.versions[desc] + t.Logf("After round 1, syncer tracked version: %s", trackedVersion) + + // Verify floors 1 and 2 were synced + floors, err := a.GetChannelFloorPackages(tChannel.ID) + require.NoError(t, err) + assert.Len(t, floors, 2, "Should have 2 floors after round 1") + + // Round 2: Server sends next batch of floors (still no target, more remain) + round2 := &omaha.UpdateResponse{ + Status: "ok", + URLs: []*omaha.URL{{CodeBase: "https://example.com"}}, + Manifests: []*omaha.Manifest{ + { + Version: "3000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-3000.0.0.gz", SHA1: "hash3000", Size: 3000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsFloor: true, + FloorReason: "Floor 3", + }, + { + Version: "4000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-4000.0.0.gz", SHA1: "hash4000", Size: 4000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsFloor: true, + FloorReason: "Floor 4", + }, + // NO TARGET - more floors remain + }, + } + + // Process round 2 + err = syncer.processMultiManifestUpdate(desc, round2) + require.NoError(t, err) + + // Verify floors 3 and 4 were synced + floors, err = a.GetChannelFloorPackages(tChannel.ID) + require.NoError(t, err) + assert.Len(t, floors, 4, "Should have 4 floors after round 2") + + // Round 3: Server sends last floor + target (all floors now sent) + round3 := &omaha.UpdateResponse{ + Status: "ok", + URLs: []*omaha.URL{{CodeBase: "https://example.com"}}, + Manifests: []*omaha.Manifest{ + { + Version: "5000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-5000.0.0.gz", SHA1: "hash5000", Size: 5000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsFloor: true, + FloorReason: "Floor 5", + }, + { + Version: "6000.0.0", + Packages: []*omaha.Package{{Name: "flatcar-6000.0.0.gz", SHA1: "hash6000", Size: 6000}}, + Actions: []*omaha.Action{{Event: "postinstall", SHA256: "dGVzdHNoYTI1Ng=="}}, + IsTarget: true, + }, + }, + } + + // Process round 3 + err = syncer.processMultiManifestUpdate(desc, round3) + require.NoError(t, err) + + // Final verification: ALL 5 floors should be synced + floors, err = a.GetChannelFloorPackages(tChannel.ID) + require.NoError(t, err) + assert.Len(t, floors, 5, "All 5 floors should be synced after 3 rounds") + + // Verify floor versions + floorVersions := make([]string, len(floors)) + for i, f := range floors { + floorVersions[i] = f.Version + } + assert.Contains(t, floorVersions, "1000.0.0", "Floor 1 should be synced") + assert.Contains(t, floorVersions, "2000.0.0", "Floor 2 should be synced") + assert.Contains(t, floorVersions, "3000.0.0", "Floor 3 should be synced") + assert.Contains(t, floorVersions, "4000.0.0", "Floor 4 should be synced") + assert.Contains(t, floorVersions, "5000.0.0", "Floor 5 should be synced") + + // Verify channel now points to target (only after all floors sent) + updatedChannel, err := a.GetChannel(tChannel.ID) + require.NoError(t, err) + assert.Equal(t, "6000.0.0", updatedChannel.Package.Version, "Channel should point to target after all floors synced") + + // Verify syncer now tracks target version (ready for next sync cycle) + finalVersion := syncer.versions[desc] + assert.Equal(t, "6000.0.0", finalVersion, "Syncer should track target after all floors synced") +} diff --git a/backend/test/api/instance_test.go b/backend/test/api/instance_test.go index dbee21dea..e0b88b72a 100644 --- a/backend/test/api/instance_test.go +++ b/backend/test/api/instance_test.go @@ -92,7 +92,7 @@ func TestGetInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // fetch instance from API @@ -116,7 +116,7 @@ func TestGetInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // fetch instance from API @@ -143,11 +143,11 @@ func TestGetInstanceStatusHistory(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // create event for instance @@ -177,11 +177,11 @@ func TestGetInstanceStatusHistory(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // create event for instance @@ -214,7 +214,7 @@ func TestUpdateInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // fetch instance from API diff --git a/backend/test/api/stats_test.go b/backend/test/api/stats_test.go index 9a3770b14..19ffd7cd0 100644 --- a/backend/test/api/stats_test.go +++ b/backend/test/api/stats_test.go @@ -121,11 +121,11 @@ func TestGroupStatusTimeline(t *testing.T) { // create instance for app[0] instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID, "", "") require.NoError(t, err) // create event for instance diff --git a/docs/ARD-001-multi-step-updates.md b/docs/ARD-001-multi-step-updates.md deleted file mode 100644 index 3d0ad0062..000000000 --- a/docs/ARD-001-multi-step-updates.md +++ /dev/null @@ -1,82 +0,0 @@ -# ARD-001: Multi-Step Updates (Floor Packages) - -## Status -Implemented - -## Context -Flatcar requires mandatory intermediate versions (floors) when updating across major versions to prevent compatibility issues and system failures. - -## Decision - -### Core Architecture -- **Channel-specific floors** stored in new `channel_package_floors` junction table with indexes -- **Dynamic metadata**: `IsFloor` and `FloorReason` populated at query time, not stored in package table -- **Reuse existing tracking**: Utilize existing `last_update_version` field for already-granted state management -- **Atomic package creation**: New `AddPackageWithMetadata` API for complete metadata insertion - -### Update Behavior - -#### Regular Clients -- Receive single package (next floor or target) and progress sequentially through floors -- Already-granted instances with `last_update_version` can continue progression without re-evaluation -- NULL `last_update_version` triggers completion to force fresh grant cycle - -#### Syncers (Nebraska instances) -- **Modern syncers with `multi_manifest_ok=true`**: Receive all packages (floors + target) in one response -- **Legacy syncers without `multi_manifest_ok`**: Blocked with `NoUpdate` response when floors exist -- Syncers identified by `InstallSource="scheduler"` in Omaha request - -#### Target Detection (Syncer-specific) -For multi-manifest responses, syncers use this priority: -1. **Explicit**: Manifest with `is_target="true"` attribute -2. **Implicit**: Last manifest that is NOT a floor (backward compatibility) -3. **None**: All manifests are floors (valid - no channel update) - -### Safety Rules & Constraints - -#### Universal Constraints -1. **Floor/Blacklist Mutual Exclusion**: Packages cannot be both floor AND blacklisted for same channel -2. **Channel Target Protection**: Channel's current package cannot be blacklisted -3. **Cross-channel Independence**: Package can be floor for one channel and blacklisted for another - -#### Syncer-specific Constraints -1. **Atomic Floor Operations**: Floor marking failures abort entire update -2. **Package Verification**: Existing packages verified for hash/size match before reuse -3. **Download Cleanup**: Failed downloads cleaned up to prevent orphaned files -4. **Legacy Syncer Safety**: Syncers without multi-manifest support blocked when floors exist - -### API Endpoints - -#### Floor Management -- `POST /api/v1/apps/{app_id}/channels/{channel_id}/packages/{package_id}/floor` - Mark as floor -- `DELETE /api/v1/apps/{app_id}/channels/{channel_id}/packages/{package_id}/floor` - Unmark as floor -- `GET /api/v1/apps/{app_id}/channels/{channel_id}/packages/floors` - List floor packages - -#### Package Response Fields -- `is_floor`: Boolean indicating floor status -- `floor_reason`: Text explanation for floor requirement - -## Consequences - -### Positive -- Safe update paths preventing incompatible version jumps -- Backward compatible with existing single-step updates -- Sequential progression ensures system stability -- Channel-specific flexibility for different update strategies -- Atomic operations prevent partial states - -### Negative -- Legacy syncers blocked when floors present (requires upgrade) -- Additional database queries for floor checking -- Increased complexity in update decision logic - -## Dependencies - -### go-omaha Library -Enhanced go-omaha library with: -- Multi-manifest support (`Manifests` array replacing single `Manifest`) -- Floor attributes (`IsFloor`, `FloorReason`, `IsTarget`) -- `MultiManifestOK` capability flag for syncers - -## References -- [Flatcar Discussion #1831](https://github.com/flatcar/Flatcar/discussions/1831) \ No newline at end of file diff --git a/docs/architecture-decisions.md b/docs/architecture-decisions.md index aa4def051..f969f46d5 100644 --- a/docs/architecture-decisions.md +++ b/docs/architecture-decisions.md @@ -15,6 +15,7 @@ This document captures important architectural decisions made for the Nebraska p **Problem:** OIDC tokens were exposed in server logs via query parameters, creating security vulnerabilities. **Additional Issues:** + - Deprecated password grant authentication - localStorage token storage (XSS vulnerable) - Backend OAuth flow complexity @@ -40,6 +41,7 @@ This document captures important architectural decisions made for the Nebraska p ### Configuration Changes **Removed Flags:** + - `--oidc-client-secret` (public client, no secret needed) - `--oidc-session-secret` (stateless backend) - `--oidc-session-crypt-key` (stateless backend) @@ -64,7 +66,7 @@ For Nebraska's use case as an infrastructure admin tool: **Usage Pattern:** Administrators typically use Nebraska a few times per month for specific maintenance tasks **Session Requirements:** SSO sessions (8-12 hours) exceed typical usage duration **User Experience:** SSO provides seamless re-authentication without manual intervention -**Complexity Trade-off:** Refresh token implementation adds significant complexity for minimal benefit given the usage pattern +**Complexity Trade-off:** Refresh token implementation adds significant complexity for minimal benefit given the usage pattern The OIDC provider's SSO session cookies handle re-authentication transparently, making refresh tokens unnecessary for this low-frequency admin tool use case. Users get the same "stay logged in" experience without the additional implementation and security complexity of refresh token rotation, storage, and revocation mechanisms. @@ -78,3 +80,106 @@ The OIDC provider's SSO session cookies handle re-authentication transparently, - [RFC 7636 - PKCE](https://datatracker.ietf.org/doc/html/rfc7636) - [OAuth 2.0 Security BCP](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics) - [OAuth 2.0 for SPAs](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-browser-based-apps) + +--- + +## ADR-002: Multi-Step Updates with Floor Packages + +**Status**: Implemented +**Date**: 2025-10-27 +**PR**: [#1195 - feat: Implement multi-step updates with floor packages](https://github.com/flatcar/nebraska/pull/1195) +**Discussion**: [Flatcar #1831 - Multi-Step Update Feature Design](https://github.com/flatcar/Flatcar/discussions/1831) + +### Context + +**Problem:** Nebraska only supported single-step updates, where instances jump directly from their current version to the target. This prevents safe rollout of breaking changes that require intermediate migration steps (e.g., filesystem support changes, partition table migrations). + +### Options Considered + +1. **Package-level prerequisites**: Each package declares its required predecessor version +2. **Version range specifications**: Complex patterns like `>=3374.2.0`, `!=3450.0.0`, etc. +3. **Channel-level floor packages**: Mandatory checkpoint versions managed per channel + +### Decision + +**Solution: Channel-level floor packages** + +Floor packages are checkpoint versions that ALL clients must install when updating through a channel. They are managed separately from packages in a dedicated `channel_package_floors` junction table. + +### Why Floor Packages Over Prerequisites + +The package-level prerequisite approach was initially implemented but revealed a critical flaw: + +- If package 3602.2.0 requires 3510.2.0 as prerequisite +- Later, 3815.0.0 is released WITHOUT prerequisites if prerquisit is not inherited +- Clients can jump directly to 3815.0.0, bypassing 3510.2.0 + +This means ALL future packages would need to inherit prerequisites, even for unrelated changes. Floor-based semantics solve this by making checkpoints a channel policy, not a package property. This way we have a central management of update path instead of tracking prerequisits for all later packages individually. + +### Architecture + +**Database**: New `channel_package_floors` table with channel_id, package_id, floor_reason + +**Update Flow**: + +- Regular clients: Receive one package at a time, progress sequentially through floors +- Modern syncers (`multi_manifest_ok=true`): Receive floors in batches (up to `NEBRASKA_MAX_FLOORS_PER_RESPONSE`). When more floors remain beyond the limit, response contains only floors (no target). Syncer requests again with highest floor version until all floors are sent, then target is included. +- Legacy syncers: Blocked with `NoUpdate` when floors exist (the syncer itself must be upgraded) + +**Safety Rules**: + +- Floor/blacklist mutual exclusion (package cannot be both for same channel) +- Channel target cannot be blacklisted +- Architecture must match between floor package and channel + +### Configuration + +**Environment Variables**: + +- `NEBRASKA_MAX_FLOORS_PER_RESPONSE`: Maximum floors per syncer response (default: 5) + +**API Endpoints** (see [API spec](../backend/api/spec.yaml) for details): + +- `PUT /api/channels/{channelID}/floors/{packageID}` - Set floor (idempotent) +- `DELETE /api/channels/{channelID}/floors/{packageID}` - Remove floor +- `GET /api/channels/{channelID}/floors` - List floors for a channel +- `GET /api/apps/{appIDorProductID}/packages/{packageID}/floor-channels` - List channels where package is a floor + +### Trade-offs + +**Benefits**: + +- Consistent update paths regardless of target version +- Simpler management (channel-level, not per-package) +- No prerequisite inheritance burden on new packages +- Clear separation of policy from package identity + +**Limitations**: + +- Legacy syncers blocked when floors exist +- Requires careful timing: configure floors BEFORE channel promotion +- Must configure floors for ALL channels (stable, beta, LTS) strategically + +### Operational Considerations + +**Timing**: Configure floors BEFORE promoting channel to new target. Adding after promotion allows clients to skip floors. + +**Cross-channel**: Don't use beta/alpha packages meant as floors for stable (would switch clients to wrong channel). + +**Use cases**: Breaking compatibility changes only (e.g., filesystem support), NOT security updates. + +### What Was NOT Implemented + +From the original design discussion, the following were deferred or rejected: + +- Complex version specifications (ranges, patterns, exclusions) +- PostgreSQL semver extension +- Recovery mechanisms +- Canary deployment checkbox for bootloader changes +- Emergency bypass mechanisms + +### References + +- [Flatcar Issue #1185 - RFE: Multi-stage updates](https://github.com/flatcar/Flatcar/issues/1185) +- [#1195 - feat: Implement multi-step updates with floor packages](https://github.com/flatcar/nebraska/pull/1195) +- [Flatcar #1831 - Multi-Step Update Feature Design](https://github.com/flatcar/Flatcar/discussions/1831)