tests passing

This commit is contained in:
Pablo Martin 2024-06-13 15:56:57 +02:00
parent de8828a9cb
commit 515a4527af
4 changed files with 79 additions and 69 deletions

View file

@ -6,6 +6,7 @@ from typing import List
from money.currency import Currency
from xecd_rates_client import XecdClient
from xexe.constants import RATES_SOURCES
from xexe.exchange_rates import ExchangeRates, add_equal_rates, add_inverse_rates
from xexe.rate_fetching import (
MockRateFetcher,
@ -13,7 +14,7 @@ from xexe.rate_fetching import (
XERateFetcher,
build_rate_fetcher,
)
from xexe.rate_writing import CSVRateWriter, RateWriter
from xexe.rate_writing import CSVRateWriter, RateWriter, build_rate_writer
from xexe.utils import DateRange, generate_currency_and_dates_combinations
logger = logging.getLogger()
@ -67,27 +68,32 @@ def run_get_rates(
) -> None:
logger.info("Getting rates")
process_state = GetRatesProcessState(output=output, ignore_warnings=ignore_warnings)
process_state = GetRatesProcessState(ignore_warnings=ignore_warnings)
rates = obtain_rates_from_source(
process_state,
rates_source=rates_source,
date_range=date_range,
currencies=currencies,
ignore_warnings=ignore_warnings,
)
logger.info("Rates obtained.")
if dry_run:
logger.info("Dry run mode active. Not writing rates to output.")
return
write_rates_to_output(process_state, rates)
write_rates_to_output(rates, output)
logger.info("Rates written to output.")
def obtain_rates_from_source(
process_state, rates_source: str, date_range: DateRange, currencies: List[Currency]
rates_source: str,
date_range: DateRange,
currencies: List[Currency],
ignore_warnings: bool,
) -> ExchangeRates:
rates_fetcher = build_rate_fetcher(rates_source)
rates_fetcher = build_rate_fetcher(
rates_source=rates_source, rate_sources_mapping=RATES_SOURCES
)
currency_and_date_combinations = generate_currency_and_dates_combinations(
date_range=date_range, currencies=currencies
@ -96,7 +102,7 @@ def obtain_rates_from_source(
large_api_call_planned = (
rates_fetcher.is_production_grade and len(currency_and_date_combinations) > 100
)
if large_api_call_planned and not process_state.ignore_warnings:
if large_api_call_planned and not ignore_warnings:
user_confirmation_string = "i understand"
user_response = input(
f"WARNING: you are about to execute a large call {len(currency_and_date_combinations)} to a metered API. Type '{user_confirmation_string}' to move forward: "
@ -130,25 +136,12 @@ def obtain_rates_from_source(
return rates
def write_rates_to_output(process_state, rates):
rates_writer = process_state.get_writer()
def write_rates_to_output(rates, output) -> None:
rates_writer = build_rate_writer(output)
logger.info("Attempting writing rates to output.")
rates_writer.write_rates(rates)
class GetRatesProcessState:
def __init__(self, output: str, ignore_warnings: bool) -> None:
self.writer = self._select_writer(output)
def __init__(self, ignore_warnings: bool) -> None:
self.ignore_warnings = ignore_warnings
@staticmethod
def _select_writer(output: str) -> CSVRateWriter:
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
if output_is_csv_file_path:
return CSVRateWriter(output_file_path=output)
raise ValueError(f"Don't know how to handle passed output: {output}")
def get_writer(self) -> RateWriter:
return self.writer

View file

@ -7,7 +7,6 @@ from money.currency import Currency, CurrencyHelper
from money.money import Money
from xecd_rates_client import XecdClient
from xexe.constants import RATES_SOURCES
from xexe.exchange_rates import ExchangeRate
@ -82,5 +81,5 @@ class XERateFetcher(RateFetcher):
)
def build_rate_fetcher(rates_source: str):
return RATES_SOURCES[rates_source]()
def build_rate_fetcher(rates_source: str, rate_sources_mapping):
return rate_sources_mapping[rates_source]()

View file

@ -1,5 +1,6 @@
import csv
import datetime
import logging
import os
import pathlib
from abc import ABC, abstractmethod
@ -10,6 +11,8 @@ from psycopg2 import sql
from xexe.constants import DWH_SCHEMA, DWH_TABLE
from xexe.exchange_rates import ExchangeRates
logger = logging.getLogger()
class RateWriter(ABC):
@abstractmethod
@ -130,3 +133,18 @@ class DWHRateWriter(RateWriter):
finally:
cursor.close()
self.connection.autocommit = True # Reset autocommit to its default state
def build_rate_writer(output: str) -> RateWriter:
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
if output_is_csv_file_path:
logger.info("Creating CSV Rate Writer.")
return CSVRateWriter(output_file_path=output)
output_is_dwh = bool(output == "dwh")
if output_is_dwh:
logger.info("Creating DWH Rate Writer.")
return DWHRateWriter()
raise ValueError(f"Don't know how to handle passed output: {output}")