Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 153 additions & 37 deletions api/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
Licence: MIT (see LICENSE)
"""

import os
import os.path
import hashlib
import json
import time
from urllib.request import urlopen, urlretrieve, HTTPError
from urllib.parse import urlparse, urlunparse
import zipfile
Expand All @@ -18,6 +21,134 @@
from . import settings


class CacheManager:
    """
    Manages the downloaded-file cache with a size limit and expiration.

    File metadata (source URL, download time, last access time, size in
    bytes) is tracked in a JSON index file stored inside the cache
    directory. Configuration is read from the ``settings`` module:

    * ``DOWNLOADED_FILE_CACHE_DIR`` -- root directory of the cache
    * ``CACHE_MAX_SIZE_GB`` -- maximum total cache size (default 10 GB)
    * ``CACHE_FILE_EXPIRY_DAYS`` -- file lifetime (default 30 days)
    """

    def __init__(self):
        # NOTE(review): if DOWNLOADED_FILE_CACHE_DIR is unset, the ""
        # default makes os.makedirs() raise FileNotFoundError -- confirm
        # the setting is always defined in deployed configurations.
        self.cache_dir = getattr(settings, "DOWNLOADED_FILE_CACHE_DIR", "")
        self.max_size_bytes = getattr(settings, "CACHE_MAX_SIZE_GB", 10) * 1024**3
        self.expiry_seconds = getattr(settings, "CACHE_FILE_EXPIRY_DAYS", 30) * 86400
        self.index_path = os.path.join(self.cache_dir, "cache_index.json")
        os.makedirs(self.cache_dir, exist_ok=True)
        self._index = self._load_index()

    def _load_index(self):
        """Load the JSON index from disk; a missing or corrupt index is {}."""
        if os.path.exists(self.index_path):
            try:
                with open(self.index_path, "r") as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                # A corrupt index is treated as empty; entries are rebuilt
                # as files are accessed again.
                return {}
        return {}

    def _save_index(self):
        """Persist the index to disk; failures are non-fatal (best effort)."""
        try:
            with open(self.index_path, "w") as f:
                json.dump(self._index, f, indent=2)
        except IOError:
            pass

    @staticmethod
    def _remove_file(file_path):
        """Delete *file_path*, ignoring races where it is already gone.

        NOTE(review): for unpacked zip archives the cache entry may be a
        directory, which os.remove() cannot delete (IsADirectoryError is
        swallowed here) -- confirm whether archive dirs need eviction too.
        """
        try:
            os.remove(file_path)
        except OSError:
            pass

    def record_access(self, file_path, url):
        """
        Record that the cached file at *file_path* (downloaded from *url*)
        was accessed, creating its index entry on first access.
        """
        key = os.path.relpath(file_path, self.cache_dir)
        now = time.time()
        entry = self._index.get(key)
        if entry is None:
            self._index[key] = {
                "url": url,
                "downloaded_at": now,
                "last_accessed": now,
                "size": os.path.getsize(file_path) if os.path.exists(file_path) else 0,
            }
        else:
            entry["last_accessed"] = now
        self._save_index()

    def is_expired(self, file_path):
        """Return True if the indexed file is older than the expiry period.

        Files without an index entry are never considered expired.
        """
        key = os.path.relpath(file_path, self.cache_dir)
        entry = self._index.get(key)
        if entry is not None:
            downloaded_at = entry.get("downloaded_at", 0)
            return (time.time() - downloaded_at) > self.expiry_seconds
        return False

    def get_cache_size(self):
        """Return the total size in bytes of all cached files (index excluded)."""
        total = 0
        for dirpath, _dirnames, filenames in os.walk(self.cache_dir):
            for name in filenames:
                if name == "cache_index.json":
                    continue
                try:
                    total += os.path.getsize(os.path.join(dirpath, name))
                except OSError:
                    # File vanished between listing and stat; skip it.
                    pass
        return total

    def cleanup_expired(self):
        """Remove files that have exceeded the expiry period."""
        now = time.time()
        expired_keys = [
            key
            for key, meta in self._index.items()
            if (now - meta.get("downloaded_at", 0)) > self.expiry_seconds
        ]
        for key in expired_keys:
            self._remove_file(os.path.join(self.cache_dir, key))
            del self._index[key]
        if expired_keys:
            self._save_index()
            self._cleanup_empty_dirs()

    def cleanup_by_size(self):
        """
        Remove least-recently-accessed files until the cache fits the limit.

        The total size is computed once and decremented per eviction,
        instead of re-walking the whole cache directory on every loop
        iteration; the index is saved once at the end rather than per file.
        """
        current_size = self.get_cache_size()
        evicted = False
        while current_size > self.max_size_bytes and self._index:
            oldest_key = min(
                self._index,
                key=lambda k: self._index[k].get("last_accessed", 0),
            )
            file_path = os.path.join(self.cache_dir, oldest_key)
            if os.path.exists(file_path):
                try:
                    current_size -= os.path.getsize(file_path)
                except OSError:
                    pass
                self._remove_file(file_path)
            del self._index[oldest_key]
            evicted = True
        if evicted:
            self._save_index()
        self._cleanup_empty_dirs()

    def cleanup(self):
        """Run expiry-based then size-based cleanup."""
        self.cleanup_expired()
        self.cleanup_by_size()

    def _cleanup_empty_dirs(self):
        """Best-effort removal of empty subdirectories, bottom-up.

        Attempting rmdir and catching OSError (raised for non-empty dirs)
        also removes parents whose only children were deleted earlier in
        this same walk -- checking the pre-collected dirnames list would
        miss them.
        """
        for dirpath, _dirnames, _filenames in os.walk(self.cache_dir, topdown=False):
            if dirpath == self.cache_dir:
                continue  # never remove the cache root itself
            try:
                os.rmdir(dirpath)  # succeeds only when the directory is empty
            except OSError:
                pass

    def should_redownload(self, file_path):
        """Return True if the file is missing from disk or has expired."""
        return not os.path.exists(file_path) or self.is_expired(file_path)

    def invalidate(self, file_path):
        """Remove a specific file from both the cache directory and the index."""
        key = os.path.relpath(file_path, self.cache_dir)
        if key in self._index:
            del self._index[key]
            self._save_index()
        if os.path.exists(file_path):
            self._remove_file(file_path)


# Module-level singleton; created at import time so all download calls in
# this module share one cache index.
cache_manager = CacheManager()


def get_base_url_and_path(url):
"""
Strip off any file name from a URL, and return the
Expand Down Expand Up @@ -60,12 +191,8 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None):
io_mode = getattr(io_cls, "rawmode", None)
if io_mode == "one-dir":
if not resolved_url.endswith(".zip"):
# In general, we don't know the names of the individual files
# and have no way to get a directory listing from a URL
# so we raise an exception
if io_cls.__name__ in ("PhyIO"):
# for the exceptions, resolved_url must represent a directory
raise NotImplementedError # todo: for these ios, the file names are known
raise NotImplementedError
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -75,16 +202,11 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None):
)
)
elif io_mode == "multi-file":
# Here the resolved_url represents a single file, with or without the file extension.
# By taking the base/root path and adding various extensions we get a list of files to download
for extension in io_cls.extensions:
file_list.append(
# Neo doesn't tell us which files are required and which are optional
# so we have to treat them all as optional at this stage
(f"{base_url}/{root_path}.{extension}", f"{cache_dir}/{root_path}.{extension}", False)
)
elif io_cls.__name__ == "BrainVisionIO":
# in should io_mode be "multi-file" for this? currently "one-file"
for extension in ("eeg", "vmrk"):
file_list.append(
(f"{base_url}/{root_path}.{extension}", f"{cache_dir}/{root_path}.{extension}", True)
Expand All @@ -95,10 +217,6 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None):
(f"{base_url}/{root_path}.{extension}", f"{cache_dir}/{root_path}.{extension}", True)
)
elif io_mode == "one-file":
# Here the resolved url should represent a single file,
# which could have different possible extensions
# todo: check the URL extension matches one of the possible extensions
# and raise an exception otherwise
pass
elif io_cls.mode == "dir":
raise HTTPException(
Expand All @@ -109,27 +227,22 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None):
)
)
else:
# we assume the resolved url represents a single file
# certain IOs have additional metadata files
if io_cls.__name__ == "AsciiSignalIO":
# if we have a text file, try to download the accompanying json file
name, ext = os.path.splitext(main_file)
if ext[1:] in neo.io.AsciiSignalIO.extensions: # ext has a leading '.'
if ext[1:] in neo.io.AsciiSignalIO.extensions:
metadata_filename = main_file.replace(ext, "_about.json")
metadata_url = resolved_url.replace(ext, "_about.json")
file_list.append((metadata_url, f"{cache_dir}/{metadata_filename}", False))
return file_list


def download_neo_data(url, io_cls=None):
def download_neo_data(url, io_cls=None, refresh_cache=False):
"""
Download a neo data file from the given URL.

We do not at present handle formats that require multiple files,
for which the URL should probably point to a zip or tar archive.
If refresh_cache is True, any previously cached version of the file
will be invalidated and the file will be re-downloaded from the source.
"""
# we first open the url to resolve any redirects and have a consistent
# address for caching.
try:
response = urlopen(url)
except HTTPError as err:
Expand All @@ -140,23 +253,27 @@ def download_neo_data(url, io_cls=None):
resolved_url = response.geturl()

cache_dir, main_file = get_cache_path(resolved_url)
if not os.path.exists(os.path.join(cache_dir, main_file)):
main_file_path = os.path.join(cache_dir, main_file)
cache_manager.cleanup()
if refresh_cache:
cache_manager.invalidate(main_file_path)
if cache_manager.should_redownload(main_file_path):
files_to_download = list_files_to_download(resolved_url, cache_dir, io_cls)
for file_url, file_path, required in files_to_download:
try:
urlretrieve(file_url, file_path)
except HTTPError:
if required:
# todo: may not be a 404, could also be a 500 if local disk is full
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, # maybe use 501 Not Implemented?
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Problem downloading '{file_url}'"
)
main_path = files_to_download[0][1]
else:
main_path = os.path.join(cache_dir, main_file)
if main_path.endswith(".zip"):
main_path = get_archive_dir(main_path, cache_dir)
cache_manager.record_access(main_path, url)
return main_path


Expand All @@ -167,9 +284,6 @@ def get_archive_dir(archive_path, cache_dir):
main_path = os.path.join(cache_dir, dir_name)
if not os.path.exists(main_path):
zf.extractall(path=cache_dir)
# we are assuming the zipfile unpacks to a single directory
# todo: check this is the case, and if not either raise an Exception
# or create our own directory to unpack in to
return main_path


Expand All @@ -179,19 +293,22 @@ def get_archive_dir(archive_path, cache_dir):
}
}

def load_blocks(url, io_class_name=None):

def load_blocks(url, io_class_name=None, refresh_cache=False):
"""
Load the first block from the data file at the given URL.

If io_class_name is provided, we use the Neo IO class with that name
to open the file, otherwise we use Neo's `get_io()` function to
find an appropriate class.

If refresh_cache is True, any cached version of the file will be
invalidated and re-downloaded from the source.
"""
assert isinstance(url, str)
# todo: handle formats with multiple files, or with a directory
if io_class_name:
io_cls = getattr(neo.io, io_class_name.value)
main_path = download_neo_data(url, io_cls=io_cls)
main_path = download_neo_data(url, io_cls=io_cls, refresh_cache=refresh_cache)
try:
if io_cls.mode == "dir":
io = io_cls(dirname=main_path)
Expand All @@ -201,10 +318,10 @@ def load_blocks(url, io_class_name=None):
io = io_cls(filename=main_path)
except ImportError:
raise HTTPException(
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, # maybe use 501 Not Implemented?
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
detail=f"This server does not have the {io_class_name} module installed.",
)
except (RuntimeError, TypeError, OSError) as err: # RuntimeError from NixIO, TypeError from TdtIO, OSError from EDFIO
except (RuntimeError, TypeError, OSError) as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f'Error when trying to open file with {io_class_name}: "{err}"',
Expand All @@ -215,8 +332,7 @@ def load_blocks(url, io_class_name=None):
detail=f'Associated file not found. More details: "{err}"'
)
else:
# todo: handle IOError, if none of the IO classes work
main_path = download_neo_data(url)
main_path = download_neo_data(url, refresh_cache=refresh_cache)
io = neo.io.get_io(main_path)

try:
Expand All @@ -232,4 +348,4 @@ def load_blocks(url, io_class_name=None):
)
if hasattr(io, "close"):
io.close()
return blocks
return blocks
Loading
Loading