Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
return filteredModels
}

func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
var filtered GalleryElements[T]
for _, m := range gm {
for _, t := range m.GetTags() {
if strings.EqualFold(t, tag) {
filtered = append(filtered, m)
break
}
}
}
return filtered
}

func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
sort.Slice(gm, func(i, j int) bool {
if sortOrder == "asc" {
Expand Down
62 changes: 62 additions & 0 deletions core/gallery/gallery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,68 @@ var _ = Describe("Gallery", func() {
})
})

Describe("GalleryElements FilterByTag", func() {
var elements GalleryElements[*GalleryModel]

BeforeEach(func() {
elements = GalleryElements[*GalleryModel]{
{
Metadata: Metadata{
Name: "whisper-asr",
Tags: []string{"asr", "stt"},
},
},
{
Metadata: Metadata{
Name: "image-diffusers",
Tags: []string{"sd", "image"},
},
},
{
Metadata: Metadata{
Name: "another-stt-model",
Tags: []string{"stt", "audio"},
},
},
{
Metadata: Metadata{
Name: "no-tags-model",
Tags: []string{},
},
},
}
})

It("should return exact tag matches only", func() {
results := elements.FilterByTag("asr")
Expect(results).To(HaveLen(1))
Expect(results[0].GetName()).To(Equal("whisper-asr"))
})

It("should not match substrings (image-diffusers must NOT match 'asr')", func() {
results := elements.FilterByTag("asr")
for _, r := range results {
Expect(r.GetName()).NotTo(Equal("image-diffusers"))
}
})

It("should be case insensitive", func() {
results := elements.FilterByTag("ASR")
Expect(results).To(HaveLen(1))
Expect(results[0].GetName()).To(Equal("whisper-asr"))
})

It("should return multiple models with the same tag", func() {
results := elements.FilterByTag("stt")
Expect(results).To(HaveLen(2))
})

It("should return empty when no models have the tag", func() {
results := elements.FilterByTag("nonexistent")
Expect(results).To(HaveLen(0))
})
})

Describe("GalleryElements SortByName", func() {
var elements GalleryElements[*GalleryModel]

Expand Down
5 changes: 2 additions & 3 deletions core/http/react-ui/src/pages/Models.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,12 @@ export default function Models() {
const searchVal = params.search !== undefined ? params.search : search
const filterVal = params.filter !== undefined ? params.filter : filter
const sortVal = params.sort !== undefined ? params.sort : sort
// Combine search text and filter into 'term' param
const term = searchVal || filterVal || ''
const queryParams = {
page: params.page || page,
items: 9,
}
if (term) queryParams.term = term
if (filterVal) queryParams.tag = filterVal
if (searchVal) queryParams.term = searchVal
if (sortVal) {
queryParams.sort = sortVal
queryParams.order = params.order || order
Expand Down
8 changes: 8 additions & 0 deletions core/http/routes/ui_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
// Model Gallery APIs
app.GET("/api/models", func(c echo.Context) error {
term := c.QueryParam("term")
tag := c.QueryParam("tag")
page := c.QueryParam("page")
if page == "" {
page = "1"
Expand Down Expand Up @@ -239,6 +240,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
}
sort.Strings(tags)

if tag != "" {
models = gallery.GalleryElements[*gallery.GalleryModel](models).FilterByTag(tag)
}
if term != "" {
models = gallery.GalleryElements[*gallery.GalleryModel](models).Search(term)
}
Expand Down Expand Up @@ -726,6 +730,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
// Backend Gallery APIs
app.GET("/api/backends", func(c echo.Context) error {
term := c.QueryParam("term")
tag := c.QueryParam("tag")
page := c.QueryParam("page")
if page == "" {
page = "1"
Expand Down Expand Up @@ -756,6 +761,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
}
sort.Strings(tags)

if tag != "" {
backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).FilterByTag(tag)
}
if term != "" {
backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).Search(term)
}
Expand Down
Loading