class CacheManager:
    """
    Manage the downloaded-file cache: track per-file metadata in a JSON
    index, expire entries older than a configured age, and evict the
    least-recently-accessed entries when the cache exceeds its size limit.

    Metadata lives in ``cache_index.json`` inside the cache directory,
    keyed by path relative to the cache root. A cache entry may be a
    plain file or a directory (zip downloads are unpacked in place and
    the unpacked directory is what gets recorded), so all removal paths
    must handle both.

    NOTE(review): a single module-level instance is shared by all
    requests and the index is mutated without locking — confirm each
    worker is effectively single-threaded, or add a lock around
    index mutations.
    """

    # Name of the on-disk index file; excluded from size accounting.
    INDEX_FILENAME = "cache_index.json"

    def __init__(self):
        # NOTE(review): if DOWNLOADED_FILE_CACHE_DIR is unset this falls
        # back to "" and os.makedirs("") raises FileNotFoundError at
        # import time — confirm settings always defines it.
        self.cache_dir = getattr(settings, "DOWNLOADED_FILE_CACHE_DIR", "")
        self.max_size_bytes = getattr(settings, "CACHE_MAX_SIZE_GB", 10) * 1024**3
        self.expiry_seconds = getattr(settings, "CACHE_FILE_EXPIRY_DAYS", 30) * 86400
        self.index_path = os.path.join(self.cache_dir, self.INDEX_FILENAME)
        os.makedirs(self.cache_dir, exist_ok=True)
        self._index = self._load_index()

    def _load_index(self):
        """Load the JSON index; return {} if it is missing or unreadable."""
        # EAFP: open directly instead of exists-then-open, which races
        # against concurrent deletion of the index file.
        try:
            with open(self.index_path, "r") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}
        except (json.JSONDecodeError, IOError):
            # Corrupt index: start fresh rather than failing every request.
            return {}

    def _save_index(self):
        """Persist the index; failures are non-fatal (cache is best-effort)."""
        try:
            with open(self.index_path, "w") as f:
                json.dump(self._index, f, indent=2)
        except IOError:
            pass

    def _key(self, file_path):
        """Index key for a cached path: its path relative to the cache root."""
        return os.path.relpath(file_path, self.cache_dir)

    @staticmethod
    def _remove_path(path):
        """
        Remove a cache entry from disk.

        Entries can be directories (unpacked zip archives), for which
        os.remove() would raise IsADirectoryError — use rmtree instead.
        """
        import shutil
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
        elif os.path.exists(path):
            os.remove(path)

    @staticmethod
    def _entry_size(path):
        """On-disk size of an entry: file size, or the directory-tree total."""
        if os.path.isfile(path):
            return os.path.getsize(path)
        total = 0
        for dirpath, _dirnames, filenames in os.walk(path):
            for name in filenames:
                try:
                    total += os.path.getsize(os.path.join(dirpath, name))
                except OSError:
                    # Entry vanished between walk() and getsize(); skip it.
                    pass
        return total

    def record_access(self, file_path, url):
        """
        Record that a cached entry was accessed.

        Creates the index entry (with download time and size) on first
        access; otherwise only refreshes ``last_accessed``.
        """
        key = self._key(file_path)
        now = time.time()
        entry = self._index.get(key)
        if entry is None:
            self._index[key] = {
                "url": url,
                "downloaded_at": now,
                "last_accessed": now,
                # _entry_size handles both files and unpacked-archive
                # directories; os.path.getsize on a directory would
                # report the meaningless inode size.
                "size": self._entry_size(file_path) if os.path.exists(file_path) else 0,
            }
        else:
            entry["last_accessed"] = now
        self._save_index()

    def is_expired(self, file_path):
        """
        Return True if the cached entry is older than the expiry period.

        Entries unknown to the index are treated as not expired.
        """
        entry = self._index.get(self._key(file_path))
        if entry is None:
            return False
        return (time.time() - entry.get("downloaded_at", 0)) > self.expiry_seconds

    def get_cache_size(self):
        """Total size in bytes of all cached files, excluding the index."""
        total = 0
        for dirpath, _dirnames, filenames in os.walk(self.cache_dir):
            for name in filenames:
                if name == self.INDEX_FILENAME:
                    continue
                try:
                    total += os.path.getsize(os.path.join(dirpath, name))
                except OSError:
                    # File removed concurrently; skip it.
                    pass
        return total

    def cleanup_expired(self):
        """Remove every entry that has exceeded the expiry period."""
        now = time.time()
        expired_keys = [
            key
            for key, meta in self._index.items()
            if (now - meta.get("downloaded_at", 0)) > self.expiry_seconds
        ]
        for key in expired_keys:
            self._remove_path(os.path.join(self.cache_dir, key))
            del self._index[key]
        if expired_keys:
            self._save_index()
            self._cleanup_empty_dirs()

    def cleanup_by_size(self):
        """
        Evict least-recently-accessed entries until the cache fits under
        ``max_size_bytes``.
        """
        # Walk the disk once and decrement as we evict, instead of
        # re-walking the entire cache tree on every iteration; persist
        # the index once at the end instead of once per eviction.
        current_size = self.get_cache_size()
        evicted = False
        while current_size > self.max_size_bytes and self._index:
            oldest_key = min(
                self._index,
                key=lambda k: self._index[k].get("last_accessed", 0),
            )
            path = os.path.join(self.cache_dir, oldest_key)
            current_size -= self._entry_size(path)
            self._remove_path(path)
            del self._index[oldest_key]
            evicted = True
        if evicted:
            self._save_index()
            self._cleanup_empty_dirs()

    def cleanup(self):
        """Run expiry-based cleanup, then size-based eviction."""
        self.cleanup_expired()
        self.cleanup_by_size()

    def _cleanup_empty_dirs(self):
        """Prune empty subdirectories left behind after evictions."""
        for dirpath, dirnames, filenames in os.walk(self.cache_dir, topdown=False):
            if dirpath == self.cache_dir:
                continue
            if not dirnames and not filenames:
                try:
                    os.rmdir(dirpath)
                except OSError:
                    # Directory re-populated or removed concurrently.
                    pass

    def should_redownload(self, file_path):
        """Return True if the entry must be (re)downloaded: missing or expired."""
        return not os.path.exists(file_path) or self.is_expired(file_path)

    def invalidate(self, file_path):
        """Drop a specific entry from both the index and the disk cache."""
        key = self._key(file_path)
        if key in self._index:
            del self._index[key]
            self._save_index()
        # _remove_path handles files, unpacked-archive directories, and
        # already-missing paths.
        self._remove_path(file_path)


cache_manager = CacheManager()
currently "one-file" for extension in ("eeg", "vmrk"): file_list.append( (f"{base_url}/{root_path}.{extension}", f"{cache_dir}/{root_path}.{extension}", True) @@ -95,10 +217,6 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None): (f"{base_url}/{root_path}.{extension}", f"{cache_dir}/{root_path}.{extension}", True) ) elif io_mode == "one-file": - # Here the resolved url should represent a single file, - # which could have different possible extensions - # todo: check the URL extension matches one of the possible extensions - # and raise an exception otherwise pass elif io_cls.mode == "dir": raise HTTPException( @@ -109,27 +227,22 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None): ) ) else: - # we assume the resolved url represents a single file - # certain IOs have additional metadata files if io_cls.__name__ == "AsciiSignalIO": - # if we have a text file, try to download the accompanying json file name, ext = os.path.splitext(main_file) - if ext[1:] in neo.io.AsciiSignalIO.extensions: # ext has a leading '.' + if ext[1:] in neo.io.AsciiSignalIO.extensions: metadata_filename = main_file.replace(ext, "_about.json") metadata_url = resolved_url.replace(ext, "_about.json") file_list.append((metadata_url, f"{cache_dir}/{metadata_filename}", False)) return file_list -def download_neo_data(url, io_cls=None): +def download_neo_data(url, io_cls=None, refresh_cache=False): """ Download a neo data file from the given URL. - We do not at present handle formats that require multiple files, - for which the URL should probably point to a zip or tar archive. + If refresh_cache is True, any previously cached version of the file + will be invalidated and the file will be re-downloaded from the source. """ - # we first open the url to resolve any redirects and have a consistent - # address for caching. 
try: response = urlopen(url) except HTTPError as err: @@ -140,16 +253,19 @@ def download_neo_data(url, io_cls=None): resolved_url = response.geturl() cache_dir, main_file = get_cache_path(resolved_url) - if not os.path.exists(os.path.join(cache_dir, main_file)): + main_file_path = os.path.join(cache_dir, main_file) + cache_manager.cleanup() + if refresh_cache: + cache_manager.invalidate(main_file_path) + if cache_manager.should_redownload(main_file_path): files_to_download = list_files_to_download(resolved_url, cache_dir, io_cls) for file_url, file_path, required in files_to_download: try: urlretrieve(file_url, file_path) except HTTPError: if required: - # todo: may not be a 404, could also be a 500 if local disk is full raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, # maybe use 501 Not Implemented? + status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem downloading '{file_url}'" ) main_path = files_to_download[0][1] @@ -157,6 +273,7 @@ def download_neo_data(url, io_cls=None): main_path = os.path.join(cache_dir, main_file) if main_path.endswith(".zip"): main_path = get_archive_dir(main_path, cache_dir) + cache_manager.record_access(main_path, url) return main_path @@ -167,9 +284,6 @@ def get_archive_dir(archive_path, cache_dir): main_path = os.path.join(cache_dir, dir_name) if not os.path.exists(main_path): zf.extractall(path=cache_dir) - # we are assuming the zipfile unpacks to a single directory - # todo: check this is the case, and if not either raise an Exception - # or create our own directory to unpack in to return main_path @@ -179,19 +293,22 @@ def get_archive_dir(archive_path, cache_dir): } } -def load_blocks(url, io_class_name=None): + +def load_blocks(url, io_class_name=None, refresh_cache=False): """ Load the first block from the data file at the given URL. If io_class_name is provided, we use the Neo IO class with that name to open the file, otherwise we use Neo's `get_io()` function to find an appropriate class. 
+ + If refresh_cache is True, any cached version of the file will be + invalidated and re-downloaded from the source. """ assert isinstance(url, str) - # todo: handle formats with multiple files, or with a directory if io_class_name: io_cls = getattr(neo.io, io_class_name.value) - main_path = download_neo_data(url, io_cls=io_cls) + main_path = download_neo_data(url, io_cls=io_cls, refresh_cache=refresh_cache) try: if io_cls.mode == "dir": io = io_cls(dirname=main_path) @@ -201,10 +318,10 @@ def load_blocks(url, io_class_name=None): io = io_cls(filename=main_path) except ImportError: raise HTTPException( - status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, # maybe use 501 Not Implemented? + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=f"This server does not have the {io_class_name} module installed.", ) - except (RuntimeError, TypeError, OSError) as err: # RuntimeError from NixIO, TypeError from TdtIO, OSError from EDFIO + except (RuntimeError, TypeError, OSError) as err: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f'Error when trying to open file with {io_class_name}: "{err}"', @@ -215,8 +332,7 @@ def load_blocks(url, io_class_name=None): detail=f'Associated file not found. 
More details: "{err}"' ) else: - # todo: handle IOError, if none of the IO classes work - main_path = download_neo_data(url) + main_path = download_neo_data(url, refresh_cache=refresh_cache) io = neo.io.get_io(main_path) try: @@ -232,4 +348,4 @@ def load_blocks(url, io_class_name=None): ) if hasattr(io, "close"): io.close() - return blocks + return blocks \ No newline at end of file diff --git a/api/resources/v1.py b/api/resources/v1.py index a1bb894..57b678b 100644 --- a/api/resources/v1.py +++ b/api/resources/v1.py @@ -49,14 +49,22 @@ async def get_block_data( ) ), ] = None, + refresh_cache: Annotated[ + bool, + Query( + description=( + "If true, any previously cached version of the file will be " + "invalidated and the file will be re-downloaded from the source." + ) + ), + ] = False, ) -> BlockContainer: """ Return metadata about all the blocks in a data file, including metadata about the segments within each block, but without any information about the data contained within each segment. """ - # here `url` is a Pydantic object, which we convert to a string - blocks = load_blocks(str(url), type) + blocks = load_blocks(str(url), type, refresh_cache=refresh_cache) return BlockContainer.from_neo(blocks, url) @@ -86,6 +94,15 @@ async def get_segment_data( ) ), ] = None, + refresh_cache: Annotated[ + bool, + Query( + description=( + "If true, any previously cached version of the file will be " + "invalidated and the file will be re-downloaded from the source." + ) + ), + ] = False, ) -> Segment: """ Return information about an individual Segment within a block, @@ -93,18 +110,18 @@ async def get_segment_data( but not the signal data themselves. 
""" try: - block = load_blocks(str(url), type)[block_id] + block = load_blocks(str(url), type, refresh_cache=refresh_cache)[block_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on block_id", # todo: improve this message in next API version + detail="IndexError on block_id", ) try: segment = block.segments[segment_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on segment_id", # todo: improve this message in next API version + detail="IndexError on segment_id", ) return Segment.from_neo(segment, url) @@ -145,21 +162,30 @@ async def get_analogsignal_data( ) ), ] = 1, + refresh_cache: Annotated[ + bool, + Query( + description=( + "If true, any previously cached version of the file will be " + "invalidated and the file will be re-downloaded from the source." + ) + ), + ] = False, ) -> AnalogSignal: """Get an analog signal from a given segment, including both data and metadata.""" try: - block = load_blocks(str(url), type)[block_id] + block = load_blocks(str(url), type, refresh_cache=refresh_cache)[block_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on block_id", # todo: improve this message in next API version + detail="IndexError on block_id", ) try: segment = block.segments[segment_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on segment_id", # todo: improve this message in next API version + detail="IndexError on segment_id", ) if len(segment.analogsignals) > 0: container = segment.analogsignals @@ -170,7 +196,7 @@ async def get_analogsignal_data( except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on analog_signal_id", # todo: improve this message in next API version + detail="IndexError on analog_signal_id", ) try: asig = AnalogSignal.from_neo(signal, down_sample_factor) @@ -208,20 
+234,29 @@ async def get_spiketrain_data( ) ), ] = None, + refresh_cache: Annotated[ + bool, + Query( + description=( + "If true, any previously cached version of the file will be " + "invalidated and the file will be re-downloaded from the source." + ) + ), + ] = False, ) -> dict[str, SpikeTrain]: """Get the spike trains from a given segment, including both data and metadata.""" try: - block = load_blocks(str(url), type)[block_id] + block = load_blocks(str(url), type, refresh_cache=refresh_cache)[block_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on block_id", # todo: improve this message in next API version + detail="IndexError on block_id", ) try: segment = block.segments[segment_id] except IndexError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="IndexError on segment_id", # todo: improve this message in next API version + detail="IndexError on segment_id", ) return {str(i): SpikeTrain.from_neo(st) for i, st in enumerate(segment.spiketrains)}