Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions project/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,29 @@ def get_storage_client(
)


@lru_cache
def get_node_id(
settings: Annotated[Settings, Depends(get_settings)],
core_client: Annotated[flame_hub.CoreClient, Depends(get_core_client)],
):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
if settings.hub.auth.flow != "client":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="It's only possible to retrieve the id of this node if a client authentication flow is configured.",
)

client_id = settings.hub.auth.id
nodes = core_client.find_nodes(filter={"client_id": client_id})

if len(nodes) != 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Found {len(nodes)} nodes with the client id {client_id}.",
)

return nodes[0].id


@lru_cache
def get_postgres_db(
settings: Annotated[Settings, Depends(get_settings)],
Expand Down
4 changes: 4 additions & 0 deletions project/routers/intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_core_client,
get_storage_client,
get_ecdh_private_key,
get_node_id,
)
from project.event_logging import EventLoggingRoute

Expand Down Expand Up @@ -65,6 +66,7 @@ async def submit_intermediate_result_to_hub(
core_client: Annotated[flame_hub.CoreClient, Depends(get_core_client)],
storage_client: Annotated[flame_hub.StorageClient, Depends(get_storage_client)],
private_key: Annotated[ec.EllipticCurvePrivateKey, Depends(get_ecdh_private_key)],
node_id: Annotated[uuid.UUID, Depends(get_node_id)],
remote_node_id: Annotated[str, Form()],
):
"""Upload a file as an intermediate result to the FLAME Hub.
Expand Down Expand Up @@ -114,6 +116,8 @@ async def submit_intermediate_result_to_hub(
request.url_for(
"intermediate.object.get",
object_id=bucket_file.id,
).include_query_params(
remote_node_id=node_id,
)
),
)
Expand Down
3 changes: 3 additions & 0 deletions project/routers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_core_client,
get_storage_client,
get_ecdh_private_key,
get_node_id,
)
from project.routers.intermediate import IntermediateUploadResponse, submit_intermediate_result_to_hub
from project.event_logging import EventLoggingRoute
Expand Down Expand Up @@ -383,6 +384,7 @@ async def upload_local_file(
storage_client: Annotated[flame_hub.StorageClient, Depends(get_storage_client)],
db: Annotated[PooledPostgresqlDatabase, Depends(get_postgres_db)],
private_key: Annotated[ec.EllipticCurvePrivateKey, Depends(get_ecdh_private_key)],
node_id: Annotated[uuid.UUID, Depends(get_node_id)],
):
"""Upload a local file directly to the FLAME Hub so that the requesting service does not have to load the file in
its working memory to use the intermediate upload endpoint. Returns a 200 on success. This endpoint returns a link
Expand Down Expand Up @@ -414,4 +416,5 @@ async def upload_local_file(
storage_client=storage_client,
private_key=private_key,
remote_node_id=remote_node_id,
node_id=node_id,
)
18 changes: 14 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from minio import Minio
from node_event_logging import EventLog

from project.dependencies import get_postgres_db, get_local_minio, get_ecdh_private_key
from project.dependencies import get_postgres_db, get_local_minio, get_ecdh_private_key, get_node_id
from project.server import get_server_instance
from tests.common import env
from tests.common.auth import get_oid_test_jwk, get_test_ecdh_keypair
Expand Down Expand Up @@ -128,9 +128,17 @@ def _get_ecdh_private_key():
yield _get_ecdh_private_key


@pytest.fixture(scope="package")
def override_node_id(this_node):
def _get_node_id():
return this_node.id

yield _get_node_id


# noinspection PyUnresolvedReferences
@pytest.fixture(scope="package")
def test_app(override_minio, override_postgres, override_ecdh_private_key):
def test_app(override_minio, override_postgres, override_ecdh_private_key, override_node_id):
app = get_server_instance()

if callable(override_postgres):
Expand All @@ -141,6 +149,8 @@ def test_app(override_minio, override_postgres, override_ecdh_private_key):

app.dependency_overrides[get_ecdh_private_key] = override_ecdh_private_key

app.dependency_overrides[get_node_id] = override_node_id

return app


Expand Down Expand Up @@ -347,7 +357,7 @@ def analysis_id(analysis_id_factory):
return analysis_id_factory()


@pytest.fixture
@pytest.fixture(scope="package")
def realm_id(auth_client):
preferred_realm_name = os.environ.get("PYTEST__PREFERRED_REALM_NAME", "master")
realm_list = auth_client.find_realms(filter={"name": preferred_realm_name})
Expand All @@ -357,7 +367,7 @@ def realm_id(auth_client):
yield realm_list.pop()


@pytest.fixture()
@pytest.fixture(scope="package")
def this_node(core_client, realm_id):
node = core_client.create_node(name=next_uuid(), realm_id=realm_id, node_type="default")
_, public_key = get_test_ecdh_keypair()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def test_200_encrypt_and_decrypt(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={"remote_node_id": this_node.id},
)
finally:
reset_private_key()

assert r.status_code == status.HTTP_200_OK, str(r.text)
assert blob == r.read()
assert storage_client.get_bucket_file(bucket_file_id=model.object_id) is None, (
"File was not deleted from the Hub after its retrieval."
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_400_decrypt_intermediate(
)

# The file is encrypted for a remote node and therefore cannot be decrypted by the node that encrypted the file
# and of course all other nodes except that one remote node.
# and of course all other nodes except that one remote node. Note that the local private key is not exchanged.
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert detail_of(r) == (
f"File with ID {model.object_id} cannot be decrypted under the assumption that the file was encrypted by node "
Expand Down
5 changes: 1 addition & 4 deletions tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,8 @@ def test_200_upload_local_file(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={
"remote_node_id": str(this_node.id),
},
)
finally:
reset_private_key()
Expand Down
5 changes: 1 addition & 4 deletions tests/test_local_tagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,8 @@ def test_200_upload_local_file(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={
"remote_node_id": str(this_node.id),
},
)
finally:
reset_private_key()
Expand Down
Loading