Skip to content
Merged
12 changes: 12 additions & 0 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, Literal, TypeVar

Expand All @@ -10,6 +11,7 @@
)
from pydantic import BaseModel, TypeAdapter, ValidationError

from blueapi import __version__
from blueapi.config import RestConfig
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
Expand All @@ -32,6 +34,8 @@

TRACER = get_tracer("rest")

LOGGER = logging.getLogger(__name__)


class UnauthorisedAccessError(Exception):
pass
Expand Down Expand Up @@ -271,6 +275,14 @@ def _request_and_deserialize(
raise exception
if response.status_code == status.HTTP_204_NO_CONTENT:
raise NoContentError(target_type)
if (server_version := response.headers.get("x-blueapi-version")) is not None:
from packaging.version import Version

if Version(server_version).release != Version(__version__).release:
LOGGER.warning(
f"Server version is {Version(server_version).release} and "
f"client version is {Version(__version__).release}"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this as a draft.

It can be some thing like

"Version mismatch: server is , client is {Version(version).release}"

Suggested change
)
if (server_version := Version(server_version).release) != (client_version:= Version(__version__).release):
LOGGER.warning(
f"Version mismatch : Blueapi server version is {server_version}"
f" and client version is {client_version}.Some features may not work as expected."
)

deserialized = TypeAdapter(target_type).validate_python(response.json())
return deserialized

Expand Down
7 changes: 5 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

import blueapi
import blueapi.cli
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import blueapi
import blueapi.cli
from blueapi import __version__

from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.worker import TrackableTask, WorkerState
Expand Down Expand Up @@ -123,7 +125,7 @@ def get_app(config: ApplicationConfig):
app.include_router(secure_router, dependencies=dependencies)
app.add_exception_handler(KeyError, on_key_error_404)
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
app.middleware("http")(add_api_version_header)
app.middleware("http")(add_version_headers)
app.middleware("http")(inject_propagated_observability_context)
app.middleware("http")(log_request_details)
if config.api.cors:
Expand Down Expand Up @@ -568,11 +570,12 @@ def start(config: ApplicationConfig):
)


async def add_api_version_header(
async def add_version_headers(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
response.headers["X-BlueAPI-Version"] = blueapi.__version__
return response


Expand Down
35 changes: 34 additions & 1 deletion tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import uuid
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
import requests
import responses
from packaging.version import Version

from blueapi import __version__
from blueapi.client.rest import (
BlueapiRestClient,
BlueskyRemoteControlError,
Expand Down Expand Up @@ -196,3 +198,34 @@ def test_parameter_error_other_string():
input=34,
)
assert str(p1) == "Invalid value 34 for field field_one.0: error_message"


@pytest.mark.parametrize(
"server_version,logging_warning_present",
[(__version__, False), ("0.0.1", True), (None, False)],
)
@patch("blueapi.client.rest.TypeAdapter")
@patch("blueapi.client.rest.requests.Session.request")
@patch("blueapi.client.rest.LOGGER")
def test_server_and_client_versions(
mock_logger: MagicMock,
mock_request: Mock,
mock_type_adapter: Mock,
rest: BlueapiRestClient,
server_version: str,
logging_warning_present: bool,
):
response = Mock(spec=requests.Response)
response.status_code = 200
response.headers = {"x-blueapi-version": server_version}
mock_request.return_value = response

rest.get_plans()

if logging_warning_present:
mock_logger.warning.assert_called_once_with(
f"Server version is {Version(server_version).release} and "
f"client version is {Version(__version__).release}"
)
else:
mock_logger.assert_not_called()
23 changes: 22 additions & 1 deletion tests/unit_tests/service/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,28 @@
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from blueapi.service.main import get_passthrough_headers, log_request_details
import blueapi
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import blueapi
from blueapi import __version__

from blueapi.config import ApplicationConfig
from blueapi.service.main import (
add_version_headers,
get_passthrough_headers,
log_request_details,
)


async def test_add_version_header():
app = FastAPI()
app.middleware("http")(add_version_headers)

@app.get("/")
async def root():
return {"message": "Hello World"}

client = TestClient(app)
response = client.get("/")

assert response.headers["X-API-VERSION"] == ApplicationConfig.REST_API_VERSION
assert response.headers["X-BlueAPI-VERSION"] == blueapi.__version__


async def test_log_request_details():
Expand Down
Loading