diff --git a/include/ctranslate2/models/transformer.h b/include/ctranslate2/models/transformer.h index 4e97f85e0..48bda6cd8 100644 --- a/include/ctranslate2/models/transformer.h +++ b/include/ctranslate2/models/transformer.h @@ -13,6 +13,7 @@ namespace ctranslate2 { TransformerModel(size_t num_heads = 0); size_t current_spec_revision() const override; std::unique_ptr as_sequence_to_sequence() const override; + std::unique_ptr as_sequence_encoder() const override; protected: bool is_linear_weight(const std::string& variable_name) const override; diff --git a/src/models/transformer.cc b/src/models/transformer.cc index f62984b2e..d01b33080 100644 --- a/src/models/transformer.cc +++ b/src/models/transformer.cc @@ -89,6 +89,15 @@ namespace ctranslate2 { return std::make_unique(model, std::move(encoder), std::move(decoder)); } + std::unique_ptr TransformerModel::as_sequence_encoder() const { + const auto scoped_device_setter = get_scoped_device_setter(); + + auto encoder = std::make_unique(*this, "encoder"); + + const auto model = std::static_pointer_cast(shared_from_this()); + return std::make_unique(model, std::move(encoder)); + } + std::unique_ptr TransformerModel::clone() const { return std::make_unique(*this); }