diff --git a/CHANGELOG.md b/CHANGELOG.md index 18b3f1c70..320f1b9e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - **Custom CA Certificate for TLS:** Added `--ca-file` flag to trust additional CA certificates for TLS verification (e.g., internal CA, Let's Encrypt staging). Applies to the OIDC provider client and the syncer. Supports multiple PEM-encoded certs, additive to system CAs. Also exposed as `config.caFile` in the Helm chart. +- **OEM Attribute Capture:** Instances now store OEM and Aleph version information from Omaha update requests. ([#1286](https://github.com/flatcar/nebraska/pull/1286)) - **Multi-Step Updates with Floor Packages:** Added support for mandatory intermediate update versions (floor packages) that clients must install before reaching the target version. This enables safe migration paths for breaking changes by ensuring clients update through specific versions in order. Floor packages can be configured per channel with optional reasons and are architecture-specific. ([#1195](https://github.com/flatcar/nebraska/pull/1195)) - **Nebraska backend is able to use OIDC userinfo endpoint:** Some OIDC providers do not return group membership inside the access token. The Nebraska frontend passes this access token via the header `Authorization: Bearer ` to the backend which can then (optionally) call the OIDC provider's userinfo endpoint to gather group membership. ([#1279](https://github.com/flatcar/nebraska/pull/1279)) diff --git a/backend/pkg/api/activity_test.go b/backend/pkg/api/activity_test.go index 9488dda8f..bbac601c4 100644 --- a/backend/pkg/api/activity_test.go +++ b/backend/pkg/api/activity_test.go @@ -20,9 +20,9 @@ func TestGetActivity(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp.ID, PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID) - tFakeInstance, _ := a.RegisterInstance("{"+uuid.New().String()+"}", "", "10.0.0.2", "1.0.0", tApp.ID, tGroup2.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + tInstance2, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup2.ID, "1.0.0")) + tFakeInstance, _ := a.RegisterInstance(Instance{ID: "{" + uuid.New().String() + "}", IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup2.ID, "1.0.0")) _ = a.newGroupActivityEntry(activityRolloutStarted, activitySuccess, tVersion, tApp.ID, tGroup.ID) _ = a.newGroupActivityEntry(activityRolloutStarted, activitySuccess, tVersion, tApp.ID, tGroup2.ID) diff --git a/backend/pkg/api/applications_test.go b/backend/pkg/api/applications_test.go index 4a6ded237..8ab9e66c3 100644 --- a/backend/pkg/api/applications_test.go +++ b/backend/pkg/api/applications_test.go @@ -158,7 +158,7 @@ func TestGetApp(t *testing.T) { assert.NoError(t, err) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) app, err := a.GetApp(tApp.ID) assert.NoError(t, err) @@ -290,8 +290,8 @@ func TestGetAppsFiltered(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) realInstanceID := uuid.New().String() fakeInstanceID := "{" + uuid.New().String() + "}" - _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: realInstanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + _, _ = a.RegisterInstance(Instance{ID: fakeInstanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) // should ignore fake instance in Instances count apps, err := a.GetApps(tTeam.ID, 1, 10) diff --git a/backend/pkg/api/db/migrations/0021_add_instance_oem.sql b/backend/pkg/api/db/migrations/0021_add_instance_oem.sql new file mode 100644 index 000000000..7772c021a --- /dev/null +++ b/backend/pkg/api/db/migrations/0021_add_instance_oem.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE instance ADD COLUMN oem VARCHAR(256) NOT NULL DEFAULT ''; +ALTER TABLE instance ADD COLUMN aleph_version VARCHAR(256) NOT NULL DEFAULT ''; + +-- +migrate Down + +ALTER TABLE instance DROP COLUMN aleph_version; +ALTER TABLE instance DROP COLUMN oem; diff --git a/backend/pkg/api/events_test.go b/backend/pkg/api/events_test.go index a343eceab..feb6c3438 100644 --- a/backend/pkg/api/events_test.go +++ b/backend/pkg/api/events_test.go @@ -19,7 +19,7 @@ func TestRegisterEvent_InvalidParams(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) err := a.RegisterEvent(uuid.New().String(), tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") assert.Equal(t, ErrInvalidInstance, err) @@ -33,7 +33,7 @@ func TestRegisterEvent_InvalidParams(t *testing.T) { err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") assert.Equal(t, ErrNoUpdateInProgress, err) - _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, 1000, ResultSuccess, "", "") assert.Equal(t, ErrInvalidEventTypeOrResult, err) @@ -51,10 +51,10 @@ func TestRegisterEvent_TriggerEventConsequences(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + tInstance2, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") @@ -77,7 +77,7 @@ func TestRegisterEvent_TriggerEventConsequences(t *testing.T) { instance, _ = a.GetInstance(tInstance.ID, tApp.ID) assert.Equal(t, null.IntFrom(int64(InstanceStatusComplete)), instance.Application.Status) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance2.ID, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance2.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultFailed, "", "") @@ -97,9 +97,9 @@ func TestRegisterEvent_TriggerEventConsequences_FirstUpdateAttemptFailed(t *test tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultFailed, "", "") @@ -115,10 +115,10 @@ func TestRegisterEvent_CheckSuccessResult(t *testing.T) { defer a.Close() performUpdate := func(tApp *Application, tGroup *Group, resultType int) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") @@ -159,10 +159,10 @@ func TestRegisterEvent_CheckFlatcarSuccessResult(t *testing.T) { defer a.Close() performUpdate := func(tApp *Application, tGroup *Group, resultType, expectedInstanceStatus int) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "11.0.0", "") @@ -207,10 +207,10 @@ func TestRegisterEvent_CheckFlatcarIgnoredUpdate(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group9", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) performUpdate := func(previousVersion string) { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, previousVersion, "") @@ -247,9 +247,9 @@ func TestRegisterEvent_GetEvent(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) - _, err := a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) _, err = a.GetEvent(tInstance.ID, tApp.ID, time.Now()) diff --git a/backend/pkg/api/groups_test.go b/backend/pkg/api/groups_test.go index 2ab4a2e97..0f559850b 100644 --- a/backend/pkg/api/groups_test.go +++ b/backend/pkg/api/groups_test.go @@ -168,9 +168,9 @@ func TestGetGroupsFiltered(t *testing.T) { realInstanceID := uuid.New().String() fakeInstanceID1 := "{" + uuid.New().String() + "}" fakeInstanceID2 := "{" + uuid.New().String() + "}" - _, _ = a.RegisterInstance(realInstanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID1, "", "10.0.0.1", "2.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(fakeInstanceID2, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: realInstanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + _, _ = a.RegisterInstance(Instance{ID: fakeInstanceID1, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "2.0.0")) + _, _ = a.RegisterInstance(Instance{ID: fakeInstanceID2, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) groups, err := a.GetGroups(tApp.ID, 0, 0) assert.NoError(t, err) @@ -226,7 +226,7 @@ func TestGetVersionCountTimeline(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "test_group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) instanceID := uuid.New().String() - _, _ = a.RegisterInstance(instanceID, "", "10.0.0.1", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, version)) instance, err := a.GetInstance(instanceID, tApp.ID) assert.NoError(t, err) @@ -304,14 +304,14 @@ func TestGetStatusCountTimeline(t *testing.T) { instanceID1 := uuid.New().String() instanceID2 := uuid.New().String() - _, _ = a.RegisterInstance(instanceID1, "", "10.0.0.1", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: instanceID1, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, version)) instance1, err := a.GetInstance(instanceID1, tApp.ID) assert.NoError(t, err) _ = a.grantUpdate(instance1, version) _ = a.updateInstanceStatus(instanceID1, tApp.ID, InstanceStatusComplete) - _, _ = a.RegisterInstance(instanceID2, "", "10.0.0.2", version, tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: instanceID2, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, version)) instance2, err := a.GetInstance(instanceID2, tApp.ID) assert.NoError(t, err) diff --git a/backend/pkg/api/instances.go b/backend/pkg/api/instances.go index 028da13c3..661f79b43 100644 --- a/backend/pkg/api/instances.go +++ b/backend/pkg/api/instances.go @@ -55,11 +55,13 @@ const ( // Instance represents an instance running one or more applications for which // Nebraska can provide updates. type Instance struct { - ID string `db:"id" json:"id"` - IP string `db:"ip" json:"ip"` - CreatedTs time.Time `db:"created_ts" json:"created_ts"` - Application InstanceApplication `db:"application" json:"application,omitempty"` - Alias string `db:"alias" json:"alias,omitempty"` + ID string `db:"id" json:"id"` + IP string `db:"ip" json:"ip"` + OEM string `db:"oem" json:"oem,omitempty"` + AlephVersion string `db:"aleph_version" json:"aleph_version,omitempty"` + CreatedTs time.Time `db:"created_ts" json:"created_ts"` + Application InstanceApplication `db:"application" json:"application,omitempty"` + Alias string `db:"alias" json:"alias,omitempty"` } type InstancesWithTotal struct { TotalInstances uint64 `json:"total"` @@ -82,6 +84,11 @@ type InstanceApplication struct { UpdateInProgress bool `db:"update_in_progress" json:"update_in_progress"` } +// NewInstanceApplication creates an InstanceApplication with the fields used for registration. +func NewInstanceApplication(appID, groupID, version string) InstanceApplication { + return InstanceApplication{ApplicationID: appID, GroupID: null.StringFrom(groupID), Version: version} +} + // InstanceStatusHistoryEntry represents an entry in the instance status // history. type InstanceStatusHistoryEntry struct { @@ -161,15 +168,23 @@ func sanitizeSortFilterParams(sortFilter string) string { } // RegisterInstance registers an instance into Nebraska. -func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) (*Instance, error) { - if !isValidSemver(instanceVersion) { +func (api *API) RegisterInstance(inst Instance, instApp InstanceApplication) (*Instance, error) { + if !isValidSemver(instApp.Version) { return nil, ErrInvalidSemver } + + appID := instApp.ApplicationID + groupID := instApp.GroupID.String + var err error if appID, groupID, err = api.validateApplicationAndGroup(appID, groupID); err != nil { return nil, err } + instanceAlias := inst.Alias + instanceOEM := inst.OEM + instanceAlephVersion := inst.AlephVersion + // We want to avoid having to create an unneeded DB transaction, so we check whether it // is necessary (we need it when writing into the two tables, instance and // instance_application). @@ -177,21 +192,28 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance updateInstance := true updateInstanceApplication := true - instance, err := api.GetInstance(instanceID, appID) + instance, err := api.GetInstance(inst.ID, appID) if err == nil { // Give precedence to an existing alias over an omitted or empty alias field if instanceAlias == "" { instanceAlias = instance.Alias } - // The instance exists, so we just update it if its IP or Alias changed - updateInstance = instance.IP != instanceIP || instance.Alias != instanceAlias + // Give precedence to existing OEM values over omitted or empty fields + if instanceOEM == "" { + instanceOEM = instance.OEM + } + if instanceAlephVersion == "" { + instanceAlephVersion = instance.AlephVersion + } + // The instance exists, so we just update it if its IP, Alias, OEM or AlephVersion changed + updateInstance = instance.IP != inst.IP || instance.Alias != instanceAlias || instance.OEM != instanceOEM || instance.AlephVersion != instanceAlephVersion recent := nowUTC().Add(-5 * time.Minute) // And we only update the instance_application if the latest registry is outdated or // older than what we establish as recent. updateInstanceApplication = instance.Application.LastCheckForUpdates.UTC().Before(recent) || - instance.Application.Version != instanceVersion || instance.Application.GroupID.String != groupID + instance.Application.Version != instApp.Version || instance.Application.GroupID.String != groupID // Skip updating anything unnecessary if !updateInstance && !updateInstanceApplication { @@ -200,9 +222,9 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance } upsertInstance, _, err := goqu.Insert("instance"). - Cols("id", "ip", "alias"). - Vals(goqu.Vals{instanceID, instanceIP, instanceAlias}). - OnConflict(goqu.DoUpdate("id", goqu.Record{"id": instanceID, "ip": instanceIP, "alias": instanceAlias})). + Cols("id", "ip", "alias", "oem", "aleph_version"). + Vals(goqu.Vals{inst.ID, inst.IP, instanceAlias, instanceOEM, instanceAlephVersion}). + OnConflict(goqu.DoUpdate("id", goqu.Record{"id": inst.ID, "ip": inst.IP, "alias": instanceAlias, "oem": instanceOEM, "aleph_version": instanceAlephVersion})). ToSQL() if err != nil { return nil, err @@ -210,8 +232,8 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance upsertInstanceApplication, _, err := goqu.Insert("instance_application"). Cols("instance_id", "application_id", "group_id", "version", "last_check_for_updates"). - Vals(goqu.Vals{instanceID, appID, groupID, instanceVersion, nowUTC()}). - OnConflict(goqu.DoUpdate("ON CONSTRAINT instance_application_pkey", goqu.Record{"group_id": groupID, "version": instanceVersion, "last_check_for_updates": nowUTC()})). + Vals(goqu.Vals{inst.ID, appID, groupID, instApp.Version, nowUTC()}). + OnConflict(goqu.DoUpdate("ON CONSTRAINT instance_application_pkey", goqu.Record{"group_id": groupID, "version": instApp.Version, "last_check_for_updates": nowUTC()})). ToSQL() if err != nil { return nil, err @@ -269,7 +291,7 @@ func (api *API) RegisterInstance(instanceID, instanceAlias, instanceIP, instance if err := tx.Commit(); err != nil { return nil, err } - return api.GetInstance(instanceID, appID) + return api.GetInstance(inst.ID, appID) } // GetInstance returns the instance identified by the id provided. diff --git a/backend/pkg/api/instances_test.go b/backend/pkg/api/instances_test.go index be3e766c3..d98f6a83c 100644 --- a/backend/pkg/api/instances_test.go +++ b/backend/pkg/api/instances_test.go @@ -25,47 +25,53 @@ func TestRegisterInstance(t *testing.T) { instanceID := uuid.New().String() - _, err := a.RegisterInstance("", "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, err := a.RegisterInstance(Instance{ID: "", IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.Error(t, err, "Using empty string as instance id.") - _, err = a.RegisterInstance(instanceID, "", "invalidIP", "1.0.0", tApp.ID, tGroup.ID) + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "invalidIP"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.Error(t, err, "Using an invalid instance ip.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", "invalidAppID", tGroup.ID) + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication("invalidAppID", tGroup.ID, "1.0.0")) assert.Error(t, err, "Using an invalid application id.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, "invalidGroupID", "1.0.0")) assert.Error(t, err, "Using an invalid group id.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, "invalidGroupID", "")) assert.Error(t, err, "Using an empty instance version.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "aaa1.0.0", tApp.ID, "invalidGroupID") + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, "invalidGroupID", "aaa1.0.0")) assert.Equal(t, ErrInvalidSemver, err, "Using an invalid instance version.") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID) + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup2.ID, "1.0.0")) assert.Equal(t, ErrInvalidApplicationOrGroup, err, "The group provided doesn't belong to the application provided.") - instance, err := a.RegisterInstance(instanceID, "myalias", "10.0.0.1", "1.0.0", "{"+tApp.ID+"}", "{"+tGroup.ID+"}") + instance, err := a.RegisterInstance(Instance{ID: instanceID, Alias: "myalias", IP: "10.0.0.1", OEM: "azure", AlephVersion: "2.9.1.1-r1"}, NewInstanceApplication("{"+tApp.ID+"}", "{"+tGroup.ID+"}", "1.0.0")) assert.NoError(t, err) assert.Equal(t, instanceID, instance.ID) assert.Equal(t, "myalias", instance.Alias) assert.Equal(t, "10.0.0.1", instance.IP) + assert.Equal(t, "azure", instance.OEM) + assert.Equal(t, "2.9.1.1-r1", instance.AlephVersion) - instance, err = a.RegisterInstance(instanceID, "mynewalias", "10.0.0.2", "1.0.2", tApp.ID, tGroup.ID) + instance, err = a.RegisterInstance(Instance{ID: instanceID, Alias: "mynewalias", IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.2")) assert.NoError(t, err, "Registering an already registered instance with some updates, that's fine.") assert.Equal(t, "mynewalias", instance.Alias) assert.Equal(t, "10.0.0.2", instance.IP) assert.Equal(t, "1.0.2", instance.Application.Version) + assert.Equal(t, "azure", instance.OEM, "OEM should be preserved when not provided") + assert.Equal(t, "2.9.1.1-r1", instance.AlephVersion, "AlephVersion should be preserved when not provided") - _, err = a.RegisterInstance(instanceID, "", "10.0.0.2", "1.0.2", tApp2.ID, tGroup.ID) + _, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.2"}, NewInstanceApplication(tApp2.ID, tGroup.ID, "1.0.2")) assert.Error(t, err, "Application id cannot be updated.") - instance, err = a.RegisterInstance(instanceID, "", "10.0.0.3", "1.0.3", tApp.ID, tGroup3.ID) + instance, err = a.RegisterInstance(Instance{ID: instanceID, IP: "10.0.0.3", OEM: "gcp", AlephVersion: "3.0.0"}, NewInstanceApplication(tApp.ID, tGroup3.ID, "1.0.3")) assert.NoError(t, err, "Registering an already registered instance using a different group, that's fine.") assert.Equal(t, "10.0.0.3", instance.IP) assert.Equal(t, "1.0.3", instance.Application.Version) assert.Equal(t, null.StringFrom(tGroup3.ID), instance.Application.GroupID) + assert.Equal(t, "gcp", instance.OEM, "OEM should be updated when provided") + assert.Equal(t, "3.0.0", instance.AlephVersion, "AlephVersion should be updated when provided") } func TestGetInstance(t *testing.T) { @@ -77,7 +83,7 @@ func TestGetInstance(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) _, err := a.GetInstance(uuid.New().String(), tApp.ID) assert.Error(t, err, "Using non existent instance id.") @@ -106,9 +112,9 @@ func TestGetInstances(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp.ID, PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup2.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.1")) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup2.ID, "1.0.2")) result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Version: "1.0.0", Page: 1, PerPage: 10}, testDuration) assert.NoError(t, err) @@ -162,7 +168,7 @@ func TestGetInstances(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 0, len(result.Instances)) - _, _ = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) _ = a.RegisterEvent(tInstance.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") result, err = a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Status: InstanceStatusComplete, Page: 1, PerPage: 10}, testDuration) @@ -191,12 +197,12 @@ func TestGetInstancesSearch(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.2", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.1")) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.2")) instanceAlias := "instance_alias" - _, _ = a.RegisterInstance(uuid.New().String(), instanceAlias, "10.0.0.4", "1.0.4", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), Alias: instanceAlias, IP: "10.0.0.4"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.4")) result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Page: 1, PerPage: 10, SearchFilter: "All", SearchValue: tInstance.ID}, testDuration) assert.NoError(t, err) @@ -232,7 +238,7 @@ func TestGetInstancesFiltered(t *testing.T) { instanceID4 := "8d180b2a07344406af029a4f86bd1ee3" for idx, id := range []string{instanceID1, instanceID2, instanceID3, instanceID4} { ip := fmt.Sprintf("10.0.0.%d", idx+1) - _, _ = a.RegisterInstance(id, "", ip, "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: id, IP: ip}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) } result, err := a.GetInstances(InstancesQueryParams{ApplicationID: tApp.ID, GroupID: tGroup.ID, Version: "1.0.0", Page: 1, PerPage: 10}, testDuration) @@ -264,7 +270,7 @@ func TestGetInstanceStatusHistory(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) newInstance1ID := uuid.New().String() - tInstance, _ := a.RegisterInstance(newInstance1ID, "analias", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, _ := a.RegisterInstance(Instance{ID: newInstance1ID, Alias: "analias", IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.Equal(t, tInstance.Alias, "analias") instance, err := a.GetInstance(tInstance.ID, tApp.ID) @@ -312,14 +318,14 @@ func TestUpdateInstanceStats(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID, Arch: ArchAMD64}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID), Arch: ArchAMD64}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - tInstance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - tInstance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.2", "1.0.0", tApp.ID, tGroup.ID) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.3", "1.0.1", tApp.ID, tGroup.ID) + tInstance1, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + tInstance2, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.1")) - _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance1.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.1", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance2.ID, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.1")) assert.NoError(t, err) ts := time.Now().UTC() @@ -343,13 +349,13 @@ func TestUpdateInstanceStats(t *testing.T) { // Next test case: Switch tInstance1 and tInstance2 versions to workaround the 5-minutes-rate-limiting of the check-in time and add new instance ts2 := time.Now().UTC() - _, err = a.GetUpdatePackage(tInstance1.ID, "", "10.0.0.1", "1.0.3", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance1.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.3")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance2.ID, "", "10.0.0.2", "1.0.4", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance2.ID, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.4")) assert.NoError(t, err) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.4", "1.0.5", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.4"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.5")) ts3 := time.Now().UTC() elapsed = ts3.Sub(ts2) @@ -381,7 +387,7 @@ func TestUpdateInstanceStatsNoArch(t *testing.T) { tPkg, _ := a.AddPackage(&Package{Type: PkgTypeOther, URL: "http://sample.url/pkg", Version: "12.1.0", ApplicationID: tApp.ID}) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + _, _ = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) ts := time.Now().UTC() // Use large duration to have some test coverage for durationToInterval diff --git a/backend/pkg/api/packages_floors_test.go b/backend/pkg/api/packages_floors_test.go index 06e3fcd6b..52e550099 100644 --- a/backend/pkg/api/packages_floors_test.go +++ b/backend/pkg/api/packages_floors_test.go @@ -187,12 +187,12 @@ func TestFloorRolloutPolicy(t *testing.T) { assert.NoError(t, err) // First client gets floor - pkg1, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "1000.0.0", setup.AppID, group.ID) + pkg1, err := a.GetUpdatePackage(Instance{ID: "i1", IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, group.ID, "1000.0.0")) assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg1.Version) // Second client blocked by policy - _, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "1000.0.0", setup.AppID, group.ID) + _, err = a.GetUpdatePackage(Instance{ID: "i2", IP: "10.0.0.2"}, NewInstanceApplication(setup.AppID, group.ID, "1000.0.0")) assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err) } @@ -245,11 +245,11 @@ func TestTargetAsFloor(t *testing.T) { } // Test that regular client gets the appropriate update - pkg, err := a.GetUpdatePackage("i1", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage(Instance{ID: "i1", IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "500.0.0")) assert.NoError(t, err) assert.Equal(t, "1000.0.0", pkg.Version) // Gets first floor - pkg, err = a.GetUpdatePackage("i2", "", "10.0.0.2", "2500.0.0", setup.AppID, setup.Group.ID) + pkg, err = a.GetUpdatePackage(Instance{ID: "i2", IP: "10.0.0.2"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "2500.0.0")) assert.NoError(t, err) assert.Equal(t, "3000.0.0", pkg.Version) // Gets target-floor directly } diff --git a/backend/pkg/api/updates.go b/backend/pkg/api/updates.go index 0fe4c7301..808b3da6f 100644 --- a/backend/pkg/api/updates.go +++ b/backend/pkg/api/updates.go @@ -62,13 +62,17 @@ var ( // GetUpdatePackage returns an update package for the instance/application // provided. The instance details and the application it's running will be // registered in Nebraska (or updated if it's already registered). -func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) (*Package, error) { - instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID) +func (api *API) GetUpdatePackage(inst Instance, instApp InstanceApplication) (*Package, error) { + instance, err := api.RegisterInstance(inst, instApp) if err != nil { l.Error().Err(err).Msg("GetUpdatePackage - could not register instance") return nil, ErrRegisterInstanceFailed } + instanceVersion := instApp.Version + appID := instApp.ApplicationID + groupID := instApp.GroupID.String + if instance.Application.Status.Valid { switch int(instance.Application.Status.Int64) { case InstanceStatusDownloading, InstanceStatusDownloaded, InstanceStatusInstalled: @@ -190,13 +194,17 @@ func (api *API) GetUpdatePackage(instanceID, instanceAlias, instanceIP, instance } // GetUpdatePackagesForSyncer returns all packages (floors + target) for a syncer client -func (api *API) GetUpdatePackagesForSyncer(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID string) ([]*Package, error) { - instance, err := api.RegisterInstance(instanceID, instanceAlias, instanceIP, instanceVersion, appID, groupID) +func (api *API) GetUpdatePackagesForSyncer(inst Instance, instApp InstanceApplication) ([]*Package, error) { + instance, err := api.RegisterInstance(inst, instApp) if err != nil { l.Error().Err(err).Msg("GetUpdatePackagesForSyncer - could not register instance") return nil, ErrRegisterInstanceFailed } + instanceVersion := instApp.Version + appID := instApp.ApplicationID + groupID := instApp.GroupID.String + if instance.Application.Status.Valid { switch int(instance.Application.Status.Int64) { case InstanceStatusDownloading, InstanceStatusDownloaded, InstanceStatusInstalled: diff --git a/backend/pkg/api/updates_test.go b/backend/pkg/api/updates_test.go index b93956995..949e5f9a5 100644 --- a/backend/pkg/api/updates_test.go +++ b/backend/pkg/api/updates_test.go @@ -25,28 +25,28 @@ func TestGetUpdatePackage(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) tGroup2, _ := a.AddGroup(&Group{Name: "group2", ApplicationID: tApp2.ID, ChannelID: null.StringFrom(tChannel2.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", "invalidApplicationID", tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication("invalidApplicationID", tGroup.ID, "1.0.0")) assert.Error(t, ErrInvalidApplicationOrGroup, err, "Invalid application id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, "invalidGroupID") + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, "invalidGroupID", "1.0.0")) assert.Error(t, err, "Invalid group id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", uuid.New().String(), tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(uuid.New().String(), tGroup.ID, "1.0.0")) assert.Error(t, err, "Non existent application id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, uuid.New().String()) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, uuid.New().String(), "1.0.0")) assert.Error(t, err, "Non existent group id.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup2.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup2.ID, "1.0.0")) assert.Error(t, err, "Group doesn't belong to the application provided.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp2.ID, tGroup2.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp2.ID, tGroup2.ID, "1.0.0")) assert.Equal(t, ErrNoPackageFound, err, "Group's channel has no package bound.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.1.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Instance version is up to date.") - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "1010.5.0+2016-05-27-1832", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1010.5.0+2016-05-27-1832")) assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Instance version is up to date.") } @@ -58,7 +58,7 @@ func TestGetUpdatePackage_GroupNoChannel(t *testing.T) { tApp, _ := a.AddApp(&Application{Name: "test_app", TeamID: tTeam.ID}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, PolicyUpdatesEnabled: false, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, _ = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Error(t, ErrNoPackageFound) } @@ -72,7 +72,7 @@ func TestGetUpdatePackage_UpdatesDisabled(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: false, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrUpdatesDisabled, err) } @@ -88,10 +88,10 @@ func TestGetUpdatePackage_MaxUpdatesPerPeriodLimitReached_SafeMode(t *testing.T) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: safeMode, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 2, PolicyUpdateTimeout: "60 minutes"}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err, "Safe mode is enabled, first update should be completed before letting more through.") } @@ -106,16 +106,16 @@ func TestGetUpdatePackage_MaxUpdatesPerPeriodLimitReached_LimitUpdated(t *testin tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 1, PolicyUpdateTimeout: "60 minutes"}) instanceID := uuid.New().String() - _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err, "Max 1 update per period, limit reached") tGroup.PolicyMaxUpdatesPerPeriod = 2 _ = a.UpdateGroup(tGroup) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) } @@ -136,23 +136,23 @@ func TestGetUpdatePackage_MaxUpdatesLimitsReached(t *testing.T) { newInstance1ID := uuid.New().String() - _, err := a.GetUpdatePackage(newInstance1ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: newInstance1ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxUpdatesPerPeriodLimitReached, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err, "Period interval is over, but there are still two updates not completed or failed.") _ = a.updateInstanceStatus(newInstance1ID, tApp.ID, InstanceStatusComplete) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) } @@ -172,20 +172,20 @@ func TestGetUpdatePackage_MaxTimedOutUpdatesLimitReached_SafeMode(t *testing.T) tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: periodIntervalSetting, PolicyMaxUpdatesPerPeriod: 1, PolicyUpdateTimeout: updateTimeoutSetting}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err) time.Sleep(updateTimeout - periodInterval + extraWaitPeriod) // ensure that update timeout is over - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxTimedOutUpdatesLimitReached, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrUpdatesDisabled, err) } @@ -206,20 +206,20 @@ func TestGetUpdatePackage_ResumeUpdates(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: false, PolicyPeriodInterval: periodIntervalSetting, PolicyMaxUpdatesPerPeriod: maxUpdatesPerPeriod, PolicyUpdateTimeout: updateTimeoutSetting}) - _, err := a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) time.Sleep(periodInterval + extraWaitPeriod) // ensure that period interval is over but update timeout isn't - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrMaxConcurrentUpdatesLimitReached, err) time.Sleep(updateTimeout - periodInterval + extraWaitPeriod) // ensure that update timeout is over - _, err = a.GetUpdatePackage(uuid.New().String(), "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: uuid.New().String(), IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) } @@ -233,13 +233,13 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 4, PolicyUpdateTimeout: "60 minutes"}) - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - instance2, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) - instance3, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + instance2, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) + instance3, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) - _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: instance1.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) + _, _ = a.GetUpdatePackage(Instance{ID: instance2.ID, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) + _, _ = a.GetUpdatePackage(Instance{ID: instance3.ID, IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) group, _ := a.GetGroup(tGroup.ID) assert.True(t, group.RolloutInProgress) @@ -257,8 +257,8 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { assert.Equal(t, int64(2), stats.OnHold.Int64) _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") - _, _ = a.GetUpdatePackage(instance2.ID, "", "10.0.0.2", "12.0.0", tApp.ID, tGroup.ID) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: instance2.ID, IP: "10.0.0.2"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) + _, _ = a.GetUpdatePackage(Instance{ID: instance3.ID, IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) group, _ = a.GetGroup(tGroup.ID) assert.True(t, group.RolloutInProgress) @@ -275,7 +275,7 @@ func TestGetUpdatePackage_RolloutStats(t *testing.T) { assert.Equal(t, int64(2), stats.Complete.Int64) assert.Equal(t, int64(1), stats.Error.Int64) - _, _ = a.GetUpdatePackage(instance3.ID, "", "10.0.0.3", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: instance3.ID, IP: "10.0.0.3"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) _ = a.RegisterEvent(instance3.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") group, _ = a.GetGroup(tGroup.ID) @@ -295,10 +295,10 @@ func TestGetUpdatePackage_CompletionStats(t *testing.T) { tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 4, PolicyUpdateTimeout: "60 minutes"}) addAndUpdateInstance := func() { - tInstance, err := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + tInstance, err := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) assert.NoError(t, err) - _, err = a.GetUpdatePackage(tInstance.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: tInstance.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) err = a.RegisterEvent(tInstance.ID, "{"+tApp.ID+"}", tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "11.0.0", "") @@ -329,19 +329,19 @@ func TestGetUpdatePackage_CompletionStats(t *testing.T) { // This instance has the group's current package's version already and reports no status. // We need to make sure it doesn't show up as undefined. - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, tPkg.Version)) stats, _ = a.GetGroupInstancesStats(tGroup.ID, testDuration) assert.Equal(t, int64(0), stats.Undefined.Int64) assert.Equal(t, int64(2), stats.Complete.Int64) // Just ensuring that a call for getting an update in an already up to date instance won't change its status - _, err := a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", tPkg.Version, tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: instance1.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, tPkg.Version)) assert.Error(t, err, "nebraska: no update package available") // This version has a version different from the group's current one, and reports no status, so the // status should be undefined. - _, err = a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "0.1.0", tApp.ID, tGroup.ID) + _, err = a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "0.1.0")) assert.NoError(t, err) stats, _ = a.GetGroupInstancesStats(tGroup.ID, testDuration) @@ -374,10 +374,10 @@ func TestGetUpdatePackage_UpdateInProgressOnInstance(t *testing.T) { instanceID := uuid.New().String() - p1, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + p1, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) - p2, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + p2, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) assert.Equal(t, p1, p2) @@ -388,7 +388,7 @@ func TestGetUpdatePackage_UpdateInProgressOnInstance(t *testing.T) { err = a.updateInstanceStatus(instanceID, tApp.ID, InstanceStatusDownloading) assert.NoError(t, err) - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.Equal(t, ErrUpdateInProgressOnInstance, err) } @@ -404,7 +404,7 @@ func TestGetUpdatePackage_CheckVersionForGrantedUpdate(t *testing.T) { instanceID := uuid.New().String() - _, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) assert.NoError(t, err) instance, err := a.GetInstance(instanceID, tApp.ID) @@ -414,7 +414,7 @@ func TestGetUpdatePackage_CheckVersionForGrantedUpdate(t *testing.T) { assert.Equal(t, "12.1.0", instance.Application.LastUpdateVersion.String) assert.Equal(t, "12.0.0", instance.Application.Version) - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "12.1.0", tApp.ID, tGroup.ID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.1.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err) instanceStatusHistory, err := a.GetInstanceStatusHistory(instanceID, tApp.ID, tGroup.ID, 1) @@ -433,9 +433,9 @@ func TestGetUpdatePackage_InstanceStatusHistory(t *testing.T) { tChannel, _ := a.AddChannel(&Channel{Name: "test_channel", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID)}) tGroup, _ := a.AddGroup(&Group{Name: "test_group", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: true, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: 3, PolicyUpdateTimeout: "60 minutes"}) - instance1, _ := a.RegisterInstance(uuid.New().String(), "", "10.0.0.1", "1.0.0", tApp.ID, tGroup.ID) + instance1, _ := a.RegisterInstance(Instance{ID: uuid.New().String(), IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "1.0.0")) - _, _ = a.GetUpdatePackage(instance1.ID, "", "10.0.0.1", "12.0.0", tApp.ID, tGroup.ID) + _, _ = a.GetUpdatePackage(Instance{ID: instance1.ID, IP: "10.0.0.1"}, NewInstanceApplication(tApp.ID, tGroup.ID, "12.0.0")) _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateDownloadStarted, ResultSuccess, "", "") _ = a.RegisterEvent(instance1.ID, tApp.ID, tGroup.ID, EventUpdateComplete, ResultSuccessReboot, "", "") @@ -462,7 +462,7 @@ func TestMultiStepUpdateProgression(t *testing.T) { groupID := setup.Group.ID // Step 1: Instance at 1000 → should get first floor (2000) - pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID) + pkg, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "1000.0.0")) assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get first floor") @@ -472,33 +472,33 @@ func TestMultiStepUpdateProgression(t *testing.T) { assert.Equal(t, "2000.0.0", instance.Application.LastUpdateVersion.String) // Step 2: Still at 1000 (already-granted) → should get 2000 again - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "1000.0.0")) assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get same floor when not updated") // Step 3: Instance updates to 2000 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "2000.0.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when floor reached") instance, _ = a.GetInstance(instanceID, appID) assert.Equal(t, InstanceStatusComplete, int(instance.Application.Status.Int64)) // Step 4: Instance at 2000 → should get second floor (2500) - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2000.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "2000.0.0")) assert.NoError(t, err) assert.Equal(t, "2500.0.0", pkg.Version, "Should get second floor") // Step 5: Instance updates to 2500 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "2500.0.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err) // Step 6: Instance at 2500 → should get target (3000) - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "2500.0.0", appID, groupID) + pkg, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "2500.0.0")) assert.NoError(t, err) assert.Equal(t, "3000.0.0", pkg.Version, "Should get target after all floors") // Step 7: Instance updates to 3000 → should complete - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "3000.0.0", appID, groupID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(appID, groupID, "3000.0.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when target reached") } @@ -511,7 +511,7 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { instanceID := "old-instance" // Get initial update - pkg, err := a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "1000.0.0")) assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version) @@ -520,12 +520,12 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { assert.Equal(t, "2000.0.0", instance.Application.LastUpdateVersion.String) // Simulate old instance: clear LastUpdateVersion but keep UpdateGranted status - _, err = a.db.Exec(`UPDATE instance_application SET last_update_version = NULL + _, err = a.db.Exec(`UPDATE instance_application SET last_update_version = NULL WHERE instance_id = $1 AND application_id = $2`, instanceID, setup.AppID) assert.NoError(t, err) // Call with already-granted but no LastUpdateVersion - should complete and return error - _, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + _, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "1000.0.0")) assert.Equal(t, ErrNoUpdatePackageAvailable, err, "Should complete when LastUpdateVersion is NULL") // Verify status is now Complete @@ -533,7 +533,7 @@ func TestAlreadyGrantedWithoutLastUpdateVersion(t *testing.T) { assert.Equal(t, InstanceStatusComplete, int(instance.Application.Status.Int64)) // Next call should go through normal grant path - pkg, err = a.GetUpdatePackage(instanceID, "", "10.0.0.1", "1000.0.0", setup.AppID, setup.Group.ID) + pkg, err = a.GetUpdatePackage(Instance{ID: instanceID, IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "1000.0.0")) assert.NoError(t, err) assert.Equal(t, "2000.0.0", pkg.Version, "Should get floor through normal grant path") @@ -551,12 +551,12 @@ func TestSafetyRulesValidation(t *testing.T) { setup := setupFloors(t, a, "safety-test", []string{"1000.0.0", "2000.0.0"}, "3000.0.0") // Register an instance to test update behavior - _, err := a.RegisterInstance("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + _, err := a.RegisterInstance(Instance{ID: "safety-instance", IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "500.0.0")) assert.NoError(t, err) t.Run("floor_never_blacklisted_for_own_channel", func(t *testing.T) { // Get update should work normally - pkg, err := a.GetUpdatePackage("safety-instance", "", "10.0.0.1", "500.0.0", setup.AppID, setup.Group.ID) + pkg, err := a.GetUpdatePackage(Instance{ID: "safety-instance", IP: "10.0.0.1"}, NewInstanceApplication(setup.AppID, setup.Group.ID, "500.0.0")) assert.NoError(t, err) assert.Equal(t, "1000.0.0", pkg.Version) diff --git a/backend/pkg/omaha/omaha.go b/backend/pkg/omaha/omaha.go index 9ab2e8f93..34901a8a3 100644 --- a/backend/pkg/omaha/omaha.go +++ b/backend/pkg/omaha/omaha.go @@ -147,8 +147,17 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o respApp.AddEvent() } + inst := api.Instance{ + ID: reqApp.MachineID, + Alias: reqApp.MachineAlias, + IP: ip, + OEM: reqApp.OEM, + AlephVersion: reqApp.AlephVersion, + } + instApp := api.NewInstanceApplication(appID, group, reqApp.Version) + if reqApp.Ping != nil { - if _, err := h.crAPI.RegisterInstance(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group); err != nil { + if _, err := h.crAPI.RegisterInstance(inst, instApp); err != nil { l.Debug().Str("machineId", reqApp.MachineID).Msgf("processPing error %s", err.Error()) } respApp.AddPing() @@ -157,7 +166,7 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o if reqApp.UpdateCheck != nil { if isSyncerClient(omahaReq) { // Syncer - get all packages - packages, err := h.crAPI.GetUpdatePackagesForSyncer(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group) + packages, err := h.crAPI.GetUpdatePackagesForSyncer(inst, instApp) if err != nil { if err == api.ErrNoUpdatePackageAvailable || err == api.ErrUpdateGrantFailed { respApp.AddUpdateCheck(omahaSpec.NoUpdate) @@ -196,7 +205,7 @@ func (h *Handler) buildOmahaResponse(omahaReq *omahaSpec.Request, ip string) (*o h.prepareMultiManifestUpdateCheck(respApp, packages) } else { // Regular client - get single package - pkg, err := h.crAPI.GetUpdatePackage(reqApp.MachineID, reqApp.MachineAlias, ip, reqApp.Version, appID, group) + pkg, err := h.crAPI.GetUpdatePackage(inst, instApp) if err != nil { if err == api.ErrNoUpdatePackageAvailable || err == api.ErrUpdateGrantFailed { respApp.AddUpdateCheck(omahaSpec.NoUpdate) diff --git a/backend/test/api/instance_test.go b/backend/test/api/instance_test.go index dbee21dea..106cadc9d 100644 --- a/backend/test/api/instance_test.go +++ b/backend/test/api/instance_test.go @@ -92,7 +92,7 @@ func TestGetInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // fetch instance from API @@ -116,7 +116,7 @@ func TestGetInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // fetch instance from API @@ -143,11 +143,11 @@ func TestGetInstanceStatusHistory(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(api.Instance{ID: instanceDB.ID, Alias: instanceDB.Alias, IP: instanceDB.IP}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, instanceDB.Application.Version)) require.NoError(t, err) // create event for instance @@ -177,11 +177,11 @@ func TestGetInstanceStatusHistory(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(api.Instance{ID: instanceDB.ID, Alias: instanceDB.Alias, IP: instanceDB.IP}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, instanceDB.Application.Version)) require.NoError(t, err) // create event for instance @@ -214,7 +214,7 @@ func TestUpdateInstance(t *testing.T) { // create instance for app instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // fetch instance from API diff --git a/backend/test/api/stats_test.go b/backend/test/api/stats_test.go index 9a3770b14..f818b4b52 100644 --- a/backend/test/api/stats_test.go +++ b/backend/test/api/stats_test.go @@ -121,11 +121,11 @@ func TestGroupStatusTimeline(t *testing.T) { // create instance for app[0] instanceID := uuid.New() - instanceDB, err := db.RegisterInstance(instanceID.String(), "alias", "0.0.0.0", "0.0.1", app.ID, app.Groups[0].ID) + instanceDB, err := db.RegisterInstance(api.Instance{ID: instanceID.String(), Alias: "alias", IP: "0.0.0.0"}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, "0.0.1")) require.NoError(t, err) // GetUpdatePackage - _, err = db.GetUpdatePackage(instanceDB.ID, instanceDB.Alias, instanceDB.IP, instanceDB.Application.Version, app.ID, app.Groups[0].ID) + _, err = db.GetUpdatePackage(api.Instance{ID: instanceDB.ID, Alias: instanceDB.Alias, IP: instanceDB.IP}, api.NewInstanceApplication(app.ID, app.Groups[0].ID, instanceDB.Application.Version)) require.NoError(t, err) // create event for instance