diff --git a/lolafect/lolaconfig.py b/lolafect/lolaconfig.py index 9ac66c0..786855a 100644 --- a/lolafect/lolaconfig.py +++ b/lolafect/lolaconfig.py @@ -69,6 +69,7 @@ class LolaConfig: ) self.TRINO_CREDENTIALS = None + self.SSH_TUNNEL_CREDENTIALS = None self._s3_reader = S3FileReader(s3_client=boto3.client("s3")) @@ -109,6 +110,29 @@ class LolaConfig: "port": env_data["trino_port"], } + def fetch_ssh_tunnel_credentials(self, s3_reader=None) -> None: + """ + Read the env file from S3 and store the SSH tunnel credentials. + + :param s3_reader: a client to fetch files from S3. + :return: None + """ + + if s3_reader is None: + s3_reader = self._s3_reader + + env_data = s3_reader.read_json_from_s3_file( + bucket=self.S3_BUCKET_NAME, key=self.ENV_FILE_PATH + ) + + self.SSH_TUNNEL_CREDENTIALS = { + "path_to_ssh_pkey": env_data["pt_ssh_pkey_path"], + "ssh_pkey_password": env_data["pt_ssh_pkey_passphrase"], + "ssh_username": env_data["pt_ssh_username"], + "ssh_port": env_data["pt_ssh_jumphost_port"], + "ssh_jumphost": env_data["pt_ssh_jumphost"], + } + def build_lolaconfig( flow_name: str, diff --git a/tests/test_lolaconfig.py b/tests/test_lolaconfig.py index 5d7189a..f42dd4b 100644 --- a/tests/test_lolaconfig.py +++ b/tests/test_lolaconfig.py @@ -50,3 +50,24 @@ def test_lolaconfig_fetches_trino_creds_properly(): lolaconfig.fetch_trino_credentials(s3_reader=fake_s3_reader) assert type(lolaconfig.TRINO_CREDENTIALS) is dict + + +def test_lolaconfig_fetches_ssh_tunnel_creds_properly(): + lolaconfig = LolaConfig(flow_name="some-flow") + + fake_s3_reader = SimpleNamespace() + + def mock_read_json_from_s3_file(bucket, key): + return { + "pt_ssh_pkey_path": "some-path", + "pt_ssh_pkey_passphrase": "some-password", + "pt_ssh_username": "some-username", + "pt_ssh_jumphost_port": "some-port", + "pt_ssh_jumphost": "some-jumphost", + } + + fake_s3_reader.read_json_from_s3_file = mock_read_json_from_s3_file + + lolaconfig.fetch_ssh_tunnel_credentials(s3_reader=fake_s3_reader) + + assert type(lolaconfig.SSH_TUNNEL_CREDENTIALS) is dict