Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions project/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,28 @@ def get_storage_client(
)


def get_node_id(
settings: Annotated[Settings, Depends(get_settings)],
core_client: Annotated[flame_hub.CoreClient, Depends(get_core_client)],
):
if settings.hub.auth.flow != AuthFlow.client:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="It's only possible to retrieve the id of this node if a client authentication flow is configured.",
)

client_id = settings.hub.auth.id
nodes = core_client.find_nodes(filter={"client_id": client_id})

if len(nodes) != 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Found {len(nodes)} nodes with the client id {client_id}.",
)

return nodes[0].id


@lru_cache
def get_postgres_db(
settings: Annotated[Settings, Depends(get_settings)],
Expand Down
4 changes: 4 additions & 0 deletions project/routers/intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_core_client,
get_storage_client,
get_ecdh_private_key,
get_node_id,
)
from project.event_logging import EventLoggingRoute

Expand Down Expand Up @@ -65,6 +66,7 @@ async def submit_intermediate_result_to_hub(
core_client: Annotated[flame_hub.CoreClient, Depends(get_core_client)],
storage_client: Annotated[flame_hub.StorageClient, Depends(get_storage_client)],
private_key: Annotated[ec.EllipticCurvePrivateKey, Depends(get_ecdh_private_key)],
node_id: Annotated[uuid.UUID, Depends(get_node_id)],
remote_node_id: Annotated[str, Form()],
):
"""Upload a file as an intermediate result to the FLAME Hub.
Expand Down Expand Up @@ -114,6 +116,8 @@ async def submit_intermediate_result_to_hub(
request.url_for(
"intermediate.object.get",
object_id=bucket_file.id,
).include_query_params(
remote_node_id=node_id,
)
),
)
Expand Down
3 changes: 3 additions & 0 deletions project/routers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_core_client,
get_storage_client,
get_ecdh_private_key,
get_node_id,
)
from project.routers.intermediate import IntermediateUploadResponse, submit_intermediate_result_to_hub
from project.event_logging import EventLoggingRoute
Expand Down Expand Up @@ -383,6 +384,7 @@ async def upload_local_file(
storage_client: Annotated[flame_hub.StorageClient, Depends(get_storage_client)],
db: Annotated[PooledPostgresqlDatabase, Depends(get_postgres_db)],
private_key: Annotated[ec.EllipticCurvePrivateKey, Depends(get_ecdh_private_key)],
node_id: Annotated[uuid.UUID, Depends(get_node_id)],
):
"""Upload a local file directly to the FLAME Hub so that the requesting service does not have to load the file in
its working memory to use the intermediate upload endpoint. Returns a 200 on success. This endpoint returns a link
Expand Down Expand Up @@ -414,4 +416,5 @@ async def upload_local_file(
storage_client=storage_client,
private_key=private_key,
remote_node_id=remote_node_id,
node_id=node_id,
)
15 changes: 11 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from minio import Minio
from node_event_logging import EventLog

from project.dependencies import get_postgres_db, get_local_minio, get_ecdh_private_key
from project.dependencies import get_postgres_db, get_local_minio, get_ecdh_private_key, get_node_id
from project.server import get_server_instance
from tests.common import env
from tests.common.auth import get_oid_test_jwk, get_test_ecdh_keypair
Expand All @@ -29,6 +29,7 @@
next_random_bytes,
next_uuid,
next_ecdh_keypair_bytes,
temporarily_change_dependency,
)


Expand Down Expand Up @@ -347,7 +348,7 @@ def analysis_id(analysis_id_factory):
return analysis_id_factory()


@pytest.fixture
@pytest.fixture(scope="package")
def realm_id(auth_client):
preferred_realm_name = os.environ.get("PYTEST__PREFERRED_REALM_NAME", "master")
realm_list = auth_client.find_realms(filter={"name": preferred_realm_name})
Expand All @@ -357,14 +358,20 @@ def realm_id(auth_client):
yield realm_list.pop()


@pytest.fixture()
def this_node(core_client, realm_id):
@pytest.fixture(scope="package")
def this_node(test_client, core_client, realm_id):
node = core_client.create_node(name=next_uuid(), realm_id=realm_id, node_type="default")
_, public_key = get_test_ecdh_keypair()
# Also update node reference.
node = core_client.update_node(
node, public_key=public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).hex()
)

def override_get_node_id():
return node.id

# Change dependency here since live infra is mandatory at this point.
temporarily_change_dependency(test_client, get_node_id, override_get_node_id)
yield node
core_client.delete_node(node.id)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def test_200_encrypt_and_decrypt(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={"remote_node_id": this_node.id},
)
finally:
reset_private_key()

assert r.status_code == status.HTTP_200_OK, str(r.text)
assert blob == r.read()
assert storage_client.get_bucket_file(bucket_file_id=model.object_id) is None, (
"File was not deleted from the Hub after its retrieval."
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_400_decrypt_intermediate(
)

# The file is encrypted for a remote node and therefore cannot be decrypted by the node that encrypted the file
# and of course all other nodes except that one remote node.
# and of course all other nodes except that one remote node. Note that the local private key is not exchanged.
assert r.status_code == status.HTTP_400_BAD_REQUEST
assert detail_of(r) == (
f"File with ID {model.object_id} cannot be decrypted under the assumption that the file was encrypted by node "
Expand Down
5 changes: 1 addition & 4 deletions tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,8 @@ def test_200_upload_local_file(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={
"remote_node_id": str(this_node.id),
},
)
finally:
reset_private_key()
Expand Down
5 changes: 1 addition & 4 deletions tests/test_local_tagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,8 @@ def test_200_upload_local_file(

try:
r = test_client.get(
model.url.path,
f"{model.url.path}?{model.url.query}",
auth=BearerAuth(issue_client_access_token(analysis_id)),
params={
"remote_node_id": str(this_node.id),
},
)
finally:
reset_private_key()
Expand Down
Loading