diff --git a/tests/tests_unit/test_input_handling.py b/tests/tests_unit/test_input_handling.py index 000c46d..865394c 100644 --- a/tests/tests_unit/test_input_handling.py +++ b/tests/tests_unit/test_input_handling.py @@ -4,6 +4,7 @@ import pathlib import pytest from money.currency import Currency +from xexe.currency_pair import CurrencyPair from xexe.inputs_handling import handle_get_rates_inputs from xexe.utils import DateRange @@ -142,7 +143,11 @@ def test_handle_input_rates_with_pairs_works_fine(): start_date=datetime.datetime.now().date(), end_date=datetime.datetime.now().date(), ), - "pairs": {"Pending real object"}, + "pairs": { + CurrencyPair(from_currency=Currency["USD"], to_currency=Currency["EUR"]), + CurrencyPair(from_currency=Currency["EUR"], to_currency=Currency["USD"]), + CurrencyPair(from_currency=Currency["GBP"], to_currency=Currency["ZAR"]), + }, "dry_run": False, "rates_source": "mock", "ignore_warnings": True, diff --git a/xexe/currency_pair.py b/xexe/currency_pair.py index d9a603b..5a89ab2 100644 --- a/xexe/currency_pair.py +++ b/xexe/currency_pair.py @@ -22,3 +22,9 @@ class CurrencyPair: return (self.from_currency == other.from_currency) and ( self.to_currency == other.to_currency ) + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self.from_currency, self.to_currency)) diff --git a/xexe/inputs_handling.py b/xexe/inputs_handling.py index 0cf5a75..6a823e7 100644 --- a/xexe/inputs_handling.py +++ b/xexe/inputs_handling.py @@ -6,6 +6,7 @@ from typing import Union from money.currency import Currency from xexe.constants import DEFAULT_CURRENCIES, RATES_SOURCES +from xexe.currency_pair import CurrencyPair from xexe.utils import DateRange logger = logging.getLogger() @@ -14,11 +15,12 @@ logger = logging.getLogger() def handle_get_rates_inputs( start_date: Union[datetime.datetime, datetime.date], end_date: Union[datetime.datetime, datetime.date], - currencies: Union[None, str], dry_run: bool, rates_source: str, ignore_warnings: bool, output: Union[str, pathlib.Path], + currencies: Union[None, str] = None, + pairs: Union[None, str] = None, ): logger.info("Handling inputs.") @@ -27,14 +29,29 @@ def handle_get_rates_inputs( if date_range.end_date > datetime.datetime.today().date(): date_range.end_date = datetime.datetime.today().date() + if pairs: + if currencies: + logger.error(f"Received both currencies and pairs.") + logger.error(f"Currencies: '{currencies}'.") + logger.error(f"Pairs: '{pairs}'.") + raise ValueError("You can pass currencies or pairs, but not both.") + + pairs = { + CurrencyPair( + from_currency=Currency[str_pair[0:3]], + to_currency=Currency[str_pair[3:6]], + ) + for str_pair in pairs.split(",") + } + if currencies: # CLI input comes as a string of comma-separated currency codes currencies = {currency_code.strip() for currency_code in currencies.split(",")} tmp = {Currency(currency_code) for currency_code in currencies} currencies = tmp - if currencies is None or currencies == "": - logger.info("No currency list passed. Running for default currencies.") + if currencies is None or currencies == "" and not pairs: + logger.info("No currency list or pairs passed. Running for default currencies.") currencies = DEFAULT_CURRENCIES if rates_source not in RATES_SOURCES: @@ -49,12 +66,17 @@ def handle_get_rates_inputs( prepared_inputs = { "date_range": date_range, - "currencies": currencies, "dry_run": dry_run, "rates_source": rates_source, "ignore_warnings": ignore_warnings, "output": output, } + if currencies: + prepared_inputs["currencies"] = currencies + + if pairs: + prepared_inputs["pairs"] = pairs + logger.debug(prepared_inputs) return prepared_inputs