# Licensed under a 3-clause BSD style license - see LICENSE.rst

import abc
import contextlib
import re
import warnings
from collections import OrderedDict
from operator import itemgetter

import numpy as np

__all__ = ['IORegistryError']


class IORegistryError(Exception):
    """Custom error for registry clashes.
    """
    pass


# -----------------------------------------------------------------------------

class _UnifiedIORegistryBase(metaclass=abc.ABCMeta):
    """Base class for registries in Astropy's Unified IO.

    This base class provides identification functions and miscellaneous
    utilities. For an example how to build a registry subclass we suggest
    :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables
    read-only registries. These higher-level subclasses will probably serve
    better as a baseclass, for instance
    :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both
    :class:`~astropy.io.registry.UnifiedInputRegistry` and
    :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both
    reading from and writing to files.

    .. versionadded:: 5.0

    """

    def __init__(self):
        # registry of identifier functions, keyed by (data_format, data_class)
        self._identifiers = OrderedDict()

        # what this class can do: e.g. 'read' &/or 'write'.
        # Each entry maps a registry name to the attribute holding it and the
        # column title used for it in ``get_formats``.
        self._registries = dict()
        self._registries["identify"] = dict(attr="_identifiers", column="Auto-identify")
        self._registries_order = ("identify", )  # match keys in `_registries`

        # If multiple formats are added to one class the update of the docs is
        # quite expensive. Classes for which the doc update is temporarily
        # delayed are added to this set.
        self._delayed_docs_classes = set()

    @property
    def available_registries(self):
        """Available registries.

        Returns
        -------
        ``dict_keys``
        """
        return self._registries.keys()

    def get_formats(self, data_class=None, filter_on=None):
        """
        Get the list of registered formats as a `~astropy.table.Table`.

        Parameters
        ----------
        data_class : class or None, optional
            Filter readers/writer to match data class (default = all classes).
        filter_on : str or None, optional
            Which registry to show. E.g. "identify".
            If None, rows from all registries are shown. Default is None.

        Returns
        -------
        format_table : :class:`~astropy.table.Table`
            Table of available I/O formats.

        Raises
        ------
        ValueError
            If ``filter_on`` is not None nor a registry name.
        """
        # Validate ``filter_on`` before doing any work (including the
        # ``astropy.table`` import) so a bad value fails fast, with a message
        # that names the offending value.
        filter_key = str(filter_on).lower()
        if filter_key in self._registries_order:
            filter_index = self._registries_order.index(filter_key)
        elif filter_on is None:
            filter_index = None
        else:
            raise ValueError(f'unrecognized value for "filter_on": {filter_on}.\n'
                             f'Allowed are {self._registries_order} and None.')

        from astropy.table import Table

        # set up the column names
        colnames = (
            "Data class", "Format",
            *[self._registries[k]["column"] for k in self._registries_order],
            "Deprecated")
        i_dataclass = colnames.index("Data class")
        i_format = colnames.index("Format")
        i_regstart = colnames.index(self._registries[self._registries_order[0]]["column"])
        i_deprecated = colnames.index("Deprecated")

        # the (format, class) pairs from all registries except "identify"
        regs = set()
        for k in self._registries.keys() - {"identify"}:
            regs |= set(getattr(self, self._registries[k]["attr"]))
        format_classes = sorted(regs, key=itemgetter(0))

        rows = []
        for (fmt, cls) in format_classes:
            # see if can skip, else need to document in row
            if (data_class is not None
                    and not self._is_best_match(data_class, cls, format_classes)):
                continue

            # flags for each registry
            has_ = {k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
                    for k, v in self._registries.items()}

            # Check if this is a short name (e.g. 'rdb') which is deprecated in
            # favor of the full 'ascii.rdb'.
            ascii_format_class = ('ascii.' + fmt, cls)
            # deprecation flag
            deprecated = "Yes" if ascii_format_class in format_classes else ""

            # add to rows
            rows.append((cls.__name__, fmt,
                         *[has_[n] for n in self._registries_order], deprecated))

        # keep only rows flagged "Yes" in the requested registry column
        if filter_index is not None:
            rows = [row for row in rows if row[i_regstart + filter_index] == 'Yes']

        # Sorting the list of tuples is much faster than sorting it after the
        # table is created. (#5262)
        if rows:
            # Indices represent "Data Class", "Deprecated" and "Format".
            data = list(zip(*sorted(
                rows, key=itemgetter(i_dataclass, i_deprecated, i_format))))
        else:
            data = None

        # make table
        # need to filter elementwise comparison failure issue
        # https://github.com/numpy/numpy/issues/6784
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=FutureWarning)

            format_table = Table(data, names=colnames)
            if not np.any(format_table['Deprecated'].data == 'Yes'):
                format_table.remove_column('Deprecated')

        return format_table

    @contextlib.contextmanager
    def delay_doc_updates(self, cls):
        """Contextmanager to disable documentation updates when registering
        reader and writer. The documentation is only built once when the
        contextmanager exits.

        .. versionadded:: 1.3

        Parameters
        ----------
        cls : class
            Class for which the documentation updates should be delayed.

        Notes
        -----
        Registering multiple readers and writers can cause significant overhead
        because the documentation of the corresponding ``read`` and ``write``
        methods are build every time.

        If the body of the ``with`` statement raises, ``cls`` is still removed
        from the delayed set (so later registrations update the docs again),
        but the documentation is not rebuilt and the exception propagates.

        Examples
        --------
        see for example the source code of ``astropy.table.__init__``.
        """
        self._delayed_docs_classes.add(cls)
        try:
            yield
        finally:
            # Always re-enable doc updates for ``cls``; without this an
            # exception inside the block would disable them permanently.
            self._delayed_docs_classes.discard(cls)

        for method in self._registries.keys() - {"identify"}:
            self._update__doc__(cls, method)

    # =========================================================================
    # Identifier methods

    def register_identifier(self, data_format, data_class, identifier, force=False):
        """
        Associate an identifier function with a specific data type.

        Parameters
        ----------
        data_format : str
            The data format identifier. This is the string that is used to
            specify the data type when reading/writing.
        data_class : class
            The class of the object that can be written.
        identifier : function
            A function that checks the argument specified to `read` or `write` to
            determine whether the input can be interpreted as a table of type
            ``data_format``. This function should take the following arguments:

               - ``origin``: A string ``"read"`` or ``"write"`` identifying whether
                 the file is to be opened for reading or writing.
               - ``path``: The path to the file.
               - ``fileobj``: An open file object to read the file's contents, or
                 `None` if the file could not be opened.
               - ``*args``: Positional arguments for the `read` or `write`
                 function.
               - ``**kwargs``: Keyword arguments for the `read` or `write`
                 function.

            One or both of ``path`` or ``fileobj`` may be `None`.  If they are
            both `None`, the identifier will need to work from ``args[0]``.

            The function should return True if the input can be identified
            as being of format ``data_format``, and False otherwise.
        force : bool, optional
            Whether to override any existing function if already present.
            Default is ``False``.

        Raises
        ------
        IORegistryError
            If an identifier is already registered for this format and class
            and ``force`` is False.

        Examples
        --------
        To set the identifier based on extensions, for formats that take a
        filename as a first argument, you can do for example

        .. code-block:: python

            from astropy.io.registry import register_identifier, unregister_identifier
            from astropy.table import Table
            def my_identifier(*args, **kwargs):
                return isinstance(args[0], str) and args[0].endswith('.tbl')
            register_identifier('ipac', Table, my_identifier)
            unregister_identifier('ipac', Table)
        """
        if (data_format, data_class) not in self._identifiers or force:
            self._identifiers[(data_format, data_class)] = identifier
        else:
            raise IORegistryError("Identifier for format '{}' and class '{}' is "
                                  'already defined'.format(data_format,
                                                           data_class.__name__))

    def unregister_identifier(self, data_format, data_class):
        """
        Unregister an identifier function

        Parameters
        ----------
        data_format : str
            The data format identifier.
        data_class : class
            The class of the object that can be read/written.

        Raises
        ------
        IORegistryError
            If no identifier is registered for this format and class.
        """
        if (data_format, data_class) in self._identifiers:
            self._identifiers.pop((data_format, data_class))
        else:
            raise IORegistryError("No identifier defined for format '{}' and class"
                                  " '{}'".format(data_format, data_class.__name__))

    def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
        """Loop through identifiers to see which formats match.

        Parameters
        ----------
        origin : str
            A string ``"read"`` or ``"write"`` identifying whether the file is to be
            opened for reading or writing.
        data_class_required : object
            The specified class for the result of `read` or the class that is to be
            written.
        path : str or path-like or None
            The path to the file or None.
        fileobj : file-like or None.
            An open file object to read the file's contents, or ``None`` if the
            file could not be opened.
        args : sequence
            Positional arguments for the `read` or `write` function. Note that
            these must be provided as sequence.
        kwargs : dict-like
            Keyword arguments for the `read` or `write` function. Note that this
            parameter must be `dict`-like.

        Returns
        -------
        valid_formats : list
            List of matching formats, in identifier-registration order.
        """
        valid_formats = []
        for data_format, data_class in self._identifiers:
            # only consult identifiers registered for the closest ancestor
            if self._is_best_match(data_class_required, data_class, self._identifiers):
                if self._identifiers[(data_format, data_class)](
                        origin, path, fileobj, *args, **kwargs):
                    valid_formats.append(data_format)

        return valid_formats

    # =========================================================================
    # Utils

    def _get_format_table_str(self, data_class, filter_on):
        """``get_formats()``, without column "Data class", as a str."""
        format_table = self.get_formats(data_class, filter_on)
        format_table.remove_column('Data class')
        format_table_str = '\n'.join(format_table.pformat(max_lines=-1))
        return format_table_str

    def _is_best_match(self, class1, class2, format_classes):
        """
        Determine if class2 is the "best" match for class1 in the list of classes.

        It is assumed that (class2 in classes) is True.
        class2 is the best match if:

        - ``class1`` is a subclass of ``class2`` AND
        - ``class2`` is the nearest ancestor of ``class1`` that is in classes
          (which includes the case that ``class1 is class2``)

        Parameters
        ----------
        class1 : class
            The class being looked up.
        class2 : class
            The candidate registered class.
        format_classes : iterable of (format, class) tuples
            All registered (format, class) pairs.

        Returns
        -------
        bool
        """
        if issubclass(class1, class2):
            classes = {cls for fmt, cls in format_classes}
            for parent in class1.__mro__:
                if parent is class2:  # class2 is closest registered ancestor
                    return True
                if parent in classes:  # class2 was superseded
                    return False
        return False

    def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):
        """
        Returns the first valid format that can be used to read/write the data in
        question.  Mode can be either 'read' or 'write'.

        Raises
        ------
        IORegistryError
            If no format could be identified, or (via
            ``_get_highest_priority_format``) if several formats tie.
        """
        valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs)

        if len(valid_formats) == 0:
            format_table_str = self._get_format_table_str(cls, mode.capitalize())
            raise IORegistryError("Format could not be identified based on the"
                                  " file name or contents, please provide a"
                                  " 'format' argument.\n"
                                  "The available formats are:\n"
                                  "{}".format(format_table_str))
        elif len(valid_formats) > 1:
            return self._get_highest_priority_format(mode, cls, valid_formats)

        return valid_formats[0]

    def _get_highest_priority_format(self, mode, cls, valid_formats):
        """
        Returns the reader or writer with the highest priority. If it is a tie, error.

        Relies on ``self._readers`` / ``self._writers`` (provided by subclasses),
        whose values are ``(function, priority)`` pairs.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"read"`` nor ``"write"``.
        IORegistryError
            If more than one format shares the highest priority.
        """
        if mode == "read":
            format_dict = self._readers
        elif mode == "write":
            format_dict = self._writers
        else:
            raise ValueError(f"mode must be 'read' or 'write', got {mode!r}")

        best_formats = []
        current_priority = - np.inf
        for fmt in valid_formats:
            try:
                _, priority = format_dict[(fmt, cls)]
            except KeyError:
                # We could throw an exception here, but get_reader/get_writer handle
                # this case better, instead maximally deprioritise the format.
                priority = - np.inf

            if priority == current_priority:
                best_formats.append(fmt)
            elif priority > current_priority:
                best_formats = [fmt]
                current_priority = priority

        if len(best_formats) > 1:
            # list the candidates alphabetically (the old key=itemgetter(0)
            # only sorted on the first character)
            raise IORegistryError("Format is ambiguous - options are: {}".format(
                ', '.join(sorted(valid_formats))
            ))
        return best_formats[0]

    def _update__doc__(self, data_class, readwrite):
        """
        Update the docstring to include all the available readers / writers for
        the ``data_class.read``/``data_class.write`` functions (respectively).
        Don't update if the data_class does not have the relevant method.
        """
        # abort if method "readwrite" isn't on data_class
        if not hasattr(data_class, readwrite):
            return

        from .interface import UnifiedReadWrite

        FORMATS_TEXT = 'The available built-in formats are:'

        # Get the existing read or write method and its docstring
        class_readwrite_func = getattr(data_class, readwrite)

        if not isinstance(class_readwrite_func.__doc__, str):
            # No docstring--could just be test code, or possibly code compiled
            # without docstrings
            return

        lines = class_readwrite_func.__doc__.splitlines()

        # Find the location of the existing formats table if it exists
        sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line]
        if sep_indices:
            # Chop off the existing formats table, including the initial blank line
            chop_index = sep_indices[0]
            lines = lines[:chop_index]

        # Find the minimum indent, skipping the first line because it might be odd.
        # ``default=0`` covers docstrings with no non-blank line after the first,
        # where min() of an empty sequence would otherwise raise ValueError.
        matches = [re.search(r'(\S)', line) for line in lines[1:]]
        left_indent = ' ' * min((match.start() for match in matches if match), default=0)

        # Get the available unified I/O formats for this class
        # Include only formats that have a reader/writer, and drop the
        # 'Data class' column
        format_table = self.get_formats(data_class, readwrite.capitalize())
        format_table.remove_column('Data class')

        # Get the available formats as a table, then munge the output of pformat()
        # a bit and put it into the docstring.
        new_lines = format_table.pformat(max_lines=-1, max_width=80)
        table_rst_sep = re.sub('-', '=', new_lines[1])
        new_lines[1] = table_rst_sep
        new_lines.insert(0, table_rst_sep)
        new_lines.append(table_rst_sep)

        # Check for deprecated names and include a warning at the end.
        if 'Deprecated' in format_table.colnames:
            new_lines.extend(['',
                              'Deprecated format names like ``aastex`` will be '
                              'removed in a future version. Use the full ',
                              'name (e.g. ``ascii.aastex``) instead.'])

        new_lines = [FORMATS_TEXT, ''] + new_lines
        lines.extend([left_indent + line for line in new_lines])

        # Depending on Python version and whether class_readwrite_func is
        # an instancemethod or classmethod, one of the following will work.
        if isinstance(class_readwrite_func, UnifiedReadWrite):
            class_readwrite_func.__class__.__doc__ = '\n'.join(lines)
        else:
            try:
                class_readwrite_func.__doc__ = '\n'.join(lines)
            except AttributeError:
                class_readwrite_func.__func__.__doc__ = '\n'.join(lines)