From 0e71541a5410f531ad78a76f1e813b94098d4647 Mon Sep 17 00:00:00 2001 From: Andrzej J Skalski Date: Thu, 5 Feb 2026 17:08:36 +0100 Subject: [PATCH 1/2] fix(langserver): support go-to-definition for plugin-defined rules Previously, go-to-definition only worked for core builtin functions. Plugin-defined rules like go_library, go_repo, etc. would return no results because they were parsed by a different parser instance than the one used by the language server. Changes: - Use parse.InitParser() to initialize the parser on BuildState, then get the same parser via parse.GetAspParser() for the language server - Add periodic loading of function definitions (every 2 seconds) so go-to-definition works progressively while the full parse runs - Add Range() method to cmap types to iterate over parsed ASTs - Add AllFunctionsByFile() to asp.Parser to retrieve function definitions - Fix file URIs to use absolute paths --- src/cmap/cerrmap.go | 11 +++++ src/cmap/cmap.go | 19 ++++++++ src/parse/asp/parser.go | 18 ++++++++ src/parse/init.go | 13 ++++++ tools/build_langserver/lsp/BUILD | 1 + tools/build_langserver/lsp/definition.go | 34 +++++++++----- tools/build_langserver/lsp/lsp.go | 58 +++++++++++++++++++++++- 7 files changed, 140 insertions(+), 14 deletions(-) diff --git a/src/cmap/cerrmap.go b/src/cmap/cerrmap.go index 687c9191bf..78c7dc8ed6 100644 --- a/src/cmap/cerrmap.go +++ b/src/cmap/cerrmap.go @@ -78,3 +78,14 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) { } return v.Val, v.Err } + +// Range calls f for each key-value pair in the map. +// No particular consistency guarantees are made during iteration. +func (m *ErrMap[K, V]) Range(f func(key K, val V)) { + m.m.Range(func(key K, val errV[V]) { + if val.Err != nil { + return // skip errors + } + f(key, val.Val) + }) +} diff --git a/src/cmap/cmap.go b/src/cmap/cmap.go index ce8508b454..f8058ef732 100644 --- a/src/cmap/cmap.go +++ b/src/cmap/cmap.go @@ -94,6 +94,14 @@ func (m *Map[K, V]) Values() []V { return ret } +// Range calls f for each key-value pair in the map. +// No particular consistency guarantees are made during iteration. +func (m *Map[K, V]) Range(f func(key K, val V)) { + for i := 0; i < len(m.shards); i++ { + m.shards[i].Range(f) + } +} + // An awaitableValue represents a value in the map & an awaitable channel for it to exist. type awaitableValue[V any] struct { Val V @@ -195,3 +203,14 @@ func (s *shard[K, V]) Contains(key K) bool { _, ok := s.m[key] return ok } + +// Range calls f for each key-value pair in this shard. +func (s *shard[K, V]) Range(f func(key K, val V)) { + s.l.RLock() + defer s.l.RUnlock() + for k, v := range s.m { + if v.Wait == nil { // Only include completed values + f(k, v.Val) + } + } +} diff --git a/src/parse/asp/parser.go b/src/parse/asp/parser.go index 36f055b966..f67a8605ae 100644 --- a/src/parse/asp/parser.go +++ b/src/parse/asp/parser.go @@ -257,6 +257,24 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) { } } +// AllFunctionsByFile returns all function definitions grouped by filename. +// This includes functions from builtins, plugins, and subincludes. +// It iterates over the ASTs stored by the interpreter. +func (p *Parser) AllFunctionsByFile() map[string][]*Statement { + if p.interpreter == nil || p.interpreter.asts == nil { + return nil + } + result := make(map[string][]*Statement) + p.interpreter.asts.Range(func(filename string, stmts []*Statement) { + for _, stmt := range stmts { + if stmt.FuncDef != nil { + result[filename] = append(result[filename], stmt) + } + } + }) + return result +} + // whitelistedKwargs returns true if the given built-in function name is allowed to // be called as non-kwargs. // TODO(peterebden): Come up with a syntax that exposes this directly in the file. diff --git a/src/parse/init.go b/src/parse/init.go index 663e265104..ee67dacda5 100644 --- a/src/parse/init.go +++ b/src/parse/init.go @@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState { return state } +// GetAspParser returns the underlying asp.Parser from the state's parser. +// This is useful for tools like the language server that need direct access to AST information. +// Returns nil if the state's parser is not set or is not an aspParser. +func GetAspParser(state *core.BuildState) *asp.Parser { + if state.Parser == nil { + return nil + } + if ap, ok := state.Parser.(*aspParser); ok { + return ap.parser + } + return nil +} + // aspParser implements the core.Parser interface around our parser package. type aspParser struct { parser *asp.Parser diff --git a/tools/build_langserver/lsp/BUILD b/tools/build_langserver/lsp/BUILD index f8c5f3485c..d0be35f006 100644 --- a/tools/build_langserver/lsp/BUILD +++ b/tools/build_langserver/lsp/BUILD @@ -17,6 +17,7 @@ go_library( "//rules", "//src/core", "//src/fs", + "//src/parse", "//src/parse/asp", "//src/plz", "//tools/build_langserver/lsp/astutils", diff --git a/tools/build_langserver/lsp/definition.go b/tools/build_langserver/lsp/definition.go index 9ee1c5df18..50c0fdd51c 100644 --- a/tools/build_langserver/lsp/definition.go +++ b/tools/build_langserver/lsp/definition.go @@ -18,20 +18,20 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca ast := h.parseIfNeeded(doc) f := doc.AspFile() - var locs []lsp.Location + locs := []lsp.Location{} pos := aspPos(params.Position) asp.WalkAST(ast, func(expr *asp.Expression) bool { - if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) { + exprStart := f.Pos(expr.Pos) + exprEnd := f.Pos(expr.EndPos) + if !asp.WithinRange(pos, exprStart, exprEnd) { return false } - if expr.Val.Ident != nil { if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" { locs = append(locs, loc) } return false } - if expr.Val.String != "" { label := astutils.TrimStrLit(expr.Val.String) if loc := h.findLabel(doc.PkgName, label); loc.URI != "" { @@ -39,20 +39,19 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca } return false } - return true }) - // It might also be a statement. + // It might also be a statement (e.g. a function call like go_library(...)) asp.WalkAST(ast, func(stmt *asp.Statement) bool { if stmt.Ident != nil { - endPos := f.Pos(stmt.Pos) + stmtStart := f.Pos(stmt.Pos) + endPos := stmtStart // TODO(jpoole): The AST should probably just have this information endPos.Column += len(stmt.Ident.Name) - if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) { - return false + if !asp.WithinRange(pos, stmtStart, endPos) { + return true // continue to other statements } - if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" { locs = append(locs, loc) } @@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location { } pkg := h.state.Graph.PackageByLabel(l) + if pkg == nil { + return lsp.Location{} + } uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) loc := lsp.Location{URI: uri} doc, err := h.maybeOpenDoc(uri) @@ -137,9 +139,17 @@ func findName(args []asp.CallArgument) string { // findGlobal returns the location of a global of the given name. func (h *Handler) findGlobal(name string) lsp.Location { - if f, present := h.builtins[name]; present { + h.mutex.Lock() + f, present := h.builtins[name] + h.mutex.Unlock() + if present { + filename := f.Pos.Filename + // Make path absolute if it's relative + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } return lsp.Location{ - URI: lsp.DocumentURI("file://" + f.Pos.Filename), + URI: lsp.DocumentURI("file://" + filename), Range: rng(f.Pos, f.EndPos), } } diff --git a/tools/build_langserver/lsp/lsp.go b/tools/build_langserver/lsp/lsp.go index b979360177..0662dc529c 100644 --- a/tools/build_langserver/lsp/lsp.go +++ b/tools/build_langserver/lsp/lsp.go @@ -20,6 +20,7 @@ import ( "github.com/thought-machine/please/rules" "github.com/thought-machine/please/src/core" "github.com/thought-machine/please/src/fs" + "github.com/thought-machine/please/src/parse" "github.com/thought-machine/please/src/parse/asp" "github.com/thought-machine/please/src/plz" ) @@ -195,16 +196,38 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul } h.state = core.NewBuildState(config) h.state.NeedBuild = false - // We need an unwrapped parser instance as well for raw access. - h.parser = asp.NewParser(h.state) + // Initialize the parser on state first, so that plz.RunHost uses the same parser. + // This ensures plugin subincludes are stored in the same AST cache we use. + parse.InitParser(h.state) + h.parser = parse.GetAspParser(h.state) + if h.parser == nil { + return nil, fmt.Errorf("failed to get asp parser from state") + } // Parse everything in the repo up front. // This is a lot easier than trying to do clever partial parses later on, although // eventually we may want that if we start dealing with truly large repos. go func() { + // Start a goroutine to periodically load parser functions as they become available. + // This allows go-to-definition to work progressively while the full parse runs. + done := make(chan struct{}) + go func() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + h.loadParserFunctions() + } + } + }() plz.RunHost(core.WholeGraph, h.state) + close(done) log.Debug("initial parse complete") h.buildPackageTree() log.Debug("built completion package tree") + h.loadParserFunctions() }() // Record all the builtin functions now if err := h.loadBuiltins(); err != nil { @@ -268,6 +291,37 @@ func (h *Handler) loadBuiltins() error { return nil } +// loadParserFunctions loads function definitions from the parser's ASTs. +// This includes plugin-defined functions like go_library, python_library, etc. +func (h *Handler) loadParserFunctions() { + funcsByFile := h.parser.AllFunctionsByFile() + if funcsByFile == nil { + return + } + h.mutex.Lock() + defer h.mutex.Unlock() + for filename, stmts := range funcsByFile { + // Read the file to create a File object for position conversion + data, err := os.ReadFile(filename) + if err != nil { + log.Warning("failed to read file %s: %v", filename, err) + continue + } + file := asp.NewFile(filename, data) + for _, stmt := range stmts { + name := stmt.FuncDef.Name + // Only add if not already present (don't override core builtins) + if _, present := h.builtins[name]; !present { + h.builtins[name] = builtin{ + Stmt: stmt, + Pos: file.Pos(stmt.Pos), + EndPos: file.Pos(stmt.EndPos), + } + } + } + } +} + // fromURI converts a DocumentURI to a path. func fromURI(uri lsp.DocumentURI) string { if !strings.HasPrefix(string(uri), "file://") { From 57b77f32794db566d5c21bb8168f65f9f1ce757d Mon Sep 17 00:00:00 2001 From: Andrzej J Skalski Date: Fri, 27 Feb 2026 16:55:38 +0100 Subject: [PATCH 2/2] made builtins an array, as suggested in review --- tools/build_langserver/lsp/completion.go | 6 +++--- tools/build_langserver/lsp/definition.go | 5 +++-- tools/build_langserver/lsp/lsp.go | 23 ++++++++++------------- tools/build_langserver/lsp/lsp_test.go | 4 ++-- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/tools/build_langserver/lsp/completion.go b/tools/build_langserver/lsp/completion.go index 081af91949..4348b24fba 100644 --- a/tools/build_langserver/lsp/completion.go +++ b/tools/build_langserver/lsp/completion.go @@ -109,10 +109,10 @@ func (h *Handler) completeString(doc *doc, s string, line, col int) (*lsp.Comple // completeIdent completes an arbitrary identifier func (h *Handler) completeIdent(doc *doc, s string, line, col int) (*lsp.CompletionList, error) { list := &lsp.CompletionList{} - for name, f := range h.builtins { - if strings.HasPrefix(name, s) { + for name, builtins := range h.builtins { + if strings.HasPrefix(name, s) && len(builtins) > 0 { item := completionItem(name, "", line, col) - item.Documentation = f.Stmt.FuncDef.Docstring + item.Documentation = builtins[0].Stmt.FuncDef.Docstring item.Kind = lsp.CIKFunction list.Items = append(list.Items, item) } diff --git a/tools/build_langserver/lsp/definition.go b/tools/build_langserver/lsp/definition.go index 50c0fdd51c..3085d5d03f 100644 --- a/tools/build_langserver/lsp/definition.go +++ b/tools/build_langserver/lsp/definition.go @@ -140,9 +140,10 @@ func findName(args []asp.CallArgument) string { // findGlobal returns the location of a global of the given name. func (h *Handler) findGlobal(name string) lsp.Location { h.mutex.Lock() - f, present := h.builtins[name] + builtins := h.builtins[name] h.mutex.Unlock() - if present { + if len(builtins) > 0 { + f := builtins[0] filename := f.Pos.Filename // Make path absolute if it's relative if !filepath.IsAbs(filename) { diff --git a/tools/build_langserver/lsp/lsp.go b/tools/build_langserver/lsp/lsp.go index 0662dc529c..fbdbdcceb0 100644 --- a/tools/build_langserver/lsp/lsp.go +++ b/tools/build_langserver/lsp/lsp.go @@ -34,7 +34,7 @@ type Handler struct { mutex sync.Mutex // guards docs state *core.BuildState parser *asp.Parser - builtins map[string]builtin + builtins map[string][]builtin pkgs *pkg root string } @@ -56,7 +56,7 @@ func NewHandler() *Handler { return &Handler{ docs: map[string]*doc{}, pkgs: &pkg{}, - builtins: map[string]builtin{}, + builtins: map[string][]builtin{}, } } @@ -216,6 +216,7 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul for { select { case <-done: + h.loadParserFunctions() return case <-ticker.C: h.loadParserFunctions() @@ -227,7 +228,6 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul log.Debug("initial parse complete") h.buildPackageTree() log.Debug("built completion package tree") - h.loadParserFunctions() }() // Record all the builtin functions now if err := h.loadBuiltins(); err != nil { @@ -279,11 +279,11 @@ func (h *Handler) loadBuiltins() error { f := asp.NewFile(dest, data) for _, stmt := range stmts { if stmt.FuncDef != nil { - h.builtins[stmt.FuncDef.Name] = builtin{ + h.builtins[stmt.FuncDef.Name] = append(h.builtins[stmt.FuncDef.Name], builtin{ Stmt: stmt, Pos: f.Pos(stmt.Pos), EndPos: f.Pos(stmt.EndPos), - } + }) } } } @@ -310,14 +310,11 @@ func (h *Handler) loadParserFunctions() { file := asp.NewFile(filename, data) for _, stmt := range stmts { name := stmt.FuncDef.Name - // Only add if not already present (don't override core builtins) - if _, present := h.builtins[name]; !present { - h.builtins[name] = builtin{ - Stmt: stmt, - Pos: file.Pos(stmt.Pos), - EndPos: file.Pos(stmt.EndPos), - } - } + h.builtins[name] = append(h.builtins[name], builtin{ + Stmt: stmt, + Pos: file.Pos(stmt.Pos), + EndPos: file.Pos(stmt.EndPos), + }) } } } diff --git a/tools/build_langserver/lsp/lsp_test.go b/tools/build_langserver/lsp/lsp_test.go index ca7a0b4c76..8d05240f8c 100644 --- a/tools/build_langserver/lsp/lsp_test.go +++ b/tools/build_langserver/lsp/lsp_test.go @@ -458,7 +458,7 @@ func TestCompletionFunction(t *testing.T) { Kind: lsp.CIKFunction, InsertTextFormat: lsp.ITFPlainText, TextEdit: textEdit("plugin_repo", 0, 4, 0), - Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring, + Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring, }}, }, completions) } @@ -492,7 +492,7 @@ func TestCompletionPartialFunction(t *testing.T) { Kind: lsp.CIKFunction, InsertTextFormat: lsp.ITFPlainText, TextEdit: textEdit("plugin_repo", 0, 9, 0), - Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring, + Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring, }}, }, completions) }