diff --git a/bindings/node/src/processors.rs b/bindings/node/src/processors.rs
index a13c6b9023..50fe3ce568 100644
--- a/bindings/node/src/processors.rs
+++ b/bindings/node/src/processors.rs
@@ -71,10 +71,10 @@ pub fn roberta_processing(
 #[napi]
 pub fn byte_level_processing(trim_offsets: Option<bool>) -> Result<Processor> {
-  let mut byte_level = tk::processors::byte_level::ByteLevel::default();
+  let mut byte_level = tk::processors::byte_level::ByteLevelPostProcessor::default();
 
   if let Some(trim_offsets) = trim_offsets {
-    byte_level = byte_level.trim_offsets(trim_offsets);
+    byte_level.0 = byte_level.0.trim_offsets(trim_offsets);
   }
 
   Ok(Processor {
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 221a2bebb0..1c9e60b743 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -10,7 +10,7 @@ use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_fallback::ByteFallback;
-use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::byte_level::ByteLevelDecoder;
 use tk::decoders::ctc::CTC;
 use tk::decoders::fuse::Fuse;
 use tk::decoders::metaspace::{Metaspace, PrependScheme};
@@ -187,7 +187,7 @@ impl PyByteLevelDec {
     #[new]
     #[pyo3(signature = (**_kwargs), text_signature = "(self)")]
     fn new(_kwargs: Option<&Bound<'_, PyDict>>) -> (Self, PyDecoder) {
-        (PyByteLevelDec {}, ByteLevel::default().into())
+        (PyByteLevelDec {}, ByteLevelDecoder::default().into())
     }
 }
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index 35ca0d3bce..c3b0fb9c3b 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -14,7 +14,7 @@ use serde::Deserializer;
 use serde::Serializer;
 use serde::{Deserialize, Serialize};
 use tk::processors::bert::BertProcessing;
-use tk::processors::byte_level::ByteLevel;
+use tk::processors::byte_level::{ByteLevel, ByteLevelPostProcessor};
 use tk::processors::roberta::RobertaProcessing;
 use tk::processors::template::{SpecialToken, Template};
 use tk::processors::PostProcessorWrapper;
@@ -538,7 +538,10 @@ impl PyByteLevel {
             byte_level = byte_level.use_regex(ur);
         }
 
-        (PyByteLevel {}, byte_level.into())
+        (
+            PyByteLevel {},
+            ByteLevelPostProcessor::from(byte_level).into(),
+        )
     }
 
     #[getter]
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 3a64d25ccf..c27ede2be6 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -813,7 +813,7 @@ def test_repr_complete(self):
         out = repr(tokenizer)
         assert (
             out
-            == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))'
+            == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))'
         )
diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs
index 87f90b8fc0..3dde5b6f57 100644
--- a/tokenizers/benches/bpe_benchmark.rs
+++ b/tokenizers/benches/bpe_benchmark.rs
@@ -4,9 +4,9 @@ extern crate criterion;
 mod common;
 
 use criterion::{Criterion, Throughput};
 use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
 use tokenizers::models::TrainerWrapper;
-use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::byte_level::{ByteLevel, ByteLevelDecoder};
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
 use tokenizers::tokenizer::{AddedToken, EncodeInput};
 use tokenizers::Tokenizer;
@@ -19,7 +19,7 @@ static BATCH_SIZE: usize = 1_000;
 fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(bpe);
     tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
-    tokenizer.with_decoder(Some(ByteLevel::default()));
+    tokenizer.with_decoder(Some(ByteLevelDecoder::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
diff --git a/tokenizers/benches/truncation_benchmark.rs b/tokenizers/benches/truncation_benchmark.rs
index 340a1d99c6..624365fa20 100644
--- a/tokenizers/benches/truncation_benchmark.rs
+++ b/tokenizers/benches/truncation_benchmark.rs
@@ -4,7 +4,7 @@ extern crate criterion;
 use criterion::{BenchmarkId, Criterion, Throughput};
 use std::hint::black_box;
 use tokenizers::models::bpe::BPE;
-use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::byte_level::{ByteLevel, ByteLevelDecoder};
 use tokenizers::tokenizer::{
     AddedToken, TruncationDirection, TruncationParams, TruncationStrategy,
 };
@@ -16,7 +16,7 @@ fn create_gpt2_tokenizer() -> Tokenizer {
         .unwrap();
     let mut tokenizer = Tokenizer::new(bpe);
     tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
-    tokenizer.with_decoder(Some(ByteLevel::default()));
+    tokenizer.with_decoder(Some(ByteLevelDecoder::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 6e79e7029c..f62e8b5d67 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -20,7 +20,7 @@ use crate::decoders::sequence::Sequence;
 use crate::decoders::strip::Strip;
 use crate::decoders::wordpiece::WordPiece;
 use crate::normalizers::replace::Replace;
-use crate::pre_tokenizers::byte_level::ByteLevel;
+use crate::pre_tokenizers::byte_level::ByteLevelDecoder;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::{Decoder, Result};
 
@@ -28,7 +28,7 @@ use crate::{Decoder, Result};
 #[serde(untagged)]
 pub enum DecoderWrapper {
     BPE(BPEDecoder),
-    ByteLevel(ByteLevel),
+    ByteLevel(ByteLevelDecoder),
     WordPiece(WordPiece),
     Metaspace(Metaspace),
     CTC(CTC),
@@ -76,7 +76,7 @@ impl<'de> Deserialize<'de> for DecoderWrapper {
     #[serde(untagged)]
     pub enum DecoderUntagged {
         BPE(BPEDecoder),
-        ByteLevel(ByteLevel),
+        ByteLevel(ByteLevelDecoder),
         WordPiece(WordPiece),
         Metaspace(Metaspace),
         CTC(CTC),
@@ -167,7 +167,7 @@ impl Decoder for DecoderWrapper {
 }
 
 impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
-impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
+impl_enum_from!(ByteLevelDecoder, DecoderWrapper, ByteLevel);
 impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
 impl_enum_from!(Fuse, DecoderWrapper, Fuse);
 impl_enum_from!(Strip, DecoderWrapper, Strip);
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8bc0f30af0..a46935ffb8 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -2,13 +2,14 @@ use ahash::{AHashMap, AHashSet};
 use std::sync::LazyLock;
 
 use crate::utils::SysRegex;
-use serde::{Deserialize, Serialize};
+use serde::de::{self, MapAccess, Visitor};
+use serde::ser::SerializeStruct;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
 use crate::tokenizer::{
     Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
     SplitDelimiterBehavior,
 };
-use crate::utils::macro_rules_attribute;
 
 /// Converts bytes to unicode characters.
 /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
@@ -48,11 +49,14 @@ static BYTES_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);
 static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
     LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect());
 
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
-#[macro_rules_attribute(impl_serde_type!)]
+///
+/// As a [`PreTokenizer`]: uses `add_prefix_space` and `use_regex`.
+/// As a [`Decoder`]: see [`ByteLevelDecoder`].
+/// As a [`PostProcessor`]: see [`ByteLevelPostProcessor`].
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 #[non_exhaustive]
 pub struct ByteLevel {
     /// Whether to add a leading space to the first word. This allows to treat the leading word
@@ -60,15 +64,176 @@ pub struct ByteLevel {
     pub add_prefix_space: bool,
     /// Whether the post processing step should trim offsets to avoid including whitespaces.
     pub trim_offsets: bool,
-
-    /// Whether to use the standard GPT2 regex for whitespace splitting
+    /// Whether to use the standard GPT2 regex for whitespace splitting.
     /// Set it to False if you want to use your own splitting.
-    #[serde(default = "default_true")]
     pub use_regex: bool,
 }
 
-fn default_true() -> bool {
-    true
+/// Serializes `ByteLevel` for the **pre-tokenizer** role.
+/// Only `add_prefix_space` and `use_regex` are relevant here; `trim_offsets` is omitted.
+impl Serialize for ByteLevel {
+    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
+        let mut state = serializer.serialize_struct("ByteLevel", 3)?;
+        state.serialize_field("type", "ByteLevel")?;
+        state.serialize_field("add_prefix_space", &self.add_prefix_space)?;
+        state.serialize_field("use_regex", &self.use_regex)?;
+        state.end()
+    }
+}
+
+/// Deserializes `ByteLevel` accepting all three fields (backward-compat).
+impl<'de> Deserialize<'de> for ByteLevel {
+    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
+        struct ByteLevelVisitor;
+
+        impl<'de> Visitor<'de> for ByteLevelVisitor {
+            type Value = ByteLevel;
+
+            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+                f.write_str("struct ByteLevel")
+            }
+
+            fn visit_map<A: MapAccess<'de>>(
+                self,
+                mut map: A,
+            ) -> std::result::Result<Self::Value, A::Error> {
+                let mut add_prefix_space = true;
+                let mut trim_offsets = true;
+                let mut use_regex = true;
+                let mut seen_type = false;
+
+                while let Some(key) = map.next_key::<String>()? {
+                    match key.as_str() {
+                        "add_prefix_space" => add_prefix_space = map.next_value()?,
+                        "trim_offsets" => trim_offsets = map.next_value()?,
+                        "use_regex" => use_regex = map.next_value()?,
+                        "type" => {
+                            let v: String = map.next_value()?;
+                            if v != "ByteLevel" {
+                                return Err(de::Error::custom(format!(
+                                    "expected type `ByteLevel`, got `{v}`"
+                                )));
+                            }
+                            seen_type = true;
+                        }
+                        _ => {
+                            map.next_value::<de::IgnoredAny>()?;
+                        }
+                    }
+                }
+                if !seen_type {
+                    return Err(de::Error::missing_field("type"));
+                }
+                Ok(ByteLevel {
+                    add_prefix_space,
+                    trim_offsets,
+                    use_regex,
+                })
+            }
+        }
+
+        deserializer.deserialize_map(ByteLevelVisitor)
+    }
+}
+
+/// `ByteLevel` in its **decoder** role. None of the byte-level flags affect decoding,
+/// so only `"type": "ByteLevel"` is serialized.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
+pub struct ByteLevelDecoder(pub ByteLevel);
+
+impl std::ops::Deref for ByteLevelDecoder {
+    type Target = ByteLevel;
+    fn deref(&self) -> &ByteLevel {
+        &self.0
+    }
+}
+
+impl std::ops::DerefMut for ByteLevelDecoder {
+    fn deref_mut(&mut self) -> &mut ByteLevel {
+        &mut self.0
+    }
+}
+
+impl From<ByteLevel> for ByteLevelDecoder {
+    fn from(bl: ByteLevel) -> Self {
+        Self(bl)
+    }
+}
+
+impl From<ByteLevelDecoder> for ByteLevel {
+    fn from(bld: ByteLevelDecoder) -> Self {
+        bld.0
+    }
+}
+
+impl Serialize for ByteLevelDecoder {
+    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
+        let mut state = serializer.serialize_struct("ByteLevel", 1)?;
+        state.serialize_field("type", "ByteLevel")?;
+        state.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for ByteLevelDecoder {
+    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
+        ByteLevel::deserialize(deserializer).map(ByteLevelDecoder)
+    }
+}
+
+/// `ByteLevel` in its **post-processor** role. Only `add_prefix_space` and `trim_offsets`
+/// are relevant here; `use_regex` is omitted.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
+pub struct ByteLevelPostProcessor(pub ByteLevel);
+
+impl std::ops::Deref for ByteLevelPostProcessor {
+    type Target = ByteLevel;
+    fn deref(&self) -> &ByteLevel {
+        &self.0
+    }
+}
+
+impl std::ops::DerefMut for ByteLevelPostProcessor {
+    fn deref_mut(&mut self) -> &mut ByteLevel {
+        &mut self.0
+    }
+}
+
+impl ByteLevelPostProcessor {
+    pub fn new(add_prefix_space: bool, trim_offsets: bool) -> Self {
+        Self(ByteLevel {
+            add_prefix_space,
+            trim_offsets,
+            use_regex: true,
+        })
+    }
+}
+
+impl From<ByteLevel> for ByteLevelPostProcessor {
+    fn from(bl: ByteLevel) -> Self {
+        Self(bl)
+    }
+}
+
+impl From<ByteLevelPostProcessor> for ByteLevel {
+    fn from(blp: ByteLevelPostProcessor) -> Self {
+        blp.0
+    }
+}
+
+impl Serialize for ByteLevelPostProcessor {
+    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
+        let mut state = serializer.serialize_struct("ByteLevel", 3)?;
+        state.serialize_field("type", "ByteLevel")?;
+        state.serialize_field("add_prefix_space", &self.0.add_prefix_space)?;
+        state.serialize_field("trim_offsets", &self.0.trim_offsets)?;
+        state.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for ByteLevelPostProcessor {
+    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
+        ByteLevel::deserialize(deserializer).map(ByteLevelPostProcessor)
+    }
 }
 
 impl Default for ByteLevel {
@@ -171,6 +336,12 @@ impl Decoder for ByteLevel {
     }
 }
 
+impl Decoder for ByteLevelDecoder {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        self.0.decode_chain(tokens)
+    }
+}
+
 /// As a `PostProcessor`, `ByteLevel` is in charge of trimming the offsets if necessary.
 impl PostProcessor for ByteLevel {
     fn added_tokens(&self, _is_pair: bool) -> usize {
@@ -199,6 +370,20 @@ impl PostProcessor for ByteLevel {
     }
 }
 
+impl PostProcessor for ByteLevelPostProcessor {
+    fn added_tokens(&self, is_pair: bool) -> usize {
+        self.0.added_tokens(is_pair)
+    }
+
+    fn process_encodings(
+        &self,
+        encodings: Vec<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<Encoding>> {
+        self.0.process_encodings(encodings, add_special_tokens)
+    }
+}
+
 pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
     encoding.process_tokens_with_offsets_mut(|(i, (token, offsets))| {
         let mut leading_spaces = token
@@ -568,6 +753,50 @@ mod tests {
         );
     }
 
+    #[test]
+    fn serialization_pre_tokenizer() {
+        // ByteLevel as pre-tokenizer: only add_prefix_space and use_regex are serialized
+        let bl = ByteLevel::default();
+        assert_eq!(
+            serde_json::to_string(&bl).unwrap(),
+            r#"{"type":"ByteLevel","add_prefix_space":true,"use_regex":true}"#
+        );
+        let bl = ByteLevel::default().add_prefix_space(false);
+        assert_eq!(
+            serde_json::to_string(&bl).unwrap(),
+            r#"{"type":"ByteLevel","add_prefix_space":false,"use_regex":true}"#
+        );
+        // trim_offsets is intentionally absent (not used by the pre-tokenizer)
+        assert!(!serde_json::to_string(&bl).unwrap().contains("trim_offsets"));
+    }
+
+    #[test]
+    fn serialization_decoder() {
+        // ByteLevelDecoder: only the type tag is serialized; no flags affect decoding
+        let decoder = ByteLevelDecoder::default();
+        assert_eq!(
+            serde_json::to_string(&decoder).unwrap(),
+            r#"{"type":"ByteLevel"}"#
+        );
+    }
+
+    #[test]
+    fn serialization_post_processor() {
+        // ByteLevelPostProcessor: only add_prefix_space and trim_offsets are serialized
+        let proc = ByteLevelPostProcessor::default();
+        assert_eq!(
+            serde_json::to_string(&proc).unwrap(),
+            r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":true}"#
+        );
+        let proc = ByteLevelPostProcessor::new(true, false);
+        assert_eq!(
+            serde_json::to_string(&proc).unwrap(),
+            r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":false}"#
+        );
+        // use_regex is intentionally absent (not used by the post-processor)
+        assert!(!serde_json::to_string(&proc).unwrap().contains("use_regex"));
+    }
+
     #[test]
     fn deserialization() {
         // Before use_regex
diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs
index 869cc68912..0ea65ad44e 100644
--- a/tokenizers/src/processors/mod.rs
+++ b/tokenizers/src/processors/mod.rs
@@ -8,7 +8,7 @@ pub use super::pre_tokenizers::byte_level;
 
 use serde::{Deserialize, Serialize};
 
-use crate::pre_tokenizers::byte_level::ByteLevel;
+use crate::pre_tokenizers::byte_level::ByteLevelPostProcessor;
 use crate::processors::bert::BertProcessing;
 use crate::processors::roberta::RobertaProcessing;
 use crate::processors::sequence::Sequence;
@@ -21,7 +21,7 @@ pub enum PostProcessorWrapper {
     // Roberta must be before Bert for deserialization (serde does not validate tags)
     Roberta(RobertaProcessing),
     Bert(BertProcessing),
-    ByteLevel(ByteLevel),
+    ByteLevel(ByteLevelPostProcessor),
     Template(TemplateProcessing),
     Sequence(Sequence),
 }
@@ -30,7 +30,7 @@ impl PostProcessor for PostProcessorWrapper {
     fn added_tokens(&self, is_pair: bool) -> usize {
         match self {
             Self::Bert(bert) => bert.added_tokens(is_pair),
-            Self::ByteLevel(bl) => bl.added_tokens(is_pair),
+            Self::ByteLevel(bl) => bl.0.added_tokens(is_pair),
             Self::Roberta(roberta) => roberta.added_tokens(is_pair),
             Self::Template(template) => template.added_tokens(is_pair),
             Self::Sequence(bl) => bl.added_tokens(is_pair),
@@ -44,7 +44,7 @@ impl PostProcessor for PostProcessorWrapper {
     ) -> Result<Vec<Encoding>> {
         match self {
             Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
-            Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
+            Self::ByteLevel(bl) => bl.0.process_encodings(encodings, add_special_tokens),
             Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
             Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
             Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
@@ -53,7 +53,7 @@ impl PostProcessor for PostProcessorWrapper {
 }
 
 impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
-impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
+impl_enum_from!(ByteLevelPostProcessor, PostProcessorWrapper, ByteLevel);
 impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
 impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
 impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs
index f44cf54ac8..d2ef60dc82 100644
--- a/tokenizers/src/processors/sequence.rs
+++ b/tokenizers/src/processors/sequence.rs
@@ -71,7 +71,8 @@ impl PostProcessor for Sequence {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::processors::{ByteLevel, PostProcessorWrapper};
+    use crate::pre_tokenizers::byte_level::ByteLevelPostProcessor;
+    use crate::processors::PostProcessorWrapper;
     use crate::tokenizer::{Encoding, PostProcessor};
     use ahash::AHashMap;
     use std::iter::FromIterator;
@@ -96,7 +97,7 @@ mod tests {
             AHashMap::new(),
         );
 
-        let bytelevel = ByteLevel::default().trim_offsets(true);
+        let bytelevel = ByteLevelPostProcessor::new(true, true);
         let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
         let expected = Encoding::new(
             vec![0; 5],
diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs
index 26129699be..82cc3e3081 100644
--- a/tokenizers/tests/common/mod.rs
+++ b/tokenizers/tests/common/mod.rs
@@ -3,7 +3,7 @@ use tokenizers::models::bpe::BPE;
 use tokenizers::models::wordpiece::WordPiece;
 use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
-use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::byte_level::{ByteLevel, ByteLevelDecoder, ByteLevelPostProcessor};
 use tokenizers::processors::bert::BertProcessing;
 use tokenizers::tokenizer::{Model, Tokenizer};
 
@@ -26,8 +26,11 @@ pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
         .with_pre_tokenizer(Some(
             ByteLevel::default().add_prefix_space(add_prefix_space),
         ))
-        .with_decoder(Some(ByteLevel::default()))
-        .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));
+        .with_decoder(Some(ByteLevelDecoder::default()))
+        .with_post_processor(Some(ByteLevelPostProcessor::new(
+            add_prefix_space,
+            trim_offsets,
+        )));
 
     tokenizer
 }
diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index 2ab99467ed..e7c31e09e4 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -4,7 +4,7 @@ use ahash::AHashMap;
 use tokenizers::decoders::byte_fallback::ByteFallback;
 use tokenizers::models::bpe::{BpeTrainerBuilder, BPE};
 use tokenizers::normalizers::{Sequence, Strip, NFC};
-use tokenizers::pre_tokenizers::byte_level::ByteLevel;
+use tokenizers::pre_tokenizers::byte_level::{ByteLevel, ByteLevelDecoder, ByteLevelPostProcessor};
 use tokenizers::{AddedToken, TokenizerBuilder};
 use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
 use tokenizers::{Tokenizer, TokenizerImpl};
@@ -19,8 +19,8 @@ fn train_tokenizer() {
             NFC.into(),
         ])))
         .with_pre_tokenizer(Some(ByteLevel::default()))
-        .with_post_processor(Some(ByteLevel::default()))
-        .with_decoder(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevelPostProcessor::default()))
+        .with_decoder(Some(ByteLevelDecoder::default()))
         .build()
         .unwrap();
 
@@ -110,7 +110,7 @@ fn streaming_tokenizer() {
             NFC.into(),
         ])))
         .with_pre_tokenizer(Some(ByteLevel::default()))
-        .with_post_processor(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevelPostProcessor::default()))
         .with_decoder(Some(ByteFallback::default()))
         .build()
         .unwrap();
diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs
index dc0c95a57e..9204f2f2bc 100644
--- a/tokenizers/tests/serialization.rs
+++ b/tokenizers/tests/serialization.rs
@@ -1,7 +1,7 @@
 mod common;
 
 use common::*;
-use tokenizers::decoders::byte_level::ByteLevel;
+use tokenizers::decoders::byte_level::{ByteLevel, ByteLevelDecoder};
 use tokenizers::decoders::DecoderWrapper;
 use tokenizers::models::bpe::BPE;
 use tokenizers::models::wordlevel::WordLevel;
@@ -169,20 +169,28 @@ fn pretoks() {
 
 #[test]
 fn decoders() {
-    let byte_level = ByteLevel::default();
-    let byte_level_ser = serde_json::to_string(&byte_level).unwrap();
-    assert_eq!(
-        byte_level_ser,
-        r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":true,"use_regex":true}"#
-    );
-    serde_json::from_str::<ByteLevel>(&byte_level_ser).unwrap();
-    let byte_level_wrapper: DecoderWrapper = serde_json::from_str(&byte_level_ser).unwrap();
+    // ByteLevelDecoder serializes with no extra fields
+    let decoder = ByteLevelDecoder::default();
+    let decoder_ser = serde_json::to_string(&decoder).unwrap();
+    assert_eq!(decoder_ser, r#"{"type":"ByteLevel"}"#);
+
+    // Old format (with all fields) still deserializes into DecoderWrapper::ByteLevel
+    let old_format =
+        r#"{"type":"ByteLevel","add_prefix_space":true,"trim_offsets":true,"use_regex":true}"#;
+    let byte_level_wrapper: DecoderWrapper = serde_json::from_str(old_format).unwrap();
     match &byte_level_wrapper {
         DecoderWrapper::ByteLevel(_) => (),
         _ => panic!("ByteLevel wrapped with incorrect variant"),
     }
-    let ser_wrapped = serde_json::to_string(&byte_level_wrapper).unwrap();
-    assert_eq!(ser_wrapped, byte_level_ser);
+
+    // ByteLevel (pre-tokenizer) serializes add_prefix_space and use_regex only
+    let pre_tok = ByteLevel::default();
+    let pre_tok_ser = serde_json::to_string(&pre_tok).unwrap();
+    assert_eq!(
+        pre_tok_ser,
+        r#"{"type":"ByteLevel","add_prefix_space":true,"use_regex":true}"#
+    );
+    serde_json::from_str::<ByteLevel>(&pre_tok_ser).unwrap();
 }
 
 #[test]