diff --git a/.gitignore b/.gitignore index e1cee41ee4..c262aa0c23 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ coverage* crd-helm-wrapper cmd/vmcp/__debug_bin* /vmcp +.worktrees/ diff --git a/cmd/thv-memory/config.go b/cmd/thv-memory/config.go new file mode 100644 index 0000000000..98cda699d9 --- /dev/null +++ b/cmd/thv-memory/config.go @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package main is the entry point for the ToolHive memory MCP server. +package main + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +const ( + providerOllama = "ollama" +) + +// Config is the memory server configuration, loaded from memory-server.yaml. +type Config struct { + Storage StorageConfig `yaml:"storage"` + Vector VectorConfig `yaml:"vector"` + Embedder EmbedderConfig `yaml:"embedder"` + Server ServerConfig `yaml:"server"` +} + +// StorageConfig configures the Store backend. +type StorageConfig struct { + Provider string `yaml:"provider"` // sqlite (default) + DSN string `yaml:"dsn"` +} + +// VectorConfig configures the VectorStore backend. +type VectorConfig struct { + Provider string `yaml:"provider"` // sqlite-vec (default) | qdrant | pgvector + URL string `yaml:"url"` +} + +// EmbedderConfig configures the Embedder backend. +type EmbedderConfig struct { + Provider string `yaml:"provider"` // ollama (default) | openai + URL string `yaml:"url"` + Model string `yaml:"model"` +} + +// ServerConfig configures the MCP server itself. +type ServerConfig struct { + Name string `yaml:"name"` + Version string `yaml:"version"` + Host string `yaml:"host"` // default 0.0.0.0 + Port int `yaml:"port"` // default 8080 + LifecycleHours int `yaml:"lifecycle_interval_hours"` // default 24 +} + +// LoadConfig reads and validates config from path. The path is operator-supplied +// and expected to be a trusted config file location. 
+func LoadConfig(path string) (*Config, error) { + // G304: path is an operator-supplied config file, not user input. + data, err := os.ReadFile(path) //nolint:gosec + if err != nil { + return nil, fmt.Errorf("reading config: %w", err) + } + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + applyStorageDefaults(&cfg) + applyEmbedderDefaults(&cfg) + applyServerDefaults(&cfg) + return &cfg, nil +} + +func applyStorageDefaults(cfg *Config) { + if cfg.Storage.Provider == "" { + cfg.Storage.Provider = "sqlite" + } + if cfg.Storage.DSN == "" && cfg.Storage.Provider == "sqlite" { + cfg.Storage.DSN = "/data/memory.db" + } + if cfg.Vector.Provider == "" { + cfg.Vector.Provider = "sqlite-vec" + } +} + +func applyEmbedderDefaults(cfg *Config) { + if cfg.Embedder.Provider == "" { + cfg.Embedder.Provider = providerOllama + } + if cfg.Embedder.Model == "" { + cfg.Embedder.Model = "nomic-embed-text" + } + if cfg.Embedder.URL == "" && cfg.Embedder.Provider == providerOllama { + cfg.Embedder.URL = "http://localhost:11434" + } +} + +func applyServerDefaults(cfg *Config) { + if cfg.Server.Name == "" { + cfg.Server.Name = "toolhive-memory" + } + if cfg.Server.Version == "" { + cfg.Server.Version = "0.1.0" + } + if cfg.Server.Host == "" { + cfg.Server.Host = "0.0.0.0" + } + if cfg.Server.Port <= 0 { + cfg.Server.Port = 8080 + } + if cfg.Server.LifecycleHours <= 0 { + cfg.Server.LifecycleHours = 24 + } +} diff --git a/cmd/thv-memory/integration_test.go b/cmd/thv-memory/integration_test.go new file mode 100644 index 0000000000..923abbd140 --- /dev/null +++ b/cmd/thv-memory/integration_test.go @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package main_test + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/stacklok/toolhive/pkg/memory" + memorysqlite "github.com/stacklok/toolhive/pkg/memory/sqlite" +) + +// fakeEmbedder returns a deterministic embedding for testing without a real model server. +type fakeEmbedder struct{} + +func (*fakeEmbedder) Embed(_ context.Context, text string) ([]float32, error) { + v := []float32{0, 0, 0} + for i, c := range text { + if i >= 3 { + break + } + v[i] = float32(c) / 128.0 + } + return v, nil +} + +func (*fakeEmbedder) Dimensions() int { return 3 } + +func TestIntegration_RememberSearchForget(t *testing.T) { + t.Parallel() + dir := t.TempDir() + resolved, _ := filepath.EvalSymlinks(dir) + db, err := memorysqlite.Open(context.Background(), filepath.Join(resolved, "test.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + store := memorysqlite.NewStore(db) + vectors := memorysqlite.NewVectorStore(db) + svc, err := memory.NewService(store, vectors, &fakeEmbedder{}, zaptest.NewLogger(t)) + require.NoError(t, err) + + ctx := context.Background() + + r, err := svc.Remember(ctx, memory.RememberInput{ + Content: "deploy to us-east-1", + Type: memory.TypeSemantic, + Author: memory.AuthorHuman, + }) + require.NoError(t, err) + require.NotEmpty(t, r.MemoryID) + require.Empty(t, r.Conflicts) + + results, err := svc.Search(ctx, "deploy to us-east-1", nil, 5) + require.NoError(t, err) + require.NotEmpty(t, results) + require.Equal(t, "deploy to us-east-1", results[0].Entry.Content) + + entry, err := store.Get(ctx, r.MemoryID) + require.NoError(t, err) + require.Equal(t, 1, entry.AccessCount) + + require.NoError(t, store.Delete(ctx, r.MemoryID)) + _, err = store.Get(ctx, r.MemoryID) + require.ErrorIs(t, err, memory.ErrNotFound) +} + +func TestIntegration_ConflictDetection(t *testing.T) { + t.Parallel() + dir := t.TempDir() + 
resolved, _ := filepath.EvalSymlinks(dir) + db, err := memorysqlite.Open(context.Background(), filepath.Join(resolved, "test2.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + store := memorysqlite.NewStore(db) + vectors := memorysqlite.NewVectorStore(db) + svc, err := memory.NewService(store, vectors, &fakeEmbedder{}, zaptest.NewLogger(t)) + require.NoError(t, err) + + ctx := context.Background() + + r1, err := svc.Remember(ctx, memory.RememberInput{ + Content: "auth port 8080", + Type: memory.TypeSemantic, + Author: memory.AuthorHuman, + }) + require.NoError(t, err) + require.NotEmpty(t, r1.MemoryID) + + // fakeEmbedder hashes first 3 chars — "aut" maps to same vector for both, + // so "auth port 9090" will have cosine similarity 1.0 with "auth port 8080". + r2, err := svc.Remember(ctx, memory.RememberInput{ + Content: "auth port 9090", + Type: memory.TypeSemantic, + Author: memory.AuthorAgent, + }) + require.NoError(t, err) + require.Empty(t, r2.MemoryID) + require.NotEmpty(t, r2.Conflicts) + + r3, err := svc.Remember(ctx, memory.RememberInput{ + Content: "auth port 9090", + Type: memory.TypeSemantic, + Author: memory.AuthorHuman, + Force: true, + }) + require.NoError(t, err) + require.NotEmpty(t, r3.MemoryID) +} diff --git a/cmd/thv-memory/lifecycle/job.go b/cmd/thv-memory/lifecycle/job.go new file mode 100644 index 0000000000..bf1469dd01 --- /dev/null +++ b/cmd/thv-memory/lifecycle/job.go @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package lifecycle provides the background maintenance job for memory entries. +package lifecycle + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// StalenessAuditThreshold is the score above which entries are logged as audit candidates. 
+const StalenessAuditThreshold = float32(0.8) + +// Job runs periodic maintenance on the memory store: expiring TTL'd entries +// and recomputing trust/staleness scores. +type Job struct { + store memory.Store + log *zap.Logger +} + +// New creates a new lifecycle Job. +func New(store memory.Store, log *zap.Logger) *Job { + return &Job{store: store, log: log} +} + +// Run starts the background job, ticking at the given interval until ctx is cancelled. +func (j *Job) Run(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := j.RunOnce(ctx); err != nil { + j.log.Warn("lifecycle job error", zap.Error(err)) + } + } + } +} + +// RunOnce executes one maintenance pass: expire TTL'd entries, update scores. +func (j *Job) RunOnce(ctx context.Context) error { + if err := j.expireEntries(ctx); err != nil { + return err + } + return j.recomputeScores(ctx) +} + +func (j *Job) expireEntries(ctx context.Context) error { + expired, err := j.store.ListExpired(ctx) + if err != nil { + return err + } + for _, e := range expired { + if err := j.store.Archive(ctx, e.ID, memory.ArchiveReasonExpired, ""); err != nil { + j.log.Warn("failed to archive expired entry", zap.String("id", e.ID), zap.Error(err)) + } + } + return nil +} + +func (j *Job) recomputeScores(ctx context.Context) error { + entries, err := j.store.ListActive(ctx) + if err != nil { + return err + } + for _, e := range entries { + trust := memory.ComputeTrustScore(e) + staleness := memory.ComputeStalenessScore(e) + if err := j.store.UpdateScores(ctx, e.ID, trust, staleness); err != nil { + j.log.Warn("failed to update scores", zap.String("id", e.ID), zap.Error(err)) + } + if staleness >= StalenessAuditThreshold { + j.log.Debug("high staleness entry", zap.String("id", e.ID), zap.Float32("staleness", staleness)) + } + } + return nil +} diff --git a/cmd/thv-memory/lifecycle/job_test.go 
b/cmd/thv-memory/lifecycle/job_test.go new file mode 100644 index 0000000000..a4302e3eaa --- /dev/null +++ b/cmd/thv-memory/lifecycle/job_test.go @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package lifecycle_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + + "github.com/stacklok/toolhive/cmd/thv-memory/lifecycle" + "github.com/stacklok/toolhive/pkg/memory" + "github.com/stacklok/toolhive/pkg/memory/mocks" +) + +func TestJob_RunOnce_ExpiresEntries(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + + expired := memory.Entry{ + ID: "mem_expired", + CreatedAt: time.Now().Add(-48 * time.Hour), + } + store.EXPECT().ListExpired(gomock.Any()).Return([]memory.Entry{expired}, nil) + store.EXPECT().Archive(gomock.Any(), "mem_expired", memory.ArchiveReasonExpired, "").Return(nil) + store.EXPECT().ListActive(gomock.Any()).Return(nil, nil) + + job := lifecycle.New(store, zaptest.NewLogger(t)) + err := job.RunOnce(context.Background()) + require.NoError(t, err) +} + +func TestJob_RunOnce_UpdatesScores(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + + entry := memory.Entry{ + ID: "mem_active", + Author: memory.AuthorHuman, + CreatedAt: time.Now(), + } + store.EXPECT().ListExpired(gomock.Any()).Return(nil, nil) + store.EXPECT().ListActive(gomock.Any()).Return([]memory.Entry{entry}, nil) + store.EXPECT().UpdateScores(gomock.Any(), "mem_active", gomock.Any(), gomock.Any()).Return(nil) + + job := lifecycle.New(store, zaptest.NewLogger(t)) + err := job.RunOnce(context.Background()) + require.NoError(t, err) +} diff --git a/cmd/thv-memory/main.go b/cmd/thv-memory/main.go new file mode 100644 index 0000000000..52704b3f3a --- /dev/null +++ b/cmd/thv-memory/main.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: 
Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + mcpserver "github.com/mark3labs/mcp-go/server" + "go.uber.org/zap" + + "github.com/stacklok/toolhive/cmd/thv-memory/lifecycle" + "github.com/stacklok/toolhive/cmd/thv-memory/resources" + "github.com/stacklok/toolhive/pkg/memory" + "github.com/stacklok/toolhive/pkg/memory/embedder/ollama" + memorysqlite "github.com/stacklok/toolhive/pkg/memory/sqlite" +) + +const ( + readHeaderTimeout = 10 * time.Second + readTimeout = 30 * time.Second + // writeTimeout is intentionally zero: SSE streams for MCP can be long-lived. + writeTimeout = 0 + idleTimeout = 120 * time.Second + shutdownTimeout = 10 * time.Second +) + +func main() { + cfgPath := os.Getenv("MEMORY_CONFIG") + if cfgPath == "" { + cfgPath = "/config/memory-server.yaml" + } + + cfg, err := LoadConfig(cfgPath) + if err != nil { + log.Fatalf("loading config: %v", err) + } + + logger, err := zap.NewProduction() + if err != nil { + log.Fatalf("creating logger: %v", err) + } + defer logger.Sync() //nolint:errcheck + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + db, err := memorysqlite.Open(ctx, cfg.Storage.DSN) + if err != nil { + logger.Fatal("opening database", zap.Error(err)) + } + defer db.Close() //nolint:errcheck + + store := memorysqlite.NewStore(db) + vectors := memorysqlite.NewVectorStore(db) + + var embedder memory.Embedder + switch cfg.Embedder.Provider { + case providerOllama: + embedder, err = ollama.New(cfg.Embedder.URL, cfg.Embedder.Model) + if err != nil { + logger.Fatal("creating ollama embedder", zap.Error(err)) + } + default: + logger.Fatal("unsupported embedder provider", zap.String("provider", cfg.Embedder.Provider)) + } + + svc, err := memory.NewService(store, vectors, embedder, logger) + if err != nil { + logger.Fatal("creating 
memory service", zap.Error(err)) + } + + job := lifecycle.New(store, logger) + go job.Run(ctx, time.Duration(cfg.Server.LifecycleHours)*time.Hour) + + s := newMCPServer(cfg, svc, store) + LoadExistingResources(ctx, s, store, logger) + + resourceAPI := resources.NewHandler( + store, vectors, embedder, + func(e memory.Entry) { RegisterResourceEntry(s, store, e) }, + func(id string) { UnregisterResourceEntry(s, id) }, + logger, + ) + + if err := serve(ctx, cfg, s, resourceAPI, logger); err != nil { + logger.Error("server exited with error", zap.Error(err)) + os.Exit(1) + } +} + +func serve( + ctx context.Context, + cfg *Config, + s *mcpserver.MCPServer, + resourceAPI http.Handler, + logger *zap.Logger, +) error { + addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("creating listener: %w", err) + } + + handler := newHandler(s, resourceAPI, logger) + httpServer := &http.Server{ + Handler: handler, + ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + IdleTimeout: idleTimeout, + } + + errCh := make(chan error, 1) + go func() { + logger.Info("memory MCP server listening", + zap.String("addr", listener.Addr().String()), + zap.String("mcp_endpoint", mcpEndpointPath), + zap.String("resource_api", "/api/resources"), + ) + if err := httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + shutCtx, shutCancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer shutCancel() + return httpServer.Shutdown(shutCtx) + case err := <-errCh: + return err + } +} diff --git a/cmd/thv-memory/resources/api.go b/cmd/thv-memory/resources/api.go new file mode 100644 index 0000000000..94fa20277e --- /dev/null +++ b/cmd/thv-memory/resources/api.go @@ -0,0 +1,327 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +// Package resources provides the management REST API for UI-managed resource +// entries. Resource entries are stored as memory.Entry values with +// source=resource and are read-only to agents via MCP tools. +// +// Routes (all under /api/resources): +// +// POST /api/resources — create resource, embed content, register in MCP +// GET /api/resources — list resources (paginated via ?limit=&offset=) +// GET /api/resources/{id} — get single resource +// PUT /api/resources/{id} — update content (re-embeds), update MCP listing +// DELETE /api/resources/{id} — delete resource, unregister from MCP +package resources + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "github.com/stacklok/toolhive/pkg/memory" +) + +const ( + defaultListLimit = 50 + maxListLimit = 200 +) + +// Handler is the management REST API handler for resource entries. +// +// registerFn and unregisterFn are injected by the caller (package main) so +// that the resources package does not import the MCP server package directly. +// They keep the MCP resource listing in sync with the database: registerFn is +// called after a create or update, unregisterFn after a delete. +type Handler struct { + store memory.Store + vectors memory.VectorStore + embedder memory.Embedder + registerFn func(memory.Entry) + unregisterFn func(id string) + log *zap.Logger +} + +// NewHandler creates a new resource management Handler. +// +// registerFn and unregisterFn are the package-level MCP sync helpers from +// server.go (RegisterResourceEntry / UnregisterResourceEntry), wrapped as +// closures so this package does not need to import the MCP server package. 
+func NewHandler( + store memory.Store, + vectors memory.VectorStore, + embedder memory.Embedder, + registerFn func(memory.Entry), + unregisterFn func(id string), + log *zap.Logger, +) *Handler { + return &Handler{ + store: store, + vectors: vectors, + embedder: embedder, + registerFn: registerFn, + unregisterFn: unregisterFn, + log: log, + } +} + +// ServeHTTP routes /api/resources requests. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Strip /api/resources prefix to get the remaining path. + path := strings.TrimPrefix(r.URL.Path, "/api/resources") + path = strings.TrimPrefix(path, "/") + + switch { + case path == "" && r.Method == http.MethodPost: + h.create(w, r) + case path == "" && r.Method == http.MethodGet: + h.list(w, r) + case path != "" && r.Method == http.MethodGet: + h.get(w, r, path) + case path != "" && r.Method == http.MethodPut: + h.update(w, r, path) + case path != "" && r.Method == http.MethodDelete: + h.delete(w, r, path) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// CreateResourceRequest is the payload for POST /api/resources. +type CreateResourceRequest struct { + Content string `json:"content"` + Type string `json:"type"` // semantic | procedural | episodic (default: semantic) + Tags []string `json:"tags"` +} + +// UpdateResourceRequest is the payload for PUT /api/resources/{id}. +type UpdateResourceRequest struct { + Content string `json:"content"` + Tags []string `json:"tags"` +} + +// ResourceResponse is the API representation of a resource entry. 
+type ResourceResponse struct { + ID string `json:"id"` + Content string `json:"content"` + Type string `json:"type"` + Tags []string `json:"tags"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func entryToResponse(e memory.Entry) ResourceResponse { + return ResourceResponse{ + ID: e.ID, + Content: e.Content, + Type: string(e.Type), + Tags: e.Tags, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + } +} + +func (h *Handler) create(w http.ResponseWriter, r *http.Request) { + var req CreateResourceRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Content) == "" { + http.Error(w, "content is required", http.StatusBadRequest) + return + } + + memType := memory.TypeSemantic + if req.Type != "" { + memType = memory.Type(req.Type) + } + + id := "res_" + uuid.New().String() + now := time.Now().UTC() + entry := memory.Entry{ + ID: id, + Type: memType, + Content: req.Content, + Tags: req.Tags, + Author: memory.AuthorHuman, + Source: memory.SourceResource, + Status: memory.EntryStatusActive, + TrustScore: 1.0, // resources are always fully trusted + CreatedAt: now, + UpdatedAt: now, + } + + if err := h.store.Create(r.Context(), entry); err != nil { + h.jsonError(w, fmt.Errorf("creating entry: %w", err)) + return + } + + embedding, err := h.embed(r.Context(), req.Content) + if err != nil { + _ = h.store.Delete(r.Context(), id) // rollback + h.jsonError(w, fmt.Errorf("embedding content: %w", err)) + return + } + if err := h.vectors.Upsert(r.Context(), id, embedding); err != nil { + _ = h.store.Delete(r.Context(), id) // rollback + h.jsonError(w, fmt.Errorf("storing embedding: %w", err)) + return + } + + h.registerFn(entry) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(entryToResponse(entry)) +} + +func (h *Handler) list(w 
http.ResponseWriter, r *http.Request) { + limit, offset := parsePagination(r) + src := memory.SourceResource + entries, err := h.store.List(r.Context(), memory.ListFilter{ + Source: &src, + Limit: limit, + Offset: offset, + }) + if err != nil { + h.jsonError(w, err) + return + } + + resp := make([]ResourceResponse, 0, len(entries)) + for _, e := range entries { + resp = append(resp, entryToResponse(e)) + } + jsonOK(w, resp) +} + +func (h *Handler) get(w http.ResponseWriter, r *http.Request, id string) { + entry, err := h.store.Get(r.Context(), id) + if err != nil { + if errors.Is(err, memory.ErrNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + h.jsonError(w, err) + return + } + if entry.Source != memory.SourceResource { + http.Error(w, "not a resource entry", http.StatusNotFound) + return + } + jsonOK(w, entryToResponse(entry)) +} + +func (h *Handler) update(w http.ResponseWriter, r *http.Request, id string) { + entry, err := h.store.Get(r.Context(), id) + if err != nil { + if errors.Is(err, memory.ErrNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + h.jsonError(w, err) + return + } + if entry.Source != memory.SourceResource { + http.Error(w, "not a resource entry", http.StatusNotFound) + return + } + + var req UpdateResourceRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Content) == "" { + http.Error(w, "content is required", http.StatusBadRequest) + return + } + + if err := h.store.Update(r.Context(), id, req.Content, memory.AuthorHuman, "updated via management API"); err != nil { + h.jsonError(w, fmt.Errorf("updating entry: %w", err)) + return + } + + embedding, err := h.embed(r.Context(), req.Content) + if err != nil { + h.log.Warn("failed to re-embed resource after update", zap.String("id", id), zap.Error(err)) + } else if err := h.vectors.Upsert(r.Context(), id, embedding); err != 
nil { + h.log.Warn("failed to update embedding after resource update", zap.String("id", id), zap.Error(err)) + } + + updated, err := h.store.Get(r.Context(), id) + if err != nil { + h.jsonError(w, err) + return + } + + h.registerFn(updated) + jsonOK(w, entryToResponse(updated)) +} + +func (h *Handler) delete(w http.ResponseWriter, r *http.Request, id string) { + entry, err := h.store.Get(r.Context(), id) + if err != nil { + if errors.Is(err, memory.ErrNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + h.jsonError(w, err) + return + } + if entry.Source != memory.SourceResource { + http.Error(w, "not a resource entry", http.StatusNotFound) + return + } + + if err := h.store.Delete(r.Context(), id); err != nil { + h.jsonError(w, fmt.Errorf("deleting entry: %w", err)) + return + } + if err := h.vectors.Delete(r.Context(), id); err != nil { + h.log.Warn("failed to delete embedding for resource", zap.String("id", id), zap.Error(err)) + } + + h.unregisterFn(id) + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) embed(ctx context.Context, text string) ([]float32, error) { + return h.embedder.Embed(ctx, text) +} + +func (h *Handler) jsonError(w http.ResponseWriter, err error) { + h.log.Warn("resource API error", zap.Error(err)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) +} + +func jsonOK(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(v) +} + +func parsePagination(r *http.Request) (limit, offset int) { + limit = defaultListLimit + if l := r.URL.Query().Get("limit"); l != "" { + if v, err := strconv.Atoi(l); err == nil && v > 0 { + limit = min(v, maxListLimit) + } + } + if o := r.URL.Query().Get("offset"); o != "" { + if v, err := strconv.Atoi(o); err == nil && v >= 0 { + offset = v + } + } + return limit, offset +} diff --git 
a/cmd/thv-memory/server.go b/cmd/thv-memory/server.go new file mode 100644 index 0000000000..97c5acb675 --- /dev/null +++ b/cmd/thv-memory/server.go @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.uber.org/zap" + + "github.com/stacklok/toolhive/cmd/thv-memory/tools" + "github.com/stacklok/toolhive/pkg/memory" +) + +const ( + mcpEndpointPath = "/mcp" + resourceURIPrefix = "memory://resource/" + resourceURITemplate = "memory://resource/{id}" +) + +// newMCPServer creates the MCP server and registers all memory tools. +// Resource capabilities (listChanged) are enabled so connected agents +// receive notifications/resources/list_changed whenever a resource is +// created or deleted via the management API. +func newMCPServer(cfg *Config, svc *memory.Service, store memory.Store) *server.MCPServer { + s := server.NewMCPServer(cfg.Server.Name, cfg.Server.Version, + server.WithResourceCapabilities(false, true), + ) + + tools.RegisterRemember(s, svc) + tools.RegisterSearch(s, svc) + tools.RegisterRecall(s, store) + tools.RegisterForget(s, store) + tools.RegisterUpdate(s, store) + tools.RegisterFlag(s, store) + tools.RegisterList(s, store) + tools.RegisterConsolidate(s, svc, store) + tools.RegisterCrystallize(s, store) + + // URI template allows agents to probe a resource by known ID without listing. 
+ s.AddResourceTemplate( + mcp.NewResourceTemplate(resourceURITemplate, "Memory Resource", + mcp.WithTemplateDescription("A UI-managed reference document stored in the memory server."), + mcp.WithTemplateMIMEType("text/plain"), + ), + server.ResourceTemplateHandlerFunc(makeResourceReadHandler(store)), + ) + + return s +} + +// LoadExistingResources registers all persisted resource entries with the MCP +// server at startup so they appear in resources/list immediately. +func LoadExistingResources(ctx context.Context, s *server.MCPServer, store memory.Store, log *zap.Logger) { + src := memory.SourceResource + entries, err := store.List(ctx, memory.ListFilter{Source: &src, Limit: 1000}) + if err != nil { + log.Warn("failed to load existing resources", zap.Error(err)) + return + } + for _, e := range entries { + registerResource(s, store, e) + } + log.Debug("loaded existing resources", zap.Int("count", len(entries))) +} + +// RegisterResourceEntry adds a resource entry to the MCP server listing. +// mcp-go automatically sends notifications/resources/list_changed to all +// connected sessions when WithResourceCapabilities listChanged is true. +func RegisterResourceEntry(s *server.MCPServer, store memory.Store, e memory.Entry) { + // Remove any previous registration (e.g., on update with name change). + s.DeleteResources(resourceURIPrefix + e.ID) + registerResource(s, store, e) +} + +// UnregisterResourceEntry removes a resource entry from the MCP server listing. 
+func UnregisterResourceEntry(s *server.MCPServer, id string) { + s.DeleteResources(resourceURIPrefix + id) +} + +func registerResource(s *server.MCPServer, store memory.Store, e memory.Entry) { + name := resourceName(e) + s.AddResource( + mcp.NewResource(resourceURIPrefix+e.ID, name, + mcp.WithResourceDescription(fmt.Sprintf("Resource entry %s", e.ID)), + mcp.WithMIMEType("text/plain"), + ), + makeResourceReadHandler(store), + ) +} + +// newHandler wraps the MCP server in the streamable-HTTP transport and +// returns a mux that exposes: +// - /mcp — MCP streamable-HTTP transport +// - /api/ — Management REST API (resource CRUD for UI) +// - /health — Liveness probe +func newHandler(s *server.MCPServer, resourceAPI http.Handler, log *zap.Logger) http.Handler { + log.Debug("registered memory MCP tools", zap.String("endpoint", mcpEndpointPath)) + + streamable := server.NewStreamableHTTPServer(s, + server.WithEndpointPath(mcpEndpointPath), + ) + + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.Handle("/api/", resourceAPI) + mux.Handle("/", streamable) + return mux +} + +// makeResourceReadHandler returns a handler that reads entry content from the +// store. When store is nil the handler is a no-op placeholder replaced at +// resource-registration time by AddResource with a proper store reference. 
+func makeResourceReadHandler(store memory.Store) server.ResourceHandlerFunc { + return func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + if store == nil { + return nil, errors.New("resource store not initialised") + } + id, ok := strings.CutPrefix(req.Params.URI, resourceURIPrefix) + if !ok || id == "" { + return nil, fmt.Errorf("invalid resource URI: %s", req.Params.URI) + } + entry, err := store.Get(ctx, id) + if err != nil { + return nil, err + } + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "text/plain", + Text: entry.Content, + }, + }, nil + } +} + +// resourceName returns a short display name for a resource entry. +// Uses the first 60 characters of content as the name. +func resourceName(e memory.Entry) string { + name := e.Content + if len(name) > 60 { + name = name[:60] + "…" + } + return name +} diff --git a/cmd/thv-memory/tools/consolidate.go b/cmd/thv-memory/tools/consolidate.go new file mode 100644 index 0000000000..3d7ef11913 --- /dev/null +++ b/cmd/thv-memory/tools/consolidate.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package tools registers MCP tools for the memory server. +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterConsolidate registers the memory_consolidate tool. +func RegisterConsolidate(s *server.MCPServer, svc *memory.Service, store memory.Store) { + tool := mcp.NewTool("memory_consolidate", + mcp.WithDescription( + "Merge related memory entries into one richer entry. 
"+ + "Originals are archived with a pointer to the new entry.", + ), + mcp.WithArray("ids", mcp.Required(), mcp.Description("Array of memory IDs to consolidate"), mcp.WithStringItems()), + mcp.WithString("content", mcp.Required(), mcp.Description("Content for the consolidated entry")), + mcp.WithString("type", mcp.Required(), mcp.Description("Memory type for the consolidated entry")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ids, err := req.RequireStringSlice("ids") + if err != nil { + return mcp.NewToolResultError("ids must be an array of strings"), nil + } + if len(ids) < 2 { + return mcp.NewToolResultError("at least 2 ids required"), nil + } + + content := req.GetString("content", "") + memTypeStr := req.GetString("type", "") + + result, err := svc.Remember(ctx, memory.RememberInput{ + Content: content, + Type: memory.Type(memTypeStr), + Author: memory.AuthorHuman, + Force: true, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("creating consolidated entry: %v", err)), nil + } + + var archiveErrors []string + for _, id := range ids { + if err := store.Archive(ctx, id, memory.ArchiveReasonConsolidated, result.MemoryID); err != nil { + archiveErrors = append(archiveErrors, fmt.Sprintf("%s: %v", id, err)) + } + } + + resp := map[string]any{ + "consolidated_id": result.MemoryID, + "archived_ids": ids, + } + if len(archiveErrors) > 0 { + resp["archive_errors"] = strings.Join(archiveErrors, "; ") + } + out, _ := json.Marshal(resp) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/cmd/thv-memory/tools/crystallize.go b/cmd/thv-memory/tools/crystallize.go new file mode 100644 index 0000000000..e7318a8127 --- /dev/null +++ b/cmd/thv-memory/tools/crystallize.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterCrystallize registers the memory_crystallize tool. +func RegisterCrystallize(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_crystallize", + mcp.WithDescription("Generate a SKILL.md scaffold from procedural memory entries for human review and publishing."), + mcp.WithArray("ids", mcp.Required(), mcp.Description("Array of procedural memory IDs"), mcp.WithStringItems()), + mcp.WithString("name", mcp.Required(), mcp.Description("Proposed skill name (kebab-case)")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ids, err := req.RequireStringSlice("ids") + if err != nil { + return mcp.NewToolResultError("ids must be an array of strings"), nil + } + name := req.GetString("name", "") + + var contents []string + for _, id := range ids { + entry, err := store.Get(ctx, id) + if err != nil { + continue + } + contents = append(contents, entry.Content) + } + if len(contents) == 0 { + return mcp.NewToolResultError("no valid entries found"), nil + } + + scaffold := buildSkillScaffold(name, contents) + out, _ := json.Marshal(map[string]string{ + "skill_name": name, + "skill_md": scaffold, + "note": "Review this scaffold, edit as needed, then publish with: thv skills push " + name, + }) + return mcp.NewToolResultText(string(out)), nil + }) +} + +func buildSkillScaffold(name string, contents []string) string { + return fmt.Sprintf(`--- +name: %s +description: "[TODO: one-line description of what this skill does]" +--- + +# %s + +## Context + +This skill was crystallized from %d procedural memory entries. 
+ +## Guidance + +%s + +## When to Use + +[TODO: describe when an agent should apply this skill] +`, name, name, len(contents), "- "+strings.Join(contents, "\n- ")) +} diff --git a/cmd/thv-memory/tools/flag.go b/cmd/thv-memory/tools/flag.go new file mode 100644 index 0000000000..db1e986d42 --- /dev/null +++ b/cmd/thv-memory/tools/flag.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterFlag registers the memory_flag tool. +func RegisterFlag(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_flag", + mcp.WithDescription("Mark a memory as potentially stale without deleting it."), + mcp.WithString("id", mcp.Required(), mcp.Description("Memory entry ID")), + mcp.WithString("reason", mcp.Required(), mcp.Description("Why this memory may be stale")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := req.GetString("id", "") + reason := req.GetString("reason", "") + if err := checkMutable(ctx, store, id); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if err := store.Flag(ctx, id, reason); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + return mcp.NewToolResultText(`{"status":"ok"}`), nil + }) +} diff --git a/cmd/thv-memory/tools/forget.go b/cmd/thv-memory/tools/forget.go new file mode 100644 index 0000000000..0c88ed1a43 --- /dev/null +++ b/cmd/thv-memory/tools/forget.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterForget registers the memory_forget tool. 
+func RegisterForget(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_forget", + mcp.WithDescription("Delete a memory entry permanently."), + mcp.WithString("id", mcp.Required(), mcp.Description("Memory entry ID")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := req.GetString("id", "") + if err := checkMutable(ctx, store, id); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if err := store.Delete(ctx, id); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + return mcp.NewToolResultText(`{"status":"ok"}`), nil + }) +} diff --git a/cmd/thv-memory/tools/helpers.go b/cmd/thv-memory/tools/helpers.go new file mode 100644 index 0000000000..c8e262bb16 --- /dev/null +++ b/cmd/thv-memory/tools/helpers.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// checkMutable returns an error if the entry's source type is read-only to +// agents (SourceSkill or SourceResource). These entries may only be modified +// via the management REST API, not via MCP tool calls. +func checkMutable(ctx context.Context, store memory.Store, id string) error { + entry, err := store.Get(ctx, id) + if err != nil { + return err + } + if entry.Source == memory.SourceSkill || entry.Source == memory.SourceResource { + return fmt.Errorf("entry %q (source=%s): %w", id, entry.Source, memory.ErrReadOnly) + } + return nil +} diff --git a/cmd/thv-memory/tools/list.go b/cmd/thv-memory/tools/list.go new file mode 100644 index 0000000000..b813dc4a18 --- /dev/null +++ b/cmd/thv-memory/tools/list.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterList registers the memory_list tool. +func RegisterList(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_list", + mcp.WithDescription("List memory entries with structured filters (not semantic). Use memory_search for semantic queries."), + mcp.WithString("type", mcp.Description("Filter by type: semantic or procedural")), + mcp.WithString("author", mcp.Description("Filter by author: human or agent")), + mcp.WithNumber("limit", mcp.Description("Max results (default 20)")), + mcp.WithNumber("offset", mcp.Description("Pagination offset")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var f memory.ListFilter + if rawType := req.GetString("type", ""); rawType != "" { + t := memory.Type(rawType) + f.Type = &t + } + if rawAuthor := req.GetString("author", ""); rawAuthor != "" { + a := memory.AuthorType(rawAuthor) + f.Author = &a + } + f.Limit = req.GetInt("limit", 20) + f.Offset = req.GetInt("offset", 0) + active := memory.EntryStatusActive + f.Status = &active + + entries, err := store.List(ctx, f) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + out, _ := json.Marshal(entries) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/cmd/thv-memory/tools/recall.go b/cmd/thv-memory/tools/recall.go new file mode 100644 index 0000000000..f4e52ccf21 --- /dev/null +++ b/cmd/thv-memory/tools/recall.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + "errors" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterRecall registers the memory_recall tool. +func RegisterRecall(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_recall", + mcp.WithDescription("Fetch a specific memory entry by ID, including its full revision history."), + mcp.WithString("id", mcp.Required(), mcp.Description("Memory entry ID")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := req.GetString("id", "") + entry, err := store.Get(ctx, id) + if err != nil { + if errors.Is(err, memory.ErrNotFound) { + return mcp.NewToolResultError("entry not found"), nil + } + return mcp.NewToolResultError(err.Error()), nil + } + _ = store.IncrementAccess(ctx, id) + out, _ := json.Marshal(entry) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/cmd/thv-memory/tools/remember.go b/cmd/thv-memory/tools/remember.go new file mode 100644 index 0000000000..585fdcac83 --- /dev/null +++ b/cmd/thv-memory/tools/remember.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterRemember registers the memory_remember tool. +func RegisterRemember(s *server.MCPServer, svc *memory.Service) { + tool := mcp.NewTool("memory_remember", + mcp.WithDescription("Store a new semantic or procedural memory. 
Returns conflict_detected if a similar memory exists."), + mcp.WithString("content", mcp.Required(), mcp.Description("The knowledge to store")), + mcp.WithString("type", mcp.Required(), mcp.Description("Memory type: semantic, procedural, or episodic")), + mcp.WithString("author", mcp.Description("Author type: human or agent (default: agent)")), + mcp.WithArray("tags", mcp.Description("Optional labels for filtering and retrieval"), mcp.WithStringItems()), + mcp.WithString("session_id", mcp.Description("Originating session ID")), + mcp.WithNumber("ttl_days", mcp.Description("Optional TTL in days")), + mcp.WithBoolean("force", mcp.Description("Write even if conflicts detected")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + content := req.GetString("content", "") + memTypeStr := req.GetString("type", "") + authorStr := req.GetString("author", "agent") + if authorStr == "" { + authorStr = "agent" + } + tags, _ := req.RequireStringSlice("tags") // optional; ignore error when absent + force := req.GetBool("force", false) + sessionID := req.GetString("session_id", "") + + var ttlDays *int + args := req.GetArguments() + if raw, ok := args["ttl_days"].(float64); ok { + v := int(raw) + ttlDays = &v + } + + result, err := svc.Remember(ctx, memory.RememberInput{ + Content: content, + Type: memory.Type(memTypeStr), + Author: memory.AuthorType(authorStr), + Tags: tags, + SessionID: sessionID, + TTLDays: ttlDays, + Force: force, + }) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + out, _ := json.Marshal(result) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/cmd/thv-memory/tools/search.go b/cmd/thv-memory/tools/search.go new file mode 100644 index 0000000000..9de39e0144 --- /dev/null +++ b/cmd/thv-memory/tools/search.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterSearch registers the memory_search tool. +func RegisterSearch(s *server.MCPServer, svc *memory.Service) { + tool := mcp.NewTool("memory_search", + mcp.WithDescription( + "Semantic search across memory entries. "+ + "Returns entries ranked by similarity with trust and staleness scores.", + ), + mcp.WithString("query", mcp.Required(), mcp.Description("Natural language query")), + mcp.WithString("type", mcp.Description("Filter by type: semantic or procedural")), + mcp.WithNumber("top_k", mcp.Description("Maximum results to return (default 10)")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query := req.GetString("query", "") + + var memType *memory.Type + if rawType := req.GetString("type", ""); rawType != "" { + t := memory.Type(rawType) + memType = &t + } + + topK := req.GetInt("top_k", 10) + + results, err := svc.Search(ctx, query, memType, topK) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + out, _ := json.Marshal(results) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/cmd/thv-memory/tools/update.go b/cmd/thv-memory/tools/update.go new file mode 100644 index 0000000000..07cd9f2f2d --- /dev/null +++ b/cmd/thv-memory/tools/update.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// RegisterUpdate registers the memory_update tool. 
+func RegisterUpdate(s *server.MCPServer, store memory.Store) { + tool := mcp.NewTool("memory_update", + mcp.WithDescription("Correct or refine an existing memory entry. Previous content is saved to history."), + mcp.WithString("id", mcp.Required(), mcp.Description("Memory entry ID")), + mcp.WithString("content", mcp.Required(), mcp.Description("Updated content")), + mcp.WithString("author", mcp.Description("Author type: human or agent (default: agent)")), + mcp.WithString("correction_note", mcp.Description("Explanation for the correction")), + ) + s.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := req.GetString("id", "") + content := req.GetString("content", "") + authorStr := req.GetString("author", "agent") + if authorStr == "" { + authorStr = "agent" + } + note := req.GetString("correction_note", "") + + if err := checkMutable(ctx, store, id); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if err := store.Update(ctx, id, content, memory.AuthorType(authorStr), note); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + entry, err := store.Get(ctx, id) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + out, _ := json.Marshal(entry) + return mcp.NewToolResultText(string(out)), nil + }) +} diff --git a/demo/recruiter/Makefile b/demo/recruiter/Makefile new file mode 100644 index 0000000000..322b675b19 --- /dev/null +++ b/demo/recruiter/Makefile @@ -0,0 +1,194 @@ +# ══════════════════════════════════════════════════════════════════════════════ +# ToolHive Memory Demo — The Recruiter Scenario +# ══════════════════════════════════════════════════════════════════════════════ +# +# Prerequisites +# • Ollama running locally with nomic-embed-text pulled +# brew install ollama && ollama pull nomic-embed-text && ollama serve +# • Go 1.22+ +# • Claude Code CLI installed and authenticated (for agent sessions) +# npm install -g @anthropic-ai/claude-code +# +# 
Quick start +# make all — build binaries, start server, run full demo, stop server +# make demo — run full demo (server must already be running) +# make teardown — stop server, delete DB and demo binary (safe to repeat) +# +# Individual targets +# make build — build thv-memory server + demo binary +# make server-start — start memory server in background +# make server-stop — stop memory server +# make server-logs — tail server log file +# make status — check server health +# +# Agent sessions (run individually after setup phase) +# make session-recruiter-alice — Phase 3: recruiter records Alice's phone screen +# make session-hiring-manager — Phase 4: hiring manager searches memory cold +# make session-recruiter-bob — Phase 5: recruiter records Bob + procedural lesson +# make session-recruiter-charlie — Phase 6: recruiter uses checklist, records Charlie +# make session-crystallize — Phase 7: crystallize phone-screen pattern → Skill +# ══════════════════════════════════════════════════════════════════════════════ + +SHELL := /bin/bash +.DEFAULT_GOAL := all + +# Locate the worktree root (parent of demo/recruiter) +WORKTREE_ROOT := $(shell cd ../.. 
&& pwd) + +# Binaries +SERVER_BIN := $(WORKTREE_ROOT)/bin/thv-memory +DEMO_BIN := $(CURDIR)/.demo-bin + +# Server config (generated from template) +CONFIG_TMPL := $(CURDIR)/config/memory-server.yaml.tmpl +CONFIG := $(CURDIR)/config/memory-server.yaml + +# Runtime files (all cleaned by teardown) +DB_FILE := $(CURDIR)/demo-memory.db +PID_FILE := $(CURDIR)/.server.pid +LOG_FILE := $(CURDIR)/.server.log + +# Demo server port — non-standard to avoid conflicts with anything on 8080 +PORT := 8765 +BASE_URL := http://127.0.0.1:$(PORT) + +# MCP config file for Claude Code agent sessions +MCP_CONFIG := $(CURDIR)/.demo.mcp.json + +# Env for the demo binary +DEMO_ENV := \ + MEMORY_MCP_URL=$(BASE_URL)/mcp \ + MEMORY_API_URL=$(BASE_URL)/api/resources \ + MEMORY_JD_FILE=$(CURDIR)/data/job-description.txt + +.PHONY: all demo build \ + server-start server-stop server-wait server-logs status \ + mcp-config \ + session-recruiter-alice session-hiring-manager \ + session-recruiter-bob session-recruiter-charlie \ + session-crystallize \ + teardown + +# ── all: one-shot full demo ─────────────────────────────────────────────────── +all: build server-start server-wait demo server-stop + +# ── demo: run full scenario (server must be up) ─────────────────────────────── +# Runs setup (Phases 1-2) then all five Claude Code agent sessions (Phases 3-7) +demo: $(DEMO_BIN) mcp-config + @echo "" + $(DEMO_ENV) $(DEMO_BIN) + @$(MAKE) --no-print-directory session-recruiter-alice + @$(MAKE) --no-print-directory session-hiring-manager + @$(MAKE) --no-print-directory session-recruiter-bob + @$(MAKE) --no-print-directory session-recruiter-charlie + @$(MAKE) --no-print-directory session-crystallize + +# ── mcp-config: generate MCP config for Claude Code sessions ───────────────── +mcp-config: + @printf '{\n "mcpServers": {\n "toolhive-memory": {\n "type": "http",\n "url": "$(BASE_URL)/mcp"\n }\n }\n}\n' > $(MCP_CONFIG) + @echo "✓ MCP config written to $(MCP_CONFIG)" + +# ── agent sessions (Phase 3-7) 
──────────────────────────────────────────────── +# Each session prints the prompt to use, then waits for you to run it in +# Claude Code (claude --mcp-config .demo.mcp.json) and press Enter to continue. + +define print-prompt + @echo "" + @echo "═══════════════════════════════════════════════════════════════" + @echo " $(1)" + @echo " MCP config: $(MCP_CONFIG)" + @echo "═══════════════════════════════════════════════════════════════" + @echo "" + @cat $(2) + @echo "" + @echo "───────────────────────────────────────────────────────────────" + @read -p " ↑ Use this prompt in Claude Code, then press Enter to continue... " _ +endef + +session-recruiter-alice: mcp-config + $(call print-prompt,Phase 3 · Recruiter Session — Alice Chen phone screen,$(CURDIR)/prompts/recruiter-alice.txt) + +session-hiring-manager: mcp-config + $(call print-prompt,Phase 4 · Hiring Manager — cold memory search,$(CURDIR)/prompts/hiring-manager.txt) + +session-recruiter-bob: mcp-config + $(call print-prompt,Phase 5 · Recruiter Session — Bob Martinez + procedural lesson,$(CURDIR)/prompts/recruiter-bob.txt) + +session-recruiter-charlie: mcp-config + $(call print-prompt,Phase 6 · Recruiter Session — Charlie Kim (HIRE),$(CURDIR)/prompts/recruiter-charlie.txt) + +session-crystallize: mcp-config + $(call print-prompt,Phase 7 · Crystallize — phone-screen pattern → Skill,$(CURDIR)/prompts/crystallize.txt) + +# ── build: compile both binaries ───────────────────────────────────────────── +build: $(SERVER_BIN) $(DEMO_BIN) + +$(SERVER_BIN): + @echo "▶ Building thv-memory server..." + @mkdir -p $(WORKTREE_ROOT)/bin + cd $(WORKTREE_ROOT) && go build -o $(SERVER_BIN) ./cmd/thv-memory/ + @echo "✓ $(SERVER_BIN)" + +$(DEMO_BIN): $(CURDIR)/cmd/demo/main.go + @echo "▶ Building demo binary..." 
+ cd $(WORKTREE_ROOT) && go build -o $(DEMO_BIN) ./demo/recruiter/cmd/demo/ + @echo "✓ $(DEMO_BIN)" + +# ── server config (generated from template) ────────────────────────────────── +$(CONFIG): $(CONFIG_TMPL) + @echo "▶ Generating server config..." + @sed 's|DEMO_DB_PATH|$(DB_FILE)|g' $(CONFIG_TMPL) > $(CONFIG) + @echo "✓ $(CONFIG)" + +# ── server-start ───────────────────────────────────────────────────────────── +server-start: $(SERVER_BIN) $(CONFIG) + @if [ -f $(PID_FILE) ] && kill -0 $$(cat $(PID_FILE)) 2>/dev/null; then \ + echo "⚠ Server already running (pid=$$(cat $(PID_FILE)))"; \ + else \ + echo "▶ Starting memory server on port $(PORT)..."; \ + MEMORY_CONFIG=$(CONFIG) $(SERVER_BIN) > $(LOG_FILE) 2>&1 & \ + echo $$! > $(PID_FILE); \ + echo "✓ Server started (pid=$$!, log=$(LOG_FILE))"; \ + fi + +# ── server-wait: poll /health until ready ──────────────────────────────────── +server-wait: + @printf "▶ Waiting for server..." + @for i in $$(seq 1 30); do \ + if curl -sf $(BASE_URL)/health > /dev/null 2>&1; then \ + printf " ready!\n"; exit 0; \ + fi; \ + printf "."; sleep 1; \ + done; \ + printf " TIMEOUT\n"; \ + echo " Check logs: make server-logs"; \ + exit 1 + +# ── server-stop ────────────────────────────────────────────────────────────── +server-stop: + @if [ -f $(PID_FILE) ]; then \ + PID=$$(cat $(PID_FILE)); \ + kill $$PID 2>/dev/null && echo "✓ Server stopped (pid=$$PID)" || echo "⚠ Process $$PID not found"; \ + rm -f $(PID_FILE); \ + else \ + echo "⚠ No PID file found — server may not be running"; \ + fi + +# ── server-logs ────────────────────────────────────────────────────────────── +server-logs: + @tail -f $(LOG_FILE) + +# ── status: quick health check ─────────────────────────────────────────────── +status: + @if curl -sf $(BASE_URL)/health > /dev/null 2>&1; then \ + echo "✓ Server healthy at $(BASE_URL)"; \ + else \ + echo "✗ Server not responding at $(BASE_URL)"; \ + fi + +# ── teardown: stop server + wipe all demo data 
─────────────────────────────── +teardown: server-stop + @echo "▶ Removing demo data..." + @rm -f $(DB_FILE) $(DEMO_BIN) $(SERVER_BIN) $(LOG_FILE) $(CONFIG) $(MCP_CONFIG) + @rm -f $(CURDIR)/data/*.md + @echo "✓ Teardown complete — safe to run 'make all' again" diff --git a/demo/recruiter/cmd/demo/main.go b/demo/recruiter/cmd/demo/main.go new file mode 100644 index 0000000000..d8f4147ecd --- /dev/null +++ b/demo/recruiter/cmd/demo/main.go @@ -0,0 +1,286 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Command demo is the automated setup phase of the ToolHive Memory recruiter demo. +// +// It handles Phase 1 (upload the job description as a static MCP Resource and +// print it in full) and Phase 2 (write shared semantic memories). Phases 3-7 +// are run as real Claude Code agent sessions — see the Makefile targets. +// +// Configuration via environment variables: +// +// MEMORY_MCP_URL — MCP endpoint (default: http://127.0.0.1:8765/mcp) +// MEMORY_API_URL — Resources REST endpoint (default: http://127.0.0.1:8765/api/resources) +// MEMORY_JD_FILE — Path to job description file (default: data/job-description.txt) +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// ─── configuration ─────────────────────────────────────────────────────────── + +var ( + mcpURL = envOr("MEMORY_MCP_URL", "http://127.0.0.1:8765/mcp") + apiURL = envOr("MEMORY_API_URL", "http://127.0.0.1:8765/api/resources") + jdFile = envOr("MEMORY_JD_FILE", "data/job-description.txt") +) + +func envOr(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +// ─── ANSI helpers ───────────────────────────────────────────────────────────── + +const ( + ansiReset = "\033[0m" + ansiBold = "\033[1m" + ansiDim = 
"\033[2m" + ansiGreen = "\033[32m" + ansiYellow = "\033[33m" + ansiPurple = "\033[35m" + ansiCyan = "\033[36m" + ansiWhite = "\033[97m" + ansiGray = "\033[90m" +) + +func col(color, s string) string { return color + s + ansiReset } + +// ─── main ───────────────────────────────────────────────────────────────────── + +func main() { + ctx := context.Background() + + jd, err := os.ReadFile(jdFile) + if err != nil { + fatalf("reading %s: %v\n Hint: run from demo/recruiter/ directory", jdFile, err) + } + + printBanner() + + // ── Phase 1 · Resource ──────────────────────────────────────────────────── + phase(1, "Resources", "Upload the job description as a static MCP Resource (read-only to agents)") + resID := uploadResource(ctx, jd) + printJobDescription(jd) + pause() + + // ── Phase 2 · Semantic Memory ───────────────────────────────────────────── + phase(2, "Semantic Memory", "Company-wide facts — written once, recalled by any agent session at any time") + cl := newSession(ctx, "setup") + defer cl.Close() + + remember(ctx, cl, "semantic", + "Company does not sponsor US work visas for any engineering role", + "policy", "visa", "hiring") + remember(ctx, cl, "semantic", + "Senior Go Engineer base salary band: $100,000–$150,000 USD; total comp includes equity", + "compensation", "hiring", "senior-go-engineer") + remember(ctx, cl, "semantic", + "Engineering team is fully remote, US timezone preferred (EST/PST). 
Async-first culture.", + "remote", "culture", "hiring") + pause() + + printHandoff(resID) +} + +// ─── MCP helpers ────────────────────────────────────────────────────────────── + +func newSession(ctx context.Context, name string) *mcpclient.Client { + t, err := transport.NewStreamableHTTP(mcpURL) + if err != nil { + fatalf("transport (%s): %v", name, err) + } + cl := mcpclient.NewClient(t) + if err := cl.Start(ctx); err != nil { + fatalf("start (%s): %v", name, err) + } + if _, err := cl.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "demo-" + name, Version: "1.0"}, + }, + }); err != nil { + fatalf("initialize (%s): %v", name, err) + } + fmt.Printf(" %s Session opened: %s\n", col(ansiGray, "→"), col(ansiBold+ansiCyan, name)) + return cl +} + +func callTool(ctx context.Context, cl *mcpclient.Client, tool string, args map[string]any) string { + result, err := cl.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{Name: tool, Arguments: args}, + }) + if err != nil { + fatalf("tool/%s: %v", tool, err) + } + for _, content := range result.Content { + if tc, ok := content.(mcp.TextContent); ok { + return tc.Text + } + } + return "" +} + +func remember(ctx context.Context, cl *mcpclient.Client, memType, content string, tags ...string) string { + raw := callTool(ctx, cl, "memory_remember", map[string]any{ + "content": content, + "type": memType, + "author": "human", + "tags": tags, + }) + // RememberResult serialises without json tags: {"MemoryID":"...","Conflicts":null} + var resp struct { + MemoryID string `json:"MemoryID"` + Conflicts []any `json:"Conflicts"` + } + _ = json.Unmarshal([]byte(raw), &resp) + + icon := typeIcon(memType) + fmt.Printf("\n %s %s %s\n", + icon, + col(ansiBold+typeColor(memType), "["+memType+"]"), + truncate(content, 72)) + if resp.MemoryID != "" { + fmt.Printf(" %sid=%-36s tags=%v%s\n", ansiGray, resp.MemoryID, tags, 
ansiReset) + } + return resp.MemoryID +} + +// ─── Resources REST API helper ───────────────────────────────────────────────── + +func uploadResource(ctx context.Context, content []byte) string { + body, _ := json.Marshal(map[string]any{ + "content": string(content), + "type": "semantic", + "tags": []string{"job-description", "senior-go-engineer", "hiring"}, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + fatalf("building request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + fatalf("POST %s: %v", apiURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + var errBody map[string]string + _ = json.NewDecoder(resp.Body).Decode(&errBody) + fatalf("POST /api/resources returned %d: %v", resp.StatusCode, errBody["error"]) + } + + var r struct { + ID string `json:"id"` + } + _ = json.NewDecoder(resp.Body).Decode(&r) + + fmt.Printf("\n %s POST /api/resources ← Senior Go Engineer job description\n", col(ansiGray, "→")) + fmt.Printf(" %s Resource registered: %s\n", col(ansiGreen, "✓"), col(ansiCyan, r.ID)) + fmt.Printf(" %s Agents discover it via memory_search or MCP resources/list\n", col(ansiGreen, "✓")) + return r.ID +} + +func printJobDescription(content []byte) { + divider := col(ansiDim, strings.Repeat("─", 64)) + fmt.Printf("\n%s\n", divider) + for _, line := range strings.Split(string(content), "\n") { + fmt.Printf(" %s%s%s\n", ansiGray, line, ansiReset) + } + fmt.Printf("%s\n", divider) +} + +// ─── display helpers ───────────────────────────────────────────────────────── + +func printBanner() { + bar := strings.Repeat("═", 64) + fmt.Printf("\n%s\n", col(ansiBold+ansiCyan, bar)) + fmt.Printf("%s\n", col(ansiBold+ansiCyan, " ToolHive Memory Demo — The Recruiter")) + fmt.Printf("%s\n", col(ansiCyan, " Scenario: Hiring a Senior Go Engineer at Stacklok")) + fmt.Printf("%s\n\n", 
col(ansiBold+ansiCyan, bar)) + fmt.Printf(" Server : %s\n\n", col(ansiGray, mcpURL)) +} + +func phase(n int, title, subtitle string) { + bar := strings.Repeat("─", 64) + fmt.Printf("\n%s\n", col(ansiYellow, bar)) + fmt.Printf(" %s\n", col(ansiBold+ansiWhite, fmt.Sprintf("Phase %d · %s", n, title))) + if subtitle != "" { + fmt.Printf(" %s%s%s\n", ansiGray, subtitle, ansiReset) + } + fmt.Printf("%s\n", col(ansiYellow, bar)) +} + +func printHandoff(resourceID string) { + bar := strings.Repeat("═", 64) + fmt.Printf("\n\n%s\n", col(ansiBold+ansiGreen, bar)) + fmt.Printf("%s\n", col(ansiBold+ansiGreen, " Setup complete — memory server is primed")) + fmt.Printf("%s\n\n", col(ansiBold+ansiGreen, bar)) + fmt.Printf(" Resource : %s\n", col(ansiCyan, resourceID)) + fmt.Printf(" Semantic : 3 company-wide facts written\n\n") + fmt.Printf(" %sNext: run the agent sessions to see Claude use the memory:%s\n\n", ansiBold, ansiReset) + fmt.Printf(" %smake session-recruiter-alice%s — recruiter records Alice Chen's interview\n", ansiCyan, ansiReset) + fmt.Printf(" %smake session-hiring-manager%s — hiring manager searches cold\n", ansiCyan, ansiReset) + fmt.Printf(" %smake session-recruiter-bob%s — recruiter records Bob + procedural lesson\n", ansiCyan, ansiReset) + fmt.Printf(" %smake session-recruiter-charlie%s — recruiter records Charlie (HIRE)\n", ansiCyan, ansiReset) + fmt.Printf(" %smake session-crystallize%s — crystallize phone-screen pattern → Skill\n\n", ansiCyan, ansiReset) + fmt.Printf(" %smake demo%s — run all sessions in sequence\n\n", ansiPurple, ansiReset) +} + +func typeIcon(t string) string { + switch t { + case "semantic": + return "🧠" + case "episodic": + return "📅" + case "procedural": + return "📋" + default: + return "💾" + } +} + +func typeColor(t string) string { + switch t { + case "semantic": + return ansiCyan + case "episodic": + return ansiYellow + case "procedural": + return ansiPurple + default: + return ansiWhite + } +} + +func truncate(s string, n int) 
string {
+	// Collapse internal whitespace so the preview reads as a single line.
+	s = strings.Join(strings.Fields(s), " ")
+	// Truncate on runes rather than bytes so multi-byte UTF-8 characters
+	// (e.g. the "…" and "·" this demo prints) are never split mid-sequence.
+	r := []rune(s)
+	if len(r) <= n {
+		return s
+	}
+	return string(r[:n-1]) + "…"
+}
+
+func pause() { time.Sleep(200 * time.Millisecond) }
+
+// fatalf prints an error banner to stderr and exits non-zero.
+func fatalf(format string, args ...any) {
+	// Red, not green: this is an error. The shared palette above does not
+	// define red, so the SGR code is inlined here.
+	fmt.Fprintf(os.Stderr, col("\033[31m", "ERROR: ")+format+"\n", args...)
+	os.Exit(1)
+}
diff --git a/demo/recruiter/config/memory-server.yaml.tmpl b/demo/recruiter/config/memory-server.yaml.tmpl
new file mode 100644
index 0000000000..9e5642ad2b
--- /dev/null
+++ b/demo/recruiter/config/memory-server.yaml.tmpl
@@ -0,0 +1,15 @@
+storage:
+  provider: sqlite
+  dsn: DEMO_DB_PATH
+
+embedder:
+  provider: ollama
+  url: http://localhost:11434
+  model: nomic-embed-text
+
+server:
+  name: "ToolHive Memory Demo"
+  version: "0.1.0"
+  host: "127.0.0.1"
+  port: 8765
+  lifecycle_interval_hours: 720
diff --git a/demo/recruiter/data/.gitignore b/demo/recruiter/data/.gitignore
new file mode 100644
index 0000000000..32b11de842
--- /dev/null
+++ b/demo/recruiter/data/.gitignore
@@ -0,0 +1,2 @@
+# Claude-generated runbooks and artifacts from demo sessions
+*.md
diff --git a/demo/recruiter/data/job-description.txt b/demo/recruiter/data/job-description.txt
new file mode 100644
index 0000000000..1de97f0c97
--- /dev/null
+++ b/demo/recruiter/data/job-description.txt
@@ -0,0 +1,88 @@
+Senior Software Engineer, Platform (Go)
+Stacklok · Remote (US only)
+
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ABOUT STACKLOK
+
+Stacklok is an open-source security company building tools that help development
+teams trust their software supply chain. We build ToolHive — a lightweight,
+secure manager for Model Context Protocol (MCP) servers — and Minder, an
+open-source policy engine for your software supply chain.
+
+We are a small, fully remote team. Everyone ships production code. We have a
+strong bias toward async communication and written clarity over meetings. 
+ +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +COMPENSATION + +Base salary: $100,000 – $150,000 USD (depending on experience and location) +Total comp includes meaningful equity at an early-stage company. + +We do NOT sponsor US work visas. Candidates must be authorized to work in +the United States without sponsorship now or in the future. + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +THE ROLE + +We are looking for a Senior Software Engineer to join our platform team. You +will work on the core infrastructure powering ToolHive and contribute to our +Kubernetes operator (thv-operator). + +WHAT YOU WILL DO + +• Design and implement distributed systems components for MCP server management +• Contribute to the Kubernetes operator using controller-runtime and CRDs +• Build and maintain the MCP proxy runner and authentication middleware +• Work across pkg/runner, pkg/auth, and cmd/thv layers of the codebase +• Write clean, well-tested Go with a focus on operational simplicity +• Review code, mentor peers, and raise the engineering bar across the team + +WHAT WE ARE LOOKING FOR + +• 5+ years of Go experience in production environments +• Strong distributed systems fundamentals (consensus, failure modes, back-pressure) +• Hands-on experience with Kubernetes and controller-runtime +• Familiarity with container runtimes (Docker, containerd) +• Strong written communication — we are async-first and documentation matters +• Comfort with open-source development and working in public + +NICE TO HAVE + +• Experience with MCP (Model Context Protocol) or similar tool-use protocols +• Contributions to open-source Go projects +• Experience with OCI registries and artifact management +• Familiarity with OIDC, OAuth 2.0, and modern auth patterns + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +WHAT WE OFFER + +• Fully remote (US only, EST/PST timezone preferred) +• $100,000–$150,000 base salary + equity +• Comprehensive health, dental, and vision insurance +• 
$3,000 annual learning & development stipend +• 4 weeks PTO + company holidays +• Home office stipend +• Async-first culture — no mandatory meetings before 9am your local time + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +HIRING PROCESS + +1. Recruiter phone screen (30 min) — role alignment, logistics, expectations +2. Technical screen with Go exercises (90 min) +3. System design interview (60 min) +4. Final loop: hiring manager + team lead (90 min) +5. Reference checks → offer + +We move quickly. Most candidates complete the process in 2–3 weeks. + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +TO APPLY + +Apply via our careers page or reach out directly at careers@stacklok.com. +Please include a link to a Go project you are proud of — open source preferred. diff --git a/demo/recruiter/demo.tape b/demo/recruiter/demo.tape new file mode 100644 index 0000000000..ddc026bb6c --- /dev/null +++ b/demo/recruiter/demo.tape @@ -0,0 +1,114 @@ +# ToolHive Memory Demo — The Recruiter +# Requires: vhs (brew install vhs), Ollama running, nomic-embed-text pulled, +# Claude Code CLI installed and authenticated (npm i -g @anthropic-ai/claude-code) +# +# Usage: +# cd demo/recruiter +# vhs demo.tape + +Output demo.gif +Set FontSize 14 +Set Width 160 +Set Height 50 +Set Theme "Dracula" +Set TypingSpeed 40ms +Set PlaybackSpeed 1.0 + +# ── title card ──────────────────────────────────────────────────────────────── +Hide +Type "echo ''" +Enter +Sleep 500ms +Show + +Type "# ToolHive Memory Demo — The Recruiter" +Sleep 1.5s +Enter +Sleep 300ms + +Type "# Hiring a Senior Go Engineer at Stacklok" +Sleep 1.5s +Enter +Sleep 500ms + +# ── step 1: change to demo dir ──────────────────────────────────────────────── +Type "cd demo/recruiter" +Sleep 500ms +Enter +Sleep 300ms + +# ── step 2: build ───────────────────────────────────────────────────────────── +Type "make build" +Sleep 500ms +Enter +Sleep 15s + +# ── step 3: start server 
────────────────────────────────────────────────────── +Type "make server-start server-wait" +Sleep 500ms +Enter +Sleep 8s + +# ── step 4: setup phase (resource upload + semantic memories) ───────────────── +Type "# Phase 1-2: Upload job description and prime shared memory" +Sleep 1s +Enter +Sleep 300ms +Type "make demo" +Sleep 500ms +Enter + +# Phase 1: Resource upload + print JD (fast — pure HTTP) +Sleep 5s + +# Phase 2: Semantic memories (3 Ollama embedding calls, ~3s each) +Sleep 12s + +# ── step 5: recruiter session — Alice Chen ──────────────────────────────────── +Type "make session-recruiter-alice" +Sleep 500ms +Enter + +# Claude Code: checks visa policy via memory_search, records 3 episodic memories +Sleep 30s + +# ── step 6: hiring manager cold search ─────────────────────────────────────── +Type "make session-hiring-manager" +Sleep 500ms +Enter + +# Claude Code: 4 independent memory_search calls, reports pipeline state +Sleep 30s + +# ── step 7: recruiter session — Bob Martinez ───────────────────────────────── +Type "make session-recruiter-bob" +Sleep 500ms +Enter + +# Claude Code: checks pipeline, records episodic + procedural memory +Sleep 35s + +# ── step 8: recruiter session — Charlie Kim ─────────────────────────────────── +Type "make session-recruiter-charlie" +Sleep 500ms +Enter + +# Claude Code: retrieves phone-screen checklist, records Charlie advancing +Sleep 30s + +# ── step 9: crystallize pattern into a Skill ────────────────────────────────── +Type "make session-crystallize" +Sleep 500ms +Enter + +# Claude Code: lists procedural memory, crystallizes → SKILL.md scaffold +Sleep 25s + +# ── step 10: teardown ──────────────────────────────────────────────────────── +Type "make teardown" +Sleep 500ms +Enter +Sleep 3s + +# Final pause +Sleep 2s diff --git a/demo/recruiter/prompts/crystallize.txt b/demo/recruiter/prompts/crystallize.txt new file mode 100644 index 0000000000..57b4a5a285 --- /dev/null +++ b/demo/recruiter/prompts/crystallize.txt @@ 
-0,0 +1 @@ +We've wrapped up the first week of screens for the Senior Go Engineer role. I want to turn what we learned into something the whole recruiting team can reuse — a proper runbook for future phone screens so we're not reinventing this every time. Can you put that together based on what we've figured out? diff --git a/demo/recruiter/prompts/hiring-manager.txt b/demo/recruiter/prompts/hiring-manager.txt new file mode 100644 index 0000000000..7a3ab9b116 --- /dev/null +++ b/demo/recruiter/prompts/hiring-manager.txt @@ -0,0 +1 @@ +Hey, I'm the hiring manager for the Senior Go Engineer opening. I haven't been in the loop on recruiting — can you catch me up? I'd like to know who's been screened so far, where they stand, what the approved comp range looks like, and a reminder of what we're actually hiring for. diff --git a/demo/recruiter/prompts/recruiter-alice.txt b/demo/recruiter/prompts/recruiter-alice.txt new file mode 100644 index 0000000000..685b431e37 --- /dev/null +++ b/demo/recruiter/prompts/recruiter-alice.txt @@ -0,0 +1,5 @@ +I just finished a phone screen with Alice Chen for the Senior Go Engineer role. She has 8 years of Go experience, strong distributed systems background, currently at Google. Technically she looks great. Her ask is $160K base. + +One issue: she mentioned she's on OPT and would need an H1-B transfer — her OPT expires in about 6 months. Before I move her forward I want to make sure that's not a blocker. Can you check if we have any policy on that? + +Also, once you've got that, can you log the outcome of this screen for me so the hiring manager and any other recruiter can see it? diff --git a/demo/recruiter/prompts/recruiter-bob.txt b/demo/recruiter/prompts/recruiter-bob.txt new file mode 100644 index 0000000000..73391bc1c4 --- /dev/null +++ b/demo/recruiter/prompts/recruiter-bob.txt @@ -0,0 +1,5 @@ +Just wrapped a screen with Bob Martinez. US citizen, no visa issues. 5 years Go, mostly microservices work. 
His ask is $145K which seems fine. Technically though he struggled — I asked him to walk me through how Raft works and he couldn't get through it. Felt like he was below bar for senior level. + +I'm going to archive him. Can you log that and update the pipeline? + +Also, this is the second screen in a row where something obvious knocked the candidate out early — visa with Alice, distributed systems fundamentals with Bob. I feel like we could save everyone time if we checked those things right at the start of each call. Worth noting that pattern somewhere so I don't forget it. diff --git a/demo/recruiter/prompts/recruiter-charlie.txt b/demo/recruiter/prompts/recruiter-charlie.txt new file mode 100644 index 0000000000..4f362ae0d3 --- /dev/null +++ b/demo/recruiter/prompts/recruiter-charlie.txt @@ -0,0 +1,7 @@ +About to jump on a screen with Charlie Kim. Before I start — do we have anything on how to run these calls? I remember we were going to standardize the approach after the last couple of screens didn't go well. + +--- + +[30 minutes later] + +That went really well. Charlie has 7 years of Go, spent time at AWS building distributed storage. US citizen. Asking $140K. He explained Raft clearly and even got into leader election edge cases without me prompting him. Great communicator too. I want to move him to the full interview loop. Can you record this and give me a quick status on where the pipeline stands overall? diff --git a/demo/recruiter/slides.html b/demo/recruiter/slides.html new file mode 100644 index 0000000000..0170bfb5f6 --- /dev/null +++ b/demo/recruiter/slides.html @@ -0,0 +1,1111 @@ + + + + + +ToolHive Memory — Agentic Shared Memory + + + + + +
+
+ + +
+

ToolHive Memory

+

Persistent Memory Across Agent Sessions

+

+ A persistent memory layer that any MCP-compatible agent can + write, + search, and + manage — + through a standard MCP interface. +

+
+ MCP + SQLite + sqlite-vec + Ollama embeddings + Go +
+
+ + +
+
Part 1
+

The Problem & the Research

+

+ Why agentic memory is hard, and what the research tells us to do about it. +

+
+ + +
+

The Problem

+

Every agent session starts with no memory of what came before.

+ + + + + + + + + Claude Code + Recruiter A + + + Claude Code + Hiring Manager + + + + + + + + + + ❌ Knowledge from Recruiter A + invisible to the hiring manager session + + + ❌ Decisions made twice + no shared context, no continuity + + no shared state + + +
+ A recruiter learns a visa policy the hard way. The next session repeats the mistake. + The hiring manager has no idea where the pipeline stands. +
+
+ + +
+

What the Research Tells Us

+

Three bodies of work converge on the same failure modes.

+ +
+
+
A Survey on the Memory Mechanism of LLM-based Agents
+
arXiv 2404.13501 · 2024
+
Most systems neglect the manage phase — consolidation, conflict resolution, staleness. Silent contradictions are the most common failure mode.
+
+
+
LinkedIn Cognitive Memory Agent (CMA)
+
LinkedIn Engineering Blog · 2024
+
Hierarchical episodic + semantic + procedural memory at LinkedIn scale. Trust-weighted retrieval and per-application tenant isolation in production.
+
+
+
Practical Guide to LLM Memory Systems
+
Towards Data Science · 2024
+
Trust differentiation between human and agent memory is widely neglected. Staleness is the second leading cause of retrieval degradation.
+
+
+ +
+
+ Key insight #1 — Memory isn't just storage. Write, retrieval, and management (consolidation, conflict, staleness) all need first-class design. +
+
+ Key insight #2 — Human-authored memories must outrank agent-authored ones. Trust differentiation is the most neglected dimension in practice. +
+
+
+ + +
+

Five Design Tensions

+

+ Every memory system is pulled along five axes that tug in opposite directions. + The right balance shifts with the application — a medical triage agent operates under a very different + faithfulness–efficiency frontier than a recipe recommender. + (arXiv 2404.13501) +

+ + +
+ + +
+
+
⚡ Utility
+
vs Efficiency
+
+
+
Maximising utility tempts you to store everything — bloating storage and retrieval cost. Aggressive compression silently discards the one rare fact that matters three weeks later.
+
+
+
Store semantically compressed memories, not raw transcripts. Trust scores surface high-value entries first so retrieval cost stays bounded.
+
+
+ + +
+
+
🎯 Faithfulness
+
vs Adaptivity
+
+
+
Stale or hallucinated recall can be worse than no recall at all. But locking down memory prevents it from reflecting a world that changes.
+
+
+
Conflict detection on write. Staleness scoring updated every 24h. Human-authored entries always outrank agent entries. Correction history preserved.
+
+
+ + +
+
+
🔄 Adaptivity
+
vs Stability
+
+
+
Memory that updates freely drifts — an agent can overwrite good knowledge with a bad observation. Full retrains are expensive and disrupt continuity.
+
+
+
Incremental writes via memory_remember. Updates create a revision history — nothing is ever silently overwritten. Flag + review before trust score recovers.
+
+
+ + +
+
+
📊 Efficiency
+
latency · tokens · storage
+
+
+
Every retrieved memory costs tokens in the context window. Large retrieval sets slow inference and dilute relevance. Embedding calls add latency on every write.
+
+
+
Embed once on write, never on read. Single-pass vector search with configurable top_k. TTL-based expiry keeps the store bounded. Score-weighted ranking cuts noise.
+
+
+ + +
+
+
🏛️ Governance
+
privacy · deletion · policy
+
+
+
Memory systems accumulate sensitive data. Without explicit deletion and access controls, memory becomes a liability — agents can recall things they shouldn't.
+
+
+
Explicit memory_forget. Agents can only delete their own entries; humans can delete any. TTL for automatic expiry. Read-only Resources protected by ErrReadOnly.
+
+
+ +
+ +
+
Tension
+
Our position
+
+
+ + +
+

Our Approach vs. the Literature

+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
DimensionLinkedIn CMAToolHive Memory
Conflict detectionOn roadmap✅ Built — cosine sim > 0.85 on write
Trust / stalenessTime-based, on roadmap✅ Built — formula + 24 h background job
User controlPlanned✅ list · update · flag · forget
Memory typesEpisodic + Semantic + Procedural+ Resource (read-only reference docs)
RetrievalLLM-orchestrated multi-stepSingle-pass vector — agent is the orchestrator
CrystallizationNot described✅ Procedural → versioned Skill (OCI)
Tenant isolationPer-application storage isolationAuth at proxy level — storage isolation on roadmap
+
+ +
+ We are ahead on conflict detection, trust scoring, and user control. + LinkedIn is ahead on hierarchical aggregation and tenant isolation — both planned for a later phase. +
+
+ + +
+
Part 2
+

Architecture

+

+ How it's built — components, memory types, and the full lifecycle. +

+
+ + +
+

Architecture

+

+ Memory is a system workload inside ToolHive — + auto-provisioned, singleton per scope, excluded from thv stop --all. + Fully pluggable storage, vector, and embedder backends. +

+ + + + + + + + + + + + + + Claude Code + Recruiter A + + + Claude Code + Hiring Manager + + + Claude Code + Recruiter B + + HTTP / MCP + + + + + + + + + + Memory Server + + MCP · STREAMABLE HTTP + + MCP TOOLS + + + memory_remember + store + conflict check + memory_search + vector similarity query + memory_crystallize + → Skill scaffold + resources/list · flag · forget + + + + + + + + + 💾 Storage + + + default + + SQLite + sqlite-vec + also: PostgreSQL + pgvector · Qdrant · Weaviate · Pinecone + swap provider in memory-server.yaml — no code changes + + + + 🔍 Vector Store + + default + sqlite-vec (embedded, zero infra) + also: Qdrant · pgvector · Weaviate · Pinecone + PostgreSQL + pgvector collapses storage + vector into one + + + + 🔤 Embedder + + default + Ollama (local, nomic-embed-text) + also: OpenAI · Cohere · Google Vertex AI + teams with no local GPU use OpenAI or Cohere with zero changes + + + +
+
+
💻 Local (thv CLI)
+
+ thv memory init
+ Personal memory, local container.
+ SQLite + sqlite-vec defaults.
+ State in ~/.local/state/toolhive/ +
+
+
+
👥 Team (thv serve)
+
+ Shared instance, all agents connect
+ through the ToolHive API proxy.
+ Auth via existing OIDC middleware.
+ Postgres + Qdrant recommended. +
+
+
+
☸️ Kubernetes (thv-operator)
+
+ MCPMemoryServer CRD.
+ Operator reconciles → Deployment
+ + Service + PVC automatically.
+ Registered in MCPRegistry. +
+
+
+
+ + +
+

Hierarchical Memory roadmap

+

+ Today memory is a flat namespace. The path to layered, scope-aware memory is well defined. +

+ +
+ +
+

TODAY — tags as a workaround

+
+ Tag memories with a project label:
+ tags: ["project:payments"]

+ Filter in search:
+ memory_search("auth", tags=["project:payments"])

+ ⚠ Convention only — no enforcement.
+ An unfiltered search still returns everything.
+
+
+ + +
+

ROADMAP — namespace isolation

+
+ Add Namespace to every entry.
+ Proxy stamps it from OIDC token or project context — agents never set it.

+ Search walks up the hierarchy:
+ projectteamglobal

+ One schema migration. No API surface change. +
+
+
+ + + + + + + + + + + + 🌍 global + company policies · salary bands · shared skills + + + + + + + + + 👥 team: platform-eng + infra patterns · on-call runbooks + + + 👥 team: recruiting + pipeline · comp bands · interview SOP + + + 👥 team: security + threat models · CVE notes + + + + + project: thv-memory + project: operator + +
+ + +
+

Memory Types

+

Four types cover reference docs, facts, events, and learned processes.

+ +
+
+
📄
+
+
Resource
+
Read-only reference documents uploaded via REST API. Agents discover via resources/list or memory_search — never writable by agents. source: resource
+
+
+
+
🧠
+
+
Semantic
+
Aggregated facts and domain knowledge — things that are durably true. Conflict detection (cosine sim > 0.85) prevents silent contradictions on write. human | agent
+
+
+
+
📅
+
+
Episodic
+
Time-indexed event records — things that happened. Phone screens, decisions, observations. Queryable by tags and time range. e.g. "Alice Chen screened 2024-03-10"
+
+
+
+
🔧
+
+
Procedural
+
Learned behaviors and processes. Emerges from episodic patterns. Can be crystallized into a versioned Skill (OCI artifact) the whole team can reuse. crystallizable → Skill
+
+
+
+
+ + +
+

Memory Lifecycle

+

Memory is managed — not just stored.

+ + + + + + + + + + + + + Remember + embed + conflict check + + + Search + vector + trust-ranked + + + Flag / Update + correct or mark stale + + + Consolidate + merge related entries + + + Crystallize + → Skill runbook + + + + + + + + + Background job every 24h: recomputes trust & staleness scores · expires TTL'd entries · surfaces consolidation candidates (sim > 0.92) + + + +
+
+ Trust score = author weight × age decay × correction penalty × flag penalty.
+ Human-authored entries always outrank agent-authored ones. +
+
+ Conflict detection on write — cosine similarity > 0.85 returns matching entries. The agent (with context) decides: force-write, update, or abort. +
+
+
+ + +
+

From Memory to Skill

+

Fluid procedural knowledge crystallizes into versioned, distributable runbooks.

+ + + + + + + + + + + + + 🔧 Procedural memory + + + 🔧 Procedural memory + + + 🔧 Procedural memory + + + + + + + memory_crystallize + drafts SKILL.md scaffold + human authors + reviews + + + thv skills push + + + Skill (OCI artifact) + versioned · immutable + distributable across teams + + archived with crystallized_into pointer + + +
+ Procedural memories and Skills are the same knowledge at different stages of maturity — + fluid and evolving in memory, crystallized and versioned as a Skill. +
+
+ + +
+
Part 3
+

Live Demo

+

+ The Recruiter — hiring a Senior Go Engineer at Stacklok +

+
+ + +
+

The Recruiter Scenario

+

A one-week hiring process — every session shares a single memory server.

+ +
+
+

Cast

+
+ 👩 Recruiter — runs phone screens
+ 🧑‍💼 Hiring Manager — checks pipeline cold
+ 🖥️ Memory Server — one server, all sessions +
+
+
+

What we'll see

+
+ ✅ Policy recalled without being told
+ ✅ Cross-session knowledge sharing
+ ✅ Process learned from repeated failure
+ ✅ Runbook born from lived experience +
+
+
+ +
+ claude --mcp-config .demo.mcp.json +  —  the only integration needed +
+
+ + +
+

Scenario — All Phases

+
    +
  • +
    1
    +
    + Resource upload  resource
    + Job description registered as a read-only MCP Resource — discoverable, never modifiable +
    +
  • +
  • +
    2
    +
    + Semantic memory  semantic
    + 3 company-wide facts: no visa sponsorship · salary $100–150K · remote US async culture +
    +
  • +
  • +
    3
    +
    + Alice Chen  episodic
    + Strong candidate, needs H1-B. Agent searches policy, finds blocker, logs outcome +
    +
  • +
  • +
    4
    +
    + Hiring Manager — cold search  semantic episodic
    + Joins with zero context. Memory gives full pipeline status, comp band, and JD +
    +
  • +
  • +
    5
    +
    + Bob Martinez  episodic procedural
    + Archived. Recruiter spots a pattern → agent writes a procedural memory unprompted +
    +
  • +
  • +
    6
    +
    + Charlie Kim  procedural episodic
    + Retrieves the checklist (written by a different session). Applies it. Charlie advances. +
    +
  • +
  • +
    7
    +
    + Crystallize → Skill  procedural
    + One week of screens → phone-screen runbook the whole recruiting team can reuse +
    +
  • +
+
+ + +
+
+
Phases 1 – 2
+
Setup — prime the memory server
+
+ +
+
+

Phase 1 · Resource

+
+ POST /api/resources

+ Job description uploaded as a read-only MCP Resource.
+ Agents discover it via memory_search or resources/list.
+ Protected by ErrReadOnly — no agent can modify it. +
+
+
+

Phase 2 · Semantic memory

+
+ 🧠  Company does not sponsor US work visas for any engineering role +
+
+ 🧠  Senior Go Engineer base: $100K–$150K + equity +
+
+ 🧠  Engineering team fully remote, US timezone, async-first +
+
+
+

Written once by the setup script — recalled by any agent session at any time.

+
+ + +
+
+
Phase 3
+
Recruiter — Alice Chen phone screen
+
+ +
+
+

The recruiter says

+
+ "She mentioned she's on OPT and would need an H1-B transfer. Before I move her forward I want to make sure that's not a blocker. Can you check if we have any policy on that?" +
+
+
+

What the agent does

+
+
memory_search("visa sponsorship policy")
+
Finds: "Company does not sponsor US work visas"
+
memory_remember (episodic)
+
Alice Chen — OPT / H1-B needed → archived
+
+
+
+ +
+ The agent found a policy it was never told about in this session — + retrieved from shared memory written in Phase 2. +
+
+ + +
+
+
Phase 4
+
Hiring Manager — cold pipeline review
+
+ +

+ "I haven't been in the loop on recruiting — can you catch me up? Pipeline status, approved comp range, and a reminder of what we're hiring for." +

+ +
+
memory_search("candidates screened")
Alice Chen → archived (visa). 1 of 1.
+
memory_search("visa sponsorship")
No sponsorship policy. Explains Alice.
+
memory_search("salary compensation")
$100K–$150K base + equity. Approved band.
+
memory_search("job description")
Retrieves the JD Resource. Full requirements.
+
+ +

The hiring manager never spoke to the recruiter. The memory server was the handoff.

+
+ + +
+
+
Phase 5
+
Recruiter — Bob Martinez + a lesson learned
+
+ +
+
+
+ "…this is the second screen in a row where something obvious knocked the candidate out early. I feel like we could save everyone time if we checked those things right at the start of each call. Worth noting that pattern somewhere." +
+
+
+
+ 🔧 Procedural memory written

+ "Phone screen gate: (1) confirm work-auth in first 5 min, (2) ask candidate to explain Raft — weak answers correlate with underperformance on distributed systems work." +
+
+
+ +
+ The recruiter said "worth noting." The agent recognised it as a reusable process and chose the right memory type without being asked. +
+
+ + +
+
+
Phase 6
+
Recruiter — Charlie Kim (HIRE)
+
+ +
+
+

Before the screen

+
+ "About to jump on a screen with Charlie Kim. Do we have anything on how to run these calls?" +
+
+ → memory_search("phone screen process")
+ Retrieves the gate checklist from Phase 5 +
+
+
+

Post-screen results

+
✅ US citizen — work-auth clear (first question)
+
✅ $140K ask — within $100–150K band
+
✅ Explained Raft + leader election edge cases unprompted
+

→ Advancing to interview loop

+
+
+ +
+ The checklist was written by a different agent session. This session retrieved and applied it cold. +
+
+ + +
+
+
Phase 7
+
Crystallize — one week of screens → a reusable Skill
+
+ +
+ "We've wrapped the first week of screens. I want to turn what we learned into something the whole recruiting team can reuse — a proper runbook for future phone screens." +
+ +
+
+
Agent retrieves & crystallizes
+
+ → memory_list(type=procedural)
+ → memory_search("phone screen patterns")
+ → memory_crystallize([ids…], name="...") +
+
+
+
+ Output: SKILL.md scaffold
+ Knockout gates · Technical depth probe
+ Decision rubric · Post-call logging template

+ Human reviews → thv skills push → OCI artifact.
+ Originals archived with crystallized_into pointer. +
+
+
+ +

One week of lived experience → a versioned runbook any recruiter can follow from day one.

+
+ + +
+

What We Just Saw

+ +
    +
  • +
    📄
    +
    Resource — reference docs agents discover through MCP; protected from modification
    +
  • +
  • +
    🧠
    +
    Semantic — company-wide facts written once, recalled by any session with no explicit handoff
    +
  • +
  • +
    📅
    +
    Episodic — time-indexed events building a shared pipeline log across recruiter sessions
    +
  • +
  • +
    🔧
    +
    Procedural — process knowledge that emerged from failure; retrieved by a session that never wrote it
    +
  • +
  • +
    +
    Crystallization — lived team experience promoted into a versioned Skill the whole org can distribute
    +
  • +
+ +
+ Any MCP-compatible agent. One config file. Shared memory across every session. +
+
+ + +
+

References

+
+
+
A Survey on the Memory Mechanism of LLM-based Agents
+
arXiv:2404.13501 · 2024
+
Taxonomy of memory types and operations lifecycle (acquire, manage, utilize). The "manage" phase — consolidation, conflict resolution, staleness — is the most neglected in practice. Informed our lifecycle design and the 24 h background job.
+
+
+
The LinkedIn Generative AI Application Tech Stack: Personalization with Cognitive Memory Agent
+
LinkedIn Engineering Blog · 2024
+
Production deployment of hierarchical episodic + aggregated semantic + procedural memory at LinkedIn scale. Trust-weighted retrieval and per-application isolation. Informed our comparison table and the prioritisation of conflict detection and user control.
+
+
+
A Practical Guide to Implementing Memory in LLM Applications
+
Towards Data Science · 2024
+
Practitioner analysis of memory degradation: staleness and trust neglect are the top two causes. Recommends human-authored memory outranking agent-authored, and explicit staleness scoring — both implemented here.
+
+
+
+ +
+
+ + + + + diff --git a/docs/proposals/2026-04-22-shared-memory-server.md b/docs/proposals/2026-04-22-shared-memory-server.md new file mode 100644 index 0000000000..899d7b15e1 --- /dev/null +++ b/docs/proposals/2026-04-22-shared-memory-server.md @@ -0,0 +1,312 @@ +# Shared Long-Term Memory Server + +**Date:** 2026-04-22 +**Status:** Implementation in progress (Plan 1 of 3 complete) + +--- + +## Problem + +ToolHive manages MCPs (tools) and Skills (procedural knowledge as OCI artifacts). The missing +primitive is **shared long-term memory**: a team-wide knowledge store that agents can query and +contribute to across sessions. + +Without it, every agent session starts cold. Facts learned by one agent are invisible to others. +Patterns that emerge from repeated interactions are lost when the session ends. + +--- + +## Memory Types + +Two long-term memory namespaces are in scope: + +| Type | Purpose | Example | +|---|---|---| +| `semantic` | Aggregated facts and world-state knowledge | "Company does not sponsor visas" | +| `procedural` | How-to knowledge, heuristics, SOPs | "Always run `task lint-fix` before committing" | +| `episodic` | Time-indexed event records | "Recruiter archived candidate on 2024-03-15 — visa required" | + +**Out of scope:** working memory and conversational memory — agents handle those internally via +their context window. + +--- + +## Architecture + +### System Workload + +The memory server is ToolHive's first **system workload** — a managed MCP server auto-provisioned +by ToolHive rather than explicitly started by users. Key properties: + +- Auto-provisioned on first use (`thv memory init`) +- Persistent — excluded from `thv stop --all` +- Singleton per scope (one per team in `thv serve` mode) +- Registered in the registry under the reserved name `toolhive.memory` + +### Transport + +The memory server uses **MCP streamable HTTP** transport (not stdio). Agents connect via +`http://:8080/mcp`. 
A `/health` liveness probe is available at the same host. + +### Pluggable Backends + +Three independent interfaces, configured via `memory-server.yaml`: + +```yaml +storage: + provider: sqlite # sqlite (default) | postgres | mongodb + dsn: /data/memory.db + +vector: + provider: sqlite-vec # sqlite-vec (default) | qdrant | pgvector + url: "" + +embedder: + provider: ollama # ollama (default) | openai | cohere + model: nomic-embed-text + url: http://localhost:11434 + +server: + host: 0.0.0.0 + port: 8080 + lifecycle_interval_hours: 24 +``` + +Zero-infra teams use SQLite defaults with no external dependencies. Teams with Postgres can +collapse both storage and vector into pgvector. + +### Deployment Modes + +**Local (`thv` CLI):** Personal memory, local container, SQLite defaults. + +**Team (`thv serve`):** Shared instance; all team agents connect via the API server proxy. +Auth enforced via existing OIDC middleware. + +**Kubernetes (`thv-operator`):** New `MCPMemoryServer` CRD (Plan 3). Operator reconciles to +`Deployment + Service + PVC`. + +--- + +## MCP Tool Surface + +Agents consume memory exactly like any other MCP — no special integration. + +| Tool | Description | +|---|---| +| `memory_remember` | Write a memory. 
Runs conflict detection; returns conflicts if similarity > 0.85 | +| `memory_search` | Semantic vector search, results ranked by composite trust+staleness score | +| `memory_recall` | Fetch a specific entry by ID, including full revision history | +| `memory_forget` | Delete a memory | +| `memory_update` | Correct content; previous version saved to revision history | +| `memory_flag` | Mark as potentially stale without deleting | +| `memory_list` | Structured listing with filters: type, author, tags, time-range | +| `memory_consolidate` | Merge related entries; originals archived with pointer | +| `memory_crystallize` | Promote procedural memories to a Skill scaffold for human authoring | + +### Conflict Detection + +On `memory_remember`, the server embeds the new content and searches for similar active entries. +If any entry has cosine similarity > 0.85, the write is blocked and the agent receives a +`conflict_detected` response with the conflicting entries. The agent decides: force-write, +update the existing entry, or abort. No LLM inference — the agent (which has context) is better +placed to judge whether two similar entries actually conflict. + +### Search Ranking + +`memory_search` returns results ranked by a composite score that combines vector similarity with +the entry's trust and staleness signals: + +``` +composite = similarity × trust_score × (1 - 0.3 × staleness_score) +``` + +This prevents a high-similarity but flagged or stale entry from ranking above a fresher, +more trusted one. 
+ +--- + +## Trust and Staleness Scoring + +### Trust Score + +``` +trust_score = author_weight + × age_decay(created_at, half_life=180d) + × (1 - min(corrections × 0.05, 0.30)) + × (0.5 if flagged else 1.0) + +author_weight: human=1.0, agent=0.7 +``` + +### Staleness Score + +``` +staleness_score = normalize(days_since_last_access, max=90d) + + (0.3 if flagged) + + min(corrections × 0.1, 0.3) +``` + +Entries with `staleness_score > 0.8` surface in the lifecycle audit log every 24 hours. + +--- + +## Skills Relationship + +Skills (existing) and procedural memory are the same kind of knowledge at different stages of +maturity: + +``` +Agent/human observes something + │ + ▼ + Procedural Memory ← fluid, emergent, evolving + (memory server) + │ + (patterns emerge, + human crystallizes) + │ + ▼ + Skill (OCI) ← crystallized, versioned, distributed + (existing skills system) +``` + +`memory_crystallize` bridges the gap: it takes stable procedural memory entries and produces a +`SKILL.md` scaffold for a human to author and push via `thv skills push`. The source entries are +archived with a `crystallized_into` pointer so search returns the canonical Skill instead. + +--- + +## Recommended Memory Activation Strategy + +Not every agent interaction should touch the memory server. The recommended approach is a +three-tier strategy: + +### Tier 1 — Session-boundary injection (always) + +At the **start** of every task-bearing session, the system prompt instructs the agent to run one +`memory_search` call with the task description before doing anything else. This is silent, +cheap (one vector search), and covers the most valuable case: cross-session continuity. + +``` +Before starting work, call memory_search with the task description to load +relevant team knowledge. Do this once, silently — do not explain it to the user. +``` + +At the **end** of a session, the agent writes what was discovered or decided that would be +useful to a different agent in a future session. 
+ +### Tier 2 — Signal-based mid-session reads (agent-decided) + +The system prompt instructs the agent to call `memory_search` when it encounters: + +1. **Uncertainty** — "I don't have enough context to answer this confidently" +2. **Cross-session references** — phrases like "last time", "previously", "we decided", + "our policy", "do you remember" +3. **Team-specific facts** — questions about preferences, conventions, or domain knowledge + not in the codebase or current context + +### Tier 3 — Write on observation, not speculation + +The agent calls `memory_remember` only for facts that: +- Were not already in the search results from Tier 1 +- Would be useful to a **different** agent in a **future** session + +The system prompt guidance: + +``` +Write a memory when you learn something that: +- corrects or refines an existing fact (use memory_update instead) +- is a team decision, constraint, or policy that will apply again +- is a recurring pattern observed more than once + +Do NOT write memories for facts already in the codebase, documentation, +or the current conversation context. +``` + +### Why not automatic ingestion? + +Auto-ingestion (LinkedIn's streaming pipeline approach) requires: +- An LLM call in the ingestion path to extract facts from raw transcripts +- Quality control to decide what is worth persisting +- Evaluation tooling to measure ingestion accuracy + +These are deferred to a later plan. The explicit tool-use model is more predictable and +debuggable for a v1, and the agent (which has full context) makes better judgments about +what is worth writing than a pipeline operating on raw text. + +--- + +## Comparison with LinkedIn's Cognitive Memory Agent + +LinkedIn's CMA (described in their [engineering blog](https://www.linkedin.com/blog/engineering/ai/the-linkedin-generative-ai-application-tech-stack-personalization-with-cognitive-memory-agent)) +is the closest public reference. 
Key differences: + +| Dimension | LinkedIn CMA | ToolHive Memory | +|---|---|---| +| Conflict detection | On roadmap | Implemented (cosine > 0.85) | +| Trust/staleness scoring | Time-based prioritization planned | Implemented (full formula + background job) | +| User control (list/update/delete/flag) | On roadmap | Implemented | +| Search ranking | Implicit | Composite score: similarity × trust × (1 − staleness penalty) | +| Episodic memory type | Distinct tier | `TypeEpisodic` with time-range `ListFilter` | +| Retrieval orchestration | LLM-powered multi-step planner | Agent calls tools directly (agent IS the orchestrator) | +| Hierarchical aggregation | Auto tree: events → summaries → facets | Explicit: `memory_consolidate` + `memory_crystallize` | +| Tenant isolation | Per-application isolated stores | Auth at proxy layer; storage-level namespace deferred | +| Auto ingestion pipeline | Streaming + batch LLM extraction | Deferred; `memory_distill` returns candidates for agent review | + +--- + +## Implementation Status + +### Plan 1 — Memory server core (this branch) + +- [x] `pkg/memory/` — domain types, interfaces (`Store`, `VectorStore`, `Embedder`), scoring, service +- [x] `pkg/memory/sqlite/` — SQLite Store + VectorStore (Go cosine similarity, no CGo) +- [x] `pkg/memory/embedder/ollama/` — Ollama HTTP embedder +- [x] `pkg/memory/mocks/` — gomock mocks for all three interfaces +- [x] `cmd/thv-memory/` — MCP server binary (streamable HTTP on `/mcp`) +- [x] `cmd/thv-memory/lifecycle/` — 24h background job (TTL expiry, score recomputation) +- [x] `cmd/thv-memory/tools/` — 9 MCP tool handlers +- [x] Integration test (SQLite + fake embedder, end-to-end remember → search → delete) + +### Plan 2 — CLI + system workload integration (not started) + +- `thv memory` subcommand tree +- System workload auto-provisioning (`thv memory init`) +- Registry integration under `toolhive.memory` + +### Plan 3 — Kubernetes operator (not started) + +- `MCPMemoryServer` CRD +- 
Operator controller: reconciles to `Deployment + Service + PVC` +- `MCPRegistry` integration + +--- + +## Package Layout + +``` +pkg/memory/ +├── types.go — Entry, Revision, ListFilter, VectorFilter, scoring types +├── interfaces.go — Store, VectorStore, Embedder interfaces + mockgen directives +├── service.go — Orchestration: conflict detection, remember, search +├── scoring.go — ComputeTrustScore, ComputeStalenessScore +├── errors.go — ErrNotFound +├── mocks/ — Generated gomock mocks +├── sqlite/ +│ ├── db.go — DB wrapper, WAL pragmas, goose migrations +│ ├── store.go — Store implementation +│ ├── vector.go — VectorStore implementation (Go cosine similarity) +│ └── migrations/ — goose SQL migrations +└── embedder/ + └── ollama/ — Ollama HTTP embedder + +cmd/thv-memory/ +├── main.go — Entry point, HTTP server lifecycle +├── server.go — MCP server construction, tool registration, HTTP handler +├── config.go — YAML config with defaults +├── lifecycle/ +│ └── job.go — Background maintenance job +└── tools/ — One file per MCP tool + ├── remember.go, search.go, recall.go, forget.go, update.go + ├── flag.go, list.go, consolidate.go, crystallize.go +``` diff --git a/pkg/memory/embedder/ollama/embedder.go b/pkg/memory/embedder/ollama/embedder.go new file mode 100644 index 0000000000..00ca3817ee --- /dev/null +++ b/pkg/memory/embedder/ollama/embedder.go @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package ollama provides a memory.Embedder backed by a local Ollama server. +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// Embedder calls the Ollama /api/embeddings endpoint. +type Embedder struct { + baseURL string + model string + dimensions int + client *http.Client +} + +// New creates an Ollama embedder. It probes the server once to discover the +// embedding dimension. 
Returns an error if the server is unreachable or the +// model returns an empty vector. +func New(baseURL, model string) (*Embedder, error) { + if _, err := url.ParseRequestURI(baseURL); err != nil { + return nil, fmt.Errorf("invalid Ollama URL %q: %w", baseURL, err) + } + if model == "" { + return nil, fmt.Errorf("model name is required") + } + e := &Embedder{baseURL: baseURL, model: model, client: &http.Client{}} + + emb, err := e.Embed(context.Background(), "probe") + if err != nil { + return nil, fmt.Errorf("probing Ollama embedder: %w", err) + } + e.dimensions = len(emb) + return e, nil +} + +// Embed calls the Ollama /api/embeddings endpoint and returns the vector. +func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) { + body, err := json.Marshal(map[string]string{"model": e.model, "prompt": text}) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/api/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := e.client.Do(req) + if err != nil { + return nil, fmt.Errorf("calling Ollama: %w", err) + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode) + } + + var result struct { + Embedding []float32 `json:"embedding"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("decoding Ollama response: %w", err) + } + if len(result.Embedding) == 0 { + return nil, fmt.Errorf("ollama returned empty embedding") + } + return result.Embedding, nil +} + +// Dimensions returns the fixed vector length produced by this embedder. 
+func (e *Embedder) Dimensions() int { return e.dimensions } + +var _ memory.Embedder = (*Embedder)(nil) diff --git a/pkg/memory/embedder/ollama/embedder_test.go b/pkg/memory/embedder/ollama/embedder_test.go new file mode 100644 index 0000000000..bb601fae1b --- /dev/null +++ b/pkg/memory/embedder/ollama/embedder_test.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ollama_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/memory/embedder/ollama" +) + +func TestEmbed(t *testing.T) { + t.Parallel() + + want := []float32{0.1, 0.2, 0.3} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/api/embeddings", r.URL.Path) + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": want}) + })) + t.Cleanup(srv.Close) + + e, err := ollama.New(srv.URL, "nomic-embed-text") + require.NoError(t, err) + require.Equal(t, 3, e.Dimensions()) + + got, err := e.Embed(context.Background(), "hello world") + require.NoError(t, err) + require.InDeltaSlice(t, want, got, 0.001) +} diff --git a/pkg/memory/errors.go b/pkg/memory/errors.go new file mode 100644 index 0000000000..77aa21a1da --- /dev/null +++ b/pkg/memory/errors.go @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory + +import "errors" + +// ErrNotFound is returned when a memory entry does not exist. +var ErrNotFound = errors.New("memory entry not found") + +// ErrReadOnly is returned when an agent attempts to mutate an entry whose +// source type is read-only (SourceSkill or SourceResource). Use the +// management REST API to modify resource entries. 
+var ErrReadOnly = errors.New("memory entry is read-only") diff --git a/pkg/memory/interfaces.go b/pkg/memory/interfaces.go new file mode 100644 index 0000000000..5f26d39076 --- /dev/null +++ b/pkg/memory/interfaces.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory + +import "context" + +//go:generate mockgen -destination mocks/mock_store.go -package mocks github.com/stacklok/toolhive/pkg/memory Store +//go:generate mockgen -destination mocks/mock_vector.go -package mocks github.com/stacklok/toolhive/pkg/memory VectorStore +//go:generate mockgen -destination mocks/mock_embedder.go -package mocks github.com/stacklok/toolhive/pkg/memory Embedder + +// Store is the structured persistence layer for memory entries. +// It handles CRUD, lifecycle transitions, and score updates. +// Implementations must be safe for concurrent use. +type Store interface { + Create(ctx context.Context, entry Entry) error + Get(ctx context.Context, id string) (Entry, error) + // Update replaces the content of an existing entry and appends the + // previous content to History. The embedding must be recomputed by + // the caller (Service) after this call succeeds. + Update(ctx context.Context, id string, content string, author AuthorType, correctionNote string) error + Flag(ctx context.Context, id string, reason string) error + Unflag(ctx context.Context, id string) error + Delete(ctx context.Context, id string) error + List(ctx context.Context, filter ListFilter) ([]Entry, error) + Archive(ctx context.Context, id string, reason ArchiveReason, ref string) error + IncrementAccess(ctx context.Context, id string) error + UpdateScores(ctx context.Context, id string, trustScore, stalenessScore float32) error + // ListExpired returns all active entries whose ExpiresAt is in the past. + ListExpired(ctx context.Context) ([]Entry, error) + // ListActive returns all non-archived entries for score recomputation. 
+ ListActive(ctx context.Context) ([]Entry, error) +} + +// VectorStore stores and queries embedding vectors for memory entries. +// Implementations must be safe for concurrent use. +type VectorStore interface { + // Upsert stores or replaces the embedding for the given entry ID. + Upsert(ctx context.Context, id string, embedding []float32) error + // Search returns the topK entries most similar to query, restricted by filter. + Search(ctx context.Context, query []float32, topK int, filter VectorFilter) ([]ScoredID, error) + Delete(ctx context.Context, id string) error +} + +// Embedder converts text to a fixed-dimension float32 vector. +// Implementations must be safe for concurrent use. +type Embedder interface { + Embed(ctx context.Context, text string) ([]float32, error) + // Dimensions returns the fixed vector length produced by this embedder. + Dimensions() int +} diff --git a/pkg/memory/mocks/mock_embedder.go b/pkg/memory/mocks/mock_embedder.go new file mode 100644 index 0000000000..ee6fbc7b24 --- /dev/null +++ b/pkg/memory/mocks/mock_embedder.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/memory (interfaces: Embedder) +// +// Generated by this command: +// +// mockgen -destination mocks/mock_embedder.go -package mocks github.com/stacklok/toolhive/pkg/memory Embedder +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockEmbedder is a mock of Embedder interface. +type MockEmbedder struct { + ctrl *gomock.Controller + recorder *MockEmbedderMockRecorder + isgomock struct{} +} + +// MockEmbedderMockRecorder is the mock recorder for MockEmbedder. +type MockEmbedderMockRecorder struct { + mock *MockEmbedder +} + +// NewMockEmbedder creates a new mock instance. 
+func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder { + mock := &MockEmbedder{ctrl: ctrl} + mock.recorder = &MockEmbedderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder { + return m.recorder +} + +// Dimensions mocks base method. +func (m *MockEmbedder) Dimensions() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dimensions") + ret0, _ := ret[0].(int) + return ret0 +} + +// Dimensions indicates an expected call of Dimensions. +func (mr *MockEmbedderMockRecorder) Dimensions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dimensions", reflect.TypeOf((*MockEmbedder)(nil).Dimensions)) +} + +// Embed mocks base method. +func (m *MockEmbedder) Embed(ctx context.Context, text string) ([]float32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Embed", ctx, text) + ret0, _ := ret[0].([]float32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Embed indicates an expected call of Embed. +func (mr *MockEmbedderMockRecorder) Embed(ctx, text any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Embed", reflect.TypeOf((*MockEmbedder)(nil).Embed), ctx, text) +} diff --git a/pkg/memory/mocks/mock_store.go b/pkg/memory/mocks/mock_store.go new file mode 100644 index 0000000000..29e3d1f860 --- /dev/null +++ b/pkg/memory/mocks/mock_store.go @@ -0,0 +1,214 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/memory (interfaces: Store) +// +// Generated by this command: +// +// mockgen -destination mocks/mock_store.go -package mocks github.com/stacklok/toolhive/pkg/memory Store +// + +// Package mocks is a generated GoMock package. 
+package mocks + +import ( + context "context" + reflect "reflect" + + memory "github.com/stacklok/toolhive/pkg/memory" + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder + isgomock struct{} +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Archive mocks base method. +func (m *MockStore) Archive(ctx context.Context, id string, reason memory.ArchiveReason, ref string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Archive", ctx, id, reason, ref) + ret0, _ := ret[0].(error) + return ret0 +} + +// Archive indicates an expected call of Archive. +func (mr *MockStoreMockRecorder) Archive(ctx, id, reason, ref any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Archive", reflect.TypeOf((*MockStore)(nil).Archive), ctx, id, reason, ref) +} + +// Create mocks base method. +func (m *MockStore) Create(ctx context.Context, entry memory.Entry) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", ctx, entry) + ret0, _ := ret[0].(error) + return ret0 +} + +// Create indicates an expected call of Create. +func (mr *MockStoreMockRecorder) Create(ctx, entry any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockStore)(nil).Create), ctx, entry) +} + +// Delete mocks base method. 
+func (m *MockStore) Delete(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockStoreMockRecorder) Delete(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), ctx, id) +} + +// Flag mocks base method. +func (m *MockStore) Flag(ctx context.Context, id, reason string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flag", ctx, id, reason) + ret0, _ := ret[0].(error) + return ret0 +} + +// Flag indicates an expected call of Flag. +func (mr *MockStoreMockRecorder) Flag(ctx, id, reason any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flag", reflect.TypeOf((*MockStore)(nil).Flag), ctx, id, reason) +} + +// Get mocks base method. +func (m *MockStore) Get(ctx context.Context, id string) (memory.Entry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, id) + ret0, _ := ret[0].(memory.Entry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockStoreMockRecorder) Get(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), ctx, id) +} + +// IncrementAccess mocks base method. +func (m *MockStore) IncrementAccess(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementAccess", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// IncrementAccess indicates an expected call of IncrementAccess. 
+func (mr *MockStoreMockRecorder) IncrementAccess(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementAccess", reflect.TypeOf((*MockStore)(nil).IncrementAccess), ctx, id) +} + +// List mocks base method. +func (m *MockStore) List(ctx context.Context, filter memory.ListFilter) ([]memory.Entry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx, filter) + ret0, _ := ret[0].([]memory.Entry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockStoreMockRecorder) List(ctx, filter any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStore)(nil).List), ctx, filter) +} + +// ListActive mocks base method. +func (m *MockStore) ListActive(ctx context.Context) ([]memory.Entry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListActive", ctx) + ret0, _ := ret[0].([]memory.Entry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListActive indicates an expected call of ListActive. +func (mr *MockStoreMockRecorder) ListActive(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListActive", reflect.TypeOf((*MockStore)(nil).ListActive), ctx) +} + +// ListExpired mocks base method. +func (m *MockStore) ListExpired(ctx context.Context) ([]memory.Entry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListExpired", ctx) + ret0, _ := ret[0].([]memory.Entry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListExpired indicates an expected call of ListExpired. +func (mr *MockStoreMockRecorder) ListExpired(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListExpired", reflect.TypeOf((*MockStore)(nil).ListExpired), ctx) +} + +// Unflag mocks base method. 
+func (m *MockStore) Unflag(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unflag", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unflag indicates an expected call of Unflag. +func (mr *MockStoreMockRecorder) Unflag(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unflag", reflect.TypeOf((*MockStore)(nil).Unflag), ctx, id) +} + +// Update mocks base method. +func (m *MockStore) Update(ctx context.Context, id, content string, author memory.AuthorType, correctionNote string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", ctx, id, content, author, correctionNote) + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockStoreMockRecorder) Update(ctx, id, content, author, correctionNote any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), ctx, id, content, author, correctionNote) +} + +// UpdateScores mocks base method. +func (m *MockStore) UpdateScores(ctx context.Context, id string, trustScore, stalenessScore float32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateScores", ctx, id, trustScore, stalenessScore) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateScores indicates an expected call of UpdateScores. +func (mr *MockStoreMockRecorder) UpdateScores(ctx, id, trustScore, stalenessScore any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateScores", reflect.TypeOf((*MockStore)(nil).UpdateScores), ctx, id, trustScore, stalenessScore) +} diff --git a/pkg/memory/mocks/mock_vector.go b/pkg/memory/mocks/mock_vector.go new file mode 100644 index 0000000000..0df2f3d80c --- /dev/null +++ b/pkg/memory/mocks/mock_vector.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: github.com/stacklok/toolhive/pkg/memory (interfaces: VectorStore) +// +// Generated by this command: +// +// mockgen -destination mocks/mock_vector.go -package mocks github.com/stacklok/toolhive/pkg/memory VectorStore +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + memory "github.com/stacklok/toolhive/pkg/memory" + gomock "go.uber.org/mock/gomock" +) + +// MockVectorStore is a mock of VectorStore interface. +type MockVectorStore struct { + ctrl *gomock.Controller + recorder *MockVectorStoreMockRecorder + isgomock struct{} +} + +// MockVectorStoreMockRecorder is the mock recorder for MockVectorStore. +type MockVectorStoreMockRecorder struct { + mock *MockVectorStore +} + +// NewMockVectorStore creates a new mock instance. +func NewMockVectorStore(ctrl *gomock.Controller) *MockVectorStore { + mock := &MockVectorStore{ctrl: ctrl} + mock.recorder = &MockVectorStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockVectorStore) EXPECT() *MockVectorStoreMockRecorder { + return m.recorder +} + +// Delete mocks base method. +func (m *MockVectorStore) Delete(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockVectorStoreMockRecorder) Delete(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockVectorStore)(nil).Delete), ctx, id) +} + +// Search mocks base method. 
+func (m *MockVectorStore) Search(ctx context.Context, query []float32, topK int, filter memory.VectorFilter) ([]memory.ScoredID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Search", ctx, query, topK, filter) + ret0, _ := ret[0].([]memory.ScoredID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Search indicates an expected call of Search. +func (mr *MockVectorStoreMockRecorder) Search(ctx, query, topK, filter any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockVectorStore)(nil).Search), ctx, query, topK, filter) +} + +// Upsert mocks base method. +func (m *MockVectorStore) Upsert(ctx context.Context, id string, embedding []float32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", ctx, id, embedding) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockVectorStoreMockRecorder) Upsert(ctx, id, embedding any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockVectorStore)(nil).Upsert), ctx, id, embedding) +} diff --git a/pkg/memory/scoring.go b/pkg/memory/scoring.go new file mode 100644 index 0000000000..93997aff87 --- /dev/null +++ b/pkg/memory/scoring.go @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory + +import ( + "math" + "time" +) + +const ( + authorWeightHuman = 1.0 + authorWeightAgent = 0.7 + halfLifeDays = 180.0 + maxCorrectionPenalty = 0.30 + correctionPenaltyPerCorrection = 0.05 + flagTrustMultiplier = 0.5 + maxStalenessAccessDays = 90.0 + flagStalenessBonus = 0.3 + correctionStalenessPerItem = 0.1 + maxCorrectionStaleness = 0.3 +) + +// ComputeTrustScore returns a value in [0,1] representing how trustworthy +// this memory entry is. Higher = more trustworthy. 
+//
+// Formula: author_weight × age_decay × (1 - correction_penalty) × flag_multiplier
+func ComputeTrustScore(entry Entry) float32 {
+	weight := authorWeightAgent // default: agent-authored (0.7); upgraded below for humans (1.0)
+	if entry.Author == AuthorHuman {
+		weight = authorWeightHuman
+	}
+
+	ageInDays := time.Since(entry.CreatedAt).Hours() / 24
+	decay := math.Exp(-ageInDays * math.Log(2) / halfLifeDays) // exponential decay: exactly 0.5 after halfLifeDays (180d)
+
+	corrections := len(entry.History)
+	correctionPenalty := math.Min(float64(corrections)*correctionPenaltyPerCorrection, maxCorrectionPenalty) // 0.05 per revision, capped at 0.30
+
+	flagMultiplier := 1.0
+	if entry.FlaggedAt != nil {
+		flagMultiplier = flagTrustMultiplier // flagged entries keep only half their trust
+	}
+
+	score := weight * decay * (1 - correctionPenalty) * flagMultiplier
+	return float32(math.Max(0, math.Min(1, score))) // clamp to [0,1]
+}
+
+// ComputeStalenessScore returns a value in [0,1] representing how stale
+// this memory entry is. Higher = more stale (more likely to need review).
+//
+// Formula: access_age_normalized + flag_bonus + correction_bonus
+func ComputeStalenessScore(entry Entry) float32 {
+	lastAccess := entry.LastAccessedAt
+	if lastAccess.IsZero() {
+		lastAccess = entry.CreatedAt // never accessed: measure staleness from creation
+	}
+	daysSinceAccess := time.Since(lastAccess).Hours() / 24
+	base := math.Min(daysSinceAccess/maxStalenessAccessDays, 1.0) // saturates at 1.0 after 90 days without access
+
+	flagBonus := 0.0
+	if entry.FlaggedAt != nil {
+		flagBonus = flagStalenessBonus
+	}
+
+	corrections := len(entry.History)
+	correctionBonus := math.Min(float64(corrections)*correctionStalenessPerItem, maxCorrectionStaleness) // 0.1 per revision, capped at 0.3
+
+	return float32(math.Min(1.0, base+flagBonus+correctionBonus)) // cap combined score at 1.0
+}
diff --git a/pkg/memory/scoring_test.go b/pkg/memory/scoring_test.go
new file mode 100644
index 0000000000..e8a5da982e
--- /dev/null
+++ b/pkg/memory/scoring_test.go
@@ -0,0 +1,138 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0 + +package memory_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/memory" +) + +func TestComputeTrustScore(t *testing.T) { + t.Parallel() + now := time.Now() + + tests := []struct { + name string + entry memory.Entry + wantMin float32 + wantMax float32 + }{ + { + name: "fresh human entry has high trust", + entry: memory.Entry{ + Author: memory.AuthorHuman, + CreatedAt: now, + }, + wantMin: 0.95, + wantMax: 1.0, + }, + { + name: "fresh agent entry has lower trust than human", + entry: memory.Entry{ + Author: memory.AuthorAgent, + CreatedAt: now, + }, + wantMin: 0.65, + wantMax: 0.75, + }, + { + name: "flagged entry has halved trust", + entry: func() memory.Entry { + ft := now + return memory.Entry{ + Author: memory.AuthorHuman, + CreatedAt: now, + FlaggedAt: &ft, + } + }(), + wantMin: 0.45, + wantMax: 0.55, + }, + { + name: "two corrections reduce trust", + entry: memory.Entry{ + Author: memory.AuthorHuman, + CreatedAt: now, + History: []memory.Revision{{}, {}}, + }, + wantMin: 0.85, + wantMax: 0.95, + }, + { + name: "old entry has decayed trust", + entry: memory.Entry{ + Author: memory.AuthorHuman, + CreatedAt: now.AddDate(0, 0, -180), // half-life + }, + wantMin: 0.45, + wantMax: 0.55, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + score := memory.ComputeTrustScore(tc.entry) + require.GreaterOrEqual(t, score, tc.wantMin, "trust score too low") + require.LessOrEqual(t, score, tc.wantMax, "trust score too high") + }) + } +} + +func TestComputeStalenessScore(t *testing.T) { + t.Parallel() + now := time.Now() + + tests := []struct { + name string + entry memory.Entry + wantMin float32 + wantMax float32 + }{ + { + name: "recently accessed entry is fresh", + entry: memory.Entry{ + CreatedAt: now, + LastAccessedAt: now, + }, + wantMin: 0.0, + wantMax: 0.05, + }, + { + name: "entry not accessed for 90 days is stale", 
+ entry: memory.Entry{ + CreatedAt: now.AddDate(0, 0, -90), + LastAccessedAt: now.AddDate(0, 0, -90), + }, + wantMin: 0.95, + wantMax: 1.0, + }, + { + name: "flagged entry adds staleness bonus", + entry: func() memory.Entry { + ft := now + return memory.Entry{ + CreatedAt: now, + LastAccessedAt: now, + FlaggedAt: &ft, + } + }(), + wantMin: 0.28, + wantMax: 0.32, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + score := memory.ComputeStalenessScore(tc.entry) + require.GreaterOrEqual(t, score, tc.wantMin) + require.LessOrEqual(t, score, tc.wantMax) + }) + } +} diff --git a/pkg/memory/service.go b/pkg/memory/service.go new file mode 100644 index 0000000000..7db0e45f82 --- /dev/null +++ b/pkg/memory/service.go @@ -0,0 +1,205 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +const ( + conflictSimilarityThreshold = float32(0.85) + defaultConflictTopK = 5 + // stalenessSearchPenaltyWeight controls how much staleness reduces ranking score. + stalenessSearchPenaltyWeight = float32(0.3) +) + +// Service orchestrates Store, VectorStore, and Embedder to provide +// the full memory lifecycle including conflict detection and scoring. +type Service struct { + store Store + vectors VectorStore + embedder Embedder + log *zap.Logger +} + +// NewService constructs a Service. All dependencies are required. +// +// The Store provides durable persistence for memory entries. +// The VectorStore enables semantic similarity search over entry embeddings. +// The Embedder converts text to vectors; the caller is responsible for +// ensuring the same Embedder is used consistently — switching embedders +// will invalidate stored vectors. 
+func NewService(store Store, vectors VectorStore, embedder Embedder, log *zap.Logger) (*Service, error) { + if store == nil { + return nil, fmt.Errorf("store is required") + } + if vectors == nil { + return nil, fmt.Errorf("vector store is required") + } + if embedder == nil { + return nil, fmt.Errorf("embedder is required") + } + if log == nil { + return nil, fmt.Errorf("logger is required") + } + return &Service{store: store, vectors: vectors, embedder: embedder, log: log}, nil +} + +// RememberInput is the input to Service.Remember. +type RememberInput struct { + Content string + Type Type + Tags []string + Author AuthorType + AgentID string + SessionID string + Source SourceType + SkillRef string + TTLDays *int + // Force bypasses conflict detection and writes unconditionally. + Force bool +} + +// RememberResult is returned by Service.Remember. +// If Conflicts is non-empty, MemoryID is empty and the write was not performed. +type RememberResult struct { + MemoryID string + Conflicts []ConflictResult +} + +// Remember embeds content, checks for conflicts, and writes the entry if none found. +// When Force is true the conflict check is skipped entirely. 
+func (s *Service) Remember(ctx context.Context, in RememberInput) (*RememberResult, error) {
+	embedding, err := s.embedder.Embed(ctx, in.Content)
+	if err != nil {
+		return nil, fmt.Errorf("embedding content: %w", err)
+	}
+
+	// Conflict path performs no writes: the caller gets the conflicting
+	// entries back and may retry with Force set.
+	if !in.Force {
+		conflicts, err := s.detectConflicts(ctx, embedding, in.Type)
+		if err != nil {
+			return nil, fmt.Errorf("detecting conflicts: %w", err)
+		}
+		if len(conflicts) > 0 {
+			return &RememberResult{Conflicts: conflicts}, nil
+		}
+	}
+
+	// Entry IDs are "mem_"-prefixed random UUIDs.
+	id := "mem_" + uuid.New().String()
+	now := time.Now().UTC()
+	entry := Entry{
+		ID:        id,
+		Type:      in.Type,
+		Content:   in.Content,
+		Tags:      in.Tags,
+		Author:    in.Author,
+		AgentID:   in.AgentID,
+		SessionID: in.SessionID,
+		Source:    sourceOrDefault(in.Source),
+		SkillRef:  in.SkillRef,
+		Status:    EntryStatusActive,
+		TTLDays:   in.TTLDays, // NOTE(review): stores the caller's pointer as-is; callers should not mutate it afterwards.
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+	// ExpiresAt is derived once at creation time from TTLDays.
+	if in.TTLDays != nil {
+		t := now.AddDate(0, 0, *in.TTLDays)
+		entry.ExpiresAt = &t
+	}
+	// Initial scores are computed before persistence so the stored row
+	// already carries them.
+	entry.TrustScore = ComputeTrustScore(entry)
+	entry.StalenessScore = ComputeStalenessScore(entry)
+
+	if err := s.store.Create(ctx, entry); err != nil {
+		return nil, fmt.Errorf("creating entry: %w", err)
+	}
+	if err := s.vectors.Upsert(ctx, id, embedding); err != nil {
+		// Best-effort rollback: remove the orphaned store entry.
+		// NOTE(review): if this Delete also fails, a store row without a
+		// vector remains; the lifecycle job should tolerate that.
+		_ = s.store.Delete(ctx, id)
+		return nil, fmt.Errorf("upserting vector: %w", err)
+	}
+
+	return &RememberResult{MemoryID: id}, nil
+}
+
+// Search embeds the query, searches the vector store, fetches entries, and
+// increments access counts.
+// The returned slice may contain fewer than topK entries when vector hits
+// cannot be resolved in the store, and is re-sorted because composite
+// scoring can change the raw-similarity order.
+func (s *Service) Search(ctx context.Context, query string, memType *Type, topK int) ([]ScoredEntry, error) {
+	if topK <= 0 {
+		topK = 10
+	}
+	embedding, err := s.embedder.Embed(ctx, query)
+	if err != nil {
+		return nil, fmt.Errorf("embedding query: %w", err)
+	}
+
+	// Only active entries are searchable.
+	active := EntryStatusActive
+	ids, err := s.vectors.Search(ctx, embedding, topK, VectorFilter{Type: memType, Status: &active})
+	if err != nil {
+		return nil, fmt.Errorf("vector search: %w", err)
+	}
+
+	var results []ScoredEntry
+	for _, scored := range ids {
+		entry, err := s.store.Get(ctx, scored.ID)
+		if err != nil {
+			s.log.Warn("skipping missing entry", zap.String("id", scored.ID), zap.Error(err))
+			continue
+		}
+		// Increment access count; failure is non-fatal.
+		_ = s.store.IncrementAccess(ctx, scored.ID)
+		// Composite score: boost by trust, penalise by staleness.
+		// composite = similarity * trust * (1 - 0.3 * staleness)
+		composite := scored.Similarity * entry.TrustScore * (1 - stalenessSearchPenaltyWeight*entry.StalenessScore)
+		results = append(results, ScoredEntry{Entry: entry, Similarity: composite})
+	}
+	sort.Slice(results, func(i, j int) bool {
+		return results[i].Similarity > results[j].Similarity
+	})
+	return results, nil
+}
+
+// detectConflicts returns any existing entries whose embedding similarity to
+// the candidate exceeds conflictSimilarityThreshold.
+// Only same-type, active entries are considered candidates.
+func (s *Service) detectConflicts(ctx context.Context, embedding []float32, memType Type) ([]ConflictResult, error) {
+	active := EntryStatusActive
+	candidates, err := s.vectors.Search(ctx, embedding, defaultConflictTopK, VectorFilter{
+		Type:   &memType,
+		Status: &active,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	var conflicts []ConflictResult
+	for _, c := range candidates {
+		if c.Similarity < conflictSimilarityThreshold {
+			continue
+		}
+		entry, err := s.store.Get(ctx, c.ID)
+		if err != nil {
+			// Skip entries that can't be fetched; they may have been deleted concurrently.
+ s.log.Warn("skipping conflict candidate", zap.String("id", c.ID), zap.Error(err)) + continue + } + conflicts = append(conflicts, ConflictResult{ + ID: entry.ID, + Content: entry.Content, + Similarity: c.Similarity, + TrustScore: entry.TrustScore, + }) + } + return conflicts, nil +} + +func sourceOrDefault(s SourceType) SourceType { + if s == "" { + return SourceMemory + } + return s +} diff --git a/pkg/memory/service_test.go b/pkg/memory/service_test.go new file mode 100644 index 0000000000..8e0d721acc --- /dev/null +++ b/pkg/memory/service_test.go @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + + "github.com/stacklok/toolhive/pkg/memory" + "github.com/stacklok/toolhive/pkg/memory/mocks" +) + +func TestService_Remember_NoConflict(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + vectors := mocks.NewMockVectorStore(ctrl) + embedder := mocks.NewMockEmbedder(ctrl) + + emb := []float32{1, 0, 0} + embedder.EXPECT().Embed(gomock.Any(), "test fact").Return(emb, nil) + active := memory.EntryStatusActive + vectors.EXPECT().Search(gomock.Any(), emb, 5, memory.VectorFilter{ + Type: ptrOf(memory.TypeSemantic), + Status: &active, + }).Return(nil, nil) + store.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) + vectors.EXPECT().Upsert(gomock.Any(), gomock.Any(), emb).Return(nil) + + svc, err := memory.NewService(store, vectors, embedder, zaptest.NewLogger(t)) + require.NoError(t, err) + + result, err := svc.Remember(context.Background(), memory.RememberInput{ + Content: "test fact", + Type: memory.TypeSemantic, + Author: memory.AuthorHuman, + }) + require.NoError(t, err) + require.NotEmpty(t, result.MemoryID) + require.Empty(t, result.Conflicts) +} + +func 
TestService_Remember_ConflictDetected(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + vectors := mocks.NewMockVectorStore(ctrl) + embedder := mocks.NewMockEmbedder(ctrl) + + emb := []float32{1, 0, 0} + embedder.EXPECT().Embed(gomock.Any(), "conflicting fact").Return(emb, nil) + active := memory.EntryStatusActive + vectors.EXPECT().Search(gomock.Any(), emb, 5, memory.VectorFilter{ + Type: ptrOf(memory.TypeSemantic), + Status: &active, + }).Return([]memory.ScoredID{{ID: "mem_existing", Similarity: 0.92}}, nil) + store.EXPECT().Get(gomock.Any(), "mem_existing").Return(memory.Entry{ + ID: "mem_existing", + Content: "existing fact", + }, nil) + + svc, err := memory.NewService(store, vectors, embedder, zaptest.NewLogger(t)) + require.NoError(t, err) + + result, err := svc.Remember(context.Background(), memory.RememberInput{ + Content: "conflicting fact", + Type: memory.TypeSemantic, + Author: memory.AuthorAgent, + }) + require.NoError(t, err) + require.Empty(t, result.MemoryID) + require.Len(t, result.Conflicts, 1) + require.Equal(t, "mem_existing", result.Conflicts[0].ID) +} + +func TestService_Remember_Force(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + vectors := mocks.NewMockVectorStore(ctrl) + embedder := mocks.NewMockEmbedder(ctrl) + + emb := []float32{1, 0, 0} + embedder.EXPECT().Embed(gomock.Any(), "forced fact").Return(emb, nil) + store.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) + vectors.EXPECT().Upsert(gomock.Any(), gomock.Any(), emb).Return(nil) + + svc, err := memory.NewService(store, vectors, embedder, zaptest.NewLogger(t)) + require.NoError(t, err) + + result, err := svc.Remember(context.Background(), memory.RememberInput{ + Content: "forced fact", + Type: memory.TypeSemantic, + Author: memory.AuthorHuman, + Force: true, + }) + require.NoError(t, err) + require.NotEmpty(t, result.MemoryID) +} + +func 
TestService_Search_CompositeScoring(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + store := mocks.NewMockStore(ctrl) + vectors := mocks.NewMockVectorStore(ctrl) + embedder := mocks.NewMockEmbedder(ctrl) + + emb := []float32{1, 0, 0} + embedder.EXPECT().Embed(gomock.Any(), "auth endpoint").Return(emb, nil) + + active := memory.EntryStatusActive + // Two results: high raw similarity but stale/flagged vs lower similarity but fresh+trusted. + vectors.EXPECT().Search(gomock.Any(), emb, 10, memory.VectorFilter{Status: &active}). + Return([]memory.ScoredID{ + {ID: "stale_high", Similarity: 0.95}, + {ID: "fresh_low", Similarity: 0.80}, + }, nil) + + now := time.Now() + flagTime := now.Add(-24 * time.Hour) + + store.EXPECT().Get(gomock.Any(), "stale_high").Return(memory.Entry{ + ID: "stale_high", Author: memory.AuthorAgent, + TrustScore: 0.5, StalenessScore: 0.8, CreatedAt: now, FlaggedAt: &flagTime, + }, nil) + store.EXPECT().IncrementAccess(gomock.Any(), "stale_high").Return(nil) + + store.EXPECT().Get(gomock.Any(), "fresh_low").Return(memory.Entry{ + ID: "fresh_low", Author: memory.AuthorHuman, + TrustScore: 1.0, StalenessScore: 0.0, CreatedAt: now, + }, nil) + store.EXPECT().IncrementAccess(gomock.Any(), "fresh_low").Return(nil) + + svc, err := memory.NewService(store, vectors, embedder, zaptest.NewLogger(t)) + require.NoError(t, err) + + results, err := svc.Search(context.Background(), "auth endpoint", nil, 0) + require.NoError(t, err) + require.Len(t, results, 2) + + // fresh_low (composite ≈ 0.80) should rank above stale_high (0.95 × 0.5 × (1-0.3×0.8) ≈ 0.361) + require.Equal(t, "fresh_low", results[0].Entry.ID) + require.Equal(t, "stale_high", results[1].Entry.ID) + require.Greater(t, results[0].Similarity, results[1].Similarity) +} + +func ptrOf[T any](v T) *T { return &v } diff --git a/pkg/memory/sqlite/db.go b/pkg/memory/sqlite/db.go new file mode 100644 index 0000000000..50a56d0ece --- /dev/null +++ b/pkg/memory/sqlite/db.go @@ -0,0 +1,106 @@ 
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package sqlite provides SQLite-backed implementations of the memory.Store
+// and memory.VectorStore interfaces.
+package sqlite
+
+import (
+	"context"
+	"database/sql"
+	"embed"
+	"errors"
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+
+	"github.com/pressly/goose/v3"
+	_ "modernc.org/sqlite" // SQLite driver
+)
+
+//go:embed migrations/*.sql
+var migrations embed.FS
+
+// DB wraps a *sql.DB connection for the memory SQLite database.
+type DB struct {
+	db *sql.DB
+}
+
+// Open opens (or creates) the memory SQLite database at path.
+//
+// The named return err lets the deferred cleanup join any Close error onto
+// the original failure.
+func Open(ctx context.Context, path string) (_ *DB, err error) {
+	if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil {
+		return nil, fmt.Errorf("creating database directory: %w", err)
+	}
+
+	// _txlock=immediate makes write transactions take the write lock up
+	// front, avoiding late SQLITE_BUSY upgrades.
+	dsn := fmt.Sprintf("file:%s?_txlock=immediate", path)
+	sqlDB, err := sql.Open("sqlite", dsn)
+	if err != nil {
+		return nil, fmt.Errorf("opening database: %w", err)
+	}
+
+	success := false
+	defer func() {
+		if !success {
+			if closeErr := sqlDB.Close(); closeErr != nil {
+				err = errors.Join(err, fmt.Errorf("closing database after failure: %w", closeErr))
+			}
+		}
+	}()
+
+	// Pin the pool to a single connection: session-scoped PRAGMAs applied
+	// below then stick to the one connection every query uses, and a single
+	// writer avoids lock contention.
+	sqlDB.SetMaxOpenConns(1)
+	sqlDB.SetMaxIdleConns(1)
+
+	// Pragmas must be applied before migrations so migration statements run
+	// with WAL/foreign-keys settings in effect.
+	if err = applyPragmas(sqlDB); err != nil {
+		return nil, err
+	}
+
+	if err = runMigrations(ctx, sqlDB); err != nil {
+		return nil, err
+	}
+
+	if err = sqlDB.PingContext(ctx); err != nil {
+		return nil, fmt.Errorf("verifying connection: %w", err)
+	}
+
+	success = true
+	return &DB{db: sqlDB}, nil
+}
+
+// Close closes the underlying database connection.
+func (d *DB) Close() error { return d.db.Close() }
+
+// DB returns the underlying *sql.DB.
+func (d *DB) DB() *sql.DB { return d.db } + +func applyPragmas(db *sql.DB) error { + for _, p := range []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA busy_timeout=5000", + "PRAGMA synchronous=NORMAL", + "PRAGMA foreign_keys=ON", + "PRAGMA cache_size=-2000", + } { + if _, err := db.Exec(p); err != nil { + return fmt.Errorf("applying pragma %q: %w", p, err) + } + } + return nil +} + +func runMigrations(ctx context.Context, db *sql.DB) error { + migrationsFS, err := fs.Sub(migrations, "migrations") + if err != nil { + return fmt.Errorf("creating migrations sub-filesystem: %w", err) + } + provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrationsFS, + goose.WithAllowOutofOrder(false), + ) + if err != nil { + return fmt.Errorf("creating goose provider: %w", err) + } + if _, err := provider.Up(ctx); err != nil { + return fmt.Errorf("running migrations: %w", err) + } + return nil +} diff --git a/pkg/memory/sqlite/migrations/001_initial.sql b/pkg/memory/sqlite/migrations/001_initial.sql new file mode 100644 index 0000000000..717f8061d9 --- /dev/null +++ b/pkg/memory/sqlite/migrations/001_initial.sql @@ -0,0 +1,60 @@ +-- +goose Up + +CREATE TABLE IF NOT EXISTS memory_entries ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL CHECK (type IN ('semantic','procedural')), + content TEXT NOT NULL, + tags TEXT NOT NULL DEFAULT '[]', -- JSON array + author TEXT NOT NULL CHECK (author IN ('human','agent')), + agent_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL DEFAULT '', + source TEXT NOT NULL CHECK (source IN ('memory','skill')), + skill_ref TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'active' + CHECK (status IN ('active','flagged','expired','archived')), + trust_score REAL NOT NULL DEFAULT 0, + staleness_score REAL NOT NULL DEFAULT 0, + access_count INTEGER NOT NULL DEFAULT 0, + last_accessed_at TEXT, + flagged_at TEXT, + flag_reason TEXT NOT NULL DEFAULT '', + ttl_days INTEGER, + expires_at TEXT, + archived_at TEXT, + consolidated_into TEXT NOT 
NULL DEFAULT '', + crystallized_into TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS memory_revisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entry_id TEXT NOT NULL REFERENCES memory_entries(id) ON DELETE CASCADE, + content TEXT NOT NULL, + author TEXT NOT NULL, + correction_note TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL +); + +-- Embeddings stored as a JSON array of float32 values. +-- Queries load all vectors for a type+status combination and compute +-- cosine similarity in Go. Switch to an external VectorStore provider +-- for datasets > 100K entries. +CREATE TABLE IF NOT EXISTS memory_embeddings ( + entry_id TEXT PRIMARY KEY REFERENCES memory_entries(id) ON DELETE CASCADE, + embedding TEXT NOT NULL -- JSON []float32 +); + +CREATE INDEX IF NOT EXISTS idx_memory_entries_type_status + ON memory_entries(type, status); + +CREATE INDEX IF NOT EXISTS idx_memory_entries_expires_at + ON memory_entries(expires_at) WHERE expires_at IS NOT NULL; + +-- +goose Down + +DROP INDEX IF EXISTS idx_memory_entries_expires_at; +DROP INDEX IF EXISTS idx_memory_entries_type_status; +DROP TABLE IF EXISTS memory_embeddings; +DROP TABLE IF EXISTS memory_revisions; +DROP TABLE IF EXISTS memory_entries; diff --git a/pkg/memory/sqlite/migrations/002_add_episodic_type.sql b/pkg/memory/sqlite/migrations/002_add_episodic_type.sql new file mode 100644 index 0000000000..d614ce3d39 --- /dev/null +++ b/pkg/memory/sqlite/migrations/002_add_episodic_type.sql @@ -0,0 +1,83 @@ +-- +goose Up + +-- SQLite does not support ALTER COLUMN, so we recreate the table with the +-- updated CHECK constraint to include the 'episodic' memory type. 
+
+CREATE TABLE IF NOT EXISTS memory_entries_new (
+    id TEXT PRIMARY KEY,
+    type TEXT NOT NULL CHECK (type IN ('semantic','procedural','episodic')),
+    content TEXT NOT NULL,
+    tags TEXT NOT NULL DEFAULT '[]',
+    author TEXT NOT NULL CHECK (author IN ('human','agent')),
+    agent_id TEXT NOT NULL DEFAULT '',
+    session_id TEXT NOT NULL DEFAULT '',
+    source TEXT NOT NULL CHECK (source IN ('memory','skill')),
+    skill_ref TEXT NOT NULL DEFAULT '',
+    status TEXT NOT NULL DEFAULT 'active'
+        CHECK (status IN ('active','flagged','expired','archived')),
+    trust_score REAL NOT NULL DEFAULT 0,
+    staleness_score REAL NOT NULL DEFAULT 0,
+    access_count INTEGER NOT NULL DEFAULT 0,
+    last_accessed_at TEXT,
+    flagged_at TEXT,
+    flag_reason TEXT NOT NULL DEFAULT '',
+    ttl_days INTEGER,
+    expires_at TEXT,
+    archived_at TEXT,
+    consolidated_into TEXT NOT NULL DEFAULT '',
+    crystallized_into TEXT NOT NULL DEFAULT '',
+    created_at TEXT NOT NULL,
+    updated_at TEXT NOT NULL
+);
+
+-- NOTE(review): if PRAGMA foreign_keys=ON is in effect when this runs,
+-- DROP TABLE performs an implicit DELETE FROM memory_entries, and the child
+-- tables (memory_revisions, memory_embeddings) declare ON DELETE CASCADE —
+-- their rows would be silently deleted. Confirm goose runs this inside a
+-- transaction (where the FK pragma cannot be toggled) with FKs effectively
+-- deferred, or mark the migration -- +goose NO TRANSACTION and wrap it in
+-- PRAGMA foreign_keys=off / on.
+INSERT INTO memory_entries_new SELECT * FROM memory_entries;
+DROP TABLE memory_entries;
+ALTER TABLE memory_entries_new RENAME TO memory_entries;
+
+CREATE INDEX IF NOT EXISTS idx_memory_entries_type_status
+    ON memory_entries(type, status);
+
+CREATE INDEX IF NOT EXISTS idx_memory_entries_expires_at
+    ON memory_entries(expires_at) WHERE expires_at IS NOT NULL;
+
+-- +goose Down
+
+-- Revert: drop episodic rows then recreate the narrower constraint.
+DELETE FROM memory_entries WHERE type = 'episodic'; + +CREATE TABLE IF NOT EXISTS memory_entries_old ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL CHECK (type IN ('semantic','procedural')), + content TEXT NOT NULL, + tags TEXT NOT NULL DEFAULT '[]', + author TEXT NOT NULL CHECK (author IN ('human','agent')), + agent_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL DEFAULT '', + source TEXT NOT NULL CHECK (source IN ('memory','skill')), + skill_ref TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'active' + CHECK (status IN ('active','flagged','expired','archived')), + trust_score REAL NOT NULL DEFAULT 0, + staleness_score REAL NOT NULL DEFAULT 0, + access_count INTEGER NOT NULL DEFAULT 0, + last_accessed_at TEXT, + flagged_at TEXT, + flag_reason TEXT NOT NULL DEFAULT '', + ttl_days INTEGER, + expires_at TEXT, + archived_at TEXT, + consolidated_into TEXT NOT NULL DEFAULT '', + crystallized_into TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +INSERT INTO memory_entries_old SELECT * FROM memory_entries; +DROP TABLE memory_entries; +ALTER TABLE memory_entries_old RENAME TO memory_entries; + +CREATE INDEX IF NOT EXISTS idx_memory_entries_type_status + ON memory_entries(type, status); + +CREATE INDEX IF NOT EXISTS idx_memory_entries_expires_at + ON memory_entries(expires_at) WHERE expires_at IS NOT NULL; diff --git a/pkg/memory/sqlite/migrations/003_add_resource_source.sql b/pkg/memory/sqlite/migrations/003_add_resource_source.sql new file mode 100644 index 0000000000..d1455edf44 --- /dev/null +++ b/pkg/memory/sqlite/migrations/003_add_resource_source.sql @@ -0,0 +1,83 @@ +-- +goose Up + +-- SQLite does not support ALTER COLUMN, so we recreate the table with the +-- updated CHECK constraint to include the 'resource' source type. 
+
+CREATE TABLE IF NOT EXISTS memory_entries_new (
+    id TEXT PRIMARY KEY,
+    type TEXT NOT NULL CHECK (type IN ('semantic','procedural','episodic')),
+    content TEXT NOT NULL,
+    tags TEXT NOT NULL DEFAULT '[]',
+    author TEXT NOT NULL CHECK (author IN ('human','agent')),
+    agent_id TEXT NOT NULL DEFAULT '',
+    session_id TEXT NOT NULL DEFAULT '',
+    source TEXT NOT NULL CHECK (source IN ('memory','skill','resource')),
+    skill_ref TEXT NOT NULL DEFAULT '',
+    status TEXT NOT NULL DEFAULT 'active'
+        CHECK (status IN ('active','flagged','expired','archived')),
+    trust_score REAL NOT NULL DEFAULT 0,
+    staleness_score REAL NOT NULL DEFAULT 0,
+    access_count INTEGER NOT NULL DEFAULT 0,
+    last_accessed_at TEXT,
+    flagged_at TEXT,
+    flag_reason TEXT NOT NULL DEFAULT '',
+    ttl_days INTEGER,
+    expires_at TEXT,
+    archived_at TEXT,
+    consolidated_into TEXT NOT NULL DEFAULT '',
+    crystallized_into TEXT NOT NULL DEFAULT '',
+    created_at TEXT NOT NULL,
+    updated_at TEXT NOT NULL
+);
+
+-- NOTE(review): as with migration 002, DROP TABLE with foreign_keys=ON
+-- implies DELETE FROM memory_entries, which cascades into memory_revisions
+-- and memory_embeddings (ON DELETE CASCADE) and would destroy their rows.
+-- Verify FK enforcement is off for this migration's connection/transaction.
+INSERT INTO memory_entries_new SELECT * FROM memory_entries;
+DROP TABLE memory_entries;
+ALTER TABLE memory_entries_new RENAME TO memory_entries;
+
+CREATE INDEX IF NOT EXISTS idx_memory_entries_type_status
+    ON memory_entries(type, status);
+
+CREATE INDEX IF NOT EXISTS idx_memory_entries_expires_at
+    ON memory_entries(expires_at) WHERE expires_at IS NOT NULL;
+
+-- +goose Down
+
+-- Revert: drop resource rows then recreate the narrower constraint.
+DELETE FROM memory_entries WHERE source = 'resource'; + +CREATE TABLE IF NOT EXISTS memory_entries_old ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL CHECK (type IN ('semantic','procedural','episodic')), + content TEXT NOT NULL, + tags TEXT NOT NULL DEFAULT '[]', + author TEXT NOT NULL CHECK (author IN ('human','agent')), + agent_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL DEFAULT '', + source TEXT NOT NULL CHECK (source IN ('memory','skill')), + skill_ref TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'active' + CHECK (status IN ('active','flagged','expired','archived')), + trust_score REAL NOT NULL DEFAULT 0, + staleness_score REAL NOT NULL DEFAULT 0, + access_count INTEGER NOT NULL DEFAULT 0, + last_accessed_at TEXT, + flagged_at TEXT, + flag_reason TEXT NOT NULL DEFAULT '', + ttl_days INTEGER, + expires_at TEXT, + archived_at TEXT, + consolidated_into TEXT NOT NULL DEFAULT '', + crystallized_into TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +INSERT INTO memory_entries_old SELECT * FROM memory_entries; +DROP TABLE memory_entries; +ALTER TABLE memory_entries_old RENAME TO memory_entries; + +CREATE INDEX IF NOT EXISTS idx_memory_entries_type_status + ON memory_entries(type, status); + +CREATE INDEX IF NOT EXISTS idx_memory_entries_expires_at + ON memory_entries(expires_at) WHERE expires_at IS NOT NULL; diff --git a/pkg/memory/sqlite/store.go b/pkg/memory/sqlite/store.go new file mode 100644 index 0000000000..80fe0c9eb0 --- /dev/null +++ b/pkg/memory/sqlite/store.go @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package sqlite + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/stacklok/toolhive/pkg/memory" +) + +// Store implements memory.Store using SQLite. +type Store struct { + db *sql.DB +} + +// NewStore creates a new SQLite-backed Store. 
+func NewStore(wrapper *DB) *Store { + return &Store{db: wrapper.DB()} +} + +var _ memory.Store = (*Store)(nil) + +// Create inserts a new memory entry. +func (s *Store) Create(ctx context.Context, e memory.Entry) error { + tags, err := json.Marshal(e.Tags) + if err != nil { + return fmt.Errorf("marshalling tags: %w", err) + } + + _, err = s.db.ExecContext(ctx, ` + INSERT INTO memory_entries + (id, type, content, tags, author, agent_id, session_id, source, skill_ref, + status, trust_score, staleness_score, access_count, last_accessed_at, + ttl_days, expires_at, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`, + e.ID, string(e.Type), e.Content, string(tags), + string(e.Author), e.AgentID, e.SessionID, string(e.Source), e.SkillRef, + string(e.Status), e.TrustScore, e.StalenessScore, e.AccessCount, + nullableTime(e.LastAccessedAt), + e.TTLDays, nullableTimePtr(e.ExpiresAt), + e.CreatedAt.UTC().Format(time.RFC3339Nano), + e.UpdatedAt.UTC().Format(time.RFC3339Nano), + ) + return err +} + +// Get retrieves a single entry by ID, including its revision history. +func (s *Store) Get(ctx context.Context, id string) (memory.Entry, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, type, content, tags, author, agent_id, session_id, source, skill_ref, + status, trust_score, staleness_score, access_count, last_accessed_at, + flagged_at, flag_reason, ttl_days, expires_at, archived_at, + consolidated_into, crystallized_into, created_at, updated_at + FROM memory_entries WHERE id = ?`, id) + + e, err := scanEntry(row) + if errors.Is(err, sql.ErrNoRows) { + return memory.Entry{}, fmt.Errorf("entry %q: %w", id, memory.ErrNotFound) + } + if err != nil { + return memory.Entry{}, err + } + + e.History, err = s.loadHistory(ctx, id) + return e, err +} + +// Update replaces content and appends the old content to revisions. 
+// The previous content is snapshotted into memory_revisions inside the same
+// transaction, so history and content can never diverge.
+func (s *Store) Update(ctx context.Context, id, content string, author memory.AuthorType, note string) error {
+	tx, err := s.db.BeginTx(ctx, nil)
+	if err != nil {
+		return err
+	}
+	defer rollback(tx)
+
+	var oldContent string
+	if err := tx.QueryRowContext(ctx, `SELECT content FROM memory_entries WHERE id = ?`, id).Scan(&oldContent); err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return fmt.Errorf("entry %q: %w", id, memory.ErrNotFound)
+		}
+		return err
+	}
+
+	// NOTE(review): two separate time.Now() calls below, so the revision's
+	// created_at and the entry's updated_at can differ by nanoseconds.
+	if _, err := tx.ExecContext(ctx,
+		`INSERT INTO memory_revisions (entry_id, content, author, correction_note, created_at)
+		 VALUES (?, ?, ?, ?, ?)`,
+		id, oldContent, string(author), note, time.Now().UTC().Format(time.RFC3339Nano),
+	); err != nil {
+		return err
+	}
+
+	if _, err := tx.ExecContext(ctx,
+		`UPDATE memory_entries SET content = ?, updated_at = ? WHERE id = ?`,
+		content, time.Now().UTC().Format(time.RFC3339Nano), id,
+	); err != nil {
+		return err
+	}
+
+	return tx.Commit()
+}
+
+// Flag marks an entry as potentially stale.
+// NOTE(review): unlike Update, Flag/Unflag/Delete do not check RowsAffected,
+// so an unknown id silently succeeds rather than returning ErrNotFound.
+func (s *Store) Flag(ctx context.Context, id, reason string) error {
+	now := time.Now().UTC().Format(time.RFC3339Nano)
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE memory_entries SET status='flagged', flagged_at=?, flag_reason=?, updated_at=? WHERE id=?`,
+		now, reason, now, id)
+	return err
+}
+
+// Unflag clears the flag on an entry.
+func (s *Store) Unflag(ctx context.Context, id string) error {
+	now := time.Now().UTC().Format(time.RFC3339Nano)
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE memory_entries SET status='active', flagged_at=NULL, flag_reason='', updated_at=? WHERE id=?`,
+		now, id)
+	return err
+}
+
+// Delete permanently removes an entry.
+// Revisions and embeddings are removed by the schema's ON DELETE CASCADE.
+func (s *Store) Delete(ctx context.Context, id string) error {
+	_, err := s.db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id=?`, id)
+	return err
+}
+
+// List returns entries matching the filter.
+// NOTE(review): filters comparing created_at/expires_at rely on lexicographic
+// TEXT comparison of RFC3339Nano strings. RFC3339Nano drops trailing
+// fractional zeros, so ordering can be wrong at sub-second granularity —
+// consider a fixed-width format if sub-second precision matters.
+func (s *Store) List(ctx context.Context, f memory.ListFilter) ([]memory.Entry, error) {
+	query := `SELECT id, type, content, tags, author, agent_id, session_id, source, skill_ref,
+		status, trust_score, staleness_score, access_count, last_accessed_at,
+		flagged_at, flag_reason, ttl_days, expires_at, archived_at,
+		consolidated_into, crystallized_into, created_at, updated_at
+		FROM memory_entries WHERE 1=1`
+	var args []any
+
+	if f.Type != nil {
+		query += " AND type=?"
+		args = append(args, string(*f.Type))
+	}
+	if f.Author != nil {
+		query += " AND author=?"
+		args = append(args, string(*f.Author))
+	}
+	if f.Source != nil {
+		query += " AND source=?"
+		args = append(args, string(*f.Source))
+	}
+	if f.Status != nil {
+		query += " AND status=?"
+		args = append(args, string(*f.Status))
+	}
+	if f.CreatedAfter != nil {
+		query += " AND created_at >= ?"
+		args = append(args, f.CreatedAfter.UTC().Format(time.RFC3339Nano))
+	}
+	if f.CreatedBefore != nil {
+		query += " AND created_at <= ?"
+		args = append(args, f.CreatedBefore.UTC().Format(time.RFC3339Nano))
+	}
+
+	query += " ORDER BY created_at DESC"
+	// NOTE(review): Offset only takes effect when Limit > 0; a bare Offset
+	// is silently ignored.
+	if f.Limit > 0 {
+		query += " LIMIT ? OFFSET ?"
+		args = append(args, f.Limit, f.Offset)
+	}
+
+	rows, err := s.db.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	var entries []memory.Entry
+	for rows.Next() {
+		e, err := scanEntry(rows)
+		if err != nil {
+			return nil, err
+		}
+		entries = append(entries, e)
+	}
+	return entries, rows.Err()
+}
+
+// Archive transitions an entry to archived status.
+// The Sprintf below interpolates only a column name, and consolidatedField
+// returns one of two fixed identifiers, so no injection is possible.
+func (s *Store) Archive(ctx context.Context, id string, reason memory.ArchiveReason, ref string) error {
+	now := time.Now().UTC().Format(time.RFC3339Nano)
+	field := consolidatedField(reason)
+	_, err := s.db.ExecContext(ctx,
+		fmt.Sprintf(`UPDATE memory_entries SET status='archived', archived_at=?, %s=?, updated_at=? WHERE id=?`, field),
+		now, ref, now, id)
+	return err
+}
+
+// IncrementAccess increments the access counter and updates last_accessed_at.
+func (s *Store) IncrementAccess(ctx context.Context, id string) error {
+	now := time.Now().UTC().Format(time.RFC3339Nano)
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE memory_entries SET access_count=access_count+1, last_accessed_at=?, updated_at=? WHERE id=?`,
+		now, now, id)
+	return err
+}
+
+// UpdateScores persists recomputed trust and staleness scores.
+// Deliberately leaves updated_at untouched so periodic rescoring does not
+// look like a content change — TODO confirm that is the intent.
+func (s *Store) UpdateScores(ctx context.Context, id string, trust, staleness float32) error {
+	_, err := s.db.ExecContext(ctx,
+		`UPDATE memory_entries SET trust_score=?, staleness_score=? WHERE id=?`,
+		trust, staleness, id)
+	return err
+}
+
+// ListExpired returns active entries whose TTL has elapsed.
+// NOTE(review): the expires_at <= ? comparison is lexicographic on
+// RFC3339Nano text and shares the sub-second caveat documented on List.
+func (s *Store) ListExpired(ctx context.Context) ([]memory.Entry, error) {
+	now := time.Now().UTC().Format(time.RFC3339Nano)
+	rows, err := s.db.QueryContext(ctx,
+		`SELECT id, type, content, tags, author, agent_id, session_id, source, skill_ref,
+		status, trust_score, staleness_score, access_count, last_accessed_at,
+		flagged_at, flag_reason, ttl_days, expires_at, archived_at,
+		consolidated_into, crystallized_into, created_at, updated_at
+		FROM memory_entries
+		WHERE expires_at IS NOT NULL AND expires_at <= ? AND status NOT IN ('expired','archived')`, now)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+	var entries []memory.Entry
+	for rows.Next() {
+		e, err := scanEntry(rows)
+		if err != nil {
+			return nil, err
+		}
+		entries = append(entries, e)
+	}
+	return entries, rows.Err()
+}
+
+// ListActive returns all active and flagged entries for score recomputation.
+func (s *Store) ListActive(ctx context.Context) ([]memory.Entry, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, type, content, tags, author, agent_id, session_id, source, skill_ref, + status, trust_score, staleness_score, access_count, last_accessed_at, + flagged_at, flag_reason, ttl_days, expires_at, archived_at, + consolidated_into, crystallized_into, created_at, updated_at + FROM memory_entries WHERE status IN ('active','flagged')`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var entries []memory.Entry + for rows.Next() { + e, err := scanEntry(rows) + if err != nil { + return nil, err + } + entries = append(entries, e) + } + return entries, rows.Err() +} + +// ---- helpers ---- + +type scanner interface { + Scan(dest ...any) error +} + +func scanEntry(sc scanner) (memory.Entry, error) { + var e memory.Entry + var ( + mtype, author, source, status string + tagsJSON string + lastAccessed, flaggedAt sql.NullString + expiresAt, archivedAt sql.NullString + createdAt, updatedAt string + ) + err := sc.Scan( + &e.ID, &mtype, &e.Content, &tagsJSON, &author, + &e.AgentID, &e.SessionID, &source, &e.SkillRef, + &status, &e.TrustScore, &e.StalenessScore, &e.AccessCount, &lastAccessed, + &flaggedAt, &e.FlagReason, &e.TTLDays, &expiresAt, &archivedAt, + &e.ConsolidatedInto, &e.CrystallizedInto, &createdAt, &updatedAt, + ) + if err != nil { + return memory.Entry{}, err + } + e.Type = memory.Type(mtype) + e.Author = memory.AuthorType(author) + e.Source = memory.SourceType(source) + e.Status = memory.EntryStatus(status) + _ = json.Unmarshal([]byte(tagsJSON), &e.Tags) + e.CreatedAt, _ = parseTime(createdAt) + e.UpdatedAt, _ = parseTime(updatedAt) + if lastAccessed.Valid { + t, _ := parseTime(lastAccessed.String) + e.LastAccessedAt = t + } + if flaggedAt.Valid { + t, _ := parseTime(flaggedAt.String) + e.FlaggedAt = &t + } + if expiresAt.Valid { + t, _ := parseTime(expiresAt.String) + e.ExpiresAt = &t + } + if archivedAt.Valid { + 
t, _ := parseTime(archivedAt.String) + e.ArchivedAt = &t + } + return e, nil +} + +func (s *Store) loadHistory(ctx context.Context, entryID string) ([]memory.Revision, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT content, author, correction_note, created_at + FROM memory_revisions WHERE entry_id=? ORDER BY created_at ASC`, entryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var revs []memory.Revision + for rows.Next() { + var r memory.Revision + var author, createdAt string + if err := rows.Scan(&r.Content, &author, &r.CorrectionNote, &createdAt); err != nil { + return nil, err + } + r.Author = memory.AuthorType(author) + r.Timestamp, _ = parseTime(createdAt) + revs = append(revs, r) + } + return revs, rows.Err() +} + +func nullableTime(t time.Time) any { + if t.IsZero() { + return nil + } + return t.UTC().Format(time.RFC3339Nano) +} + +func nullableTimePtr(t *time.Time) any { + if t == nil { + return nil + } + return t.UTC().Format(time.RFC3339Nano) +} + +func parseTime(s string) (time.Time, error) { + return time.Parse(time.RFC3339Nano, s) +} + +func consolidatedField(reason memory.ArchiveReason) string { + if reason == memory.ArchiveReasonCrystallized { + return "crystallized_into" + } + return "consolidated_into" +} + +func rollback(tx *sql.Tx) { + _ = tx.Rollback() +} diff --git a/pkg/memory/sqlite/store_test.go b/pkg/memory/sqlite/store_test.go new file mode 100644 index 0000000000..8b009d12b6 --- /dev/null +++ b/pkg/memory/sqlite/store_test.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
// SPDX-License-Identifier: Apache-2.0

package sqlite_test

import (
	"context"
	"fmt"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/stacklok/toolhive/pkg/memory"
	memorysqlite "github.com/stacklok/toolhive/pkg/memory/sqlite"
)

// openTestDB opens a fresh SQLite database in a per-test temp dir and
// registers cleanup; each test gets an isolated file so t.Parallel is safe.
func openTestDB(t *testing.T) *memorysqlite.DB {
	t.Helper()
	dir := t.TempDir()
	// Resolve symlinks (e.g. /var -> /private/var on macOS) so the path the
	// driver sees is canonical. Resolution failure is intentionally ignored;
	// the raw dir still works in that case.
	resolved, _ := filepath.EvalSymlinks(dir)
	db, err := memorysqlite.Open(context.Background(), filepath.Join(resolved, "memory.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })
	return db
}

// TestMemoryStore_CreateAndGet verifies that a created entry round-trips
// through Get with its identity, content, tags, author, and status intact.
func TestMemoryStore_CreateAndGet(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)

	entry := memory.Entry{
		ID:        "mem_test_001",
		Type:      memory.TypeSemantic,
		Content:   "we deploy to us-east-1",
		Tags:      []string{"deployment", "infra"},
		Author:    memory.AuthorHuman,
		Source:    memory.SourceMemory,
		Status:    memory.EntryStatusActive,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}

	err := store.Create(context.Background(), entry)
	require.NoError(t, err)

	got, err := store.Get(context.Background(), "mem_test_001")
	require.NoError(t, err)
	require.Equal(t, entry.ID, got.ID)
	require.Equal(t, entry.Content, got.Content)
	require.Equal(t, entry.Tags, got.Tags)
	require.Equal(t, entry.Author, got.Author)
	require.Equal(t, entry.Status, got.Status)
}

// TestMemoryStore_Update verifies that Update replaces the content and
// records exactly one revision holding the prior content and the note.
func TestMemoryStore_Update(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)

	entry := memory.Entry{
		ID:        "mem_test_002",
		Type:      memory.TypeSemantic,
		Content:   "old content",
		Author:    memory.AuthorHuman,
		Source:    memory.SourceMemory,
		Status:    memory.EntryStatusActive,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	require.NoError(t, store.Create(context.Background(), entry))

	err := store.Update(context.Background(), "mem_test_002", "new content", memory.AuthorHuman, "corrected")
	require.NoError(t, err)

	got, err := store.Get(context.Background(), "mem_test_002")
	require.NoError(t, err)
	require.Equal(t, "new content", got.Content)
	require.Len(t, got.History, 1)
	require.Equal(t, "old content", got.History[0].Content)
	require.Equal(t, "corrected", got.History[0].CorrectionNote)
}

// TestMemoryStore_Archive verifies that archiving with a consolidation
// reason sets archived status, the consolidation target, and the timestamp.
func TestMemoryStore_Archive(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)

	entry := memory.Entry{
		ID:        "mem_test_003",
		Type:      memory.TypeProcedural,
		Content:   "check Docker health before E2E tests",
		Author:    memory.AuthorAgent,
		Source:    memory.SourceMemory,
		Status:    memory.EntryStatusActive,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	require.NoError(t, store.Create(context.Background(), entry))

	err := store.Archive(context.Background(), "mem_test_003", memory.ArchiveReasonConsolidated, "mem_test_consolidated")
	require.NoError(t, err)

	got, err := store.Get(context.Background(), "mem_test_003")
	require.NoError(t, err)
	require.Equal(t, memory.EntryStatusArchived, got.Status)
	require.Equal(t, "mem_test_consolidated", got.ConsolidatedInto)
	require.NotNil(t, got.ArchivedAt)
}

// TestMemoryStore_List verifies that a Type filter returns only matching
// entries (two semantic out of three total).
func TestMemoryStore_List(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)

	ctx := context.Background()
	for i, content := range []string{"fact A", "fact B", "procedure X"} {
		mtype := memory.TypeSemantic
		if i == 2 {
			mtype = memory.TypeProcedural
		}
		require.NoError(t, store.Create(ctx, memory.Entry{
			ID:        fmt.Sprintf("mem_list_%d", i),
			Type:      mtype,
			Content:   content,
			Author:    memory.AuthorHuman,
			Source:    memory.SourceMemory,
			Status:    memory.EntryStatusActive,
			CreatedAt: time.Now(),
			UpdatedAt: time.Now(),
		}))
	}

	sem := memory.TypeSemantic
	results, err := store.List(ctx, memory.ListFilter{Type: &sem, Limit: 10})
	require.NoError(t, err)
	require.Len(t, results, 2)
}

// TestMemoryStore_ListTimeRange verifies CreatedAfter filtering: only the
// entry created inside the last hour survives the cutoff.
func TestMemoryStore_ListTimeRange(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)
	ctx := context.Background()

	past := time.Now().Add(-2 * time.Hour)
	recent := time.Now().Add(-30 * time.Minute)

	require.NoError(t, store.Create(ctx, memory.Entry{
		ID: "mem_old", Type: memory.TypeEpisodic, Content: "old event",
		Author: memory.AuthorAgent, Source: memory.SourceMemory, Status: memory.EntryStatusActive,
		CreatedAt: past, UpdatedAt: past,
	}))
	require.NoError(t, store.Create(ctx, memory.Entry{
		ID: "mem_new", Type: memory.TypeEpisodic, Content: "recent event",
		Author: memory.AuthorAgent, Source: memory.SourceMemory, Status: memory.EntryStatusActive,
		CreatedAt: recent, UpdatedAt: recent,
	}))

	cutoff := time.Now().Add(-1 * time.Hour)
	results, err := store.List(ctx, memory.ListFilter{CreatedAfter: &cutoff})
	require.NoError(t, err)
	require.Len(t, results, 1)
	require.Equal(t, "mem_new", results[0].ID)
}
diff --git a/pkg/memory/sqlite/vector.go b/pkg/memory/sqlite/vector.go
new file mode 100644
index 0000000000..a0e734833b
--- /dev/null
+++ b/pkg/memory/sqlite/vector.go
@@ -0,0 +1,131 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package sqlite

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"math"
	"sort"

	"github.com/stacklok/toolhive/pkg/memory"
)

// VectorStore implements memory.VectorStore using SQLite blob storage and
// Go-native cosine similarity. Suitable for datasets up to ~100K entries.
// Use an external VectorStore (Qdrant, pgvector) for larger datasets.
type VectorStore struct {
	db *sql.DB
}

// NewVectorStore creates a new SQLite-backed VectorStore.
func NewVectorStore(wrapper *DB) *VectorStore {
	return &VectorStore{db: wrapper.DB()}
}

// Compile-time check that VectorStore satisfies the memory.VectorStore interface.
var _ memory.VectorStore = (*VectorStore)(nil)

// Upsert stores or replaces the embedding for entry id.
+func (v *VectorStore) Upsert(ctx context.Context, id string, embedding []float32) error { + data, err := json.Marshal(embedding) + if err != nil { + return fmt.Errorf("marshalling embedding: %w", err) + } + _, err = v.db.ExecContext(ctx, + `INSERT INTO memory_embeddings (entry_id, embedding) VALUES (?,?) + ON CONFLICT(entry_id) DO UPDATE SET embedding=excluded.embedding`, + id, string(data)) + return err +} + +// Search loads all embeddings matching the filter, computes cosine similarity +// against query, and returns the topK results in descending score order. +func (v *VectorStore) Search( + ctx context.Context, query []float32, topK int, filter memory.VectorFilter, +) ([]memory.ScoredID, error) { + q := `SELECT e.entry_id, e.embedding + FROM memory_embeddings e + JOIN memory_entries m ON m.id = e.entry_id + WHERE 1=1` + var args []any + if filter.Type != nil { + q += " AND m.type=?" + args = append(args, string(*filter.Type)) + } + if filter.Status != nil { + q += " AND m.status=?" + args = append(args, string(*filter.Status)) + } + if filter.Source != nil { + q += " AND m.source=?" + args = append(args, string(*filter.Source)) + } + + rows, err := v.db.QueryContext(ctx, q, args...) 
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + qNorm := l2Norm(query) + if qNorm == 0 { + return nil, fmt.Errorf("query vector has zero magnitude") + } + + var scored []memory.ScoredID + for rows.Next() { + var id, embJSON string + if err := rows.Scan(&id, &embJSON); err != nil { + return nil, err + } + var emb []float32 + if err := json.Unmarshal([]byte(embJSON), &emb); err != nil { + continue + } + sim := cosineSimilarity(query, emb, qNorm) + scored = append(scored, memory.ScoredID{ID: id, Similarity: sim}) + } + if err := rows.Err(); err != nil { + return nil, err + } + + sort.Slice(scored, func(i, j int) bool { + return scored[i].Similarity > scored[j].Similarity + }) + if topK > 0 && len(scored) > topK { + scored = scored[:topK] + } + return scored, nil +} + +// Delete removes the embedding for entry id. +func (v *VectorStore) Delete(ctx context.Context, id string) error { + _, err := v.db.ExecContext(ctx, `DELETE FROM memory_embeddings WHERE entry_id=?`, id) + return err +} + +func cosineSimilarity(a, b []float32, aNorm float32) float32 { + if len(a) != len(b) || aNorm == 0 { + return 0 + } + bNorm := l2Norm(b) + if bNorm == 0 { + return 0 + } + var dot float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + } + return float32(dot / (float64(aNorm) * float64(bNorm))) +} + +func l2Norm(v []float32) float32 { + var sum float64 + for _, x := range v { + sum += float64(x) * float64(x) + } + return float32(math.Sqrt(sum)) +} diff --git a/pkg/memory/sqlite/vector_test.go b/pkg/memory/sqlite/vector_test.go new file mode 100644 index 0000000000..b613db0edc --- /dev/null +++ b/pkg/memory/sqlite/vector_test.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
// SPDX-License-Identifier: Apache-2.0

package sqlite_test

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/stacklok/toolhive/pkg/memory"
	memorysqlite "github.com/stacklok/toolhive/pkg/memory/sqlite"
)

// TestVectorStore_UpsertAndSearch verifies that a top-2 cosine search returns
// the two vectors near the query and excludes the orthogonal one, with
// results ordered by descending similarity.
func TestVectorStore_UpsertAndSearch(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)
	vectors := memorysqlite.NewVectorStore(db)

	ctx := context.Background()

	entries := []struct {
		id        string
		embedding []float32
	}{
		{"vec_001", []float32{1, 0, 0}},
		{"vec_002", []float32{0.9, 0.1, 0}},
		{"vec_003", []float32{0, 0, 1}},
	}
	for _, e := range entries {
		// A matching memory_entries row must exist: Search joins embeddings
		// against entries, so an orphan embedding would never be returned.
		require.NoError(t, store.Create(ctx, memory.Entry{
			ID: e.id, Type: memory.TypeSemantic, Content: "c",
			Author: memory.AuthorAgent, Source: memory.SourceMemory,
			Status: memory.EntryStatusActive,
		}))
		require.NoError(t, vectors.Upsert(ctx, e.id, e.embedding))
	}

	query := []float32{0.95, 0.05, 0}
	results, err := vectors.Search(ctx, query, 2, memory.VectorFilter{})
	require.NoError(t, err)
	require.Len(t, results, 2)

	ids := []string{results[0].ID, results[1].ID}
	require.Contains(t, ids, "vec_001")
	require.Contains(t, ids, "vec_002")
	require.NotContains(t, ids, "vec_003")

	require.GreaterOrEqual(t, results[0].Similarity, results[1].Similarity)
}

// TestVectorStore_Delete verifies that a deleted embedding no longer appears
// in search results.
func TestVectorStore_Delete(t *testing.T) {
	t.Parallel()
	db := openTestDB(t)
	store := memorysqlite.NewStore(db)
	vectors := memorysqlite.NewVectorStore(db)

	ctx := context.Background()
	require.NoError(t, store.Create(ctx, memory.Entry{
		ID: "vec_del", Type: memory.TypeSemantic, Content: "c",
		Author: memory.AuthorAgent, Source: memory.SourceMemory,
		Status: memory.EntryStatusActive,
	}))
	require.NoError(t, vectors.Upsert(ctx, "vec_del", []float32{1, 0, 0}))
	require.NoError(t, vectors.Delete(ctx, "vec_del"))

	results, err := vectors.Search(ctx, []float32{1, 0, 0}, 5, memory.VectorFilter{})
	require.NoError(t, err)
	for _, r := range results {
		require.NotEqual(t, "vec_del", r.ID)
	}
}
diff --git a/pkg/memory/types.go b/pkg/memory/types.go
new file mode 100644
index 0000000000..7369656d3b
--- /dev/null
+++ b/pkg/memory/types.go
@@ -0,0 +1,153 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

// Package memory defines the types and interfaces for ToolHive's shared long-term memory system.
package memory

import "time"

// Type distinguishes the two long-term memory namespaces.
type Type string

// NOTE: the string values of the constants below are persisted (and used in
// SQL filters elsewhere in this package); changing one is a data migration.
const (
	// TypeSemantic represents factual, aggregated knowledge and world-state memories
	// (e.g. "company does not sponsor visas"). Contrast with TypeEpisodic.
	TypeSemantic Type = "semantic"
	// TypeProcedural represents how-to knowledge and step-based memories.
	TypeProcedural Type = "procedural"
	// TypeEpisodic represents time-indexed event records tied to a specific
	// moment (e.g. "recruiter archived candidate on 2024-03-15 — visa required").
	// Use CreatedAfter/CreatedBefore in ListFilter to query timelines.
	TypeEpisodic Type = "episodic"
)

// AuthorType records whether a memory was written by a human or an agent.
type AuthorType string

const (
	// AuthorHuman indicates the memory was written by a human user.
	AuthorHuman AuthorType = "human"
	// AuthorAgent indicates the memory was written by an AI agent.
	AuthorAgent AuthorType = "agent"
)

// SourceType records whether a memory entry originates from the store or is a
// read-only index of an installed Skill.
type SourceType string

const (
	// SourceMemory indicates the entry originates from the writable memory store.
	SourceMemory SourceType = "memory"
	// SourceSkill indicates the entry is a read-only index of an installed Skill.
	SourceSkill SourceType = "skill"
	// SourceResource indicates the entry is a UI-managed resource document that
	// is read-only to agents. Resources are written via the management REST API
	// and are progressively discovered by agents through memory_search and MCP
	// Resources protocol (resources/list, resources/read).
	SourceResource SourceType = "resource"
)

// EntryStatus is the lifecycle state of a memory entry.
type EntryStatus string

const (
	// EntryStatusActive indicates the entry is in normal use.
	EntryStatusActive EntryStatus = "active"
	// EntryStatusFlagged indicates the entry has been marked for review.
	EntryStatusFlagged EntryStatus = "flagged"
	// EntryStatusExpired indicates the entry has passed its TTL.
	EntryStatusExpired EntryStatus = "expired"
	// EntryStatusArchived indicates the entry has been moved to the archive.
	EntryStatusArchived EntryStatus = "archived"
)

// ArchiveReason records why an entry was archived.
type ArchiveReason string

const (
	// ArchiveReasonConsolidated indicates the entry was merged into a newer entry.
	ArchiveReasonConsolidated ArchiveReason = "consolidated"
	// ArchiveReasonCrystallized indicates the entry was promoted to a skill.
	ArchiveReasonCrystallized ArchiveReason = "crystallized"
	// ArchiveReasonManual indicates the entry was manually archived.
	ArchiveReasonManual ArchiveReason = "manual"
	// ArchiveReasonExpired indicates the entry exceeded its TTL.
	ArchiveReasonExpired ArchiveReason = "expired"
)

// Entry is the core domain type representing one stored memory.
+type Entry struct { + ID string + Type Type + Content string + Tags []string + Author AuthorType + AgentID string + SessionID string + Source SourceType + SkillRef string + Status EntryStatus + TrustScore float32 + StalenessScore float32 + AccessCount int + LastAccessedAt time.Time + FlaggedAt *time.Time + FlagReason string + TTLDays *int + ExpiresAt *time.Time + ArchivedAt *time.Time + ConsolidatedInto string + CrystallizedInto string + History []Revision + CreatedAt time.Time + UpdatedAt time.Time +} + +// Revision records a single correction to a memory entry. +type Revision struct { + Content string + Author AuthorType + CorrectionNote string + Timestamp time.Time +} + +// ListFilter restricts results returned by MemoryStore.List. +type ListFilter struct { + Type *Type + Author *AuthorType + Tags []string + Source *SourceType + Status *EntryStatus + CreatedAfter *time.Time + CreatedBefore *time.Time + Limit int + Offset int +} + +// VectorFilter restricts similarity search to a subset of entries. +type VectorFilter struct { + Type *Type + Status *EntryStatus + Source *SourceType +} + +// ScoredID pairs an entry ID with its cosine similarity to a query. +type ScoredID struct { + ID string + Similarity float32 +} + +// ScoredEntry pairs a full Entry with its similarity to a query. +type ScoredEntry struct { + Entry Entry + Similarity float32 +} + +// ConflictResult describes a potentially conflicting existing memory returned +// during a write conflict check. +type ConflictResult struct { + ID string + Content string + Similarity float32 + TrustScore float32 +} diff --git a/pkg/memory/types_test.go b/pkg/memory/types_test.go new file mode 100644 index 0000000000..8163259793 --- /dev/null +++ b/pkg/memory/types_test.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package memory_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/memory" +) + +func TestMemoryTypeConstants(t *testing.T) { + t.Parallel() + require.Equal(t, memory.Type("semantic"), memory.TypeSemantic) + require.Equal(t, memory.Type("procedural"), memory.TypeProcedural) + require.Equal(t, memory.Type("episodic"), memory.TypeEpisodic) + require.Equal(t, memory.AuthorType("human"), memory.AuthorHuman) + require.Equal(t, memory.AuthorType("agent"), memory.AuthorAgent) + require.Equal(t, memory.EntryStatus("active"), memory.EntryStatusActive) + require.Equal(t, memory.EntryStatus("flagged"), memory.EntryStatusFlagged) + require.Equal(t, memory.EntryStatus("expired"), memory.EntryStatusExpired) + require.Equal(t, memory.EntryStatus("archived"), memory.EntryStatusArchived) + require.Equal(t, memory.SourceType("memory"), memory.SourceMemory) + require.Equal(t, memory.SourceType("skill"), memory.SourceSkill) + require.Equal(t, memory.ArchiveReason("consolidated"), memory.ArchiveReasonConsolidated) + require.Equal(t, memory.ArchiveReason("crystallized"), memory.ArchiveReasonCrystallized) + require.Equal(t, memory.ArchiveReason("manual"), memory.ArchiveReasonManual) + require.Equal(t, memory.ArchiveReason("expired"), memory.ArchiveReasonExpired) +}