189 lines
5.4 KiB
Python
189 lines
5.4 KiB
Python
from typing import Union, Callable
|
|
|
|
import mysql.connector
|
|
import trino.dbapi
|
|
from trino.auth import BasicAuthentication
|
|
from trino.dbapi import connect
|
|
from sshtunnel import SSHTunnelForwarder
|
|
|
|
|
|
def singleton(class_):
    """
    Class decorator that lets ``class_`` be constructed at most once.

    The first call that completes successfully builds the instance with the
    arguments it was given; every later call returns that same object and
    ignores whatever arguments it receives. If construction raises, nothing
    is cached and the next call tries again.

    :param class_: the class to wrap.
    :return: a factory function that stands in for ``class_``.
    """
    not_built = object()  # sentinel meaning "no instance created yet"
    cached = not_built

    def getinstance(*args, **kwargs):
        nonlocal cached
        if cached is not_built:
            cached = class_(*args, **kwargs)
        return cached

    return getinstance
|
|
|
|
|
|
@singleton
class MySSHTunnel:
    """
    Process-wide holder for a single ``SSHTunnelForwarder``.

    Because of the ``@singleton`` decorator, only the arguments of the first
    construction are used; any later ``MySSHTunnel(...)`` call returns that
    first instance.
    """

    def __init__(
        self,
        ssh_host,
        ssh_port,
        ssh_username,
        ssh_pkey,
        remote_host,
        remote_port,
        ssh_private_key_password=None,
    ):
        """
        Configure the forwarder (it is not started here).

        :param ssh_host: hostname of the SSH gateway.
        :param ssh_port: port of the SSH gateway.
        :param ssh_username: user to authenticate as on the gateway.
        :param ssh_pkey: private key, passed to ``SSHTunnelForwarder`` as ``ssh_pkey``.
        :param remote_host: host the gateway forwards traffic to.
        :param remote_port: port on ``remote_host``; also used as the local port.
        :param ssh_private_key_password: optional passphrase for the key.
        """
        gateway = (ssh_host, ssh_port)
        destination = (remote_host, remote_port)
        # The local end listens on the same port number as the remote service,
        # so clients can simply connect to 127.0.0.1:<remote_port>.
        local_end = ("127.0.0.1", remote_port)

        self.tunnel = SSHTunnelForwarder(
            ssh_host=gateway,
            ssh_username=ssh_username,
            ssh_pkey=ssh_pkey,
            remote_bind_address=destination,
            local_bind_address=local_end,
            ssh_private_key_password=ssh_private_key_password,
        )

    def start(self):
        """Open the tunnel by starting the underlying forwarder."""
        self.tunnel.start()

    def stop(self):
        """Close the tunnel by stopping the underlying forwarder."""
        self.tunnel.stop()
|
|
|
|
|
|
def get_connection(connection_config: dict) -> Union[trino.dbapi.Connection, mysql.connector.MySQLConnection]:
    """
    Build a database connection for the engine named in the config.

    :param connection_config: the connection details. Its ``"engine"`` entry
        selects the builder, and the whole dict is then handed to that builder.
    :return: the connection object returned by the selected builder.
    :raises KeyError: if ``"engine"`` is missing from the config.
    :raises ValueError: if the engine name is not recognised.
    """
    engine_name = connection_config["engine"]
    return pick_connection_builder(engine_name)(connection_config)
|
|
|
|
|
|
def get_possible_connection_builders() -> dict:
    """
    Map each supported engine name to the function that connects to it.

    :return: a new dict on every call. Keys are engine identifiers
        (``"trino"``, ``"mysql"``); values are callables that take a
        connection config dict and return a connection object.
    """
    builders = {}
    builders["trino"] = get_connection_to_trino
    builders["mysql"] = get_connection_to_mysql
    return builders
|
|
|
|
|
|
def pick_connection_builder(connection_engine_name: str) -> Callable:
    """
    Look up the connection builder registered for an engine name.

    :param connection_engine_name: the engine identifier from the config,
        e.g. ``"trino"`` or ``"mysql"``.
    :return: the connection builder function for that engine.
    :raises ValueError: if no builder is registered under that name.
    """
    possible_connection_builders = get_possible_connection_builders()

    try:
        return possible_connection_builders[connection_engine_name]
    except KeyError:
        # ``from None`` hides the internal KeyError: the ValueError already
        # names the offending engine, and the chained context is just noise.
        raise ValueError(
            f"Connection type {connection_engine_name} is unknown. Please review config."
        ) from None
|
|
|
|
|
|
def get_connection_to_trino(connection_config: dict) -> trino.dbapi.Connection:
    """
    Open a Trino DB-API connection that uses HTTP basic authentication.

    :param connection_config: must provide ``host``, ``port``, ``user``,
        ``password``, ``http_scheme``, ``catalog`` and ``schema``.
    :return: the connection object created by ``trino.dbapi.connect``.
    :raises KeyError: if any of the required keys is missing.
    """
    cfg = connection_config
    # Keys are read in the same order the connect() arguments list them.
    host, port, user = cfg["host"], cfg["port"], cfg["user"]
    credentials = BasicAuthentication(user, cfg["password"])

    return connect(
        host=host,
        port=port,
        user=user,
        auth=credentials,
        http_scheme=cfg["http_scheme"],
        catalog=cfg["catalog"],
        schema=cfg["schema"],
    )
|
|
|
|
|
|
def get_connection_to_mysql(
    connection_config: dict,
) -> mysql.connector.connection.MySQLConnection:
    """
    Build a connection to MySQL, optionally through an SSH tunnel.

    When ``connection_config["ssh_tunneling"]["use_tunnel"]`` is truthy, an SSH
    tunnel is opened first and the connection goes to 127.0.0.1 instead of
    ``connection_config["host"]``. The port stays the same because the
    tunnel's local end listens on the same port number as the remote server.

    :param connection_config: must provide ``host``, ``port``, ``user``,
        ``password`` and ``schema`` (used as the database name). It may also
        provide an ``ssh_tunneling`` section; a missing, empty or None section
        means no tunnel.
    :return: the connection object created by ``mysql.connector.connect``.
    """
    mysql_connection_host = connection_config["host"]

    # ``or {}`` also covers an ``ssh_tunneling`` key that is present but None,
    # which previously raised AttributeError on ``.get``.
    tunneling = connection_config.get("ssh_tunneling") or {}
    if tunneling.get("use_tunnel"):
        open_ssh_tunnel(connection_config)
        # If we open an SSH tunnel, we reference the local bind instead of the
        # actual host.
        mysql_connection_host = "127.0.0.1"

    connection = mysql.connector.connect(
        host=mysql_connection_host,
        port=connection_config["port"],
        user=connection_config["user"],
        password=connection_config["password"],
        database=connection_config["schema"],
    )

    return connection
|
|
|
|
|
|
def clean_up_connection(connection_config: dict) -> None:
    """
    Perform any necessary connection clean-up steps after the measuring session.

    Currently the only step is closing the SSH tunnel, and only when the config
    asked for one. A missing, empty or None ``ssh_tunneling`` section means
    there is nothing to clean up. This matches how ``get_connection_to_mysql``
    treats that section as optional.

    :param connection_config: the connection details.
    :return: None.
    """
    # Direct indexing used to raise KeyError for configs without tunneling
    # (e.g. trino, or mysql without a tunnel).
    tunneling = connection_config.get("ssh_tunneling") or {}
    if tunneling.get("use_tunnel"):
        close_ssh_tunnel()
|
|
|
|
|
|
def open_ssh_tunnel(connection_config: dict) -> None:
    """
    Start the shared SSH tunnel described by the config's ``ssh_tunneling`` section.

    The tunnel forwards, through the SSH gateway, to ``connection_config["host"]``
    on ``connection_config["port"]``. ``MySSHTunnel`` is a singleton, so if a
    tunnel object already exists it is reused, with its original settings,
    rather than rebuilt from this config.

    :param connection_config: the connection details. Its ``ssh_tunneling``
        section must provide ``ssh_host``, ``ssh_port``, ``ssh_username`` and
        ``path_to_key``, and may provide ``ssh_private_key_password``.
    :return: None.
    """
    ssh_settings = connection_config["ssh_tunneling"]
    print(f"Opening up an SSH tunnel to {ssh_settings['ssh_host']}")

    tunnel = MySSHTunnel(
        ssh_host=ssh_settings["ssh_host"],
        ssh_port=ssh_settings["ssh_port"],
        ssh_username=ssh_settings["ssh_username"],
        ssh_pkey=ssh_settings["path_to_key"],
        remote_host=connection_config["host"],
        remote_port=connection_config["port"],
        # The key passphrase is optional in the config, so default to None.
        ssh_private_key_password=ssh_settings.get("ssh_private_key_password"),
    )
    tunnel.start()

    print("SSH tunnel is now open.")
|
|
|
|
|
|
def close_ssh_tunnel():
    """
    Close the SSH tunnel. No details required because a singleton is being used.

    ``MySSHTunnel()`` is called without arguments, which returns the instance
    created by an earlier ``open_ssh_tunnel`` call and stops it.

    NOTE(review): if no tunnel was ever created in this process, the
    argument-less ``MySSHTunnel()`` call would try to build a new instance and
    fail with ``TypeError``, because its ``__init__`` requires the SSH settings.
    Confirm that callers only reach this after a tunnel was actually opened.

    :return: None
    """
    print("Closing down the SSH tunnel...")
    MySSHTunnel().stop()
    print("SSH tunnel is now closed.")
|