diff --git a/lolafect/connections.py b/lolafect/connections.py index 40c450f..250c389 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -5,6 +5,8 @@ from prefect import task from prefect.triggers import all_finished from trino.auth import BasicAuthentication import trino +import pymysql +import sshtunnel from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME @@ -57,3 +59,86 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None: logger.info("Trino connection closed successfully.") return logger.info("No connection received.") + + +@task(log_stdout=True, nout=2) +def connect_to_dw(use_ssh_tunnel): + print("Connecting to DW") + import pymysql + + mysql_host = LOLACONFIG.DW_CREDENTIALS["host"] + tunnel = None + if use_ssh_tunnel: + print("Going to open an SSH tunnel.") + from sshtunnel import SSHTunnelForwarder + + temp_file_path = "temp" + try: + boto3.client("s3").download_file( + LOLACONFIG.S3_BUCKET_NAME, + LOLACONFIG.SSH_TUNNEL_CREDENTIALS["path_to_ssh_pkey"], + temp_file_path, + ) + tunnel = SSHTunnelForwarder( + ssh_host=( + LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_jumphost"], + LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_port"], + ), + ssh_username=LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_username"], + ssh_pkey=temp_file_path, + remote_bind_address=( + LOLACONFIG.DW_CREDENTIALS["host"], + LOLACONFIG.DW_CREDENTIALS["port"], + ), + local_bind_address=("127.0.0.1", LOLACONFIG.DW_CREDENTIALS["port"]), + ssh_private_key_password=LOLACONFIG.SSH_TUNNEL_CREDENTIALS[ + "ssh_pkey_password" + ], + ) + except Exception as e: + raise e + finally: + # No matter what happens above, we always must delete the temp copy of the key + os.remove(temp_file_path) + + tunnel.start() + print("SSH tunnel is now open.") + mysql_host = "127.0.0.1" + + db_connection = pymysql.connect( + host=mysql_host, + port=LOLACONFIG.DW_CREDENTIALS["port"], + user=LOLACONFIG.DW_CREDENTIALS["user"], + password=LOLACONFIG.DW_CREDENTIALS["password"], + database="dw_xl", + ) + + # Customizing this attributes to retrieve them later in the flow + db_connection.raw_user = LOLACONFIG.DW_CREDENTIALS["user"] + db_connection.raw_password = LOLACONFIG.DW_CREDENTIALS["password"] + + print("Connected to DW.") + + return db_connection, tunnel + + +@task(log_stdout=True, trigger=all_finished) +def close_dw_connection(dw_connection): + import pymysql + + if isinstance(dw_connection, pymysql.Connection): + dw_connection.close() + print("DW connection closed successfully.") + return + print("No connection received.") + + +@task(log_stdout=True, trigger=all_finished) +def close_ssh_tunnel(ssh_tunnel): + from sshtunnel import SSHTunnelForwarder + + if isinstance(ssh_tunnel, SSHTunnelForwarder): + ssh_tunnel.stop() + print("SSH tunnel closed successfully.") + return + print("No connection received.") \ No newline at end of file