query-performance-gauge/connections.py

190 lines
5.4 KiB
Python
Raw Normal View History

from typing import Union, Callable
import mysql.connector
import trino.dbapi
from trino.auth import BasicAuthentication
from trino.dbapi import connect
from sshtunnel import SSHTunnelForwarder
def singleton(class_):
    """
    Turn ``class_`` into a lazily-created, process-wide singleton.

    The first call constructs the instance with whatever arguments it receives;
    every later call returns that same object and ignores its arguments. If the
    first construction raises, nothing is cached and the next call tries again.

    :param class_: the class to wrap.
    :return: a factory function that stands in for the class.
    """
    unset = object()  # sentinel: "no instance has been created yet"
    instance = unset

    def getinstance(*args, **kwargs):
        nonlocal instance
        if instance is unset:
            instance = class_(*args, **kwargs)
        return instance

    return getinstance
@singleton
class MySSHTunnel:
    """
    Process-wide wrapper around a single ``SSHTunnelForwarder``.

    Because the class is decorated with ``@singleton``, only the arguments of
    the first instantiation are used; later ``MySSHTunnel(...)`` calls return
    that same object.

    The local end of the tunnel binds to ``127.0.0.1`` on the same port number
    as the remote service, so clients reach the service at
    ``127.0.0.1:<remote_port>``.
    """

    def __init__(
        self,
        ssh_host,
        ssh_port,
        ssh_username,
        ssh_pkey,
        remote_host,
        remote_port,
        ssh_private_key_password=None,
    ):
        forwarder_options = dict(
            ssh_host=(ssh_host, ssh_port),
            ssh_username=ssh_username,
            ssh_pkey=ssh_pkey,
            remote_bind_address=(remote_host, remote_port),
            # Reuse the remote port number locally so callers can keep using
            # the configured port when they switch the host to 127.0.0.1.
            local_bind_address=("127.0.0.1", remote_port),
            ssh_private_key_password=ssh_private_key_password,
        )
        self.tunnel = SSHTunnelForwarder(**forwarder_options)

    def start(self):
        """Start the underlying ``SSHTunnelForwarder``."""
        self.tunnel.start()

    def stop(self):
        """Stop the underlying ``SSHTunnelForwarder``."""
        self.tunnel.stop()
2022-08-22 13:45:46 +02:00
def get_connection(
    connection_config: dict,
) -> Union[trino.dbapi.Connection, mysql.connector.MySQLConnection]:
    """
    Build a database connection for the engine named in the config.

    The ``"engine"`` entry selects a builder via ``pick_connection_builder``;
    the whole config is then handed to that builder.

    :param connection_config: connection details; must contain an ``"engine"``
        key plus whatever keys the selected builder reads.
    :return: the connection object created by the selected builder.
    :raises KeyError: if ``"engine"`` is missing from the config.
    :raises ValueError: if the engine name is not recognised.
    """
    engine_name = connection_config["engine"]
    return pick_connection_builder(engine_name)(connection_config)
def get_possible_connection_builders() -> dict:
    """
    Return the registry of supported connection engines.

    :return: a new dict mapping each engine identifier (the value expected in a
        config's ``"engine"`` entry) to the function that builds a connection
        for that engine from the config.
    """
    return dict(
        trino=get_connection_to_trino,
        mysql=get_connection_to_mysql,
    )
def pick_connection_builder(connection_engine_name: str) -> Callable:
    """
    Look up the connection builder registered under ``connection_engine_name``.

    :param connection_engine_name: the engine identifier from the config
        (one of the keys of ``get_possible_connection_builders()``).
    :return: the connection builder function.
    :raises ValueError: if no builder is registered under that name; the
        message lists the known engine names.
    """
    possible_connection_builders = get_possible_connection_builders()
    try:
        return possible_connection_builders[connection_engine_name]
    except KeyError:
        known_engines = ", ".join(sorted(possible_connection_builders))
        # ``from None`` suppresses the internal KeyError so the user sees only
        # the actionable configuration error, not a chained traceback.
        raise ValueError(
            f"Connection type {connection_engine_name} is unknown. Please review config."
            f" Known types: {known_engines}."
        ) from None
def get_connection_to_trino(connection_config: dict) -> trino.dbapi.Connection:
    """
    Open a Trino DB-API connection described by ``connection_config``.

    The same ``"user"`` value is used both as the session user and as the
    username for basic authentication.

    :param connection_config: must provide ``host``, ``port``, ``user``,
        ``password``, ``http_scheme``, ``catalog`` and ``schema``.
    :return: the connection object returned by ``trino.dbapi.connect``.
    """
    host = connection_config["host"]
    port = connection_config["port"]
    user = connection_config["user"]
    credentials = BasicAuthentication(user, connection_config["password"])
    return connect(
        host=host,
        port=port,
        user=user,
        auth=credentials,
        http_scheme=connection_config["http_scheme"],
        catalog=connection_config["catalog"],
        schema=connection_config["schema"],
    )
def get_connection_to_mysql(
    connection_config: dict,
) -> mysql.connector.connection.MySQLConnection:
    """
    Build a connection to MySQL from the passed config.

    If ``connection_config["ssh_tunneling"]["use_tunnel"]`` is truthy, an SSH
    tunnel is opened first and the connection goes through its local end
    (``127.0.0.1`` on the configured port) instead of the real host.

    :param connection_config: must provide ``host``, ``port``, ``user``,
        ``password`` and ``schema`` (used as the database name); may provide an
        ``ssh_tunneling`` section.
    :return: the connection object.
    :raises Exception: whatever ``mysql.connector.connect`` raises; if a tunnel
        was opened by this call it is closed before the error propagates.
    """
    # ``or {}`` also covers an ``ssh_tunneling`` key that is present but empty
    # (``None``), which ``.get("ssh_tunneling", {})`` alone would not.
    ssh_tunneling = connection_config.get("ssh_tunneling") or {}
    use_tunnel = bool(ssh_tunneling.get("use_tunnel"))

    mysql_connection_host = connection_config["host"]
    if use_tunnel:
        open_ssh_tunnel(connection_config)
        # With a tunnel open, connect to the local bind instead of the
        # actual host.
        mysql_connection_host = "127.0.0.1"

    try:
        return mysql.connector.connect(
            host=mysql_connection_host,
            port=connection_config["port"],
            user=connection_config["user"],
            password=connection_config["password"],
            database=connection_config["schema"],
        )
    except Exception:
        # Don't leave the freshly opened tunnel running when the connection
        # attempt fails; the caller never receives a connection to clean up.
        if use_tunnel:
            close_ssh_tunnel()
        raise
def clean_up_connection(connection_config: dict) -> None:
    """
    Perform any necessary connection clean up steps after the measuring session.

    Currently the only step is closing the SSH tunnel, done when
    ``connection_config["ssh_tunneling"]["use_tunnel"]`` is truthy. The
    ``ssh_tunneling`` section is optional (the Trino builder never reads it),
    so configs without it, or with it left empty, need no clean-up.

    :param connection_config: the connection details.
    :return: None.
    """
    # Direct indexing here used to raise KeyError for configs that have no
    # ``ssh_tunneling`` section; treat a missing/empty section as "no tunnel".
    ssh_tunneling = connection_config.get("ssh_tunneling") or {}
    if ssh_tunneling.get("use_tunnel"):
        close_ssh_tunnel()
def open_ssh_tunnel(connection_config: dict) -> None:
    """
    Start the process-wide SSH tunnel described by ``connection_config``.

    Tunnel settings come from ``connection_config["ssh_tunneling"]``
    (``ssh_host``, ``ssh_port``, ``ssh_username``, ``path_to_key`` and the
    optional ``ssh_private_key_password``); the tunnel's remote end is the
    config's own ``host``/``port``. Progress is reported with ``print``.

    :param connection_config: the connection details.
    :return: None.
    """
    tunnel_settings = connection_config["ssh_tunneling"]
    print(f"Opening up an SSH tunnel to {tunnel_settings['ssh_host']}")
    tunnel = MySSHTunnel(
        ssh_host=tunnel_settings["ssh_host"],
        ssh_port=tunnel_settings["ssh_port"],
        ssh_username=tunnel_settings["ssh_username"],
        ssh_pkey=tunnel_settings["path_to_key"],
        remote_host=connection_config["host"],
        remote_port=connection_config["port"],
        # The key passphrase is optional, so fall back to None when absent.
        ssh_private_key_password=tunnel_settings.get("ssh_private_key_password"),
    )
    tunnel.start()
    print("SSH tunnel is now open.")
def close_ssh_tunnel():
    """
    Close the SSH tunnel. No details required because a singleton is being used.

    If no tunnel has been created in this process, there is nothing to close:
    the function reports that and returns instead of crashing.

    :return: None
    """
    print("Closing down the SSH tunnel...")
    try:
        # Returns the existing singleton instance. If none exists yet, the
        # no-argument call reaches ``MySSHTunnel.__init__``, which fails with a
        # TypeError because its required arguments are missing.
        tunnel = MySSHTunnel()
    except TypeError:
        print("No SSH tunnel was open; nothing to close.")
        return
    tunnel.stop()
    print("SSH tunnel is now closed.")