Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions BetterHolidays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .days import Day, Holiday, TradingDay, PartialTradingDay, NonTradingDay
from .multi import get_market
from .markets import Market, NYSE, MARKETS
from .markets import Market, NYSE, NASDAQ, SSE, LSE, MARKETS

__all__ = [
"Day",
Expand All @@ -10,6 +10,9 @@
"NonTradingDay",
"MARKETS",
"NYSE",
"NASDAQ",
"SSE",
"LSE",
"Market",
"get_market"
]
"get_market",
]
34 changes: 19 additions & 15 deletions BetterHolidays/const.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import typing as t

DAYS = (
"MONDAY",
"TUESDAY",
"WEDNESDAY",
"THURSDAY",
"FRIDAY",
"SATURDAY",
"SUNDAY"
)
DAYS = ("MONDAY", "TUESDAY", "WEDNESDAY", "THURSDAY", "FRIDAY", "SATURDAY", "SUNDAY")

DAYS_MAP = {day: i for i, day in enumerate(DAYS)}
MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY = 0,1,2,3,4,5,6
DAYS_TYPE = t.Literal[0,1,2,3,4,5,6]
MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY = 0, 1, 2, 3, 4, 5, 6
DAYS_TYPE = t.Literal[0, 1, 2, 3, 4, 5, 6]

MONTHS = (
"JANUARY",
Expand All @@ -26,11 +18,24 @@
"SEPTEMBER",
"OCTOBER",
"NOVEMBER",
"DECEMBER"
"DECEMBER",
)
MONTHS_MAP = {month: i for i, month in enumerate(MONTHS, 1)}

JANUARY, FEBRUARY, MARCH, APRIL, MAY, JUNE, JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER = range(1, 13)
(
JANUARY,
FEBRUARY,
MARCH,
APRIL,
MAY,
JUNE,
JULY,
AUGUST,
SEPTEMBER,
OCTOBER,
NOVEMBER,
DECEMBER,
) = range(1, 13)


DAYS_IN_MONTH = {
Expand All @@ -45,6 +50,5 @@
"SEPTEMBER": 30,
"OCTOBER": 31,
"NOVEMBER": 30,
"DECEMBER": 31
"DECEMBER": 31,
}

41 changes: 24 additions & 17 deletions BetterHolidays/days.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,39 @@
import dataclasses as dc
import datetime as dt
import peewee as pw
from . import utils

@dc.dataclass(frozen=True)
class Day:
class Day(pw.Model):
"""Base class representing a calendar day."""
date: dt.date

@dc.dataclass(frozen=True)
class Meta:
database = utils.get_db()

Comment on lines +7 to +9
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

utils.get_db() is called without the required argument

get_db(cache) expects a cache dict, but no argument is supplied, so importing this module raises
TypeError: get_db() missing 1 required positional argument: 'cache'.

Two minimal fixes:

-    class Meta:
-        database = utils.get_db()
+    class Meta:
+        # Provide a cache dict that contains a configured DB instance.
+        database = utils.get_db({"db": pw.SqliteDatabase(":memory:")})

or refactor utils.get_db to accept no arguments and return a singleton database.

Either way, the current code prevents the package from being imported.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class Meta:
database = utils.get_db()
class Meta:
# Provide a cache dict that contains a configured DB instance.
database = utils.get_db({"db": pw.SqliteDatabase(":memory:")})

date = pw.DateField(primary_key=True)
market = pw.CharField(null=True)

class Holiday(Day):
"""Represents a full holiday (market closed)."""
name: str

@dc.dataclass(frozen=True)
name = pw.CharField()


class TradingDay(Day):
"""Represents a full trading day with standard open/close times."""
open_time: dt.time
close_time: dt.time

@dc.dataclass(frozen=True)
open_time = pw.TimeField()
close_time = pw.TimeField()


class NonTradingDay(Day):
"""Represents a non-trading day (e.g. weekends)."""
pass

@dc.dataclass(frozen=True)

class PartialTradingDay(TradingDay, Holiday):
"""Represents a partial trading day (early close or late open)."""
name: str
early_close: bool = False
late_open: bool = False
early_close_reason: str = ""
late_open_reason: str = ""

early_close = pw.BooleanField(default=False)
late_open = pw.BooleanField(default=False)
early_close_reason = pw.CharField(default="")
late_open_reason = pw.CharField(default="")

utils.get_db().create_tables([Day, Holiday, TradingDay, NonTradingDay, PartialTradingDay])
12 changes: 9 additions & 3 deletions BetterHolidays/markets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from .market import Market

from .nyse import NYSE
from .nasdaq import NASDAQ
from .sse import SSE
from .lse import LSE

MARKETS:'dict[str, type[Market]]' = {
"NYSE": NYSE
MARKETS: "dict[str, type[Market]]" = {
"NYSE": NYSE,
"NASDAQ": NASDAQ,
"SSE": SSE,
"LSE": LSE,
}

__all__ = ["MARKETS", "Market", "NYSE"]
__all__ = ["MARKETS", "Market", "NYSE", "NASDAQ", "SSE", "LSE"]
63 changes: 41 additions & 22 deletions BetterHolidays/markets/cache.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,61 @@
import datetime as dt
import typing as t
from ..days import Day
from ..utils import NOT_SET
from ..utils import NOT_SET, get_db
import peewee as pw

T = t.TypeVar("T")

DB = get_db()


class Cache:
def __init__(self):
self.cache: 'dict[dt.date, Day]' = {}
def __init__(self, market: "str"):
self.market = market

def get(self, key: 'dt.date') -> 't.Optional[Day]':
return self.cache.get(key)
def get(self, key: "dt.date") -> "Day":
selection = Day.select().where(Day.date == key, Day.market == self.market)
print(selection)
return selection.get()

def set(self, key: 'dt.date', value: 'Day'):
self.cache[key] = value
def set(self, day: "Day"):
day.market = self.market
day.save()

def get_or_set(self, key: 'dt.date', func: 't.Callable[[int], None]') -> 'Day':
if key in self.cache:
def get_or_set(self, key: "dt.date", func: "t.Callable[[int], None]") -> "Day":
try:
return self.get(key)
func(key.year)
if key in self.cache:
except Exception:
func(key.year)

try:
return self.get(key)
raise ValueError("Cache miss")
except Exception:
raise ValueError(f"Could not find day {key} after fetching data")

def clear(self):
self.cache.clear()
Day.delete().where(Day.market == self.market).execute()

@t.overload
def pop(self, key:'dt.date') -> 'Day': ...
def pop(self, key: "dt.date") -> "Day": ...

@t.overload
def pop(self, key:'dt.date', default:'T') -> 't.Union[Day, T]': ...
def pop(self, key: "dt.date", default: "T") -> "t.Union[Day, T]": ...

def pop(self, key, default=NOT_SET):
if default == NOT_SET:
return self.cache.pop(key)

return self.cache.pop(key, default)

def __contains__(self, key: 'dt.date') -> bool:
return key in self.cache
query = Day.select().where(Day.market == self.market, Day.date == key)
try:
item = query.get()
except pw.DoesNotExist:
if default is NOT_SET:
raise KeyError(key) from None
return default

if default is NOT_SET:
item.delete_instance(recursive=False)
return item

def __contains__(self, key: "dt.date") -> bool:
return (
key in Day.select().where(Day.market == self.market, Day.date == key).get()
)
Comment on lines +58 to +61
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix incorrect implementation of contains

The current implementation tries to check if a date is within a Day object, which is incorrect. It should check if the query returns any results.

    def __contains__(self, key: "dt.date") -> bool:
-        return (
-            key in Day.select().where(Day.market == self.market, Day.date == key).get()
-        )
+        try:
+            Day.select().where(Day.market == self.market, Day.date == key).get()
+            return True
+        except pw.DoesNotExist:
+            return False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __contains__(self, key: "dt.date") -> bool:
return (
key in Day.select().where(Day.market == self.market, Day.date == key).get()
)
def __contains__(self, key: "dt.date") -> bool:
try:
Day.select().where(Day.market == self.market, Day.date == key).get()
return True
except pw.DoesNotExist:
return False

Loading