SSHTunnel can now optionally take the password of the SSH key file.

This commit is contained in:
Pablo Martin 2022-08-19 16:57:14 +02:00
parent c9f81f1c07
commit 9db02c154b

View file

@ -21,15 +21,22 @@ def singleton(class_):
@singleton
class MySSHTunnel:
def __init__(
self, ssh_host, ssh_port, ssh_username, ssh_pkey, remote_host, remote_port
self,
ssh_host,
ssh_port,
ssh_username,
ssh_pkey,
remote_host,
remote_port,
ssh_private_key_password=None,
):
self.tunnel = SSHTunnelForwarder(
ssh_host=(ssh_host, ssh_port),
ssh_username=ssh_username,
ssh_pkey=ssh_pkey,
remote_bind_address=(remote_host, remote_port),
local_bind_address=("127.0.0.1", remote_port),
ssh_private_key_password=ssh_private_key_password,
)
def start(self):
@ -115,12 +122,13 @@ def get_connection_to_mysql(
:param connection_config: specifies host, port, etc.
:return: the connection object
"""
if connection_config["ssh_tunneling"]["use_tunnel"]:
mysql_connection_host = connection_config["host"]
if connection_config.get("ssh_tunneling", {}).get("use_tunnel", None):
open_ssh_tunnel(connection_config)
mysql_connection_host = "127.0.0.1"
if not connection_config["ssh_tunneling"]["use_tunnel"]:
mysql_connection_host = connection_config["host"]
# If we open an SSH tunnel, we reference the local bind instead of the
# actual host
connection = mysql.connector.connect(
host=mysql_connection_host,
@ -162,6 +170,10 @@ def open_ssh_tunnel(connection_config: dict) -> None:
ssh_pkey=connection_config["ssh_tunneling"]["path_to_key"],
remote_host=connection_config["host"],
remote_port=connection_config["port"],
ssh_private_key_password=connection_config["ssh_tunneling"].get(
"ssh_private_key_password",
None, # Since password is optional, we need a safe default
),
).start()
print("SSH tunnel is now open.")