From 515a4527af3e4845c437126268c325df90eaaa5a Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Thu, 13 Jun 2024 15:56:57 +0200 Subject: [PATCH] tests passing --- tests/tests_integration/test_get_rates.py | 86 +++++++++++------------ xexe/processes.py | 39 +++++----- xexe/rate_fetching.py | 5 +- xexe/rate_writing.py | 18 +++++ 4 files changed, 79 insertions(+), 69 deletions(-) diff --git a/tests/tests_integration/test_get_rates.py b/tests/tests_integration/test_get_rates.py index c41ec72..a3a59cc 100644 --- a/tests/tests_integration/test_get_rates.py +++ b/tests/tests_integration/test_get_rates.py @@ -78,50 +78,50 @@ def test_get_rates_dry_run_always_returns_42_as_rates(): runner = CliRunner() - with runner.isolated_filesystem(): - run_result = runner.invoke( - get_rates, - [ - "--start-date", - some_random_date.strftime("%Y-%m-%d"), - "--end-date", - (some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"), - "--currencies", - ",".join(some_random_currencies), - "--output", - "test_output.csv", - ], - ) + # with runner.isolated_filesystem(): + run_result = runner.invoke( + get_rates, + [ + "--start-date", + some_random_date.strftime("%Y-%m-%d"), + "--end-date", + (some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"), + "--currencies", + ",".join(some_random_currencies), + "--output", + "test_output.csv", + ], + ) - assert run_result.exit_code == 0 + assert run_result.exit_code == 0 - with open("test_output.csv", newline="") as csv_file: - reader = csv.DictReader(csv_file) - rows = list(reader) + with open("test_output.csv", newline="") as csv_file: + reader = csv.DictReader(csv_file) + rows = list(reader) + + # Ensure that the output contains the correct number of rows + expected_num_rows = 36 + assert ( + len(rows) == expected_num_rows + ), f"Expected {expected_num_rows} rows, but got {len(rows)}" + + # Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates + for row in rows: + assert row["rate"] in ( + "42", + "0.024", + "0.02", + "0", + "1", + ), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}" + assert row["rate_date"] in [ + (some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d") + for i in range(4) + ], f"Unexpected rate_date {row['rate_date']}" - # Ensure that the output contains the correct number of rows - expected_num_rows = 36 assert ( - len(rows) == expected_num_rows - ), f"Expected {expected_num_rows} rows, but got {len(rows)}" - - # Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates - for row in rows: - assert row["rate"] in ( - "42", - "0.024", - "0.02", - "0", - "1", - ), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}" - assert row["rate_date"] in [ - (some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d") - for i in range(4) - ], f"Unexpected rate_date {row['rate_date']}" - - assert ( - row["from_currency"] in some_random_currencies - ), f"Unexpected from_currency {row['from_currency']}" - assert ( - row["to_currency"] in some_random_currencies - ), f"Unexpected to_currency {row['to_currency']}" + row["from_currency"] in some_random_currencies + ), f"Unexpected from_currency {row['from_currency']}" + assert ( + row["to_currency"] in some_random_currencies + ), f"Unexpected to_currency {row['to_currency']}" diff --git a/xexe/processes.py b/xexe/processes.py index 020c918..0974989 100644 --- a/xexe/processes.py +++ b/xexe/processes.py @@ -6,6 +6,7 @@ from typing import List from money.currency import Currency from xecd_rates_client import XecdClient +from xexe.constants import RATES_SOURCES from xexe.exchange_rates import ExchangeRates, add_equal_rates, add_inverse_rates from xexe.rate_fetching import ( MockRateFetcher, @@ -13,7 +14,7 @@ from xexe.rate_fetching import ( XERateFetcher, build_rate_fetcher, ) -from xexe.rate_writing import CSVRateWriter, RateWriter +from xexe.rate_writing import CSVRateWriter, RateWriter, build_rate_writer from xexe.utils import DateRange, generate_currency_and_dates_combinations logger = logging.getLogger() @@ -67,27 +68,32 @@ def run_get_rates( ) -> None: logger.info("Getting rates") - process_state = GetRatesProcessState(output=output, ignore_warnings=ignore_warnings) + process_state = GetRatesProcessState(ignore_warnings=ignore_warnings) rates = obtain_rates_from_source( - process_state, rates_source=rates_source, date_range=date_range, currencies=currencies, + ignore_warnings=ignore_warnings, ) logger.info("Rates obtained.") if dry_run: logger.info("Dry run mode active. Not writing rates to output.") return - write_rates_to_output(process_state, rates) + write_rates_to_output(rates, output) logger.info("Rates written to output.") def obtain_rates_from_source( - process_state, rates_source: str, date_range: DateRange, currencies: List[Currency] + rates_source: str, + date_range: DateRange, + currencies: List[Currency], + ignore_warnings: bool, ) -> ExchangeRates: - rates_fetcher = build_rate_fetcher(rates_source) + rates_fetcher = build_rate_fetcher( + rates_source=rates_source, rate_sources_mapping=RATES_SOURCES + ) currency_and_date_combinations = generate_currency_and_dates_combinations( date_range=date_range, currencies=currencies @@ -96,7 +102,7 @@ def obtain_rates_from_source( large_api_call_planned = ( rates_fetcher.is_production_grade and len(currency_and_date_combinations) > 100 ) - if large_api_call_planned and not process_state.ignore_warnings: + if large_api_call_planned and not ignore_warnings: user_confirmation_string = "i understand" user_response = input( f"WARNING: you are about to execute a large call {len(currency_and_date_combinations)} to a metered API. Type '{user_confirmation_string}' to move forward: " @@ -130,25 +136,12 @@ def obtain_rates_from_source( return rates -def write_rates_to_output(process_state, rates): - rates_writer = process_state.get_writer() +def write_rates_to_output(rates, output) -> None: + rates_writer = build_rate_writer(output) logger.info("Attempting writing rates to output.") rates_writer.write_rates(rates) class GetRatesProcessState: - def __init__(self, output: str, ignore_warnings: bool) -> None: - self.writer = self._select_writer(output) + def __init__(self, ignore_warnings: bool) -> None: self.ignore_warnings = ignore_warnings - - @staticmethod - def _select_writer(output: str) -> CSVRateWriter: - output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv") - - if output_is_csv_file_path: - return CSVRateWriter(output_file_path=output) - - raise ValueError(f"Don't know how to handle passed output: {output}") - - def get_writer(self) -> RateWriter: - return self.writer diff --git a/xexe/rate_fetching.py b/xexe/rate_fetching.py index 17ae137..705f0f0 100644 --- a/xexe/rate_fetching.py +++ b/xexe/rate_fetching.py @@ -7,7 +7,6 @@ from money.currency import Currency, CurrencyHelper from money.money import Money from xecd_rates_client import XecdClient -from xexe.constants import RATES_SOURCES from xexe.exchange_rates import ExchangeRate @@ -82,5 +81,5 @@ class XERateFetcher(RateFetcher): ) -def build_rate_fetcher(rates_source: str): - return RATES_SOURCES[rates_source]() +def build_rate_fetcher(rates_source: str, rate_sources_mapping): + return rate_sources_mapping[rates_source]() diff --git a/xexe/rate_writing.py b/xexe/rate_writing.py index fc82805..a2ac25e 100644 --- a/xexe/rate_writing.py +++ b/xexe/rate_writing.py @@ -1,5 +1,6 @@ import csv import datetime +import logging import os import pathlib from abc import ABC, abstractmethod @@ -10,6 +11,8 @@ from psycopg2 import sql from xexe.constants import DWH_SCHEMA, DWH_TABLE from xexe.exchange_rates import ExchangeRates +logger = logging.getLogger() + class RateWriter(ABC): @abstractmethod @@ -130,3 +133,18 @@ class DWHRateWriter(RateWriter): finally: cursor.close() self.connection.autocommit = True # Reset autocommit to its default state + + +def build_rate_writer(output: str) -> RateWriter: + output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv") + + if output_is_csv_file_path: + logger.info("Creating CSV Rate Writer.") + return CSVRateWriter(output_file_path=output) + + output_is_dwh = bool(output == "dwh") + if output_is_dwh: + logger.info("Creating DWH Rate Writer.") + return DWHRateWriter() + + raise ValueError(f"Don't know how to handle passed output: {output}")