Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions common/annotation_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <functional>
#include <optional>
#include <string>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
Expand Down Expand Up @@ -303,6 +304,30 @@ absl::StatusOr<std::optional<std::string>> GetAnnotationWithStringArg(
return std::string(*arg);
}

absl::StatusOr<std::optional<std::vector<std::string>>>
GetAnnotationWithStringArgs(const clang::Decl& decl,
absl::string_view annotation_name) {
CRUBIT_ASSIGN_OR_RETURN(std::optional<AnnotateArgs> maybe_args,
GetAnnotateAttrArgs(decl, annotation_name));
if (!maybe_args.has_value()) {
return std::nullopt;
}
const AnnotateArgs& args = *maybe_args;
std::vector<std::string> result;
result.reserve(args.size());
for (const clang::Expr* arg_expr : args) {
absl::StatusOr<absl::string_view> arg =
GetExprAsStringLiteral(*arg_expr, decl.getASTContext());
if (!arg.ok()) {
return absl::InvalidArgumentError(
absl::StrCat("Annotation ", annotation_name,
" arguments must be string literals."));
}
result.push_back(std::string(*arg));
}
return result;
}

absl::StatusOr<const clang::AnnotateTypeAttr* absl_nullable>
GetTypeAnnotationSingleDecl(const clang::Type* absl_nonnull type
ABSL_ATTRIBUTE_LIFETIME_BOUND,
Expand Down
10 changes: 10 additions & 0 deletions common/annotation_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ absl::StatusOr<absl::string_view> GetExprAsStringLiteral(
absl::StatusOr<std::optional<std::string>> GetAnnotationWithStringArg(
const clang::Decl& decl, absl::string_view annotation_name);

// Returns the string arguments of [[clang::annotate(annotation_name,
// string_arg1, string_arg2, ...)]] annotation on `decl`, or none if the
// annotation does not exist.
//
// Returns an error if there are conflicting annotations or if any argument is
// not a string.
absl::StatusOr<std::optional<std::vector<std::string>>>
GetAnnotationWithStringArgs(const clang::Decl& decl,
absl::string_view annotation_name);

// Returns true if `decl` has an annotation with the given name.
//
// Returns an error if an annotation with the given name exists, but it has
Expand Down
37 changes: 37 additions & 0 deletions common/annotation_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
namespace crubit {
namespace {

using testing::ElementsAre;
using testing::Eq;
using testing::HasSubstr;
using testing::Ne;
Expand Down Expand Up @@ -151,5 +152,41 @@ TEST(AnnotationReaderTest,
ASSERT_THAT(GetAnnotateAttrArgs(var, "foo"), IsOkAndHolds(Ne(std::nullopt)));
}

TEST(AnnotationReaderTest, GetAnnotationWithStringArgsSuccess) {
clang::TestAST ast(R"cc(
[[clang::annotate("foo", "arg1", "arg2")]] int i;
)cc");

auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");

auto result = GetAnnotationWithStringArgs(var, "foo");
ASSERT_THAT(result, IsOkAndHolds(Ne(std::nullopt)));
EXPECT_THAT(**result, ElementsAre("arg1", "arg2"));
}

TEST(AnnotationReaderTest, GetAnnotationWithStringArgsNone) {
clang::TestAST ast(R"cc(
int i;
)cc");

auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");

EXPECT_THAT(GetAnnotationWithStringArgs(var, "foo"),
IsOkAndHolds(Eq(std::nullopt)));
}

TEST(AnnotationReaderTest, GetAnnotationWithStringArgsFailureNonString) {
clang::TestAST ast(R"cc(
[[clang::annotate("foo", "arg1", 42)]] int i;
)cc");

auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");

EXPECT_THAT(
GetAnnotationWithStringArgs(var, "foo"),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Annotation foo arguments must be string literals.")));
}

} // namespace
} // namespace crubit
24 changes: 16 additions & 8 deletions rs_bindings_from_cc/generate_bindings/database/code_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ pub fn generated_items_to_tokens<'db>(
nested_items,
indirect_functions,
delete,
owned_type_name,
owned_ptr_config,
member_methods,
free_functions,
lifetime_params,
Expand Down Expand Up @@ -604,10 +604,12 @@ pub fn generated_items_to_tokens<'db>(
None
};

let owned_type_def = owned_type_name.as_ref().map(|owned_type_name| {
let owned_type_def = owned_ptr_config.as_ref().map(|cfg| {
let owned_type_name = &cfg.owned_type_name;
let drop_meth = &cfg.drop_impl;
let doc_comment = format!(
"Wrapper for a C++ {} owned by Rust. \n\n Style guide: The C++ type to which this refers should be wrapped in an `Arc` or `Mutex` if it is not already thread-safe. \n\n THIS TYPE REQUIRES A MANUAL DROP IMPLEMENTATION. \n You MUST provide an `impl {} {{ pub fn DropImpl(&mut self) {{ ... }} }}` block in a separate Rust file (e.g., via `additional_rust_srcs`). Failure to do so will result in a compile-time error: `method not found in `{}``.",
ident, owned_type_name, owned_type_name
"Wrapper for a C++ {} owned by Rust. \n\n Style guide: The C++ type to which this refers should be wrapped in an `Arc` or `Mutex` if it is not already thread-safe. \n\n THIS TYPE REQUIRES A MANUAL DROP IMPLEMENTATION. \n You MUST provide an `impl {} {{ pub fn {}(&mut self) {{ ... }} }}` block in a separate Rust file (e.g., via `additional_rust_srcs`). Failure to do so will result in a compile-time error: `method not found in `{}``.",
ident, owned_type_name, drop_meth, owned_type_name
);
quote! {
__NEWLINE__ __NEWLINE__
Expand All @@ -618,10 +620,10 @@ pub fn generated_items_to_tokens<'db>(

impl Drop for #owned_type_name {
fn drop(&mut self) {
__COMMENT__ "IMPORTANT: The DropImpl method for `{}` MUST be implemented in a user-written .rs file (e.g., using `additional_rust_srcs`)."
__COMMENT__ "IMPORTANT: The drop method MUST be implemented in a user-written .rs file (e.g., using `additional_rust_srcs`)."
__COMMENT__ "Crubit cannot automatically generate the destruction logic for this type."
__COMMENT__ "See the struct documentation for more details."
self.DropImpl();
self.#drop_meth();
}
}
}
Expand Down Expand Up @@ -977,6 +979,12 @@ impl GeneratedItem {
}
}

#[derive(Clone, Debug)]
pub struct OwnedPtrConfig {
pub owned_type_name: Ident,
pub drop_impl: Ident,
}

#[derive(Clone, Debug)]
pub struct Record {
pub doc_comment_attr: Option<DocCommentAttr>,
Expand Down Expand Up @@ -1005,8 +1013,8 @@ pub struct Record {
/// Functions that get attached either by a trait or from a base class.
pub indirect_functions: Vec<TokenStream>,
pub delete: Option<DeleteImpl>,
/// The name of the owning wrapper type when the type was annotated with CRUBIT_OWNED_POINTEE.
pub owned_type_name: Option<Ident>,
/// The owning wrapper type configuration when the type was annotated with CRUBIT_OWNED_POINTEE.
pub owned_ptr_config: Option<OwnedPtrConfig>,
pub member_methods: Vec<TokenStream>,
pub free_functions: Vec<TokenStream>,
pub lifetime_params: Vec<syn::Lifetime>,
Expand Down
4 changes: 2 additions & 2 deletions rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ impl RsTypeKind {
lifetimes,
in_cc_std,
)?,
owned_ptr_type: record.owned_ptr_type.clone(),
owned_ptr_type: record.owned_ptr_config.as_ref().map(|cfg| cfg.owned_ptr_type.clone()),
record,
crate_path,
lifetimes: lifetimes.to_vec(),
Expand Down Expand Up @@ -1448,7 +1448,7 @@ impl RsTypeKind {
)
};

let owned_ptr_type = record.owned_ptr_type.as_ref().expect(
let owned_ptr_type = record.owned_ptr_config.as_ref().map(|cfg| cfg.owned_ptr_type.as_ref()).expect(
"CRUBIT_OWNED_POINTER annotated pointers should point to a struct with an associated CRUBIT_OWNED_POINTEE",
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,11 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
})
};

let owned_type_name = record.owned_ptr_type.as_ref().map(|opt| make_rs_ident(opt.as_ref()));
let owned_ptr_config =
record.owned_ptr_config.as_ref().map(|cfg| database::code_snippet::OwnedPtrConfig {
owned_type_name: make_rs_ident(cfg.owned_ptr_type.as_ref()),
drop_impl: make_rs_ident(cfg.drop_impl.as_ref()),
});
let member_methods = api_snippets.member_functions.remove(&record.id).unwrap_or_default();
let free_functions = api_snippets.free_functions.remove(&record.id).unwrap_or_default();

Expand Down Expand Up @@ -753,7 +757,7 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
items,
nested_items,
indirect_functions,
owned_type_name,
owned_ptr_config,
member_methods,
free_functions,
delete: operator_delete_impl,
Expand Down
34 changes: 29 additions & 5 deletions rs_bindings_from_cc/importers/cxx_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,12 +922,36 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
std::optional<BridgeType> bridge_type =
GetBridgeTypeAnnotation(ictx_, *record_decl);

absl::StatusOr<std::optional<std::string>> owned_ptr_type =
GetAnnotationWithStringArg(*record_decl, "crubit_owned_pointee");
if (!owned_ptr_type.ok()) {
absl::StatusOr<std::optional<std::vector<std::string>>> args =
GetAnnotationWithStringArgs(*record_decl, "crubit_owned_pointee");
if (!args.ok()) {
return ictx_.ImportUnsupportedItem(
*record_decl, std::nullopt,
FormattedError::FromStatus(std::move(owned_ptr_type).status()));
FormattedError::FromStatus(std::move(args).status()));
}

std::optional<OwnedPtrConfig> owned_ptr_config;

if (args->has_value()) {
const auto& args_vec = **args;
if (args_vec.empty() || args_vec.size() > 2) {
return ictx_.ImportUnsupportedItem(
*record_decl, std::nullopt,
FormattedError::Static(
"crubit_owned_pointee takes 1 or 2 arguments"));
}

std::string owned_ptr_type = args_vec[0];
std::string drop_impl = "DropImpl";

if (args_vec.size() == 2) {
drop_impl = args_vec[1];
}

owned_ptr_config = OwnedPtrConfig{
.owned_ptr_type = std::move(owned_ptr_type),
.drop_impl = std::move(drop_impl),
};
}

BazelLabel owning_target = ictx_.GetOwningTarget(record_decl);
Expand Down Expand Up @@ -1195,7 +1219,7 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
.unknown_attr = std::move(*unknown_attr),
.doc_comment = std::move(doc_comment),
.bridge_type = std::move(bridge_type),
.owned_ptr_type = *std::move(owned_ptr_type),
.owned_ptr_config = std::move(owned_ptr_config),
.source_loc = ictx_.ConvertSourceLocation(source_loc, nullptr),
.unambiguous_public_bases = GetUnambiguousPublicBases(*record_decl),
.fields = ImportFields(record_decl),
Expand Down
12 changes: 11 additions & 1 deletion rs_bindings_from_cc/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,13 @@ llvm::json::Value TraitDerives::ToJson() const {
};
}

llvm::json::Value OwnedPtrConfig::ToJson() const {
return llvm::json::Object{
{"owned_ptr_type", owned_ptr_type},
{"drop_impl", drop_impl},
};
}

llvm::json::Value Record::ToJson() const {
std::vector<llvm::json::Value> json_item_ids;
json_item_ids.reserve(child_item_ids.size());
Expand All @@ -686,7 +693,6 @@ llvm::json::Value Record::ToJson() const {
{"unknown_attr", unknown_attr},
{"doc_comment", doc_comment},
{"bridge_type", bridge_type},
{"owned_ptr_type", owned_ptr_type},
{"source_loc", source_loc},
{"unambiguous_public_bases", unambiguous_public_bases},
{"fields", fields},
Expand All @@ -713,6 +719,10 @@ llvm::json::Value Record::ToJson() const {
{"detected_formatter", detected_formatter},
};

if (owned_ptr_config.has_value()) {
record.insert({"owned_ptr_config", owned_ptr_config->ToJson()});
}

if (!lifetime_inputs.empty()) {
record.insert({"lifetime_inputs", lifetime_inputs});
}
Expand Down
9 changes: 8 additions & 1 deletion rs_bindings_from_cc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,13 @@ struct TraitDerives {
std::vector<std::string> custom;
};

struct OwnedPtrConfig {
llvm::json::Value ToJson() const;

std::string owned_ptr_type;
std::string drop_impl;
};

// A record (struct, class, union).
struct Record {
llvm::json::Value ToJson() const;
Expand All @@ -730,7 +737,7 @@ struct Record {
std::optional<std::string> unknown_attr;
std::optional<std::string> doc_comment;
std::optional<BridgeType> bridge_type;
std::optional<std::string> owned_ptr_type;
std::optional<OwnedPtrConfig> owned_ptr_config;
std::string source_loc;
std::vector<BaseClass> unambiguous_public_bases;
std::vector<Field> fields;
Expand Down
12 changes: 10 additions & 2 deletions rs_bindings_from_cc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use code_gen_utils::make_rs_ident;
use crubit_feature::CrubitFeature;
use proc_macro2::{Ident, TokenStream};
use quote::{quote, ToTokens};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::cell::OnceCell;
use std::cmp::Ordering;
use std::collections::hash_map::{Entry, HashMap};
Expand Down Expand Up @@ -1148,6 +1148,13 @@ pub struct TraitDerives {
pub custom: Vec<Rc<str>>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
#[serde(deny_unknown_fields)]
pub struct OwnedPtrConfig {
pub owned_ptr_type: Rc<str>,
pub drop_impl: Rc<str>,
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Record {
Expand All @@ -1174,7 +1181,8 @@ pub struct Record {
pub unknown_attr: Option<Rc<str>>,
pub doc_comment: Option<Rc<str>>,
pub bridge_type: Option<BridgeType>,
pub owned_ptr_type: Option<Rc<str>>,
#[serde(default)]
pub owned_ptr_config: Option<OwnedPtrConfig>,
pub source_loc: Rc<str>,
pub unambiguous_public_bases: Vec<BaseClass>,
pub fields: Vec<Field>,
Expand Down
4 changes: 2 additions & 2 deletions rs_bindings_from_cc/ir_from_cc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,8 @@ fn test_struct_with_owned_ptr_type_annotation() -> googletest::Result<()> {

let record =
ir.records().find(|record| record.rs_name == "RecordWithOwnedPtrType").or_fail()?;
let owned_ptr_type = &record.owned_ptr_type.clone().or_fail()?;
expect_that!(&**owned_ptr_type, eq("SomeOwnedPtrType"));
let owned_ptr_config = record.owned_ptr_config.as_ref().or_fail()?;
expect_that!(&*owned_ptr_config.owned_ptr_type, eq("SomeOwnedPtrType"));
Ok(())
}

Expand Down
9 changes: 9 additions & 0 deletions rs_bindings_from_cc/test/annotations/owned_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,13 @@ struct CRUBIT_OWNED_POINTEE("OwnedThing") CRUBIT_RUST_NAME("RawThing") Thing {
void Close() { delete this; }
};

// A struct that specifies a custom drop method name.
struct CRUBIT_OWNED_POINTEE("CustomOwnedThing", "CustomDropImpl")
CRUBIT_RUST_NAME("CustomRawThing") CustomThing {
explicit CustomThing(int32_t value) : value(value) {};
int32_t value;

void CustomDropImpl() { delete this; }
};

#endif // THIRD_PARTY_CRUBIT_RS_BINDINGS_FROM_CC_TEST_ANNOTATIONS_OWNED_PTR_H_
16 changes: 16 additions & 0 deletions rs_bindings_from_cc/test/annotations/owned_ptr_api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,20 @@ extern "C" void __rust_thunk___ZN5Thing5CloseEv(struct Thing* __this) {

static_assert((void (::Thing::*)()) & ::Thing::Close);

static_assert(CRUBIT_SIZEOF(struct CustomThing) == 4);
static_assert(alignof(struct CustomThing) == 4);
static_assert(CRUBIT_OFFSET_OF(value, struct CustomThing) == 0);

extern "C" void __rust_thunk___ZN11CustomThingC1Ei(struct CustomThing* __this,
int32_t value) {
crubit::construct_at(__this, value);
}

extern "C" void __rust_thunk___ZN11CustomThing14CustomDropImplEv(
struct CustomThing* __this) {
__this->CustomDropImpl();
}

static_assert((void (::CustomThing::*)()) & ::CustomThing::CustomDropImpl);

#pragma clang diagnostic pop
Loading