diff --git a/lolafect/connections.py b/lolafect/connections.py index caf3fdf..223c9f4 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -24,7 +24,7 @@ def connect_to_trino( :param trino_credentials: a dict with the host, port, user and password. :param http_schema: which http schema to use in the connection. - :return: + :return: the connection to trino. """ logger = prefect.context.get("logger") logger.info( @@ -96,7 +96,7 @@ def _temp_secret_file_from_s3( raise e finally: # Regardless of what happens in the context manager, we always delete the temp - # copy of the private key. + # copy of the secret file. os.remove(local_temp_file_path) @@ -187,6 +187,27 @@ def open_ssh_tunnel( return tunnel +@task(trigger=all_finished) +def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None: + """ + Close a SSH tunnel, or do nothing if something different is passed. + + :param tunnel: a SSH tunnel. + :return: + """ + logger = prefect.context.get("logger") + + if isinstance(tunnel, SSHTunnelForwarder): + tunnel.stop() + logger.info("SSH tunnel closed successfully.") + return + logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.") + raise DeprecationWarning( + "This method will only accept the type 'SSHTunnelForwarder' in next major release.\n" + "Please, update your code accordingly." + ) + + @task() def get_local_bind_address_from_ssh_tunnel( tunnel: SSHTunnelForwarder, @@ -270,24 +291,3 @@ def close_mysql_connection(connection: pymysql.Connection) -> None: "This method will only accept the type 'pymysql.Connection' in next major release.\n" "Please, update your code accordingly." ) - - -@task(trigger=all_finished) -def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None: - """ - Close a SSH tunnel, or do nothing if something different is passed. - - :param tunnel: a SSH tunnel. - :return: - """ - logger = prefect.context.get("logger") - - if isinstance(tunnel, SSHTunnelForwarder): - tunnel.stop() - logger.info("SSH tunnel closed successfully.") - return - logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.") - raise DeprecationWarning( - "This method will only accept the type 'SSHTunnelForwarder' in next major release.\n" - "Please, update your code accordingly." - )