Review notes (free-text preamble; `git apply` and `patch` skip text that
precedes the first "diff --git" header).

This patch makes three changes:

1. bitsandbytes/autograd/_functions.py: matmul_4bit() no longer assigns
   `quant_state.dtype = A.dtype` when `A.device.type == "cpu"`; the
   now-stale "Change dtype to input dtype on CPU" comment is removed too.

2. bitsandbytes/nn/modules.py: the `_QUANT_STATE_ATTR_MAP` lookup table and
   the `__getattr__` hook that consulted it are replaced by ten read-only
   `@property` accessors (absmax, code, quant_map, offset, state2,
   nested_absmax, nested_blocksize, nested_quant_map, nested_dtype,
   nested_offset). Each one reads `quant_state` from `self.__dict__` and
   raises AttributeError when it is missing. The former map entries
   blocksize, dtype, shape and quant_type get no property; the in-patch
   comment says this is intentional.

3. tests/test_linear4bit.py: adds
   test_linear4bit_torch_compile_activation_checkpointing, which compares
   eager and `torch.compile(fullgraph=True)` outputs and input gradients for
   four checkpointed Linear4bit layers, and rewords one comment in
   test_params4bit_quant_state_attr_access.

NOTE(review): nested_offset returns `qs.offset` without the
`qs.state2 is not None` guard used by the other nested_* properties. This
mirrors the removed map entry ("nested_offset": lambda qs: qs.offset), but
confirm that returning the offset for a quant state without state2 is
intended rather than raising AttributeError like its siblings.

NOTE(review): SOURCE for this patch had its line breaks and indentation
collapsed. The indentation below was restored from Python syntax and the
hunk line counts; verify against the upstream files before applying.

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 87f41fbc9..95a7d9090 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -382,10 +382,7 @@ def matmul_4bit(
     bias: Optional[torch.Tensor] = None,
 ):
     assert quant_state is not None
-    # Change dtype to input dtype on CPU
     if A.device.type == "cpu":
-        quant_state.dtype = A.dtype
-
         if getattr(quant_state, "packing_format_for_cpu", False):
             out = F.gemv_4bit(A, B, out, state=quant_state)
             if bias is not None:
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3c065b739..bfd41d5dd 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -258,42 +258,85 @@ def __setstate__(self, state):
         self.bnb_quantized = state["bnb_quantized"]
         self.module = state["module"]
 
-    # Map from state_dict key names (as produced by QuantState.as_dict) to
-    # the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
-    # dotted FQN keys via getattr, so "weight.quant_map" becomes
-    # getattr(weight, "quant_map") — we must map that to quant_state.code.
-    _QUANT_STATE_ATTR_MAP = {
-        # Direct QuantState attributes
-        "absmax": lambda qs: qs.absmax,
-        "code": lambda qs: qs.code,
-        "blocksize": lambda qs: qs.blocksize,
-        "dtype": lambda qs: qs.dtype,
-        "shape": lambda qs: qs.shape,
-        "offset": lambda qs: qs.offset,
-        "state2": lambda qs: qs.state2,
-        # as_dict serializes code → "quant_map"
-        "quant_map": lambda qs: qs.code,
-        "quant_type": lambda qs: qs.quant_type,
-        # as_dict serializes nested state2 attributes under "nested_*" keys
-        "nested_absmax": lambda qs: qs.state2.absmax,
-        "nested_blocksize": lambda qs: qs.state2.blocksize,
-        "nested_quant_map": lambda qs: qs.state2.code,
-        "nested_dtype": lambda qs: qs.state2.dtype,
-        "nested_offset": lambda qs: qs.offset,
-    }
-
-    def __getattr__(self, name):
-        # Proxy known QuantState attributes so that PyTorch's FSDP state_dict
-        # machinery (which traverses FQN paths via getattr) can find them.
-        accessor = self._QUANT_STATE_ATTR_MAP.get(name)
-        if accessor is not None:
-            quant_state = self.__dict__.get("quant_state")
-            if quant_state is not None:
-                try:
-                    return accessor(quant_state)
-                except AttributeError:
-                    pass
-        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+    # Properties that proxy QuantState attributes for FSDP state_dict traversal.
+    # FSDP's _get_fqns() resolves dotted FQN keys via getattr, e.g. "weight.absmax"
+    # becomes getattr(weight, "absmax"). Using @property instead of __getattr__
+    # avoids torch.compile graph breaks (see #1904), since Dynamo can trace
+    # descriptor protocol access but not __getattr__ on Tensor subclasses.
+    #
+    # Note: attributes that collide with Params4bit instance attrs (blocksize,
+    # quant_type) or Tensor attrs (dtype, shape) are intentionally omitted —
+    # they are packed into the bitsandbytes__* blob and not traversed by FSDP.
+
+    @property
+    def absmax(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.absmax
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'absmax'")
+
+    @property
+    def code(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.code
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'code'")
+
+    @property
+    def quant_map(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.code
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'quant_map'")
+
+    @property
+    def offset(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.offset
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'offset'")
+
+    @property
+    def state2(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.state2
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'state2'")
+
+    @property
+    def nested_absmax(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None and qs.state2 is not None:
+            return qs.state2.absmax
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_absmax'")
+
+    @property
+    def nested_blocksize(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None and qs.state2 is not None:
+            return qs.state2.blocksize
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_blocksize'")
+
+    @property
+    def nested_quant_map(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None and qs.state2 is not None:
+            return qs.state2.code
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_quant_map'")
+
+    @property
+    def nested_dtype(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None and qs.state2 is not None:
+            return qs.state2.dtype
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_dtype'")
+
+    @property
+    def nested_offset(self):
+        qs = self.__dict__.get("quant_state")
+        if qs is not None:
+            return qs.offset
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_offset'")
 
     def __deepcopy__(self, memo):
         new_instance = type(self).__new__(type(self))
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index d43656b63..296a70b0f 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -434,6 +434,75 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     torch.testing.assert_close(grad_compiled, grad_ref)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.skipif(torch.__version__ < (2, 8, 0, "dev"), reason="fullgraph requires torch 2.8+")
+@pytest.mark.skipif(
+    torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
+)
+def test_linear4bit_torch_compile_activation_checkpointing(device, quant_type, compress_statistics):
+    """Regression test for #1904: __getattr__ on Params4bit causes graph breaks under torch.compile.
+
+    Activation checkpointing replays the forward pass during backward, which multiplies
+    attribute accesses on Params4bit. If __getattr__ is defined (instead of @property),
+    Dynamo cannot trace through it and creates graph breaks. With fullgraph=True, this
+    causes torch.compile to raise an error rather than silently degrading performance.
+    """
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+    dim = 256
+    batch_size = 16
+    compute_dtype = torch.bfloat16
+
+    torch.compiler.reset()
+
+    class CheckpointedNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = torch.nn.ModuleList(
+                [
+                    bnb.nn.Linear4bit(
+                        dim,
+                        dim,
+                        bias=False,
+                        compute_dtype=compute_dtype,
+                        compress_statistics=compress_statistics,
+                        quant_type=quant_type,
+                    )
+                    for _ in range(4)
+                ]
+            )
+
+        def forward(self, x):
+            for layer in self.layers:
+                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+            return x
+
+    net = CheckpointedNet().to(device)
+
+    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device, requires_grad=True)
+
+    # Reference output (eager)
+    ref_output = net(x)
+    ref_output.sum().backward()
+    grad_ref = x.grad.clone()
+    x.grad = None
+
+    # Compiled with fullgraph=True — will raise if there are graph breaks
+    compile_backend = "hpu_backend" if device == "hpu" else "inductor"
+    compiled_net = torch.compile(net, fullgraph=True, backend=compile_backend)
+
+    compiled_output = compiled_net(x)
+    compiled_output.sum().backward()
+    grad_compiled = x.grad.clone()
+
+    torch.testing.assert_close(compiled_output, ref_output)
+    torch.testing.assert_close(grad_compiled, grad_ref)
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@@ -494,7 +563,7 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
     with pytest.raises(AttributeError, match="nonexistent_attribute"):
         _ = w.nonexistent_attribute
 
-    # Verify that normal Params4bit attributes are unaffected by __getattr__
+    # Verify that normal Params4bit instance attributes are unaffected
     assert isinstance(w.quant_state, bnb.functional.QuantState)
     assert isinstance(w.bnb_quantized, bool)
     assert w.bnb_quantized is True