# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Utility functions and classes.
"""
import inspect
import queue
import traceback
import xmlrpc.client as xmlrpc
from io import StringIO
from urllib.request import urlopen
from .constants import SAMP_STATUS_ERROR
from .errors import SAMPProxyError
def internet_on():
    """
    Report whether the public internet appears to be reachable.

    Returns
    -------
    bool
        `False` when ``conf.use_internet`` is falsy (no request is made in
        that case) or when fetching ``http://google.com`` with a one second
        timeout raises any exception; `True` when the request succeeds.
    """
    from . import conf

    if not conf.use_internet:
        return False

    try:
        # The response is only a reachability probe.  Close it immediately so
        # the underlying socket is not left open until garbage collection.
        with urlopen("http://google.com", timeout=1.0):
            pass
    except Exception:
        # Deliberate best effort: any failure (DNS, timeout, HTTP error, ...)
        # is reported to the caller as "offline".
        return False
    return True
# Names exported by a star-import of this module; the other helpers defined in
# this file are not part of that list.
__all__ = ["SAMPMsgReplierWrapper", "SAMPXXEServerProxy", "safe_xmlrpc_loads"]
def get_safe_parser(use_datetime=False, use_builtin_types=False):
    """
    Build an XML-RPC parser/unmarshaller pair with external entities disabled.

    Parameters
    ----------
    use_datetime : bool, optional
        Forwarded to `xmlrpc.client.Unmarshaller`.
    use_builtin_types : bool, optional
        Forwarded to `xmlrpc.client.Unmarshaller`.

    Returns
    -------
    parser : `xmlrpc.client.ExpatParser`
        Parser that feeds the returned unmarshaller.
    unmarshaller : `xmlrpc.client.Unmarshaller`
        Target object collecting the decoded values.
    """
    target = xmlrpc.Unmarshaller(
        use_datetime=use_datetime, use_builtin_types=use_builtin_types
    )
    expat_parser = xmlrpc.ExpatParser(target)

    # Clear the external-entity handler explicitly to guard against XXE.
    # None is often already the default, but spelling it out keeps the
    # behaviour the same across environments and expat builds.
    if hasattr(expat_parser, "_parser"):
        low_level = expat_parser._parser
        low_level.ExternalEntityRefHandler = None

    return expat_parser, target
def safe_xmlrpc_loads(data, use_datetime=False, use_builtin_types=False):
    """
    Decode an XML-RPC packet with the XXE-hardened parser.

    Intended as a secure stand-in for `xmlrpc.client.loads`: the parser and
    unmarshaller come from `get_safe_parser` instead of the default factory.

    Parameters
    ----------
    data : str or bytes
        The XML-RPC document to decode; handed to the parser's ``feed``.
    use_datetime : bool, optional
        Forwarded to `get_safe_parser`.
    use_builtin_types : bool, optional
        Forwarded to `get_safe_parser`.

    Returns
    -------
    tuple
        ``(params, method_name)``: the results of the unmarshaller's
        ``close()`` and ``getmethodname()``, in that order.
    """
    xml_parser, target = get_safe_parser(
        use_datetime=use_datetime, use_builtin_types=use_builtin_types
    )

    # Push the whole document through, then signal end-of-input.
    xml_parser.feed(data)
    xml_parser.close()

    params = target.close()
    method_name = target.getmethodname()
    return params, method_name
class SAMPXXEServerProxy(xmlrpc.ServerProxy):
    """
    An XML-RPC server proxy that uses a safe transport to prevent XXE.

    Unless the caller supplies a transport object of their own, the proxy is
    given a transport whose ``getparser`` returns the parser built by
    `get_safe_parser`.  The transport derives from `xmlrpc.client.SafeTransport`
    for ``https`` URIs and from `xmlrpc.client.Transport` otherwise.

    Parameters
    ----------
    uri : str
        Address of the XML-RPC server.
    *args, **kwargs
        Passed on to `xmlrpc.client.ServerProxy`.  A ``transport`` that is
        missing or `None` (given positionally or by keyword) is replaced by
        the safe transport; any other transport is left untouched.
    """

    def __init__(self, uri, *args, **kwargs):
        # After ``uri``, ServerProxy takes (positionally or by keyword):
        # transport, encoding, verbose, allow_none, use_datetime,
        # use_builtin_types.  Look each option up in both places.
        def option(name, position, default):
            if len(args) > position:
                return args[position]
            return kwargs.get(name, default)

        # ``transport=None`` means "use the default" to ServerProxy, so it
        # must get the safe transport as well; otherwise XXE protection would
        # be silently lost.
        if option("transport", 0, None) is None:
            from urllib.parse import urlparse

            is_https = urlparse(uri).scheme == "https"
            base_class = xmlrpc.SafeTransport if is_https else xmlrpc.Transport

            class XXESafeTransport(base_class):
                def getparser(self):
                    return get_safe_parser(self._use_datetime, self._use_builtin_types)

            transport_kwargs = {
                "use_datetime": option("use_datetime", 4, False),
                "use_builtin_types": option("use_builtin_types", 5, False),
            }
            # ServerProxy only applies ``headers``/``context`` to a transport
            # it creates itself, so hand them to our transport explicitly.
            if "headers" in kwargs:
                transport_kwargs["headers"] = kwargs["headers"]
            if is_https and "context" in kwargs:
                transport_kwargs["context"] = kwargs["context"]

            transport = XXESafeTransport(**transport_kwargs)
            if args:
                # Transport was passed positionally (as None): replace it in
                # place to avoid a duplicate-argument TypeError.
                args = (transport, *args[1:])
            else:
                kwargs["transport"] = transport

        super().__init__(uri, *args, **kwargs)
def getattr_recursive(variable, attribute):
    """
    Resolve a dotted attribute path on an object.

    Parameters
    ----------
    variable : object
        Object on which the lookup starts.
    attribute : str
        Attribute name, optionally dotted (``"a.b.c"``); each component is
        looked up with `getattr` on the result of the previous one.

    Returns
    -------
    object
        The value found at the end of the path.
    """
    current = variable
    for component in attribute.split("."):
        current = getattr(current, component)
    return current
class _ServerProxyPoolMethod:
    """
    Bind an XML-RPC method name to a pool of server proxies.

    Attribute access extends the bound name with a dot, which supports
    "nested" methods (e.g. ``examples.getStateName``).  Calling the object
    borrows a proxy from the pool, invokes the method on it and puts the
    proxy back.

    Parameters
    ----------
    proxies : `queue.Queue`
        Queue holding the pooled proxies.
    name : str
        (Possibly dotted) name of the remote method.
    """

    def __init__(self, proxies, name):
        self.__proxies = proxies
        self.__name = name

    def __getattr__(self, name):
        return _ServerProxyPoolMethod(self.__proxies, f"{self.__name}.{name}")

    def __call__(self, *args, **kwrds):
        """
        Invoke the bound method on a pooled proxy.

        Raises
        ------
        SAMPProxyError
            If the call raises `xmlrpc.client.Fault`; the fault code and
            fault string are carried over.
        """
        # Blocks until a proxy is available.
        proxy = self.__proxies.get()
        try:
            # The lookup is inside the ``try`` so that a failing lookup
            # cannot permanently remove the proxy from the pool.
            function = getattr_recursive(proxy, self.__name)
            return function(*args, **kwrds)
        except xmlrpc.Fault as exc:
            raise SAMPProxyError(exc.faultCode, exc.faultString) from exc
        finally:
            self.__proxies.put(proxy)
class ServerProxyPool:
    """
    A thread-safe pool of `xmlrpc.ServerProxy` objects.

    Parameters
    ----------
    size : int
        Number of proxies to create; also the capacity of the internal queue.
    proxy_class : callable
        Factory called as ``proxy_class(*args, **keywords)`` once per proxy.
    *args, **keywords
        Arguments forwarded to ``proxy_class``.
    """

    def __init__(self, size, proxy_class, *args, **keywords):
        self._proxies = queue.Queue(size)
        for _ in range(size):
            self._proxies.put(proxy_class(*args, **keywords))

    def __getattr__(self, name):
        # magic method dispatcher
        return _ServerProxyPoolMethod(self._proxies, name)

    def shutdown(self):
        """Shut down the proxy pool by closing all active connections."""
        while True:
            try:
                proxy = self._proxies.get_nowait()
            except queue.Empty:
                break
            # Calling a ServerProxy with "close" is the way to reach methods
            # that are not dispatched to the remote server.  It only *returns*
            # the close callable, so that callable must itself be invoked.
            closer = proxy("close")
            # Proxy classes whose ``("close")`` call closes directly and
            # returns nothing callable are left as they are.
            if callable(closer):
                closer()
class SAMPMsgReplierWrapper:
    """
    Function decorator that allows to automatically grab errors and returned
    maps (if any) from a function bound to a SAMP call (or notify).

    Parameters
    ----------
    cli : :class:`~astropy.samp.SAMPIntegratedClient` or :class:`~astropy.samp.SAMPClient`
        SAMP client instance. Decorator initialization, accepting the instance
        of the client that receives the call or notification.
    """

    def __init__(self, cli):
        self.cli = cli

    def __call__(self, f):
        """
        Wrap ``f`` so that its outcome is sent back through the client's hub.

        Parameters
        ----------
        f : function or method
            Handler to wrap.  It is treated as a notification handler (no
            reply is sent) when it takes five arguments or when the third
            positional argument of the call is `None`.

        Returns
        -------
        function
            Wrapper accepting the same positional arguments as ``f``.
        """
        # Function-scope import: the module only imports SAMP_STATUS_ERROR at
        # the top.  NOTE(review): assumes ``.constants`` defines
        # SAMP_STATUS_OK next to SAMP_STATUS_ERROR -- confirm.
        from .constants import SAMP_STATUS_OK

        def wrapped_f(*args):
            if get_num_args(f) == 5 or args[2] is None:  # notification
                f(*args)
            else:  # call
                # args[2] is handed to ``hub.reply`` as its second argument
                # (presumably the message id -- verify against the caller).
                try:
                    result = f(*args)
                    if result:
                        # A returned map is a successful outcome, so it is
                        # reported with the OK status rather than as an error.
                        self.cli.hub.reply(
                            self.cli.get_private_key(),
                            args[2],
                            {"samp.status": SAMP_STATUS_OK, "samp.result": result},
                        )
                except Exception:
                    # Send the formatted traceback back to the caller instead
                    # of letting the exception escape the handler.
                    err = StringIO()
                    traceback.print_exc(file=err)
                    txt = err.getvalue()
                    self.cli.hub.reply(
                        self.cli.get_private_key(),
                        args[2],
                        {"samp.status": SAMP_STATUS_ERROR, "samp.result": {"txt": txt}},
                    )

        return wrapped_f
class _HubAsClient:
def __init__(self, handler):
self._handler = handler
def __getattr__(self, name):
# magic method dispatcher
return _HubAsClientMethod(self._handler, name)
class _HubAsClientMethod:
def __init__(self, send, name):
self.__send = send
self.__name = name
def __getattr__(self, name):
return _HubAsClientMethod(self.__send, f"{self.__name}.{name}")
def __call__(self, *args):
return self.__send(self.__name, args)
def get_num_args(f):
    """
    Count the positional arguments a function or method declares.

    Parameters
    ----------
    f : function or bound method
        Callable to inspect.

    Returns
    -------
    int
        ``co_argcount`` of the underlying code object; for a bound method the
        count is reduced by one so that ``self`` is excluded.

    Raises
    ------
    TypeError
        If ``f`` is neither a Python function nor a method.
    """
    if inspect.isfunction(f):
        return f.__code__.co_argcount

    if inspect.ismethod(f):
        # Drop the implicit first argument of the bound method.
        declared = f.__func__.__code__.co_argcount
        return declared - 1

    raise TypeError("f should be a function or a method")