Source code for astropy.samp.utils

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Utility functions and classes.
"""

import inspect
import queue
import traceback
import xmlrpc.client as xmlrpc
from io import StringIO
from urllib.request import urlopen

from .constants import SAMP_STATUS_ERROR
from .errors import SAMPProxyError


def internet_on():
    """
    Report whether an outbound internet connection appears to be available.

    Returns
    -------
    bool
        `False` immediately when the package configuration item
        ``conf.use_internet`` is false. Otherwise `True` if
        ``http://google.com`` could be opened within one second, and `False`
        if any exception was raised while doing so.
    """
    from . import conf

    if not conf.use_internet:
        return False
    try:
        # Use the response as a context manager so the underlying socket is
        # closed right away instead of being left for the garbage collector.
        with urlopen("http://google.com", timeout=1.0):
            pass
    except Exception:
        # Any failure (DNS, timeout, HTTP error, ...) is treated as "offline".
        return False
    return True


# Names exported by ``from astropy.samp.utils import *``; the other helpers
# defined in this module are not listed here.
__all__ = ["SAMPMsgReplierWrapper", "SAMPXXEServerProxy", "safe_xmlrpc_loads"]


def get_safe_parser(use_datetime=False, use_builtin_types=False):
    """
    Build an XML-RPC parser/unmarshaller pair with external entities disabled.

    Parameters
    ----------
    use_datetime : bool, optional
        Passed through to `xmlrpc.client.Unmarshaller`.
    use_builtin_types : bool, optional
        Passed through to `xmlrpc.client.Unmarshaller`.

    Returns
    -------
    parser : `xmlrpc.client.ExpatParser`
        Parser that feeds its events into ``unmarshaller``.
    unmarshaller : `xmlrpc.client.Unmarshaller`
        Target that accumulates the decoded XML-RPC values.
    """
    target = xmlrpc.Unmarshaller(use_datetime, use_builtin_types)
    safe_parser = xmlrpc.ExpatParser(target)
    try:
        expat_parser = safe_parser._parser
    except AttributeError:
        pass
    else:
        # Clear the external-entity handler explicitly: even though None is
        # usually already the default, this guarantees no XXE resolution
        # regardless of the environment or expat build.
        expat_parser.ExternalEntityRefHandler = None
    return safe_parser, target


def safe_xmlrpc_loads(data, use_datetime=False, use_builtin_types=False):
    """
    A secure replacement for `xmlrpc.client.loads` that prevents XXE.

    Parameters
    ----------
    data : str or bytes
        The XML-RPC document to decode.
    use_datetime, use_builtin_types : bool, optional
        Forwarded to the underlying unmarshaller.

    Returns
    -------
    params : tuple
        The decoded parameters.
    method_name : str or None
        The method name for a ``methodCall`` document, otherwise `None`.
    """
    xml_parser, target = get_safe_parser(use_datetime, use_builtin_types)
    xml_parser.feed(data)
    xml_parser.close()
    # ``close()`` must run before ``getmethodname()`` (same order as before).
    params = target.close()
    return params, target.getmethodname()


class _XXESafeParserMixin:
    """Transport mixin whose ``getparser`` returns an XXE-hardened parser."""

    def getparser(self):
        # Same contract as ``Transport.getparser`` -- a (parser, unmarshaller)
        # pair -- but built by ``get_safe_parser`` so external entities are off.
        return get_safe_parser(self._use_datetime, self._use_builtin_types)


class _XXESafeTransport(_XXESafeParserMixin, xmlrpc.Transport):
    """Plain-HTTP XML-RPC transport using the XXE-hardened parser."""


class _XXESafeSafeTransport(_XXESafeParserMixin, xmlrpc.SafeTransport):
    """HTTPS XML-RPC transport using the XXE-hardened parser."""


class SAMPXXEServerProxy(xmlrpc.ServerProxy):
    """
    An XML-RPC server proxy that uses a safe transport to prevent XXE.

    Accepts exactly the same arguments as `xmlrpc.client.ServerProxy`. When no
    ``transport`` is given (or it is `None`), an XXE-hardened transport is
    created in the same way `xmlrpc.client.ServerProxy` builds its default
    one: ``use_datetime``, ``use_builtin_types`` and ``headers`` are forwarded,
    and for ``https`` URIs ``context`` is forwarded as well. A transport
    supplied by the caller (positionally or by keyword) is used unchanged.
    """

    def __init__(self, uri, *args, **kwargs):
        from urllib.parse import urlparse

        # Bind against ServerProxy's own signature so that options passed
        # positionally (``transport``, ``use_datetime``, ...) are honoured and
        # invalid argument combinations raise the usual TypeError.
        bound = inspect.signature(xmlrpc.ServerProxy).bind(uri, *args, **kwargs)
        bound.apply_defaults()
        arguments = bound.arguments
        if arguments["transport"] is None:
            transport_kwargs = {
                "use_datetime": arguments["use_datetime"],
                "use_builtin_types": arguments["use_builtin_types"],
                # ServerProxy ignores ``headers``/``context`` once a transport
                # is supplied, so they must be handed to the transport here.
                "headers": arguments["headers"],
            }
            if urlparse(uri).scheme == "https":
                transport_class = _XXESafeSafeTransport
                transport_kwargs["context"] = arguments["context"]
            else:
                transport_class = _XXESafeTransport
            arguments["transport"] = transport_class(**transport_kwargs)
        super().__init__(*bound.args, **bound.kwargs)


def getattr_recursive(variable, attribute):
    """
    Get attributes recursively.

    ``attribute`` may be a dotted path such as ``"a.b.c"``; each component is
    looked up in turn, starting from ``variable``.
    """
    if "." in attribute:
        top, remaining = attribute.split(".", 1)
        return getattr_recursive(getattr(variable, top), remaining)
    else:
        return getattr(variable, attribute)


class _ServerProxyPoolMethod:
    # Binds a (possibly dotted) XML-RPC method name to a pool of proxies;
    # supports "nested" methods (e.g. examples.getStateName).

    def __init__(self, proxies, name):
        self.__proxies = proxies
        self.__name = name

    def __getattr__(self, name):
        return _ServerProxyPoolMethod(self.__proxies, f"{self.__name}.{name}")

    def __call__(self, *args, **kwrds):
        # Borrow a proxy (blocks until one is free) and always give it back.
        proxy = self.__proxies.get()
        try:
            # Resolve the method inside the try so that a failed lookup still
            # returns the proxy to the pool instead of leaking it.
            function = getattr_recursive(proxy, self.__name)
            return function(*args, **kwrds)
        except xmlrpc.Fault as exc:
            raise SAMPProxyError(exc.faultCode, exc.faultString) from exc
        finally:
            self.__proxies.put(proxy)


class ServerProxyPool:
    """
    A thread-safe pool of `xmlrpc.ServerProxy` objects.

    ``size`` proxies are created up front as ``proxy_class(*args, **keywords)``
    and stored in a `queue.Queue`; any attribute access on the pool returns a
    callable that borrows a proxy for the duration of one call.
    """

    def __init__(self, size, proxy_class, *args, **keywords):
        self._proxies = queue.Queue(size)
        for _ in range(size):
            self._proxies.put(proxy_class(*args, **keywords))

    def __getattr__(self, name):
        # magic method dispatcher
        return _ServerProxyPoolMethod(self._proxies, name)

    def shutdown(self):
        """Shut down the proxy pool by closing all active connections."""
        while True:
            try:
                proxy = self._proxies.get_nowait()
            except queue.Empty:
                break
            # ``proxy("close")`` is the ServerProxy escape hatch for local,
            # non-remote methods. It *returns* the close method rather than
            # running it, so the result must be called to close the transport.
            proxy("close")()


class SAMPMsgReplierWrapper:
    """
    Function decorator that allows to automatically grab errors and returned
    maps (if any) from a function bound to a SAMP call (or notify).

    Parameters
    ----------
    cli : :class:`~astropy.samp.SAMPIntegratedClient` or :class:`~astropy.samp.SAMPClient`
        SAMP client instance. Decorator initialization, accepting the instance
        of the client that receives the call or notification.
    """

    def __init__(self, cli):
        self.cli = cli

    def __call__(self, f):
        """
        Wrap ``f`` so that the outcome of a SAMP call is replied to the hub.

        The wrapper forwards its positional arguments to ``f``. It is treated
        as a notification handler (nothing is replied) when ``f`` takes five
        arguments or when the third argument is `None`. Otherwise it is a call
        handler:

        * if ``f`` returns a truthy value, it is replied with ``samp.status``
          set to ``SAMP_STATUS_OK`` and the value as ``samp.result``;
        * if ``f`` raises, the formatted traceback is replied under
          ``samp.result["txt"]`` with ``samp.status`` set to
          ``SAMP_STATUS_ERROR``.

        The third argument is passed as the second argument of
        ``cli.hub.reply`` (presumably the message ID -- confirm with callers).
        """
        from functools import wraps

        from .constants import SAMP_STATUS_OK

        @wraps(f)
        def wrapped_f(*args):
            if get_num_args(f) == 5 or args[2] is None:  # notification
                f(*args)
            else:  # call
                try:
                    result = f(*args)
                    if result:
                        # A successful result must be reported as samp.ok;
                        # reporting it as an error misleads the caller.
                        self.cli.hub.reply(
                            self.cli.get_private_key(),
                            args[2],
                            {"samp.status": SAMP_STATUS_OK, "samp.result": result},
                        )
                except Exception:
                    # Also reached if the success reply itself fails, in which
                    # case the failure is reported back as an error reply.
                    err = StringIO()
                    traceback.print_exc(file=err)
                    txt = err.getvalue()
                    self.cli.hub.reply(
                        self.cli.get_private_key(),
                        args[2],
                        {"samp.status": SAMP_STATUS_ERROR, "samp.result": {"txt": txt}},
                    )

        return wrapped_f


class _HubAsClient:
    # Exposes ``handler`` as an object whose (possibly dotted) method calls,
    # e.g. ``obj.samp.hub.notify(a, b)``, become ``handler("samp.hub.notify", (a, b))``.

    def __init__(self, handler):
        self._handler = handler

    def __getattr__(self, name):
        # Any unknown attribute starts a new method-name chain.
        return _HubAsClientMethod(self._handler, name)


class _HubAsClientMethod:
    # Accumulates a dotted method name and forwards calls to ``send``.

    def __init__(self, send, name):
        self.__send = send
        self.__name = name

    def __getattr__(self, name):
        dotted = f"{self.__name}.{name}"
        return _HubAsClientMethod(self.__send, dotted)

    def __call__(self, *args):
        forward = self.__send
        return forward(self.__name, args)


def get_num_args(f):
    """
    Find the number of arguments a function or method takes (excluding
    ``self``).

    Only the positional parameters counted by ``co_argcount`` are included.

    Raises
    ------
    TypeError
        If ``f`` is neither a plain function nor a bound method.
    """
    if inspect.isfunction(f):
        return f.__code__.co_argcount
    if inspect.ismethod(f):
        # Bound methods supply ``self`` implicitly, so it is not counted.
        return f.__func__.__code__.co_argcount - 1
    raise TypeError("f should be a function or a method")