diff --git a/Makefile b/Makefile index 11dacc626..b2c584eb7 100644 --- a/Makefile +++ b/Makefile @@ -59,10 +59,10 @@ clean: ## Delete intermediate build artifacts test: build ## Run unit tests $(GO) test -race -cover ./... $(GO) test -tags protolegacy ./... - cd internal/benchmarks && SKIP_DOWNLOAD_GOOGLEAPIS=true $(GO) test -race -cover ./... .PHONY: benchmarks benchmarks: build ## Run benchmarks + $(GO) test -bench=. -benchmem -v ./experimental/benchmark cd internal/benchmarks && $(GO) test -bench=. -benchmem -v ./... .PHONY: build diff --git a/experimental/ast/commas.go b/experimental/ast/commas.go index 2acd59cc3..a8dd98697 100644 --- a/experimental/ast/commas.go +++ b/experimental/ast/commas.go @@ -66,7 +66,7 @@ func (c commas[T, _]) AppendComma(value T, comma token.Token) { } func (c commas[T, _]) InsertComma(n int, value T, comma token.Token) { - c.file.Nodes().panicIfNotOurs(comma) + c.file.Nodes().panicIfNotOurs(comma.Context()) v := c.SliceInserter.Unwrap(n, value) v.Comma = comma.ID() @@ -74,6 +74,6 @@ func (c commas[T, _]) InsertComma(n int, value T, comma token.Token) { } func (c commas[T, _]) SetComma(n int, comma token.Token) { - c.file.Nodes().panicIfNotOurs(comma) + c.file.Nodes().panicIfNotOurs(comma.Context()) (*c.SliceInserter.Slice)[n].Comma = comma.ID() } diff --git a/experimental/ast/decl.go b/experimental/ast/decl.go index 7cbf98f89..9c68a2567 100644 --- a/experimental/ast/decl.go +++ b/experimental/ast/decl.go @@ -124,18 +124,24 @@ func (d DeclAny) AsRange() DeclRange { // Span implements [source.Spanner]. func (d DeclAny) Span() source.Span { - // At most one of the below will produce a non-zero decl, and that will be - // the span selected by source.Join. If all of them are zero, this produces - // the zero span. 
- return source.Join( - d.AsEmpty(), - d.AsSyntax(), - d.AsPackage(), - d.AsImport(), - d.AsDef(), - d.AsBody(), - d.AsRange(), - ) + switch d.Kind() { + case DeclKindBody: + return d.AsBody().Span() + case DeclKindDef: + return d.AsDef().Span() + case DeclKindEmpty: + return d.AsEmpty().Span() + case DeclKindImport: + return d.AsImport().Span() + case DeclKindPackage: + return d.AsPackage().Span() + case DeclKindRange: + return d.AsRange().Span() + case DeclKindSyntax: + return d.AsSyntax().Span() + default: + return source.Span{} + } } func (DeclKind) DecodeDynID(lo, _ int32) DeclKind { diff --git a/experimental/ast/decl_body.go b/experimental/ast/decl_body.go index 83a61df13..e64854299 100644 --- a/experimental/ast/decl_body.go +++ b/experimental/ast/decl_body.go @@ -108,7 +108,7 @@ func (d DeclBody) Decls() seq.Inserter[DeclAny] { return id.WrapDyn(d.Context(), id.NewDyn(k, p)) }, func(_ int, d DeclAny) (DeclKind, id.ID[DeclAny]) { - d.Context().Nodes().panicIfNotOurs(d) + d.Context().Nodes().panicIfNotOurs(d.Context()) return d.ID().Kind(), d.ID().Value() }, ) diff --git a/experimental/ast/decl_def.go b/experimental/ast/decl_def.go index 0d0e56f8c..22168d5c6 100644 --- a/experimental/ast/decl_def.go +++ b/experimental/ast/decl_def.go @@ -520,15 +520,15 @@ func (d DeclDef) Span() source.Span { return source.Span{} } - return source.Join( - d.Type(), - d.Name(), - d.Signature(), - d.Equals(), - d.Value(), - d.Options(), - d.Body(), - d.Semicolon(), + return source.JoinSpans( + d.Type().Span(), + d.Name().Span(), + d.Signature().Span(), + d.Equals().Span(), + d.Value().Span(), + d.Options().Span(), + d.Body().Span(), + d.Semicolon().Span(), ) } diff --git a/experimental/ast/decl_file.go b/experimental/ast/decl_file.go index c8ec06cce..78f965c73 100644 --- a/experimental/ast/decl_file.go +++ b/experimental/ast/decl_file.go @@ -323,7 +323,7 @@ func (d DeclImport) ModifierTokens() seq.Inserter[token.Token] { return seq.NewSliceInserter(&d.Raw().modifiers, func(_ 
int, e token.ID) token.Token { return id.Wrap(d.Context().Stream(), e) }, func(_ int, t token.Token) token.ID { - d.Context().Nodes().panicIfNotOurs(t) + d.Context().Nodes().panicIfNotOurs(t.Context()) return t.ID() }, ) diff --git a/experimental/ast/decl_range.go b/experimental/ast/decl_range.go index f0c4cb041..56757d03d 100644 --- a/experimental/ast/decl_range.go +++ b/experimental/ast/decl_range.go @@ -93,7 +93,7 @@ func (d DeclRange) Ranges() Commas[ExprAny] { return id.WrapDyn(d.Context(), c.Value) }, func(_ int, e ExprAny) withComma[id.Dyn[ExprAny, ExprKind]] { - d.Context().Nodes().panicIfNotOurs(e) + d.Context().Nodes().panicIfNotOurs(e.Context()) return withComma[id.Dyn[ExprAny, ExprKind]]{Value: e.ID()} }, ), diff --git a/experimental/ast/expr.go b/experimental/ast/expr.go index f1b76be3f..dbf0ef77b 100644 --- a/experimental/ast/expr.go +++ b/experimental/ast/expr.go @@ -143,18 +143,26 @@ func (e ExprAny) AsField() ExprField { // Span implements [source.Spanner]. func (e ExprAny) Span() source.Span { - // At most one of the below will produce a non-nil type, and that will be - // the span selected by source.Join. If all of them are nil, this produces - // the nil span. - return source.Join( - e.AsLiteral(), - e.AsPath(), - e.AsPrefixed(), - e.AsRange(), - e.AsArray(), - e.AsDict(), - e.AsField(), - ) + switch e.Kind() { + case ExprKindArray: + return e.AsArray().Span() + case ExprKindDict: + return e.AsDict().Span() + case ExprKindError: + return e.AsError().Span() + case ExprKindField: + return e.AsField().Span() + case ExprKindLiteral: + return e.AsLiteral().Span() + case ExprKindPath: + return e.AsPath().Span() + case ExprKindPrefixed: + return e.AsPrefixed().Span() + case ExprKindRange: + return e.AsRange().Span() + default: + return source.Span{} + } } // ExprError represents an unrecoverable parsing error in an expression context. 
diff --git a/experimental/ast/expr_array.go b/experimental/ast/expr_array.go index baa745ae2..8f5a4d494 100644 --- a/experimental/ast/expr_array.go +++ b/experimental/ast/expr_array.go @@ -69,7 +69,7 @@ func (e ExprArray) Elements() Commas[ExprAny] { return id.WrapDyn(e.Context(), c.Value) }, func(_ int, e ExprAny) withComma[id.Dyn[ExprAny, ExprKind]] { - e.Context().Nodes().panicIfNotOurs(e) + e.Context().Nodes().panicIfNotOurs(e.Context()) return withComma[id.Dyn[ExprAny, ExprKind]]{Value: e.ID()} }, ), diff --git a/experimental/ast/expr_dict.go b/experimental/ast/expr_dict.go index 437935157..1af3ae939 100644 --- a/experimental/ast/expr_dict.go +++ b/experimental/ast/expr_dict.go @@ -73,7 +73,7 @@ func (e ExprDict) Elements() Commas[ExprField] { return id.Wrap(e.Context(), c.Value) }, func(_ int, e ExprField) withComma[id.ID[ExprField]] { - e.Context().Nodes().panicIfNotOurs(e) + e.Context().Nodes().panicIfNotOurs(e.Context()) return withComma[id.ID[ExprField]]{Value: e.ID()} }, ), diff --git a/experimental/ast/nodes.go b/experimental/ast/nodes.go index 0c26036ad..8d975da16 100644 --- a/experimental/ast/nodes.go +++ b/experimental/ast/nodes.go @@ -42,7 +42,7 @@ func (n *Nodes) File() *File { // // To create a path component with an extension value, see [Nodes.NewExtensionComponent]. func (n *Nodes) NewPathComponent(separator, name token.Token) PathComponent { - n.panicIfNotOurs(separator, name) + n.panicIfNotOurs(separator.Context(), name.Context()) if !separator.IsZero() { if separator.Kind() != token.Keyword || (separator.Text() != "." && separator.Text() != "/") { panic(fmt.Sprintf("protocompile/ast: passed non '.' or '/' separator to NewPathComponent: %s", separator)) @@ -62,7 +62,7 @@ func (n *Nodes) NewPathComponent(separator, name token.Token) PathComponent { // NewExtensionComponent returns a new extension path component containing the // given path. 
func (n *Nodes) NewExtensionComponent(separator token.Token, path Path) PathComponent { - n.panicIfNotOurs(separator, path) + n.panicIfNotOurs(separator.Context(), path.Context()) if !separator.IsZero() { if separator.Kind() != token.Keyword || (separator.Text() != "." && separator.Text() != "/") { panic(fmt.Sprintf("protocompile/ast: passed non '.' or '/' separator to NewPathComponent: %s", separator)) @@ -75,15 +75,14 @@ func (n *Nodes) NewExtensionComponent(separator token.Token, path Path) PathComp start := stream.NewPunct("(") end := stream.NewPunct(")") var children []token.Token - path.Components(func(pc PathComponent) bool { + for pc := range path.Components() { if !pc.Separator().IsZero() { children = append(children, pc.Separator()) } if !pc.Name().IsZero() { children = append(children, pc.Name()) } - return true - }) + } stream.NewFused(start, end, children...) name = start.ID() @@ -107,7 +106,7 @@ func (n *Nodes) NewPath(components ...PathComponent) Path { } for _, t := range components { - n.panicIfNotOurs(t) + n.panicIfNotOurs(t.Context()) } stream := n.stream @@ -139,7 +138,7 @@ func (n *Nodes) NewPath(components ...PathComponent) Path { // NewDeclEmpty creates a new DeclEmpty node. func (n *Nodes) NewDeclEmpty(semicolon token.Token) DeclEmpty { - n.panicIfNotOurs(semicolon) + n.panicIfNotOurs(semicolon.Context()) decl := id.Wrap(n.File(), id.ID[DeclEmpty](n.decls.empties.NewCompressed(rawDeclEmpty{ semi: semicolon.ID(), @@ -150,7 +149,9 @@ func (n *Nodes) NewDeclEmpty(semicolon token.Token) DeclEmpty { // NewDeclSyntax creates a new DeclSyntax node. 
func (n *Nodes) NewDeclSyntax(args DeclSyntaxArgs) DeclSyntax { - n.panicIfNotOurs(args.Keyword, args.Equals, args.Value, args.Options, args.Semicolon) + n.panicIfNotOurs( + args.Keyword.Context(), args.Equals.Context(), args.Value.Context(), + args.Options.Context(), args.Semicolon.Context()) return id.Wrap(n.File(), id.ID[DeclSyntax](n.decls.syntaxes.NewCompressed(rawDeclSyntax{ keyword: args.Keyword.ID(), @@ -163,7 +164,8 @@ func (n *Nodes) NewDeclSyntax(args DeclSyntaxArgs) DeclSyntax { // NewDeclPackage creates a new DeclPackage node. func (n *Nodes) NewDeclPackage(args DeclPackageArgs) DeclPackage { - n.panicIfNotOurs(args.Keyword, args.Path, args.Options, args.Semicolon) + n.panicIfNotOurs(args.Keyword.Context(), args.Path.Context(), + args.Options.Context(), args.Semicolon.Context()) return id.Wrap(n.File(), id.ID[DeclPackage](n.decls.packages.NewCompressed(rawDeclPackage{ keyword: args.Keyword.ID(), @@ -175,14 +177,15 @@ func (n *Nodes) NewDeclPackage(args DeclPackageArgs) DeclPackage { // NewDeclImport creates a new DeclImport node. func (n *Nodes) NewDeclImport(args DeclImportArgs) DeclImport { - n.panicIfNotOurs(args.Keyword, args.ImportPath, args.Options, args.Semicolon) + n.panicIfNotOurs(args.Keyword.Context(), args.ImportPath.Context(), + args.Options.Context(), args.Semicolon.Context()) return id.Wrap(n.File(), id.ID[DeclImport](n.decls.imports.NewCompressed(rawDeclImport{ keyword: args.Keyword.ID(), modifiers: slices.Collect(iterx.Map( slices.Values(args.Modifiers), func(t token.Token) token.ID { - n.panicIfNotOurs(t) + n.panicIfNotOurs(t.Context()) return t.ID() }), ), @@ -195,8 +198,9 @@ func (n *Nodes) NewDeclImport(args DeclImportArgs) DeclImport { // NewDeclDef creates a new DeclDef node. 
func (n *Nodes) NewDeclDef(args DeclDefArgs) DeclDef { n.panicIfNotOurs( - args.Keyword, args.Type, args.Name, args.Returns, - args.Equals, args.Value, args.Options, args.Body, args.Semicolon) + args.Keyword.Context(), args.Type.Context(), args.Name.Context(), + args.Returns.Context(), args.Equals.Context(), args.Value.Context(), + args.Options.Context(), args.Body.Context(), args.Semicolon.Context()) raw := rawDeclDef{ name: args.Name.raw, @@ -225,7 +229,7 @@ func (n *Nodes) NewDeclDef(args DeclDefArgs) DeclDef { // // To add declarations to the returned body, use [DeclBody.Append]. func (n *Nodes) NewDeclBody(braces token.Token) DeclBody { - n.panicIfNotOurs(braces) + n.panicIfNotOurs(braces.Context()) return id.Wrap(n.File(), id.ID[DeclBody](n.decls.bodies.NewCompressed(rawDeclBody{ braces: braces.ID(), @@ -236,7 +240,7 @@ func (n *Nodes) NewDeclBody(braces token.Token) DeclBody { // // To add ranges to the returned declaration, use [DeclRange.Append]. func (n *Nodes) NewDeclRange(args DeclRangeArgs) DeclRange { - n.panicIfNotOurs(args.Keyword, args.Options, args.Semicolon) + n.panicIfNotOurs(args.Keyword.Context(), args.Options.Context(), args.Semicolon.Context()) return id.Wrap(n.File(), id.ID[DeclRange](n.decls.ranges.NewCompressed(rawDeclRange{ keyword: args.Keyword.ID(), @@ -247,7 +251,7 @@ func (n *Nodes) NewDeclRange(args DeclRangeArgs) DeclRange { // NewExprPrefixed creates a new ExprPrefixed node. func (n *Nodes) NewExprPrefixed(args ExprPrefixedArgs) ExprPrefixed { - n.panicIfNotOurs(args.Prefix, args.Expr) + n.panicIfNotOurs(args.Prefix.Context(), args.Expr.Context()) return id.Wrap(n.File(), id.ID[ExprPrefixed](n.exprs.prefixes.NewCompressed(rawExprPrefixed{ prefix: args.Prefix.ID(), @@ -257,7 +261,7 @@ func (n *Nodes) NewExprPrefixed(args ExprPrefixedArgs) ExprPrefixed { // NewExprRange creates a new ExprRange node. 
func (n *Nodes) NewExprRange(args ExprRangeArgs) ExprRange { - n.panicIfNotOurs(args.Start, args.To, args.End) + n.panicIfNotOurs(args.Start.Context(), args.To.Context(), args.End.Context()) return id.Wrap(n.File(), id.ID[ExprRange](n.exprs.ranges.NewCompressed(rawExprRange{ to: args.To.ID(), @@ -270,7 +274,7 @@ func (n *Nodes) NewExprRange(args ExprRangeArgs) ExprRange { // // To add elements to the returned expression, use [ExprArray.Append]. func (n *Nodes) NewExprArray(brackets token.Token) ExprArray { - n.panicIfNotOurs(brackets) + n.panicIfNotOurs(brackets.Context()) return id.Wrap(n.File(), id.ID[ExprArray](n.exprs.arrays.NewCompressed(rawExprArray{ brackets: brackets.ID(), @@ -281,7 +285,7 @@ func (n *Nodes) NewExprArray(brackets token.Token) ExprArray { // // To add elements to the returned expression, use [ExprDict.Append]. func (n *Nodes) NewExprDict(braces token.Token) ExprDict { - n.panicIfNotOurs(braces) + n.panicIfNotOurs(braces.Context()) return id.Wrap(n.File(), id.ID[ExprDict](n.exprs.dicts.NewCompressed(rawExprDict{ braces: braces.ID(), @@ -290,7 +294,7 @@ func (n *Nodes) NewExprDict(braces token.Token) ExprDict { // NewExprField creates a new ExprPrefixed node. func (n *Nodes) NewExprField(args ExprFieldArgs) ExprField { - n.panicIfNotOurs(args.Key, args.Colon, args.Value) + n.panicIfNotOurs(args.Key.Context(), args.Colon.Context(), args.Value.Context()) return id.Wrap(n.File(), id.ID[ExprField](n.exprs.fields.NewCompressed(rawExprField{ key: args.Key.ID(), @@ -301,7 +305,7 @@ func (n *Nodes) NewExprField(args ExprFieldArgs) ExprField { // NewTypePrefixed creates a new TypePrefixed node. 
func (n *Nodes) NewTypePrefixed(args TypePrefixedArgs) TypePrefixed { - n.panicIfNotOurs(args.Prefix, args.Type) + n.panicIfNotOurs(args.Prefix.Context(), args.Type.Context()) return id.Wrap(n.File(), id.ID[TypePrefixed](n.types.prefixes.NewCompressed(rawTypePrefixed{ prefix: args.Prefix.ID(), @@ -313,7 +317,7 @@ func (n *Nodes) NewTypePrefixed(args TypePrefixedArgs) TypePrefixed { // // To add arguments to the returned type, use [TypeGeneric.Append]. func (n *Nodes) NewTypeGeneric(args TypeGenericArgs) TypeGeneric { - n.panicIfNotOurs(args.Path, args.AngleBrackets) + n.panicIfNotOurs(args.Path.Context(), args.AngleBrackets.Context()) return id.Wrap(n.File(), id.ID[TypeGeneric](n.types.generics.NewCompressed(rawTypeGeneric{ path: args.Path.raw, @@ -323,7 +327,7 @@ func (n *Nodes) NewTypeGeneric(args TypeGenericArgs) TypeGeneric { // NewCompactOptions creates a new CompactOptions node. func (n *Nodes) NewCompactOptions(brackets token.Token) CompactOptions { - n.panicIfNotOurs(brackets) + n.panicIfNotOurs(brackets.Context()) return id.Wrap(n.File(), id.ID[CompactOptions](n.options.NewCompressed(rawCompactOptions{ brackets: brackets.ID(), @@ -333,30 +337,27 @@ func (n *Nodes) NewCompactOptions(brackets token.Token) CompactOptions { // panicIfNotOurs checks that a contextual value is owned by this context, and panics if not. // // Does not panic if that is zero or has a zero context. Panics if n is zero. 
-func (n *Nodes) panicIfNotOurs(that ...any) { +func (n *Nodes) panicIfNotOurs(that ...interface{ Path() string }) { for _, that := range that { - if that == nil { + var path string + switch ctx := that.(type) { + case nil: continue - } - var path string - switch that := that.(type) { - case interface{ Context() *token.Stream }: - ctx := that.Context() + case *token.Stream: if ctx == nil || ctx == n.File().Stream() { continue } path = ctx.Path() - case interface{ Context() *File }: - ctx := that.Context() + case *File: if ctx == nil || ctx == n.File() { continue } path = ctx.Stream().Path() default: - continue + panic(fmt.Errorf("protocompile/ast: invalid type %T", that)) } panic(fmt.Sprintf( diff --git a/experimental/ast/options.go b/experimental/ast/options.go index 88bd63dea..ed30ecbbd 100644 --- a/experimental/ast/options.go +++ b/experimental/ast/options.go @@ -44,7 +44,7 @@ type Option struct { // Span implements [source.Spanner]. func (o Option) Span() source.Span { - return source.Join(o.Path, o.Equals, o.Value) + return source.JoinSpans(o.Path.Span(), o.Equals.Span(), o.Value.Span()) } type rawOption struct { @@ -76,7 +76,8 @@ func (o CompactOptions) Entries() Commas[Option] { return c.Value.With(o.Context()) }, func(_ int, v Option) withComma[rawOption] { - o.Context().Nodes().panicIfNotOurs(v.Path, v.Equals, v.Value) + o.Context().Nodes().panicIfNotOurs( + v.Path.Context(), v.Equals.Context(), v.Value.Context()) return withComma[rawOption]{Value: rawOption{ path: v.Path.ID(), equals: v.Equals.ID(), diff --git a/experimental/ast/path.go b/experimental/ast/path.go index fe794d921..b61af9719 100644 --- a/experimental/ast/path.go +++ b/experimental/ast/path.go @@ -16,6 +16,7 @@ package ast import ( "fmt" + "iter" "strings" "github.com/bufbuild/protocompile/experimental/ast/predeclared" @@ -83,7 +84,7 @@ func (p Path) ID() PathID { // Absolute returns whether this path starts with a dot. 
func (p Path) Absolute() bool { - first, ok := iterx.First(p.Components) + first, ok := iterx.First(p.Components()) return ok && !first.Separator().IsZero() } @@ -98,7 +99,7 @@ func (p Path) IsSynthetic() bool { // // If called on zero or a relative path, returns p. func (p Path) ToRelative() Path { - for pc := range p.Components { + for pc := range p.Components() { if !pc.IsEmpty() { p.raw.start = pc.name break @@ -110,11 +111,15 @@ func (p Path) ToRelative() Path { // AsIdent returns the single identifier that comprises this path, or // the zero token. func (p Path) AsIdent() token.Token { - first, _ := iterx.OnlyOne(p.Components) - if !first.Separator().IsZero() { + if p.raw.start != p.raw.end { return token.Zero } - return first.AsIdent() + + tok := id.Wrap(p.Context().Stream(), p.raw.start) + if tok.Kind() != token.Ident { + return token.Zero + } + return tok } // AsPredeclared returns the [predeclared.Name] that this path represents. @@ -133,7 +138,7 @@ func (p Path) AsKeyword() keyword.Keyword { // IsIdents returns whether p is a sequence of exactly the given identifiers. func (p Path) IsIdents(idents ...string) bool { - for i, pc := range iterx.Enumerate(p.Components) { + for i, pc := range iterx.Enumerate(p.Components()) { if i >= len(idents) || pc.AsIdent().Text() != idents[i] { break } @@ -149,56 +154,58 @@ func (p Path) IsIdents(idents ...string) bool { func (p Path) Span() source.Span { // No need to check for zero here, if p is zero both start and end will be // zero tokens. - return source.Join( - id.Wrap(p.Context().Stream(), p.raw.start), - id.Wrap(p.Context().Stream(), p.raw.end), + return source.JoinSpans( + id.Wrap(p.Context().Stream(), p.raw.start).Span(), + id.Wrap(p.Context().Stream(), p.raw.end).Span(), ) } // Components is an [iter.Seq] that ranges over each component in this path. // Specifically, it yields the (possibly zero) dot that precedes the component, // and the identifier token. 
-func (p Path) Components(yield func(PathComponent) bool) { - if p.IsZero() { - return - } - - var cursor *token.Cursor - first := id.Wrap(p.Context().Stream(), p.raw.start) - if p.IsSynthetic() { - cursor = first.SyntheticChildren(p.raw.synthRange()) - } else { - cursor = token.NewCursorAt(first) - } +func (p Path) Components() iter.Seq[PathComponent] { + return func(yield func(PathComponent) bool) { + if p.IsZero() { + return + } - var sep token.Token - var idx uint32 - for tok := range cursor.Rest() { - if !p.IsSynthetic() && tok.ID() > p.raw.end { - // We've reached the end of the path. - break + var cursor *token.Cursor + first := id.Wrap(p.Context().Stream(), p.raw.start) + if p.IsSynthetic() { + cursor = first.SyntheticChildren(p.raw.synthRange()) + } else { + cursor = token.NewCursorAt(first) } - if tok.Text() == "." || tok.Text() == "/" { - if !sep.IsZero() { - // Uh-oh, empty path component! - if !yield(PathComponent{p.withContext, p.raw, sep.ID(), 0, idx}) { - return + var sep token.Token + var idx uint32 + for tok := range cursor.Rest() { + if !p.IsSynthetic() && tok.ID() > p.raw.end { + // We've reached the end of the path. + break + } + + if tok.Text() == "." || tok.Text() == "/" { + if !sep.IsZero() { + // Uh-oh, empty path component! 
+ if !yield(PathComponent{p.withContext, p.raw, sep.ID(), 0, idx}) { + return + } + idx++ } - idx++ + sep = tok + continue } - sep = tok - continue - } - if !yield(PathComponent{p.withContext, p.raw, sep.ID(), tok.ID(), idx}) { - return + if !yield(PathComponent{p.withContext, p.raw, sep.ID(), tok.ID(), idx}) { + return + } + idx++ + sep = token.Zero + } + if !sep.IsZero() { + yield(PathComponent{p.withContext, p.raw, sep.ID(), 0, idx}) } - idx++ - sep = token.Zero - } - if !sep.IsZero() { - yield(PathComponent{p.withContext, p.raw, sep.ID(), 0, idx}) } } @@ -223,7 +230,7 @@ func (p Path) Split(n int) (prefix, suffix Path) { var i int var prev PathComponent var found bool - for pc := range p.Components { + for pc := range p.Components() { if n > 0 { prev = pc n-- @@ -293,7 +300,7 @@ func (p Path) Canonicalized() string { } func (p Path) canonicalized(out *strings.Builder) { - for i, pc := range iterx.Enumerate(p.Components) { + for i, pc := range iterx.Enumerate(p.Components()) { if pc.Name().IsZero() { continue } @@ -313,7 +320,7 @@ func (p Path) canonicalized(out *strings.Builder) { func (p Path) isCanonical() bool { var prev PathComponent - for pc := range p.Components { + for pc := range p.Components() { sep := pc.Separator() name := pc.Name() @@ -510,7 +517,7 @@ func (p PathComponent) IsEmpty() bool { // Next returns the next path component after this one, if there is one. func (p PathComponent) Next() PathComponent { _, after := p.SplitAfter() - next, _ := iterx.First(after.Components) + next, _ := iterx.First(after.Components()) return next } @@ -557,7 +564,7 @@ func (p PathComponent) AsIdent() token.Token { // Span implements [source.Spanner]. 
func (p PathComponent) Span() source.Span { - return source.Join(p.Separator(), p.Name()) + return source.JoinSpans(p.Separator().Span(), p.Name().Span()) } func (p PathID) synthRange() (start, end int) { diff --git a/experimental/ast/path_test.go b/experimental/ast/path_test.go index d0257cb34..a6f8a6b3f 100644 --- a/experimental/ast/path_test.go +++ b/experimental/ast/path_test.go @@ -80,24 +80,24 @@ func TestNaturalSplit(t *testing.T) { pathEq(t, start, components[:2]) pathEq(t, end, components[2:]) - start, end = nth(path.Components, 0).SplitBefore() + start, end = nth(path.Components(), 0).SplitBefore() pathEq(t, start, [][2]token.Token{}) pathEq(t, end, components) - start, end = nth(path.Components, 0).SplitAfter() + start, end = nth(path.Components(), 0).SplitAfter() pathEq(t, start, components[:1]) pathEq(t, end, components[1:]) - start, end = nth(path.Components, 1).SplitBefore() + start, end = nth(path.Components(), 1).SplitBefore() pathEq(t, start, components[:1]) pathEq(t, end, components[1:]) - start, end = nth(path.Components, 1).SplitAfter() + start, end = nth(path.Components(), 1).SplitAfter() pathEq(t, start, components[:2]) pathEq(t, end, components[2:]) - start, end = nth(path.Components, 3).SplitBefore() + start, end = nth(path.Components(), 3).SplitBefore() pathEq(t, start, components[:3]) pathEq(t, end, components[3:]) - start, end = nth(path.Components, 3).SplitAfter() + start, end = nth(path.Components(), 3).SplitAfter() pathEq(t, start, components) pathEq(t, end, [][2]token.Token{}) } @@ -153,7 +153,7 @@ func TestSyntheticSplit(t *testing.T) { func pathEq(t *testing.T, path ast.Path, want [][2]token.Token) { t.Helper() - components := slices.Collect(iterx.Map(path.Components, func(pc ast.PathComponent) [2]token.Token { + components := slices.Collect(iterx.Map(path.Components(), func(pc ast.PathComponent) [2]token.Token { return [2]token.Token{pc.Separator(), pc.Name()} })) stringEq(t, components, want) diff --git a/experimental/ast/type.go 
b/experimental/ast/type.go index f5d05784c..ec7f59705 100644 --- a/experimental/ast/type.go +++ b/experimental/ast/type.go @@ -122,14 +122,18 @@ func (t TypeAny) RemovePrefixes() TypeAny { // source.Span implements [source.Spanner]. func (t TypeAny) Span() source.Span { - // At most one of the below will produce a non-zero type, and that will be - // the span selected by source.Join. If all of them are zero, this produces - // the zero span. - return source.Join( - t.AsPath(), - t.AsPrefixed(), - t.AsGeneric(), - ) + switch t.Kind() { + case TypeKindError: + return t.AsError().Span() + case TypeKindGeneric: + return t.AsGeneric().Span() + case TypeKindPath: + return t.AsPath().Span() + case TypeKindPrefixed: + return t.AsPrefixed().Span() + default: + return source.Span{} + } } // TypeError represents an unrecoverable parsing error in a type context. diff --git a/experimental/ast/type_generic.go b/experimental/ast/type_generic.go index a7a0690b6..06cf448a4 100644 --- a/experimental/ast/type_generic.go +++ b/experimental/ast/type_generic.go @@ -139,7 +139,7 @@ func (d TypeList) Brackets() token.Token { // SetBrackets sets the token tree for the brackets wrapping the argument list. func (d TypeList) SetBrackets(brackets token.Token) { - d.Context().Nodes().panicIfNotOurs(brackets) + d.Context().Nodes().panicIfNotOurs(brackets.Context()) d.raw.brackets = brackets.ID() } @@ -159,7 +159,7 @@ func (d TypeList) At(n int) TypeAny { // SetAt implements [seq.Setter]. func (d TypeList) SetAt(n int, ty TypeAny) { - d.Context().Nodes().panicIfNotOurs(ty) + d.Context().Nodes().panicIfNotOurs(ty.Context()) d.raw.args[n].Value = ty.ID() } @@ -185,14 +185,14 @@ func (d TypeList) AppendComma(value TypeAny, comma token.Token) { // InsertComma implements [Commas]. 
func (d TypeList) InsertComma(n int, ty TypeAny, comma token.Token) { - d.Context().Nodes().panicIfNotOurs(ty, comma) + d.Context().Nodes().panicIfNotOurs(ty.Context(), comma.Context()) d.raw.args = slices.Insert(d.raw.args, n, withComma[id.Dyn[TypeAny, TypeKind]]{ty.ID(), comma.ID()}) } // SetComma implements [Commas]. func (d TypeList) SetComma(n int, comma token.Token) { - d.Context().Nodes().panicIfNotOurs(comma) + d.Context().Nodes().panicIfNotOurs(comma.Context()) d.raw.args[n].Comma = comma.ID() } diff --git a/experimental/ast/type_prefixed.go b/experimental/ast/type_prefixed.go index 83aba5b11..8aba25b79 100644 --- a/experimental/ast/type_prefixed.go +++ b/experimental/ast/type_prefixed.go @@ -98,5 +98,5 @@ func (t TypePrefixed) Span() source.Span { return source.Span{} } - return source.Join(t.PrefixToken(), t.Type()) + return source.JoinSpans(t.PrefixToken().Span(), t.Type().Span()) } diff --git a/experimental/benchmark/benchmark_test.go b/experimental/benchmark/benchmark_test.go new file mode 100644 index 000000000..fcadec167 --- /dev/null +++ b/experimental/benchmark/benchmark_test.go @@ -0,0 +1,138 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package benchmark + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bufbuild/protocompile/experimental/fdp" + "github.com/bufbuild/protocompile/experimental/incremental" + "github.com/bufbuild/protocompile/experimental/incremental/queries" + "github.com/bufbuild/protocompile/experimental/ir" + "github.com/bufbuild/protocompile/experimental/source" + "github.com/bufbuild/protocompile/internal/ext/bitsx" + "github.com/bufbuild/protocompile/internal/testing/googleapis" + "github.com/bufbuild/protocompile/internal/testing/memory" +) + +func BenchmarkCompileGoogleapis(b *testing.B) { + workspace, sources := googleapis.Get() + if workspace == nil { + b.Skip() + } + sources = &source.Openers{sources, source.WKTs()} + benchmark(b, sources, workspace) +} + +func BenchmarkCompileDescriptor(b *testing.B) { + sources := &source.Openers{source.WKTs()} + workspace := source.NewWorkspace("google/protobuf/descriptor.proto") + benchmark(b, sources, workspace) +} + +func benchmark(b *testing.B, sources source.Opener, workspace source.Workspace) { + for _, what := range []string{"hot", "cold"} { + hot := what == "hot" + b.Run(what, func(b *testing.B) { + b.Run("link", func(b *testing.B) { + exec := incremental.New() + sess := new(ir.Session) + for b.Loop() { + if !hot { + exec = incremental.New() + sess = new(ir.Session) + } + _, _, _ = incremental.Run(b.Context(), exec, queries.Link{ + Opener: sources, + Session: sess, + Workspace: workspace, + }) + } + }) + + b.Run("desc", func(b *testing.B) { + exec := incremental.New() + sess := new(ir.Session) + for b.Loop() { + if !hot { + exec = incremental.New() + sess = new(ir.Session) + } + result, _, _ := incremental.Run(b.Context(), exec, queries.Link{ + Opener: sources, + Session: sess, + Workspace: workspace, + }) + _, _ = fdp.DescriptorSetBytes(result[0].Value) + } + }) + + b.Run("sci", func(b *testing.B) { + exec := incremental.New() + sess := new(ir.Session) + for b.Loop() { + if !hot { 
+ exec = incremental.New() + sess = new(ir.Session) + } + result, _, _ := incremental.Run(b.Context(), exec, queries.Link{ + Opener: sources, + Session: sess, + Workspace: workspace, + }) + _, _ = fdp.DescriptorSetBytes(result[0].Value, fdp.IncludeSourceCodeInfo(true)) + } + }) + }) + } +} + +func TestCompileGoogleapisMemory(t *testing.T) { + workspace, sources := googleapis.Get() + if workspace == nil { + t.Skip() + } + sources = &source.Openers{sources, source.WKTs()} + testMemory(t, sources, workspace) +} + +func TestCompileDescriptorMemory(t *testing.T) { + sources := &source.Openers{source.WKTs()} + workspace := source.NewWorkspace("google/protobuf/descriptor.proto") + testMemory(t, sources, workspace) +} + +func testMemory(t *testing.T, sources source.Opener, workspace source.Workspace) { + exec := incremental.New() + sess := new(ir.Session) + results, _, err := incremental.Run(t.Context(), exec, queries.Link{ + Opener: sources, + Session: sess, + Workspace: workspace, + }) + require.NoError(t, err) + + runtime.GC() + m := new(runtime.MemStats) + runtime.ReadMemStats(m) + t.Logf("heap usage: %v", bitsx.ByteSize(m.Alloc)) + + tape := new(memory.MeasuringTape) + tape.Measure(results) + t.Logf("reachable memory: %v", bitsx.ByteSize(tape.Usage())) +} diff --git a/experimental/internal/astx/encode.go b/experimental/internal/astx/encode.go index 6ec8855e1..57fa0027a 100644 --- a/experimental/internal/astx/encode.go +++ b/experimental/internal/astx/encode.go @@ -159,7 +159,7 @@ func (c *protoEncoder) path(path ast.Path) *compilerpb.Path { proto := &compilerpb.Path{ Span: c.span(path), } - for pc := range path.Components { + for pc := range path.Components() { component := new(compilerpb.Path_Component) switch pc.Separator().Text() { case ".": diff --git a/experimental/internal/taxa/classify.go b/experimental/internal/taxa/classify.go index 56d58d274..94f3f129f 100644 --- a/experimental/internal/taxa/classify.go +++ b/experimental/internal/taxa/classify.go @@ -70,7 
+70,7 @@ func Classify(node source.Spanner) Noun { case *ast.File: return TopLevel case ast.Path: - if first, ok := iterx.OnlyOne(node.Components); ok && first.Separator().IsZero() { + if first, ok := iterx.OnlyOne(node.Components()); ok && first.Separator().IsZero() { if id := first.AsIdent(); !id.IsZero() { return Classify(id) } @@ -157,11 +157,7 @@ func Classify(node source.Spanner) Noun { case ast.DefExtend: return Extend case ast.DefOption: - var first ast.PathComponent - node.Path.Components(func(pc ast.PathComponent) bool { - first = pc - return false - }) + first, _ := iterx.First(node.Path.Components()) if !first.AsExtension().IsZero() { return CustomOption } diff --git a/experimental/ir/ir_features.go b/experimental/ir/ir_features.go index ce32b32b6..bc449af8a 100644 --- a/experimental/ir/ir_features.go +++ b/experimental/ir/ir_features.go @@ -41,7 +41,7 @@ type FeatureInfo struct { } type rawFeatureSet struct { - features map[featureKey]rawFeature + features map[uint64]rawFeature parent id.ID[FeatureSet] options id.ID[Value] } @@ -58,10 +58,6 @@ type rawFeatureInfo struct { deprecationWarning string } -type featureKey struct { - extension, field *rawMember -} - type featureDefault struct { edition syntax.Syntax value id.ID[Value] @@ -96,8 +92,12 @@ func (fs FeatureSet) LookupCustom(extension, field Member) Feature { if fs.IsZero() { return Feature{} } + + // This key is guaranteed to be unique, because FQNs are unique. This + // allows us to use Go's fast path for 64-bit integer keys. + key := uint64(extension.InternedFullName())<<32 | uint64(field.InternedFullName()) + // First, check if this value is cached. 
- key := featureKey{extension.Raw(), field.Raw()} if f, ok := fs.Raw().features[key]; ok { return Feature{id.WrapContext(fs.Context()), f} } @@ -134,7 +134,7 @@ func (fs FeatureSet) LookupCustom(extension, field Member) Feature { } if fs.Raw().features == nil { - fs.Raw().features = make(map[featureKey]rawFeature) + fs.Raw().features = make(map[uint64]rawFeature) } fs.Raw().features[key] = raw return Feature{id.WrapContext(fs.Context()), raw} diff --git a/experimental/ir/ir_file.go b/experimental/ir/ir_file.go index 0b01e6807..b9213f7de 100644 --- a/experimental/ir/ir_file.go +++ b/experimental/ir/ir_file.go @@ -358,7 +358,7 @@ func (f *File) Deprecated() Value { // imported by the file. The symbols are returned in an arbitrary but fixed // order. func (f *File) Symbols() seq.Indexer[Symbol] { - var symbols []Ref[Symbol] + var symbols []symbol if f != nil { symbols = f.imported } @@ -375,7 +375,7 @@ func (f *File) FindSymbol(fqn FullName) Symbol { // ExportedSymbols returns this file's exported symbols. func (f *File) ExportedSymbols() seq.Indexer[Symbol] { - var symbols []Ref[Symbol] + var symbols []symbol if f != nil { symbols = f.exported } @@ -385,8 +385,8 @@ func (f *File) ExportedSymbols() seq.Indexer[Symbol] { func (f *File) symbols(symtab symtab) seq.Indexer[Symbol] { return seq.NewFixedSlice( symtab, - func(_ int, r Ref[Symbol]) Symbol { - return GetRef(f, r) + func(_ int, r symbol) Symbol { + return GetRef(f, r.ref) }, ) } diff --git a/experimental/ir/ir_symbol.go b/experimental/ir/ir_symbol.go index 068847b57..afaeff179 100644 --- a/experimental/ir/ir_symbol.go +++ b/experimental/ir/ir_symbol.go @@ -311,7 +311,12 @@ var optionTargets = [...]OptionTarget{ // // The elements of a symtab are sorted by the [intern.ID] of their FQN, allowing // for O(n) merging of symbol tables. -type symtab []Ref[Symbol] +type symtab []symbol + +type symbol struct { + fqn intern.ID // This avoids an extra pointer chase in lookupBytes. 
+ ref Ref[Symbol] +} var resolveScratch = sync.Pool{ New: func() any { return new([]byte) }, @@ -322,18 +327,18 @@ func symtabMerge(file *File, tables iter.Seq[symtab], fileForTable func(int) *Fi return slicesx.MergeKeySeq( tables, - func(which int, elem Ref[Symbol]) intern.ID { + func(which int, elem symbol) intern.ID { f := fileForTable(which) - return GetRef(f, elem).InternedFullName() + return GetRef(f, elem.ref).InternedFullName() }, - func(which int, elem Ref[Symbol]) Ref[Symbol] { + func(which int, elem symbol) symbol { // We need top map the file number from src to the current one. src := fileForTable(which) if src != file { - theirs := GetRef(src, elem) + theirs := GetRef(src, elem.ref) ours := file.imports.byPath[theirs.Context().InternedPath()] - elem.file = int32(ours + 1) + elem.ref.file = int32(ours + 1) } return elem @@ -344,23 +349,24 @@ func symtabMerge(file *File, tables iter.Seq[symtab], fileForTable func(int) *Fi // sort sorts this symbol table according to the value of each intern // ID. func (s symtab) sort(file *File) { - slices.SortFunc(s, func(a, b Ref[Symbol]) int { - symA := GetRef(file, a) - symB := GetRef(file, b) - return cmp.Compare(symA.InternedFullName(), symB.InternedFullName()) + for i := range s { + s[i].fqn = GetRef(file, s[i].ref).InternedFullName() + } + slices.SortFunc(s, func(a, b symbol) int { + return cmp.Compare(a.fqn, b.fqn) }) } // lookupBytes looks up a symbol with the given fully-qualified name. func (s symtab) lookup(file *File, fqn intern.ID) Ref[Symbol] { - idx, ok := slicesx.BinarySearchKey(s, fqn, func(r Ref[Symbol]) intern.ID { - return GetRef(file, r).InternedFullName() + idx, ok := slicesx.BinarySearchKey(s, fqn, func(r symbol) intern.ID { + return r.fqn }) if !ok { return Ref[Symbol]{} } - return s[idx] + return s[idx].ref } // lookupBytes looks up a symbol with the given fully-qualified name. 
@@ -369,14 +375,35 @@ func (s symtab) lookupBytes(file *File, fqn []byte) Ref[Symbol] { if !ok { return Ref[Symbol]{} } - idx, ok := slicesx.BinarySearchKey(s, id, func(r Ref[Symbol]) intern.ID { - return GetRef(file, r).InternedFullName() - }) + + var idx int + // Manual inlining of slices.BinarySearch. + // + // Doing this avoids log(len(s)) virtual calls to compare + { + x, target := s, id + n := len(x) + // Define cmp(x[-1], target) < 0 and cmp(x[n], target) >= 0 . + // Invariant: cmp(x[i - 1], target) < 0, cmp(x[j], target) >= 0. + i, j := 0, n + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + if x[h].fqn < target { + i = h + 1 // preserves cmp(x[i - 1], target) < 0 + } else { + j = h // preserves cmp(x[j], target) >= 0 + } + } + // i == j, cmp(x[i-1], target) < 0, and cmp(x[j], target) (= cmp(x[i], target)) >= 0 => answer is i. + idx, ok = i, i < n && x[i].fqn == target + } + if !ok { return Ref[Symbol]{} } - return s[idx] + return s[idx].ref } // resolve attempts to resolve the relative path name within the given scope @@ -493,7 +520,9 @@ func (s symtab) resolve( again: for { r := s.lookupBytes(file, candidate) - remarks.Apply(report.Debugf("candidate: `%s`", candidate)) + if remarks != nil { + remarks.Apply(report.Debugf("candidate: `%s`", candidate)) + } if !r.IsZero() { found = r diff --git a/experimental/ir/ir_value.go b/experimental/ir/ir_value.go index 8f1962137..f5a1a2d6c 100644 --- a/experimental/ir/ir_value.go +++ b/experimental/ir/ir_value.go @@ -288,14 +288,17 @@ func (v Value) Container() MessageValue { // The indexer will be nonempty except for the zero Value. That is to say, unset // fields of [MessageValue]s are not represented as a distinct "empty" Value. 
func (v Value) Elements() seq.Indexer[Element] { - return seq.NewFixedSlice(v.getElements(), func(n int, bits rawValueBits) Element { - return Element{ - withContext: id.WrapContext(v.Context()), - index: n, - value: v, - bits: bits, - } - }) + return seq.Slice[Element, rawValueBits]{ + Slice: v.getElements(), + Wrap: func(n int, bits rawValueBits) Element { + return Element{ + withContext: id.WrapContext(v.Context()), + index: n, + value: v, + bits: bits, + } + }, + } } // IsTopLevel returns whether this value corresponds with a top-level option declaration diff --git a/experimental/ir/lower_eval.go b/experimental/ir/lower_eval.go index de5ed8256..d4447be97 100644 --- a/experimental/ir/lower_eval.go +++ b/experimental/ir/lower_eval.go @@ -571,7 +571,7 @@ func (e *evaluator) evalMessage(args evalArgs, expr ast.ExprDict) Value { splitURL := func(path ast.Path) (before, after ast.Path) { // Figure out what part of the key expression actually contains // the domain. Look for the last component whose separator is a /. - pc, _ := iterx.Last(iterx.Filter(path.Components, func(pc ast.PathComponent) bool { + pc, _ := iterx.Last(iterx.Filter(path.Components(), func(pc ast.PathComponent) bool { return pc.Separator().Text() == "/" })) hostSpan := path.Span() diff --git a/experimental/ir/lower_options.go b/experimental/ir/lower_options.go index 0d90debff..9f42450c5 100644 --- a/experimental/ir/lower_options.go +++ b/experimental/ir/lower_options.go @@ -44,7 +44,7 @@ func resolveEarlyOptions(file *File) { option := def.AsOption().Option // If this option's path has more than one component, skip. - first, ok := iterx.OnlyOne(option.Path.Components) + first, ok := iterx.OnlyOne(option.Path.Components()) if !ok || !first.Separator().IsZero() { continue } @@ -337,7 +337,7 @@ func validateOptionTargetsInValue(m MessageValue, decl source.Span, target Optio if path := key.AsPath(); !path.IsZero() { // Pull out the last component. 
// TODO: write a function on Path that does this cheaply. - last, _ := iterx.Last(path.Components) + last, _ := iterx.Last(path.Components()) span = last.Name().Span() } @@ -393,7 +393,7 @@ func (r optionRef) resolve() { field := current.Field() var path ast.Path var raw slot - for pc := range r.def.Path.Components { + for pc := range r.def.Path.Components() { // If this is the first iteration, use the *Options value as the current // message. message := field.Element() diff --git a/experimental/ir/lower_symbols.go b/experimental/ir/lower_symbols.go index b378db8a5..ef32176f0 100644 --- a/experimental/ir/lower_symbols.go +++ b/experimental/ir/lower_symbols.go @@ -39,7 +39,9 @@ func buildLocalSymbols(file *File) { kind: SymbolKindPackage, fqn: file.InternedPackage(), }) - file.exported = append(file.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + file.exported = append(file.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) for ty := range seq.Values(file.AllTypes()) { newTypeSymbol(ty) @@ -77,7 +79,9 @@ func newTypeSymbol(ty Type) { fqn: ty.InternedFullName(), data: arena.Untyped(c.arenas.types.Compress(ty.Raw())), }) - c.exported = append(c.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + c.exported = append(c.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) } func newFieldSymbol(f Member) { @@ -93,7 +97,9 @@ func newFieldSymbol(f Member) { fqn: f.InternedFullName(), data: arena.Untyped(c.arenas.members.Compress(f.Raw())), }) - c.exported = append(c.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + c.exported = append(c.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) } func newOneofSymbol(o Oneof) { @@ -103,7 +109,9 @@ func newOneofSymbol(o Oneof) { fqn: o.InternedFullName(), data: arena.Untyped(c.arenas.oneofs.Compress(o.Raw())), }) - c.exported = append(c.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + c.exported = append(c.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) } func 
newServiceSymbol(s Service) { @@ -113,7 +121,9 @@ func newServiceSymbol(s Service) { fqn: s.InternedFullName(), data: arena.Untyped(c.arenas.services.Compress(s.Raw())), }) - c.exported = append(c.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + c.exported = append(c.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) } func newMethodSymbol(m Method) { @@ -123,7 +133,9 @@ func newMethodSymbol(m Method) { fqn: m.InternedFullName(), data: arena.Untyped(c.arenas.methods.Compress(m.Raw())), }) - c.exported = append(c.exported, Ref[Symbol]{id: id.ID[Symbol](sym)}) + c.exported = append(c.exported, symbol{ + ref: Ref[Symbol]{id: id.ID[Symbol](sym)}, + }) } // mergeImportedSymbolTables builds a symbol table of every imported symbol. @@ -201,14 +213,14 @@ func mergeImportedSymbolTables(file *File, r *report.Report) { func dedupSymbols(file *File, symbols *symtab, r *report.Report) { *symbols = slicesx.DedupKey( *symbols, - func(r Ref[Symbol]) intern.ID { return GetRef(file, r).InternedFullName() }, - func(refs []Ref[Symbol]) Ref[Symbol] { + func(r symbol) intern.ID { return r.fqn }, + func(refs []symbol) symbol { if len(refs) == 1 { return refs[0] } slices.SortFunc(refs, cmpx.Map( - func(r Ref[Symbol]) Symbol { return GetRef(file, r) }, + func(r symbol) Symbol { return GetRef(file, r.ref) }, cmpx.Key(Symbol.Kind), // Packages sort first, reserved names sort last. 
cmpx.Key(func(s Symbol) string { // NOTE: we do not choose a winner based on the path's intern @@ -219,14 +231,14 @@ func dedupSymbols(file *File, symbols *symtab, r *report.Report) { cmpx.Key(func(s Symbol) int { return s.Definition().Start }), )) - types := mapsx.CollectSet(iterx.FilterMap(slices.Values(refs), func(r Ref[Symbol]) (ast.DeclDef, bool) { - s := GetRef(file, r) + types := mapsx.CollectSet(iterx.FilterMap(slices.Values(refs), func(r symbol) (ast.DeclDef, bool) { + s := GetRef(file, r.ref) ty := s.AsType() return ty.AST(), !ty.IsZero() })) isFirst := true - refs = slices.DeleteFunc(refs, func(r Ref[Symbol]) bool { - s := GetRef(file, r) + refs = slices.DeleteFunc(refs, func(r symbol) bool { + s := GetRef(file, r.ref) if !isFirst && !s.AsMember().Container().MapField().IsZero() { // Ignore all symbols that are map entry fields, because those // can only be duplicated when two map entry messages' names @@ -253,8 +265,8 @@ func dedupSymbols(file *File, symbols *symtab, r *report.Report) { refs = slicesx.Dedup(refs) if len(refs) > 1 && r != nil { r.Error(errDuplicates{ - symbols: slices.Collect(slicesx.Map(refs, func(r Ref[Symbol]) Symbol { - return GetRef(file, r) + symbols: slices.Collect(slicesx.Map(refs, func(r symbol) Symbol { + return GetRef(file, r.ref) })), }) } diff --git a/experimental/parser/legalize_def.go b/experimental/parser/legalize_def.go index 308884a91..0dd752df9 100644 --- a/experimental/parser/legalize_def.go +++ b/experimental/parser/legalize_def.go @@ -94,9 +94,9 @@ func legalizeTypeDefLike(p *parser, what taxa.Noun, def ast.DeclDef) { // Look for a separator, and use that instead. We can't "just" pick out // the first separator, because def.Name might be a one-component // extension path, e.g. (a.b.c). 
- def.Name().Components(func(pc ast.PathComponent) bool { + for pc := range def.Name().Components() { if pc.Separator().IsZero() { - return true + continue } err = errtoken.Unexpected{ @@ -105,8 +105,9 @@ func legalizeTypeDefLike(p *parser, what taxa.Noun, def ast.DeclDef) { RepeatUnexpected: true, } - return false - }) + + break + } p.Error(err).Apply( report.Notef("the name of a %s must be a single identifier", what), diff --git a/experimental/parser/legalize_option.go b/experimental/parser/legalize_option.go index d049e7758..90960cb29 100644 --- a/experimental/parser/legalize_option.go +++ b/experimental/parser/legalize_option.go @@ -194,7 +194,7 @@ func legalizeValue(p *parser, decl source.Span, parent ast.ExprAny, value ast.Ex case ast.ExprKindPath: path := kv.Key().AsPath() - first, _ := iterx.First(path.Components) + first, _ := iterx.First(path.Components()) if !first.AsExtension().IsZero() { // TODO: move this into ir/lower_eval.go p.Errorf("cannot name extension field using `(...)` in %s", taxa.Dict).Apply( @@ -225,7 +225,7 @@ func legalizeValue(p *parser, decl source.Span, parent ast.ExprAny, value ast.Ex break } - slashIdx, _ := iterx.Find(path.Components, func(pc ast.PathComponent) bool { + slashIdx, _ := iterx.Find(path.Components(), func(pc ast.PathComponent) bool { return pc.Separator().Keyword() == keyword.Div }) if slashIdx != -1 { diff --git a/experimental/parser/legalize_path.go b/experimental/parser/legalize_path.go index 51d0bfb6a..3634d6724 100644 --- a/experimental/parser/legalize_path.go +++ b/experimental/parser/legalize_path.go @@ -20,7 +20,6 @@ import ( "github.com/bufbuild/protocompile/experimental/report" "github.com/bufbuild/protocompile/experimental/token" "github.com/bufbuild/protocompile/experimental/token/keyword" - "github.com/bufbuild/protocompile/internal/ext/iterx" ) // pathOptions is configuration for [legalizePath]. 
@@ -47,14 +46,14 @@ func legalizePath(p *parser, where taxa.Place, path ast.Path, opts pathOptions) var bytes, components int var slash token.Token - for i, pc := range iterx.Enumerate(path.Components) { + for pc := range path.Components() { bytes += pc.Separator().Span().Len() // Just Len() here is technically incorrect, because it could be an // extension, but MaxBytes is never used with AllowExts. bytes += pc.Name().Span().Len() components++ - if i == 0 && !opts.AllowAbsolute && pc.Separator().Text() == "." { + if pc.IsFirst() && !opts.AllowAbsolute && pc.Separator().Text() == "." { p.Errorf("unexpected absolute path %s", where).Apply( report.Snippetf(path, "expected a path without a leading `%s`", pc.Separator().Text()), report.SuggestEdits(path, "remove the leading `.`", report.Edit{Start: 0, End: 1}), diff --git a/experimental/parser/parse_def.go b/experimental/parser/parse_def.go index 684bf4800..dbd42cf25 100644 --- a/experimental/parser/parse_def.go +++ b/experimental/parser/parse_def.go @@ -247,16 +247,16 @@ func (defOutputs) parse(p *defParser) source.Span { } if !list.IsZero() { - return source.Join(returns, list) + return source.JoinSpans(returns.Span(), list.Span()) } - return source.Join(returns, ty) + return source.JoinSpans(returns.Span(), ty.Span()) } func (defOutputs) prev(p *defParser) source.Span { if !p.outputTy.IsZero() { - return source.Join(p.args.Returns, p.outputTy) + return source.JoinSpans(p.args.Returns.Span(), p.outputTy.Span()) } - return source.Join(p.args.Returns, p.outputs) + return source.JoinSpans(p.args.Returns.Span(), p.outputs.Span()) } type defValue struct{} @@ -321,14 +321,14 @@ func (defValue) parse(p *defParser) source.Span { p.args.Equals = eq p.args.Value = expr } - return source.Join(eq, expr) + return source.JoinSpans(eq.Span(), expr.Span()) } func (defValue) prev(p *defParser) source.Span { if p.args.Value.IsZero() { return source.Span{} } - return source.Join(p.args.Equals, p.args.Value) + return 
source.JoinSpans(p.args.Equals.Span(), p.args.Value.Span()) } type defOptions struct{} diff --git a/experimental/parser/parse_type.go b/experimental/parser/parse_type.go index 298b08d07..bf4dd9348 100644 --- a/experimental/parser/parse_type.go +++ b/experimental/parser/parse_type.go @@ -84,7 +84,7 @@ func parseTypeImpl(p *parser, c *token.Cursor, where taxa.Place, pathAfter bool) break // Absolute paths cannot start with a modifier, so we are done. } - first, _ := iterx.First(tyPath.Components) + first, _ := iterx.First(tyPath.Components()) ident := first.AsIdent() if ident.IsZero() { break // If this starts with an extension, we're done. } diff --git a/experimental/source/span.go b/experimental/source/span.go index 8f0808631..1ec3f2ded 100644 --- a/experimental/source/span.go +++ b/experimental/source/span.go @@ -22,6 +22,7 @@ import ( "unicode/utf8" "github.com/bufbuild/protocompile/experimental/source/length" + "github.com/bufbuild/protocompile/internal/ext/iterx" ) // Spanner is any type with a [Span]. @@ -186,14 +187,23 @@ func (s Span) String() string { // If there are at least two distinct files among the non-zero spans, // this function panics. func Join(spans ...Spanner) Span { - return JoinSeq[Spanner](slices.Values(spans)) + return JoinSeq(slices.Values(spans)) } // JoinSeq is like [Join], but takes a sequence of any spannable type. func JoinSeq[S Spanner](seq iter.Seq[S]) Span { + return JoinSpanSeq(iterx.Map(seq, func(s S) Span { return GetSpan(s) })) +} + +// JoinSpans is like [Join], but takes [Span] values directly; see go.dev/issue/78336. +func JoinSpans(spans ...Span) Span { + return JoinSpanSeq(slices.Values(spans)) +} + +// JoinSpanSeq is like [JoinSeq], but takes a sequence of [Span] values directly; see go.dev/issue/78336. 
+func JoinSpanSeq(seq iter.Seq[Span]) Span { joined := Span{Start: math.MaxInt} - for spanner := range seq { - span := GetSpan(spanner) + for span := range seq { if span.IsZero() { continue } diff --git a/experimental/token/cursor.go b/experimental/token/cursor.go index 73e6ba506..273c6af0c 100644 --- a/experimental/token/cursor.go +++ b/experimental/token/cursor.go @@ -52,6 +52,11 @@ type CursorMark struct { // // Panics if the token is zero or synthetic. func NewCursorAt(tok Token) *Cursor { + return newCursorAt(new(Cursor), tok) +} + +//go:noinline +func newCursorAt(c *Cursor, tok Token) *Cursor { if tok.IsZero() { panic(fmt.Sprintf("protocompile/token: passed zero token to NewCursorAt: %v", tok)) } @@ -59,11 +64,12 @@ func NewCursorAt(tok Token) *Cursor { panic(fmt.Sprintf("protocompile/token: passed synthetic token to NewCursorAt: %v", tok)) } - return &Cursor{ + *c = Cursor{ context: tok.Context(), idx: naturalIndex(tok.ID()), // Convert to 0-based index. isBackwards: tok.nat().IsClose(), // Set the direction to calculate the offset. } + return c } // NewSliceCursor returns a new cursor over a slice of token IDs in the given diff --git a/experimental/token/keyword/methods.go b/experimental/token/keyword/methods.go index b8f4bc8e8..2ceb55ebf 100644 --- a/experimental/token/keyword/methods.go +++ b/experimental/token/keyword/methods.go @@ -17,7 +17,6 @@ package keyword import ( "iter" - "github.com/bufbuild/protocompile/internal/ext/iterx" "github.com/bufbuild/protocompile/internal/trie" ) @@ -41,7 +40,13 @@ func Prefix(text string) Keyword { // Prefix returns an iterator over the keywords that can be returned by [Lookup] // which are prefixes of text, in ascending order of length. 
func Prefixes(text string) iter.Seq[Keyword] { - return iterx.Right(kwTrie.Prefixes(text)) + return func(yield func(Keyword) bool) { + for _, kw := range kwTrie.Prefixes(text) { + if !yield(kw) { + return + } + } + } } // Brackets returns the open and close brackets if k is a bracket keyword. diff --git a/experimental/token/token.go b/experimental/token/token.go index f522f848e..2a8124a66 100644 --- a/experimental/token/token.go +++ b/experimental/token/token.go @@ -306,23 +306,32 @@ func Fuse(open, close Token) { //nolint:predeclared,revive // For close. // // If the token is zero or is a leaf token, returns nil. func (t Token) Children() *Cursor { + // Make sure that Children is inlinable; this avoids heap allocations in + // the caller. + return t.children(new(Cursor)) +} + +//go:noinline +func (t Token) children(c *Cursor) *Cursor { if t.IsZero() || t.IsLeaf() { return nil } if impl := t.nat(); impl != nil { start, _ := t.StartEnd() - return &Cursor{ + *c = Cursor{ context: t.Context(), idx: naturalIndex(start.ID()) + 1, // Skip the start! 
} + return c } synth := t.synth() if synth.IsClose() { - return id.Wrap(t.Context(), synth.otherEnd).Children() + return id.Wrap(t.Context(), synth.otherEnd).children(c) } - return NewSliceCursor(t.Context(), synth.children) + *c = *NewSliceCursor(t.Context(), synth.children) + return c } // SyntheticChildren returns a cursor over the given subslice of the children diff --git a/internal/benchmarks/benchmark_test.go b/internal/benchmarks/benchmark_test.go index e6603a012..7d94c64df 100644 --- a/internal/benchmarks/benchmark_test.go +++ b/internal/benchmarks/benchmark_test.go @@ -15,26 +15,17 @@ package benchmarks import ( - "archive/tar" "bytes" - "compress/gzip" - "context" "errors" "fmt" - "io" "io/fs" - "net/http" "os" "os/exec" "path/filepath" - "reflect" "runtime" - "sort" - "strings" "sync" "sync/atomic" "testing" - "time" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc/protoparse" @@ -45,23 +36,19 @@ import ( "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/ast" + "github.com/bufbuild/protocompile/internal/ext/bitsx" "github.com/bufbuild/protocompile/internal/protoc" + "github.com/bufbuild/protocompile/internal/testing/googleapis" + "github.com/bufbuild/protocompile/internal/testing/memory" "github.com/bufbuild/protocompile/parser" "github.com/bufbuild/protocompile/parser/fastscan" "github.com/bufbuild/protocompile/protoutil" "github.com/bufbuild/protocompile/reporter" ) -const ( - googleapisCommit = "cb6fbe8784479b22af38c09a5039d8983e894566" -) var ( protocPath string - skipDownload = os.Getenv("SKIP_DOWNLOAD_GOOGLEAPIS") == "true" - - googleapisURI = fmt.Sprintf("https://github.com/googleapis/googleapis/archive/%s.tar.gz", googleapisCommit) googleapisDir string googleapisSources []string ) @@ -85,152 +72,35 @@ func TestMain(m *testing.M) { os.Exit(1) } - var stat int - defer func() { - os.Exit(stat) - }() + stat := new(int) + // NOTE: deferred call arguments are evaluated at defer time, so a plain + // `defer os.Exit(*stat)` would capture 0 and always exit successfully. + // The closure defers the dereference until exit. + defer func() { os.Exit(*stat) }() + // After this point, we can set stat and return instead of
directly calling os.Exit. // That allows deferred functions to execute, to perform cleanup, before exiting. - if !skipDownload { - dir, err := os.MkdirTemp("", "testdownloads") - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Could not create temporary directory: %v\n", err) - stat = 1 - return - } - defer func() { - if err := os.RemoveAll(dir); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Failed to cleanup temp directory %s: %v\n", dir, err) - } - }() - - if err := downloadAndExpand(context.Background(), googleapisURI, dir); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Failed to download and expand googleapis: %v\n", err) - stat = 1 - return - } - - googleapisDir = filepath.Join(dir, "googleapis-"+googleapisCommit) + "/" - var sourceSize int64 - err = filepath.Walk(googleapisDir, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return err - } - if !info.IsDir() && strings.HasSuffix(path, ".proto") { - relPath := strings.TrimPrefix(path, googleapisDir) - googleapisSources = append(googleapisSources, relPath) - sourceSize += info.Size() - } - return nil - }) - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Failed to enumerate googleapis source files: %v\n", err) - stat = 1 - return - } - sort.Strings(googleapisSources) - fmt.Printf("%d total source files found in googleapis (%d bytes).\n", len(googleapisSources), sourceSize) - } - - stat = m.Run() -} - -func downloadAndExpand(ctx context.Context, url, targetDir string) (e error) { - start := time.Now() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + dir, err := os.MkdirTemp("", "testdownloads") if err != nil { - return err - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - if resp.Body != nil { - defer func() { - if err = resp.Body.Close(); err != nil && e == nil { - e = err - } - }() - } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("downloading %s resulted in status code %s", url, resp.Status) - } - if err := 
os.MkdirAll(targetDir, 0777); err != nil { - return err - } - f, err := os.CreateTemp(targetDir, "testdownload.*.tar.gz") - if err != nil { - return err + _, _ = fmt.Fprintf(os.Stderr, "Could not create temporary directory: %v\n", err) + *stat = 1 + return } defer func() { - if f != nil { - if err := f.Close(); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "warning: failed to close %s: %v\n", f.Name(), err) - } + if err := os.RemoveAll(dir); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Failed to cleanup temp directory %s: %v\n", dir, err) } }() - n, err := io.Copy(f, resp.Body) - if err != nil { - return err - } - fmt.Printf("Downloaded %v; %d bytes (%v).\n", url, n, time.Since(start)) - archiveName := f.Name() - if err := f.Close(); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "warning: failed to close %s: %v\n", f.Name(), err) - } - f = nil - - f, err = os.OpenFile(archiveName, os.O_RDONLY, 0) - if err != nil { - return err - } - gzr, err := gzip.NewReader(f) - if err != nil { - return err + if err := googleapis.WriteTo(dir, 0666); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Failed to download and expand googleapis: %v\n", err) + *stat = 1 + return } - defer func() { - if err = gzr.Close(); err != nil && e == nil { - e = err - } - }() - tr := tar.NewReader(gzr) - count := 0 - for { - hdr, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return err - } - if hdr == nil { - continue - } - target := filepath.Join(targetDir, hdr.Name) - switch hdr.Typeflag { - case tar.TypeDir: - if err := os.MkdirAll(target, 0777); err != nil { - return err - } - case tar.TypeReg: - f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_EXCL, os.FileMode(hdr.Mode)) - if err != nil { - return err - } - if _, err := io.Copy(f, tr); err != nil { - return err - } - count++ - default: - // skip anything else - } - } - fmt.Printf("Expanded archive into %d files.\n", count) + googleapisDir = dir + ws, _ := googleapis.Get() + googleapisSources = ws.Paths() - return 
nil + *stat = m.Run() } func BenchmarkGoogleapisProtocompile(b *testing.B) { @@ -511,9 +381,6 @@ func writeToNull(b *testing.B, fds *descriptorpb.FileDescriptorSet) { } func TestGoogleapisProtocompileResultMemory(t *testing.T) { - if skipDownload { - t.Skip() - } c := protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -526,9 +393,6 @@ func TestGoogleapisProtocompileResultMemory(t *testing.T) { } func TestGoogleapisProtocompileResultMemoryNoSourceInfo(t *testing.T) { - if skipDownload { - t.Skip() - } c := protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -541,9 +405,6 @@ func TestGoogleapisProtocompileResultMemoryNoSourceInfo(t *testing.T) { } func TestGoogleapisProtocompileASTMemory(t *testing.T) { - if skipDownload { - t.Skip() - } var asts []*ast.FileNode for _, file := range googleapisSources { func() { @@ -564,9 +425,6 @@ func TestGoogleapisProtocompileASTMemory(t *testing.T) { } func TestGoogleapisProtoparseResultMemory(t *testing.T) { - if skipDownload { - t.Skip() - } p := protoparse.Parser{ ImportPaths: []string{googleapisDir}, IncludeSourceCodeInfo: true, @@ -577,9 +435,6 @@ func TestGoogleapisProtoparseResultMemory(t *testing.T) { } func TestGoogleapisProtoparseResultMemoryNoSourceInfo(t *testing.T) { - if skipDownload { - t.Skip() - } p := protoparse.Parser{ ImportPaths: []string{googleapisDir}, IncludeSourceCodeInfo: false, @@ -590,9 +445,6 @@ func TestGoogleapisProtoparseResultMemoryNoSourceInfo(t *testing.T) { } func TestGoogleapisProtoparseASTMemory(t *testing.T) { - if skipDownload { - t.Skip() - } p := protoparse.Parser{ IncludeSourceCodeInfo: true, } @@ -611,10 +463,10 @@ func measure(t *testing.T, v any) { runtime.GC() var m runtime.MemStats runtime.ReadMemStats(&m) - t.Logf("(heap used: %d bytes)", m.Alloc) + t.Logf("(heap used: %v)", bitsx.ByteSize(m.Alloc)) // and then try 
to directly measure just the given value - mt := newMeasuringTape() - mt.measure(reflect.ValueOf(v)) - t.Logf("memory used: %d bytes", mt.memoryUsed()) + mt := new(memory.MeasuringTape) + mt.Measure(v) + t.Logf("memory used: %v", bitsx.ByteSize(mt.Usage())) } diff --git a/internal/benchmarks/measure.go b/internal/benchmarks/measure.go deleted file mode 100644 index a553c7562..000000000 --- a/internal/benchmarks/measure.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2020-2025 Buf Technologies, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package benchmarks - -import ( - "math/bits" - "reflect" - - "github.com/igrmk/treemap/v2" -) - -type measuringTape struct { - bst *treemap.TreeMap[uintptr, uint64] - other uint64 -} - -func newMeasuringTape() *measuringTape { - return &measuringTape{ - bst: treemap.New[uintptr, uint64](), - } -} - -func (t *measuringTape) insert(start uintptr, length uint64) bool { - if start == 0 { - // nil ptr - return false - } - end := start + uintptr(length) - iter := t.bst.LowerBound(start) - if !iter.Valid() { - // tree is empty or all entries are too low to overlap - t.bst.Set(end, length) - return true - } - entryEnd := iter.Key() - entryStart := entryEnd - uintptr(iter.Value()) - if entryStart > end { - // range does not exist; add it - t.bst.Set(end, length) - return true - } - if entryStart <= start && entryEnd >= end { - // range is entirely encompassed in existing entry - return false - } - - // navigate back to find the first overlapping range and push - // start out if needed to encompass all overlaps - first := t.bst.Iterator().Key() - for entryStart > start { - if iter.Key() == first { - // can go no further - break - } - iter.Prev() - if iter.Key() < start { - // gone back too far - break - } - entryStart = iter.Key() - uintptr(iter.Value()) - } - if entryStart < start { - start = entryStart - } - - // find last overlapping range - if entryEnd < end { - for entryEnd < end { - // remove overlaps that will be replaced with - // new, larger, encompassing range - t.bst.Del(entryEnd) - - // Iterator doesn't like concurrent removal of node. So after - // Del above, we can't call Next; we have to re-search the tree - // for the next node. 
- iter = t.bst.LowerBound(entryEnd) - if !iter.Valid() { - // can go no further - break - } - st := iter.Key() - uintptr(iter.Value()) - if st > end { - // gone too far - break - } - entryEnd = iter.Key() - } - } - if entryEnd > end { - end = entryEnd - } - - t.bst.Set(end, uint64(end-start)) - return true -} - -func (t *measuringTape) memoryUsed() uint64 { - iter := t.bst.Iterator() - var total uint64 - for iter.Valid() { - total += iter.Value() - iter.Next() - } - return total + t.other -} - -func (t *measuringTape) measure(value reflect.Value) { - // We only need to measure outbound references. So we don't care about the size of the pointer itself - // if value is a pointer, since that is either passed by value (not on heap) or accounted for in the - // type that contains the pointer (which we'll have already measured). - - switch value.Kind() { - case reflect.Pointer: - if !t.insert(value.Pointer(), uint64(value.Type().Elem().Size())) { - return - } - t.measure(value.Elem()) - - case reflect.Slice: - if !t.insert(value.Pointer(), uint64(value.Cap())*uint64(value.Type().Elem().Size())) { - return - } - for i := range value.Len() { - t.measure(value.Index(i)) - } - - case reflect.Chan: - if !t.insert(value.Pointer(), uint64(value.Cap())*uint64(value.Type().Elem().Size())) { - return - } - // no way to query for objects in the channel's buffer :( - - case reflect.Map: - const mapHdrSz = 48 // estimate based on struct hmap in runtime/map.go - if !t.insert(value.Pointer(), mapHdrSz) { - return - } - - // Can't really get pointers to bucket arrays, - // so we estimate their size and add them via t.other. 
- buckets := numBuckets(value.Len()) - // estimate based on struct bmap in runtime/map.go - bucketSz := uint64(8 * (value.Type().Key().Size() + value.Type().Elem().Size() + 1)) - t.other += uint64(buckets) * bucketSz - - for iter := value.MapRange(); iter.Next(); { - t.measure(iter.Key()) - t.measure(iter.Value()) - } - - case reflect.Interface: - v := value.Elem() - if v.IsValid() { - if !isReference(v.Kind()) { - t.other += uint64(v.Type().Size()) - } - t.measure(v) - } - - case reflect.String: - t.insert(value.Pointer(), uint64(value.Len())) - - case reflect.Struct: - for i := range value.NumField() { - t.measure(value.Field(i)) - } - - default: - // nothing to do - } -} - -func numBuckets(mapSize int) int { - // each bucket holds 8 entries - buckets := mapSize / 8 - if mapSize > buckets*8 { - buckets++ - } - // Number of buckets is a power of two (map doubles each - // time it grows). - highestBit := 63 - bits.LeadingZeros64(uint64(buckets)) - if highestBit >= 0 { - powerOf2 := 1 << highestBit - if buckets > powerOf2 { - powerOf2 <<= 1 - } - buckets = powerOf2 - } - return buckets -} - -func isReference(k reflect.Kind) bool { - switch k { - case reflect.Pointer, reflect.Chan, reflect.Map, reflect.Func: - return true - default: - return false - } -} diff --git a/internal/benchmarks/measure_test.go b/internal/benchmarks/measure_test.go deleted file mode 100644 index 1886cd63a..000000000 --- a/internal/benchmarks/measure_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2020-2025 Buf Technologies, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package benchmarks - -import ( - "math/bits" - "reflect" - "testing" - - "github.com/igrmk/treemap/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMeasuringTapeInsert(t *testing.T) { - t.Parallel() - - mt := newMeasuringTape() - assert.True(t, mt.insert(100, 300)) // 100 -> 400 - verifyMap(t, mt.bst, 100, 400) - - // wholly contained - assert.False(t, mt.insert(100, 300)) - assert.False(t, mt.insert(150, 200)) - - // extends range start - assert.True(t, mt.insert(50, 300)) // 50 -> 350 - verifyMap(t, mt.bst, 50, 400) - - // extends range end - assert.True(t, mt.insert(300, 175)) // 300 -> 475 - verifyMap(t, mt.bst, 50, 475) - - // new range above - assert.True(t, mt.insert(1500, 100)) // 1500 -> 1600 - verifyMap(t, mt.bst, 50, 475, 1500, 1600) - - // new range below - assert.True(t, mt.insert(10, 10)) // 10 -> 20 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 1600) - - // new range above - assert.True(t, mt.insert(25000, 50000)) // 25,000 -> 75,000 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 1600, 25000, 75000) - - // new interior range - assert.True(t, mt.insert(1700, 300)) // 1700 -> 2000 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 1600, 1700, 2000, 25000, 75000) - - // new interior range - assert.True(t, mt.insert(2100, 300)) // 2100 -> 2400 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 1600, 1700, 2000, 2100, 2400, 25000, 75000) - - // matches range boundary, extends end - assert.True(t, mt.insert(2400, 100)) // 2400 -> 2500 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 1600, 1700, 2000, 2100, 2500, 25000, 75000) - - // matches both adjacent range boundaries, collapses - assert.True(t, mt.insert(1600, 100)) // 1600 -> 1700 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 2000, 2100, 2500, 25000, 75000) - - // matches range boundary, extends start - assert.True(t, mt.insert(24000, 1000)) // 24,000 -> 
25,000 - verifyMap(t, mt.bst, 10, 20, 50, 475, 1500, 2000, 2100, 2500, 24000, 75000) - - // encompasses many ranges, collapses - assert.True(t, mt.insert(10, 3000)) // 10 -> 3010 - verifyMap(t, mt.bst, 10, 3010, 24000, 75000) - - // wholly contained - assert.False(t, mt.insert(1500, 1510)) // 1500 -> 3010 - - mt.other = 99 - assert.Equal(t, 54099, int(mt.memoryUsed())) -} - -func TestMeasuringTapeMeasure(t *testing.T) { - t.Parallel() - - mt := newMeasuringTape() - bytes := make([]byte, 1000000) - mt.measure(reflect.ValueOf(bytes)) - require.Equal(t, uint64(1000000), mt.memoryUsed()) - // these do nothing since they are part of already-measured slice - mt.measure(reflect.ValueOf(bytes[0:10])) - mt.measure(reflect.ValueOf(bytes[1000:10000])) - require.Equal(t, uint64(1000000), mt.memoryUsed()) - - int64s := make([]int64, 1000000) - mt.measure(reflect.ValueOf(int64s)) - require.Equal(t, uint64(9000000), mt.memoryUsed()) - - int64ptrs := make([]*int64, 1000000) - for i := range int64ptrs { - int64ptrs[i] = &int64s[i] - } - mt.measure(reflect.ValueOf(int64ptrs)) - // increase is only the size of slice, not pointed-to values, since all pointers - // point to locations in already-measured slice above - ptrsSz := uint64(1000000 * reflect.TypeOf(uintptr(0)).Size()) - require.Equal(t, 9000000+ptrsSz, mt.memoryUsed()) -} - -func verifyMap(t *testing.T, tree *treemap.TreeMap[uintptr, uint64], ranges ...uintptr) { - t.Helper() - require.Equal(t, 0, len(ranges)%2, "ranges must be even number of values") - - iter := tree.Iterator() - for i := 0; i < len(ranges); i += 2 { - require.True(t, iter.Valid()) - entryEnd := iter.Key() - entryStart := entryEnd - uintptr(iter.Value()) - type pair struct { - start, end uintptr - } - expected := pair{ranges[i], ranges[i+1]} - actual := pair{entryStart, entryEnd} - require.Equal(t, expected, actual) - iter.Next() - } -} - -func TestNumBuckets(t *testing.T) { - t.Parallel() - - assert.Equal(t, 0, numBuckets(0)) - assert.Equal(t, 1, 
numBuckets(8)) - assert.Equal(t, 2, numBuckets(9)) - assert.Equal(t, 2, numBuckets(16)) - assert.Equal(t, 4, numBuckets(17)) - assert.Equal(t, 4, numBuckets(32)) - assert.Equal(t, 8, numBuckets(33)) - - check := func(sz int) { - b := numBuckets(sz) - // power of 2 - assert.Equal(t, 1, bits.OnesCount(uint(b))) - // that fits given size (each bucket holds 8 entries) - assert.Less(t, b*4, sz) - assert.GreaterOrEqual(t, b*8, sz) - } - check(7364) - check(1234567) - check(918373645623) -} diff --git a/internal/ext/bitsx/bytesize.go b/internal/ext/bitsx/bytesize.go new file mode 100644 index 000000000..a52ccd0d5 --- /dev/null +++ b/internal/ext/bitsx/bytesize.go @@ -0,0 +1,40 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bitsx + +import ( + "fmt" + "math/bits" +) + +// ByteSize formats a number as a human-readable number of bytes. 
+func ByteSize[T Int](v T) string { + abs := v + if v < 0 { + abs = -v + } + + n := bits.Len64(uint64(abs)) + if n >= 30 { + return fmt.Sprintf("%.03f GB", float64(v)/float64(1024*1024*1024)) + } + if n >= 20 { + return fmt.Sprintf("%.03f MB", float64(v)/float64(1024*1024)) + } + if n >= 10 { + return fmt.Sprintf("%.03f KB", float64(v)/float64(1024)) + } + return fmt.Sprintf("%v.000 B", v) +} diff --git a/internal/ext/iterx/get.go b/internal/ext/iterx/get.go index 895fc5864..6d69c567a 100644 --- a/internal/ext/iterx/get.go +++ b/internal/ext/iterx/get.go @@ -54,8 +54,8 @@ func Last2[K, V any](seq iter.Seq2[K, V]) (k K, v V, ok bool) { // OnlyOne retrieves the only element of an iterator. func OnlyOne[T any](seq iter.Seq[T]) (v T, ok bool) { - for i, x := range Enumerate(seq) { - if i > 0 { + for x := range seq { + if ok { var z T // Ensure we return the zero value if there is more // than one element. diff --git a/internal/ext/reflectx/reflectx.go b/internal/ext/reflectx/reflectx.go new file mode 100644 index 000000000..d626ad89a --- /dev/null +++ b/internal/ext/reflectx/reflectx.go @@ -0,0 +1,61 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reflectx + +import "reflect" + +// UnwrapStruct removes as many layers of "one field wrappers" on v as possible. +// +// This means: +// 1. The one field in a struct with non-zero size. +// 2. The one element of a [1]T. 
+func UnwrapStruct(v reflect.Value) reflect.Value { +loop: + switch v.Kind() { + case reflect.Struct: + ty := v.Type() + + var nonzero reflect.StructField + for i := range ty.NumField() { + f := ty.Field(i) + if f.Offset > 0 { + // This catches the following problematic struct: + // + // struct { A int; B [0]int } + // + // Zero-sized fields after the last non-zero-sized field + // result in padding. + break loop + } + if f.Type.Size() > 0 { + if nonzero.Type != nil { + break loop + } + nonzero = f + } + } + + if nonzero.Type == nil { + // No non-zero-sized field was found (e.g. struct{} or a struct + // of only zero-sized fields). FieldByIndex(nil) would return v + // unchanged and the goto below would loop forever. + break loop + } + v = v.FieldByIndex(nonzero.Index) + goto loop + + case reflect.Array: + if v.Len() == 1 { + v = v.Index(0) + goto loop + } + } + + return v +} diff --git a/internal/ext/reflectx/reflectx_test.go b/internal/ext/reflectx/reflectx_test.go new file mode 100644 index 000000000..00675cc0e --- /dev/null +++ b/internal/ext/reflectx/reflectx_test.go @@ -0,0 +1,78 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package reflectx_test + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/bufbuild/protocompile/internal/ext/reflectx" +) + +func TestUnwrap(t *testing.T) { + t.Parallel() + tests := []struct { + have, want reflect.Type + }{ + {have: reflect.TypeFor[int](), want: reflect.TypeFor[int]()}, + {have: reflect.TypeFor[[1]int](), want: reflect.TypeFor[int]()}, + {have: reflect.TypeFor[[2]int](), want: reflect.TypeFor[[2]int]()}, + + {have: reflect.TypeFor[struct { + V int + }](), want: reflect.TypeFor[int]()}, + + {have: reflect.TypeFor[struct { + _ [0]uint64 + V byte + }](), want: reflect.TypeFor[byte]()}, + + {have: reflect.TypeFor[struct { + _ [1]uint64 + V byte + }](), want: reflect.TypeFor[struct { + _ [1]uint64 + V byte + }]()}, + + {have: reflect.TypeFor[struct { + _ [0]uint32 + V [1]struct { + _ [0]uint64 + V int + } + }](), want: reflect.TypeFor[int]()}, + + {have: reflect.TypeFor[struct { + V uint32 + _ [0]uint32 + }](), want: reflect.TypeFor[struct { + V uint32 + _ [0]uint32 + }]()}, + } + + for _, tt := range tests { + t.Run(tt.have.Name(), func(t *testing.T) { + t.Parallel() + + v := reflect.New(tt.have).Elem() + v = reflectx.UnwrapStruct(v) + assert.Equal(t, tt.want, v.Type()) + }) + } +} diff --git a/internal/ext/unsafex/unsafex.go b/internal/ext/unsafex/unsafex.go index 4a9bcaa27..43bbc4f8f 100644 --- a/internal/ext/unsafex/unsafex.go +++ b/internal/ext/unsafex/unsafex.go @@ -166,3 +166,8 @@ func NoEscape[P ~*E, E any](ptr P) P { p = unsafe.Pointer(uintptr(p) ^ 0) //nolint:staticcheck return P(p) } + +// NoEscapeSlice marks a slice as not escaping, as by [NoEscape]. 
+func NoEscapeSlice[S ~[]E, E any](s S) S { + return unsafe.Slice(NoEscape(unsafe.SliceData(s)), cap(s))[:len(s)] +} diff --git a/internal/testing/googleapis/archive.go b/internal/testing/googleapis/archive.go new file mode 100644 index 000000000..ef8462b67 --- /dev/null +++ b/internal/testing/googleapis/archive.go @@ -0,0 +1,22 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by github.com/bufbuild/protocompile/internal/testing/googleapis/gen. DO NOT EDIT. + +package googleapis + +import _ "embed" + +//go:embed googleapis-cb6fbe8784479b22af38c09a5039d8983e894566.tar.gz +var archive string diff --git a/internal/testing/googleapis/gen/main.go b/internal/testing/googleapis/gen/main.go new file mode 100644 index 000000000..143e7614e --- /dev/null +++ b/internal/testing/googleapis/gen/main.go @@ -0,0 +1,168 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:gosec +package main + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "errors" + "flag" + "fmt" + "go/format" + "io" + "net/http" + "os" + "path/filepath" + "runtime/debug" + "time" + + "github.com/bufbuild/protocompile/internal/ext/bitsx" + "github.com/bufbuild/protocompile/internal/ext/flagx" +) + +func main() { + flagx.Main(func() (e error) { + commit := flag.Arg(0) + + out := fmt.Sprintf("googleapis-%s.tar.gz", commit) + _, err := os.Stat(out) + if err == nil || !errors.Is(err, os.ErrNotExist) { + return err + } + + url := fmt.Sprintf("https://github.com/googleapis/googleapis/archive/%s.tar.gz", commit) + + start := time.Now() + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + if resp.Body != nil { + defer func() { + if err = resp.Body.Close(); err != nil && e == nil { + e = err + } + }() + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("downloading %s resulted in status code %s", url, resp.Status) + } + + ar, err := os.Create(out) + if err != nil { + return err + } + + dir := "googleapis-" + commit + total, err := filterArchive(ar, resp.Body, func(path string) string { + rel, err := filepath.Rel(dir, path) + if err != nil || filepath.Ext(path) != ".proto" { + return "" + } + return rel + }) + if err != nil { + return err + } + elapsed := time.Since(start) + + stat, err := os.Stat(out) + if err != nil { + return err + } + + fmt.Printf("googleapis: downloaded commit %v (%v, compressed to %v, %0.3f%%) in %v\n", + commit, bitsx.ByteSize(total), bitsx.ByteSize(stat.Size()), float64(stat.Size())/float64(total)*100, elapsed) + + info, ok := debug.ReadBuildInfo() + if !ok { + return errors.New("debug: could not read build info") + } + + embed := new(bytes.Buffer) 
+ fmt.Fprintf(embed, "// Code generated by %s. DO NOT EDIT.\n\n", info.Path) + fmt.Fprintf(embed, "package %s\n\n", os.Getenv("GOPACKAGE")) + fmt.Fprintf(embed, "import _ \"embed\"\n\n") + + fmt.Fprintf(embed, "//go:embed %s\n", out) + fmt.Fprintf(embed, "var archive string\n") + + src, err := format.Source(embed.Bytes()) + if err != nil { + return err + } + return os.WriteFile("archive.go", src, 0666) + }) +} + +func filterArchive(dst io.Writer, src io.Reader, filter func(string) string) (total int64, err error) { + dstGz, err := gzip.NewWriterLevel(dst, gzip.BestCompression) + if err != nil { + return total, err + } + dstAr := tar.NewWriter(dstGz) + + SrcGz, err := gzip.NewReader(src) + if err != nil { + return total, err + } + srcAr := tar.NewReader(SrcGz) + + for { + hdr, err := srcAr.Next() + if err == io.EOF { + break + } + if err != nil { + return total, err + } + + if hdr == nil || hdr.Typeflag != tar.TypeReg { + continue + } + + hdr.Name = filter(hdr.Name) + if hdr.Name == "" { + continue + } + + if err := dstAr.WriteHeader(hdr); err != nil { + return total, err + } + + n, err := io.Copy(dstAr, srcAr) + if err != nil { + return total, err + } + total += n + } + + if err := dstAr.Close(); err != nil { + return total, err + } + if err := dstGz.Close(); err != nil { + return total, err + } + return total, err +} diff --git a/internal/testing/googleapis/googleapis-cb6fbe8784479b22af38c09a5039d8983e894566.tar.gz b/internal/testing/googleapis/googleapis-cb6fbe8784479b22af38c09a5039d8983e894566.tar.gz new file mode 100644 index 000000000..1014e553e Binary files /dev/null and b/internal/testing/googleapis/googleapis-cb6fbe8784479b22af38c09a5039d8983e894566.tar.gz differ diff --git a/internal/testing/googleapis/googleapis.go b/internal/testing/googleapis/googleapis.go new file mode 100644 index 000000000..b332a336e --- /dev/null +++ b/internal/testing/googleapis/googleapis.go @@ -0,0 +1,123 @@ +// Copyright 2020-2025 Buf Technologies, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package googleapis makes a checked-in download of +// https://github.com/googleapis/googleapis available for use by tests. +// +// The checked in data at googleapis-xxx.tar.gz is governed by +// https://github.com/googleapis/googleapis/blob/master/LICENSE. +package googleapis + +//go:generate go run ./gen cb6fbe8784479b22af38c09a5039d8983e894566 + +import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/bufbuild/protocompile/experimental/source" +) + +var ( + opener source.Opener + workspace source.Workspace + once sync.Once +) + +// Get returns a workspace and opener containing the entire googleapis project, +// for use in tests. +func Get() (source.Workspace, source.Opener) { + once.Do(func() { + protos, err := unpack(archive) + if err != nil { + panic(fmt.Errorf("googleapis: %w", err)) + } + + var paths []string + for path := range protos.Get() { + paths = append(paths, path) + } + slices.Sort(paths) + + opener = protos + workspace = source.NewWorkspace(paths...) + }) + + return workspace, opener +} + +// WriteTo writes the entire googleapis tree onto the given directory. 
+func WriteTo(dir string, perm os.FileMode) error {
+	ws, op := Get()
+	for _, path := range ws.Paths() {
+		src, err := op.Open(path)
+		if err != nil {
+			return err
+		}
+
+		if err := os.MkdirAll(filepath.Join(dir, filepath.Dir(path)), 0777); err != nil {
+			return err
+		}
+
+		// O_TRUNC is required: without it, writing over an existing, longer
+		// file would leave stale trailing bytes and corrupt the output.
+		f, err := os.OpenFile(filepath.Join(dir, path), os.O_CREATE|os.O_RDWR|os.O_TRUNC, perm)
+		if err != nil {
+			return err
+		}
+
+		_, err = f.WriteString(src.Text())
+		_ = f.Close()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func unpack(archive string) (opener source.Map, err error) {
+	gz, err := gzip.NewReader(strings.NewReader(archive))
+	if err != nil {
+		return opener, err
+	}
+
+	ar := tar.NewReader(gz)
+	opener = source.NewMap(nil)
+	for {
+		hdr, err := ar.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return opener, err
+		}
+
+		if hdr == nil || hdr.Typeflag != tar.TypeReg {
+			continue
+		}
+
+		buf := new(strings.Builder)
+		if _, err := io.Copy(buf, ar); err != nil { //nolint:gosec
+			return opener, err
+		}
+
+		opener.Add(hdr.Name, buf.String())
+	}
+
+	return opener, nil
+}
diff --git a/internal/testing/memory/measure.go b/internal/testing/memory/measure.go
new file mode 100644
index 000000000..6cefbf807
--- /dev/null
+++ b/internal/testing/memory/measure.go
@@ -0,0 +1,154 @@
+// Copyright 2020-2025 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package memory + +import ( + "reflect" + + "github.com/bufbuild/protocompile/internal/ext/bitsx" + "github.com/bufbuild/protocompile/internal/ext/reflectx" + "github.com/bufbuild/protocompile/internal/interval" +) + +// MeasuringTape measures how much memory a particular value used. +type MeasuringTape struct { + // Which memory regions have already been measured. This allows us to + // detect cycles in a robust manner. + heap interval.Intersect[uintptr, struct{}] + visited map[[2]uintptr]struct{} + + // Extra memory that we cannot get an actual pointer to. + extra uint64 +} + +// Usage returns the total number of bytes used. +func (t *MeasuringTape) Usage() uint64 { + var total uint64 + for entry := range t.heap.Contiguous(false) { + total += uint64(entry.End - entry.Start + 1) + } + return total + t.extra +} + +// Measure records the memory transitively reachable through v. +func (t *MeasuringTape) Measure(v any) { + t.measure(reflect.ValueOf(v)) +} + +func (t *MeasuringTape) measure(v reflect.Value) { + insert := func(start uintptr, bytes int) bool { + if bytes == 0 { + return false + } + end := start + uintptr(bytes) + if _, ok := t.visited[[2]uintptr{start, end}]; ok { + return false + } + if t.visited == nil { + t.visited = make(map[[2]uintptr]struct{}) + } + t.visited[[2]uintptr{start, end}] = struct{}{} + + t.heap.Insert(start, end-1, struct{}{}) + return true + } + + // We only need to measure outbound references. So we don't care about the + // size of the pointer itself if value is a pointer, since that is either + // passed by value (not on heap) or accounted for in the type that contains + // the pointer (which we'll have already measured). + // + // Note that we cannot handle unsafe.Pointer, because reflection cannot + // tell us how large the pointee is. 
+ + switch v.Kind() { + case reflect.Pointer: + if !insert(v.Pointer(), int(v.Type().Elem().Size())) { + return + } + t.measure(v.Elem()) + + case reflect.Slice: + if !insert(v.Pointer(), v.Cap()*int(v.Type().Elem().Size())) { + return + } + for i := range v.Len() { + t.measure(v.Index(i)) + } + + case reflect.Chan: + if !insert(v.Pointer(), v.Cap()*int(v.Type().Elem().Size())) { + return + } + // no way to query for objects in the channel's buffer :( + + case reflect.Map: + const header = 8 * 6 // See internal/maps.Map in maps/map.go. + if !insert(v.Pointer(), header) { + return + } + + t.extra += uint64(estimateMapSize(v)) + for iter := v.MapRange(); iter.Next(); { + t.measure(iter.Key()) + t.measure(iter.Value()) + } + + case reflect.Interface: + v := v.Elem() + if v.IsValid() { + inner := reflectx.UnwrapStruct(v) + switch inner.Kind() { + case reflect.Pointer, reflect.Chan, reflect.Map, reflect.Func: + default: + t.extra += uint64(v.Type().Size()) + } + t.measure(v) + } + + case reflect.String: + insert(v.Pointer(), v.Len()) + + case reflect.Array: + for i := range v.Len() { + t.measure(v.Index(i)) + } + + case reflect.Struct: + for i := range v.NumField() { + t.measure(v.Field(i)) + } + + default: + // nothing to do + } +} + +//nolint:revive,predeclared +func estimateMapSize(m reflect.Value) int { + const table = 8 * 4 // See internal/maps.table in maps/table.go. + const groupSize = 8 + + // Map size must be a power of two. + // Note that if len is a power of two, the cap must be the next power of + // two, because SwissTable requires a load factor of ~7/8. + cap := bitsx.NextPowerOfTwo(uint(m.Len())) + + // Approximation: this is missing padding. + group := groupSize + groupSize*(m.Type().Key().Size()+m.Type().Elem().Size()) + + // We assume that the internal map directory has exactly one entry in it. 
+ return table + int(cap/groupSize)*int(group) +} diff --git a/internal/testing/memory/measure_test.go b/internal/testing/memory/measure_test.go new file mode 100644 index 000000000..9c0840da2 --- /dev/null +++ b/internal/testing/memory/measure_test.go @@ -0,0 +1,51 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memory_test + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bufbuild/protocompile/internal/testing/memory" +) + +func TestMeasuringTapeMeasure(t *testing.T) { + t.Parallel() + + mt := new(memory.MeasuringTape) + bytes := make([]byte, 1000000) + mt.Measure(bytes) + require.Equal(t, 1000000, int(mt.Usage())) + // these do nothing since they are part of already-measured slice + mt.Measure(bytes[0:10]) + mt.Measure(bytes[1000:10000]) + require.Equal(t, 1000000, int(mt.Usage())) + + int64s := make([]int64, 1000000) + mt.Measure(int64s) + require.Equal(t, 9000000, int(mt.Usage())) + + int64ptrs := make([]*int64, 1000000) + for i := range int64ptrs { + int64ptrs[i] = &int64s[i] + } + mt.Measure(int64ptrs) + // increase is only the size of slice, not pointed-to values, since all pointers + // point to locations in already-measured slice above + ptrBytes := len(int64ptrs) * int(reflect.TypeOf((*int64)(nil)).Size()) + require.Equal(t, 9000000+ptrBytes, int(mt.Usage())) +} diff --git a/internal/trie/nybbles.go b/internal/trie/nybbles.go index 
ab4b86133..9fe7357e7 100644 --- a/internal/trie/nybbles.go +++ b/internal/trie/nybbles.go @@ -73,6 +73,50 @@ func (t *nybbles[N]) search(key string, yield func(string, int) bool) { } } +// searcher is used in [nybbles.step]. +type searcher struct { + i int // Length of the key examined so far, plus one. + n int // Entry being examined. +} + +// step walks a single step along the trie. +// +// Unlike search, which takes a func, this function takes state that does not +// need to be passed by pointer. This avoids an escape analysis failure in +// functions like [Trie.Prefixes], where the closure (which is passed to an +// interface implemented by nybbles[T]) and everything it captures escapes. +func (t *nybbles[N]) step(key string, s searcher) searcher { + if s.i == 0 { + s.i++ + if t.has(0) { + return s + } + } + + for ; s.i <= len(key); s.i++ { + b := key[s.i-1] + lo, hi := b&0xf, b>>4 + + if len(t.hi) <= s.n { + break + } + m := int(t.hi[s.n][hi]) + + if len(t.lo) <= m { + break + } + s.n = int(t.lo[m][lo]) + + if t.has(s.n) { + s.i++ + return s + } + } + + s.n = -1 + return s +} + // insert adds a new key to the trie; returns the index to insert the // corresponding value at. // diff --git a/internal/trie/trie.go b/internal/trie/trie.go index 4385f00f3..ce4a7f050 100644 --- a/internal/trie/trie.go +++ b/internal/trie/trie.go @@ -19,7 +19,6 @@ import ( "strings" "github.com/bufbuild/protocompile/internal/ext/iterx" - "github.com/bufbuild/protocompile/internal/ext/unsafex" ) // Trie implements a map from strings to V, except lookups return the key @@ -28,6 +27,8 @@ import ( // The zero value is empty and ready to use. 
type Trie[V any] struct { impl interface { + step(key string, s searcher) searcher + search(key string, yield func(string, int) bool) insert(key string) int @@ -55,14 +56,19 @@ func (t *Trie[V]) Prefixes(key string) iter.Seq2[string, V] { return } - adapt := func(prefix string, index int) bool { - return yield(prefix, t.values[index]) - } + var s searcher + for { + s = t.impl.step(key, s) + if s.n == -1 { + break + } - // No implementation of impl will ever cause adapt to escape. This - // avoids a heap allocation. - adapt = *unsafex.NoEscape(&adapt) - t.impl.search(key, adapt) + prefix := key[:s.i-1] + entry := t.values[s.n] + if !yield(prefix, entry) { + return + } + } } }