# Copyright (C) 2012 Anaconda, Inc # SPDX-License-Identifier: BSD-3-Clause """ Handles all caching logic including: - Retrieving from cache - Saving to cache - Determining whether not certain items have expired and need to be refreshed """ from __future__ import annotations import json import logging import os from datetime import datetime, timezone from functools import wraps from pathlib import Path from typing import TYPE_CHECKING try: from platformdirs import user_cache_dir except ImportError: # pragma: no cover from .._vendor.appdirs import user_cache_dir from ..base.constants import APP_NAME, NOTICES_CACHE_FN, NOTICES_CACHE_SUBDIR from ..utils import ensure_dir_exists from .types import ChannelNoticeResponse if TYPE_CHECKING: from typing import Sequence from .types import ChannelNotice logger = logging.getLogger(__name__) def cached_response(func): @wraps(func) def wrapper(url: str, name: str): cache_dir = get_notices_cache_dir() cache_val = get_notice_response_from_cache(url, name, cache_dir) if cache_val: return cache_val return_value = func(url, name) if return_value is not None: write_notice_response_to_cache(return_value, cache_dir) return return_value return wrapper def is_notice_response_cache_expired( channel_notice_response: ChannelNoticeResponse, ) -> bool: """ This checks the contents of the cache response to see if it is expired. If for whatever reason we encounter an exception while parsing the individual messages, we assume an invalid cache and return true. """ now = datetime.now(timezone.utc) def is_channel_notice_expired(expired_at: datetime | None) -> bool: """If there is no "expired_at" field present assume it is expired.""" if expired_at is None: return True return expired_at < now return any( is_channel_notice_expired(chn.expired_at) for chn in channel_notice_response.notices ) @ensure_dir_exists def get_notices_cache_dir() -> Path: """Returns the location of the notices cache directory as a Path object""" cache_dir = user_cache_dir(APP_NAME, appauthor=APP_NAME) return Path(cache_dir).joinpath(NOTICES_CACHE_SUBDIR) def get_notices_cache_file() -> Path: """Returns the location of the notices cache file as a Path object""" cache_dir = get_notices_cache_dir() cache_file = cache_dir.joinpath(NOTICES_CACHE_FN) if not cache_file.is_file(): with open(cache_file, "w") as fp: fp.write("") return cache_file def get_notice_response_from_cache( url: str, name: str, cache_dir: Path ) -> ChannelNoticeResponse | None: """Retrieves a notice response object from cache if it exists.""" cache_key = ChannelNoticeResponse.get_cache_key(url, cache_dir) if os.path.isfile(cache_key): with open(cache_key) as fp: data = json.load(fp) chn_ntc_resp = ChannelNoticeResponse(url, name, data) if not is_notice_response_cache_expired(chn_ntc_resp): return chn_ntc_resp def write_notice_response_to_cache( channel_notice_response: ChannelNoticeResponse, cache_dir: Path ) -> None: """Writes our notice data to our local cache location.""" cache_key = ChannelNoticeResponse.get_cache_key( channel_notice_response.url, cache_dir ) with open(cache_key, "w") as fp: json.dump(channel_notice_response.json_data, fp) def mark_channel_notices_as_viewed( cache_file: Path, channel_notices: Sequence[ChannelNotice] ) -> None: """Insert channel notice into our database marking it as read.""" notice_ids = {chn.id for chn in channel_notices} with open(cache_file) as fp: contents: str = fp.read() contents_unique = set(filter(None, set(contents.splitlines()))) contents_new = contents_unique.union(notice_ids) # Save new version of cache file with open(cache_file, "w") as fp: fp.write("\n".join(contents_new)) def get_viewed_channel_notice_ids( cache_file: Path, channel_notices: Sequence[ChannelNotice] ) -> set[str]: """Return the ids of the channel notices which have already been seen.""" notice_ids = {chn.id for chn in channel_notices} with open(cache_file) as fp: contents: str = fp.read() contents_unique = set(filter(None, set(contents.splitlines()))) return notice_ids.intersection(contents_unique)