diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index d79b3e569973..607ad37d1923 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -92,6 +92,19 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] { return filteredModels } +func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] { + var filtered GalleryElements[T] + for _, m := range gm { + for _, t := range m.GetTags() { + if strings.EqualFold(t, tag) { + filtered = append(filtered, m) + break + } + } + } + return filtered +} + func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] { sort.Slice(gm, func(i, j int) bool { if sortOrder == "asc" { diff --git a/core/gallery/gallery_test.go b/core/gallery/gallery_test.go index e2be9b19ba40..ef09c076ddde 100644 --- a/core/gallery/gallery_test.go +++ b/core/gallery/gallery_test.go @@ -159,6 +159,68 @@ var _ = Describe("Gallery", func() { }) }) + Describe("GalleryElements FilterByTag", func() { + var elements GalleryElements[*GalleryModel] + + BeforeEach(func() { + elements = GalleryElements[*GalleryModel]{ + { + Metadata: Metadata{ + Name: "whisper-asr", + Tags: []string{"asr", "stt"}, + }, + }, + { + Metadata: Metadata{ + Name: "image-diffusers", + Tags: []string{"sd", "image"}, + }, + }, + { + Metadata: Metadata{ + Name: "another-stt-model", + Tags: []string{"stt", "audio"}, + }, + }, + { + Metadata: Metadata{ + Name: "no-tags-model", + Tags: []string{}, + }, + }, + } + }) + + It("should return exact tag matches only", func() { + results := elements.FilterByTag("asr") + Expect(results).To(HaveLen(1)) + Expect(results[0].GetName()).To(Equal("whisper-asr")) + }) + + It("should not match substrings (image-diffusers must NOT match 'asr')", func() { + results := elements.FilterByTag("asr") + for _, r := range results { + Expect(r.GetName()).NotTo(Equal("image-diffusers")) + } + }) + + It("should be case insensitive", func() { + results := elements.FilterByTag("ASR") + Expect(results).To(HaveLen(1)) + Expect(results[0].GetName()).To(Equal("whisper-asr")) + }) + + It("should return multiple models with the same tag", func() { + results := elements.FilterByTag("stt") + Expect(results).To(HaveLen(2)) + }) + + It("should return empty when no models have the tag", func() { + results := elements.FilterByTag("nonexistent") + Expect(results).To(HaveLen(0)) + }) + }) + Describe("GalleryElements SortByName", func() { var elements GalleryElements[*GalleryModel] diff --git a/core/http/react-ui/src/pages/Models.jsx b/core/http/react-ui/src/pages/Models.jsx index 029cd71ff124..8797c800971c 100644 --- a/core/http/react-ui/src/pages/Models.jsx +++ b/core/http/react-ui/src/pages/Models.jsx @@ -130,13 +130,12 @@ export default function Models() { const filterVal = params.filter !== undefined ? params.filter : filter const sortVal = params.sort !== undefined ? params.sort : sort const backendVal = params.backendFilter !== undefined ? params.backendFilter : backendFilter - // Combine search text and filter into 'term' param - const term = searchVal || filterVal || '' const queryParams = { page: params.page || page, items: 9, } - if (term) queryParams.term = term + if (filterVal) queryParams.tag = filterVal + if (searchVal) queryParams.term = searchVal if (backendVal) queryParams.backend = backendVal if (sortVal) { queryParams.sort = sortVal diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 16cc104995e5..81d9b4275ef0 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -210,6 +210,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model // Model Gallery APIs (admin only) app.GET("/api/models", func(c echo.Context) error { term := c.QueryParam("term") + tag := c.QueryParam("tag") page := c.QueryParam("page") if page == "" { page = "1" @@ -253,6 +254,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } sort.Strings(backendNames) + if tag != "" { + models = gallery.GalleryElements[*gallery.GalleryModel](models).FilterByTag(tag) + } if term != "" { models = gallery.GalleryElements[*gallery.GalleryModel](models).Search(term) } @@ -776,6 +780,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model // Backend Gallery APIs app.GET("/api/backends", func(c echo.Context) error { term := c.QueryParam("term") + tag := c.QueryParam("tag") page := c.QueryParam("page") if page == "" { page = "1" @@ -806,6 +811,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } sort.Strings(tags) + if tag != "" { + backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).FilterByTag(tag) + } if term != "" { backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).Search(term) }