from typing import Union, Callable

import mysql.connector
import trino.dbapi
from trino.auth import BasicAuthentication
from trino.dbapi import connect
from sshtunnel import SSHTunnelForwarder


def singleton(class_):
    """
    Turn a class into a lazily-created, per-class singleton.

    The first call constructs the instance with the arguments it receives;
    every later call returns that same instance and ignores its arguments.

    :param class_: the class to wrap.
    :return: a factory function that always returns the single instance.
    """
    instances = {}

    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    return getinstance


@singleton
class MySSHTunnel:
    """
    Holder for a single SSH port forward shared by the whole process.

    Because of ``@singleton``, only the arguments of the first construction
    are used; later ``MySSHTunnel(...)`` calls (including ``MySSHTunnel()``
    with no arguments) return the already-built tunnel.
    """

    def __init__(
        self, ssh_host, ssh_port, ssh_username, ssh_pkey, remote_host, remote_port
    ):
        self.tunnel = SSHTunnelForwarder(
            ssh_host=(ssh_host, ssh_port),
            ssh_username=ssh_username,
            ssh_pkey=ssh_pkey,
            remote_bind_address=(remote_host, remote_port),
            # NOTE(review): binds locally on the same port number as the
            # remote service, presumably so clients can use the configured
            # port against 127.0.0.1 — confirm with callers.
            local_bind_address=("127.0.0.1", remote_port),
        )

    def start(self):
        """Start forwarding through the SSH tunnel."""
        self.tunnel.start()

    def stop(self):
        """Stop forwarding through the SSH tunnel."""
        self.tunnel.stop()


def get_connection(
    connection_config: dict,
) -> Union[trino.dbapi.Connection, mysql.connector.connection.MySQLConnection]:
    """
    Pick the right way to build a connection and pass it the connection details.

    :param connection_config: the connection details. Its ``"engine"`` key
        selects the builder (see ``get_possible_connection_builders``); the
        whole dict is then handed to that builder.
    :return: the connection object produced by the selected builder.
    :raises ValueError: if ``connection_config["engine"]`` is not a known engine.
    """
    connection_builder = pick_connection_builder(connection_config["engine"])
    connection = connection_builder(connection_config)
    return connection


def get_possible_connection_builders() -> dict:
    """
    Get the complete list of connection builders.

    :return: a dict where the keys are the strings that identify each
        connection engine and the values are the callable function that will
        return a valid connection object.
    """
    return {
        "trino": get_connection_to_trino,
        "mysql": get_connection_to_mysql,
    }


def pick_connection_builder(connection_engine_name: str) -> Callable:
    """
    Get a connection builder from a string name.

    :param connection_engine_name: the string defining the connection engine.
    :return: the connection builder function.
    :raises ValueError: if the engine name is not one of the known builders.
    """
    possible_connection_builders = get_possible_connection_builders()
    try:
        connection_builder = possible_connection_builders[connection_engine_name]
    except KeyError:
        # The ValueError message already names the bad engine; suppress the
        # internal KeyError so the traceback shows only the actionable error.
        raise ValueError(
            f"Connection type {connection_engine_name} is unknown. Please review config."
        ) from None
    return connection_builder


def get_connection_to_trino(connection_config: dict) -> trino.dbapi.Connection:
    """
    Build a connection to Trino from the passed config.

    :param connection_config: specifies host, port, user, password,
        http_scheme, catalog and schema.
    :return: the connection object
    """
    return connect(
        host=connection_config["host"],
        port=connection_config["port"],
        user=connection_config["user"],
        auth=BasicAuthentication(
            connection_config["user"],
            connection_config["password"],
        ),
        http_scheme=connection_config["http_scheme"],
        catalog=connection_config["catalog"],
        schema=connection_config["schema"],
    )


def get_connection_to_mysql(
    connection_config,
) -> mysql.connector.connection.MySQLConnection:
    """
    Build a connection to MySQL from the passed config.

    :param connection_config: specifies host, port, user, password and
        schema (the latter is used as the MySQL ``database``).
    :return: the connection object
    """
    connection = mysql.connector.connect(
        host=connection_config["host"],
        port=connection_config["port"],
        user=connection_config["user"],
        password=connection_config["password"],
        database=connection_config["schema"],
    )
    return connection


def clean_up_connection(connection_config: dict) -> None:
    """
    Perform any necessary connection clean up steps after the measuring session.

    Currently this only closes the SSH tunnel when
    ``connection_config["ssh_tunneling"]["use_tunnel"]`` is truthy. A missing
    or empty ``"ssh_tunneling"`` section is treated as "no tunnel".

    :param connection_config: the connection details.
    :return: none.
    """
    tunneling = connection_config.get("ssh_tunneling") or {}
    if tunneling.get("use_tunnel"):
        close_ssh_tunnel()


def open_ssh_tunnel(connection_config: dict) -> None:
    """
    Start an SSH tunnel with the passed details.

    The tunnel forwards to ``connection_config["host"]`` /
    ``connection_config["port"]`` via the SSH settings found under
    ``connection_config["ssh_tunneling"]``. Because ``MySSHTunnel`` is a
    singleton, if a tunnel object already exists its original settings are
    reused and the ones passed here are ignored.

    :param connection_config: the connection details.
    :return: none.
    """
    print(
        f"""Opening up an SSH tunnel to {connection_config["ssh_tunneling"]["ssh_host"]}"""
    )
    MySSHTunnel(
        ssh_host=connection_config["ssh_tunneling"]["ssh_host"],
        ssh_port=connection_config["ssh_tunneling"]["ssh_port"],
        ssh_username=connection_config["ssh_tunneling"]["ssh_username"],
        ssh_pkey=connection_config["ssh_tunneling"]["path_to_key"],
        remote_host=connection_config["host"],
        remote_port=connection_config["port"],
    ).start()
    print("SSH tunnel is now open.")


def close_ssh_tunnel():
    """
    Close the SSH tunnel.

    No details required because a singleton is being used: ``MySSHTunnel()``
    returns the tunnel created by ``open_ssh_tunnel``. If no tunnel was ever
    created, this call raises ``TypeError``, because ``MySSHTunnel.__init__``
    requires arguments.

    :return: None
    """
    print("Closing down the SSH tunnel...")
    MySSHTunnel().stop()
    print("SSH tunnel is now closed.")