Adapting references

This commit is contained in:
Pablo Martin 2023-01-23 15:49:54 +01:00
parent 9feda795e9
commit e4e1f42309

View file

@ -1,4 +1,6 @@
import datetime
import os
from typing import Tuple
import prefect
from prefect import task
@ -6,7 +8,8 @@ from prefect.triggers import all_finished
from trino.auth import BasicAuthentication
import trino
import pymysql
import sshtunnel
import boto3
from sshtunnel import SSHTunnelForwarder
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
@ -62,38 +65,42 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
# NOTE(review): this span is a rendered diff of `connect_to_dw`, not runnable code.
# Removed lines (LOLACONFIG / print based) and their replacements (explicit
# parameters / Prefect logger) appear interleaved, with no +/- markers and no
# indentation. Comments below describe the replacement lines unless stated otherwise.
#
# Purpose: Prefect task that opens a pymysql connection to the `dw_xl` database,
# optionally through an SSH tunnel, and returns the pair (db_connection, tunnel).
# `nout=2` matches that two-element return value.
@task(log_stdout=True, nout=2)
def connect_to_dw(use_ssh_tunnel):
print("Connecting to DW")
import pymysql
def connect_to_dw(
mysql_credentials: dict,
use_ssh_tunnel: bool,
s3_bucket_name,
ssh_tunnel_credentials,
) -> Tuple[pymysql.Connection, SSHTunnelForwarder]:
# Keys read from `mysql_credentials` in this span: "host", "port", "user",
# "password" (and "ssh_pkey_password", see the note further down).
# Keys read from `ssh_tunnel_credentials`: "path_to_ssh_pkey", "ssh_jumphost",
# "ssh_port", "ssh_username".
logger = prefect.context.get("logger")
logger.info(
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
)
mysql_host = LOLACONFIG.DW_CREDENTIALS["host"]
mysql_host = mysql_credentials["host"]
# `tunnel` is initialised to None and is what gets returned as the second
# element when no tunnel is opened.
tunnel = None
if use_ssh_tunnel:
print("Going to open an SSH tunnel.")
from sshtunnel import SSHTunnelForwarder
logger.info("Going to open an SSH tunnel.")
# The SSH private key is downloaded from S3 to this relative path and handed
# to SSHTunnelForwarder via `ssh_pkey`; it is removed again after the hunk
# boundary below (os.remove).
temp_file_path = "temp"
try:
boto3.client("s3").download_file(
LOLACONFIG.S3_BUCKET_NAME,
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["path_to_ssh_pkey"],
s3_bucket_name,
ssh_tunnel_credentials["path_to_ssh_pkey"],
temp_file_path,
)
tunnel = SSHTunnelForwarder(
ssh_host=(
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_jumphost"],
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_port"],
ssh_tunnel_credentials["ssh_jumphost"],
ssh_tunnel_credentials["ssh_port"],
),
ssh_username=LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_username"],
ssh_username=ssh_tunnel_credentials["ssh_username"],
ssh_pkey=temp_file_path,
remote_bind_address=(
LOLACONFIG.DW_CREDENTIALS["host"],
LOLACONFIG.DW_CREDENTIALS["port"],
mysql_credentials["host"],
mysql_credentials["port"],
),
local_bind_address=("127.0.0.1", LOLACONFIG.DW_CREDENTIALS["port"]),
ssh_private_key_password=LOLACONFIG.SSH_TUNNEL_CREDENTIALS[
"ssh_pkey_password"
],
# The tunnel listens locally on 127.0.0.1 at the same port number as the
# remote MySQL server.
local_bind_address=("127.0.0.1", mysql_credentials["port"]),
# NOTE(review): the removed line above read "ssh_pkey_password" from the SSH
# tunnel credentials; this replacement reads it from `mysql_credentials`.
# Every other SSH setting here comes from `ssh_tunnel_credentials`, so this
# looks like it should be ssh_tunnel_credentials["ssh_pkey_password"] — confirm.
ssh_private_key_password=mysql_credentials["ssh_pkey_password"],
)
# Catch-and-re-raise with no extra handling; the exception propagates unchanged.
except Exception as e:
raise e
# Hunk boundary: the lines between the previous hunk and this one are not shown
# in this view, so the statement that guards os.remove() is not visible here.
@ -102,43 +109,45 @@ def connect_to_dw(use_ssh_tunnel):
os.remove(temp_file_path)
tunnel.start()
print("SSH tunnel is now open.")
# NOTE(review): the message below has no space after "at" and reports the MySQL
# host/port, whereas local_bind_address above is 127.0.0.1 — confirm the wording.
logger.info(
f"SSH tunnel is now open and listening at{mysql_credentials['host']}:{mysql_credentials['port']}."
)
# With a tunnel open, the MySQL client connects to the local end of the tunnel.
mysql_host = "127.0.0.1"
db_connection = pymysql.connect(
host=mysql_host,
port=LOLACONFIG.DW_CREDENTIALS["port"],
user=LOLACONFIG.DW_CREDENTIALS["user"],
password=LOLACONFIG.DW_CREDENTIALS["password"],
port=mysql_credentials["port"],
user=mysql_credentials["user"],
password=mysql_credentials["password"],
database="dw_xl",
)
# Customizing these attributes to retrieve them later in the flow
# (the plain user and password are stored on the connection object itself).
db_connection.raw_user = LOLACONFIG.DW_CREDENTIALS["user"]
db_connection.raw_password = LOLACONFIG.DW_CREDENTIALS["password"]
db_connection.raw_user = mysql_credentials["user"]
db_connection.raw_password = mysql_credentials["password"]
print("Connected to DW.")
logger.info(f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}.")
return db_connection, tunnel
# NOTE(review): rendered diff of `close_dw_connection`. The first three lines appear
# to be the pre-commit decorator, signature and local import; the two lines after
# them are their replacements. Each print(...) is followed by the logger.info(...)
# that replaces it.
#
# Purpose: Prefect task that closes the DW connection. It uses the `all_finished`
# trigger imported from prefect.triggers, so it is meant to run once upstream tasks
# have finished, whatever their outcome.
@task(log_stdout=True, trigger=all_finished)
def close_dw_connection(dw_connection):
import pymysql
@task(trigger=all_finished)
def close_dw_connection(dw_connection: pymysql.Connection) -> None:
logger = prefect.context.get("logger")
# Only a real pymysql.Connection is closed; any other value falls through to the
# "No connection received." message (presumably when the upstream connect task
# did not produce a connection — confirm against the flow).
if isinstance(dw_connection, pymysql.Connection):
dw_connection.close()
print("DW connection closed successfully.")
logger.info("DW connection closed successfully.")
return
print("No connection received.")
logger.info("No connection received.")
# NOTE(review): rendered diff of `close_ssh_tunnel`. The untyped signature and the
# local sshtunnel import appear to be the pre-commit lines; the typed signature is
# their replacement. Each print(...) is followed by the logger.info(...) that
# replaces it. No replacement decorator follows the one below, so
# `log_stdout=True` appears to be kept even though the prints are replaced —
# confirm whether it is still needed.
#
# Purpose: Prefect task that stops the SSH tunnel. It uses the `all_finished`
# trigger imported from prefect.triggers, so it is meant to run once upstream tasks
# have finished, whatever their outcome.
@task(log_stdout=True, trigger=all_finished)
def close_ssh_tunnel(ssh_tunnel):
from sshtunnel import SSHTunnelForwarder
def close_ssh_tunnel(ssh_tunnel: SSHTunnelForwarder) -> None:
logger = prefect.context.get("logger")
# Only an actual SSHTunnelForwarder is stopped; any other value (for example the
# None that connect_to_dw returns when no tunnel is opened) is just logged.
if isinstance(ssh_tunnel, SSHTunnelForwarder):
ssh_tunnel.stop()
print("SSH tunnel closed successfully.")
logger.info("SSH tunnel closed successfully.")
return
# NOTE(review): this message says "connection" although the task handles a tunnel;
# it matches the wording used in close_dw_connection.
print("No connection received.")
logger.info("No connection received.")