-
Notifications
You must be signed in to change notification settings - Fork 1
More Markets #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
More Markets #3
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,39 @@ | ||
| import dataclasses as dc | ||
| import datetime as dt | ||
| import peewee as pw | ||
| from . import utils | ||
|
|
||
| @dc.dataclass(frozen=True) | ||
| class Day: | ||
| class Day(pw.Model): | ||
| """Base class representing a calendar day.""" | ||
| date: dt.date | ||
|
|
||
| @dc.dataclass(frozen=True) | ||
| class Meta: | ||
| database = utils.get_db() | ||
|
|
||
| date = pw.DateField(primary_key=True) | ||
| market = pw.CharField(null=True) | ||
|
|
||
| class Holiday(Day): | ||
| """Represents a full holiday (market closed).""" | ||
| name: str | ||
|
|
||
| @dc.dataclass(frozen=True) | ||
| name = pw.CharField() | ||
|
|
||
|
|
||
| class TradingDay(Day): | ||
| """Represents a full trading day with standard open/close times.""" | ||
| open_time: dt.time | ||
| close_time: dt.time | ||
|
|
||
| @dc.dataclass(frozen=True) | ||
| open_time = pw.TimeField() | ||
| close_time = pw.TimeField() | ||
|
|
||
|
|
||
| class NonTradingDay(Day): | ||
| """Represents a non-trading day (e.g. weekends).""" | ||
| pass | ||
|
|
||
| @dc.dataclass(frozen=True) | ||
|
|
||
| class PartialTradingDay(TradingDay, Holiday): | ||
| """Represents a partial trading day (early close or late open).""" | ||
| name: str | ||
| early_close: bool = False | ||
| late_open: bool = False | ||
| early_close_reason: str = "" | ||
| late_open_reason: str = "" | ||
|
|
||
| early_close = pw.BooleanField(default=False) | ||
| late_open = pw.BooleanField(default=False) | ||
| early_close_reason = pw.CharField(default="") | ||
| late_open_reason = pw.CharField(default="") | ||
|
|
||
| utils.get_db().create_tables([Day, Holiday, TradingDay, NonTradingDay, PartialTradingDay]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,15 @@ | ||
| from .market import Market | ||
|
|
||
| from .nyse import NYSE | ||
| from .nasdaq import NASDAQ | ||
| from .sse import SSE | ||
| from .lse import LSE | ||
|
|
||
| MARKETS:'dict[str, type[Market]]' = { | ||
| "NYSE": NYSE | ||
| MARKETS: "dict[str, type[Market]]" = { | ||
| "NYSE": NYSE, | ||
| "NASDAQ": NASDAQ, | ||
| "SSE": SSE, | ||
| "LSE": LSE, | ||
| } | ||
|
|
||
| __all__ = ["MARKETS", "Market", "NYSE"] | ||
| __all__ = ["MARKETS", "Market", "NYSE", "NASDAQ", "SSE", "LSE"] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,42 +1,61 @@ | ||||||||||||||||||||||
| import datetime as dt | ||||||||||||||||||||||
| import typing as t | ||||||||||||||||||||||
| from ..days import Day | ||||||||||||||||||||||
| from ..utils import NOT_SET | ||||||||||||||||||||||
| from ..utils import NOT_SET, get_db | ||||||||||||||||||||||
| import peewee as pw | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| T = t.TypeVar("T") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| DB = get_db() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class Cache: | ||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||
| self.cache: 'dict[dt.date, Day]' = {} | ||||||||||||||||||||||
| def __init__(self, market: "str"): | ||||||||||||||||||||||
| self.market = market | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get(self, key: 'dt.date') -> 't.Optional[Day]': | ||||||||||||||||||||||
| return self.cache.get(key) | ||||||||||||||||||||||
| def get(self, key: "dt.date") -> "Day": | ||||||||||||||||||||||
| selection = Day.select().where(Day.date == key, Day.market == self.market) | ||||||||||||||||||||||
| print(selection) | ||||||||||||||||||||||
| return selection.get() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def set(self, key: 'dt.date', value: 'Day'): | ||||||||||||||||||||||
| self.cache[key] = value | ||||||||||||||||||||||
| def set(self, day: "Day"): | ||||||||||||||||||||||
| day.market = self.market | ||||||||||||||||||||||
| day.save() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_or_set(self, key: 'dt.date', func: 't.Callable[[int], None]') -> 'Day': | ||||||||||||||||||||||
| if key in self.cache: | ||||||||||||||||||||||
| def get_or_set(self, key: "dt.date", func: "t.Callable[[int], None]") -> "Day": | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| return self.get(key) | ||||||||||||||||||||||
| func(key.year) | ||||||||||||||||||||||
| if key in self.cache: | ||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||
| func(key.year) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| return self.get(key) | ||||||||||||||||||||||
| raise ValueError("Cache miss") | ||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||
| raise ValueError(f"Could not find day {key} after fetching data") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def clear(self): | ||||||||||||||||||||||
| self.cache.clear() | ||||||||||||||||||||||
| Day.delete().where(Day.market == self.market).execute() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @t.overload | ||||||||||||||||||||||
| def pop(self, key:'dt.date') -> 'Day': ... | ||||||||||||||||||||||
| def pop(self, key: "dt.date") -> "Day": ... | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @t.overload | ||||||||||||||||||||||
| def pop(self, key:'dt.date', default:'T') -> 't.Union[Day, T]': ... | ||||||||||||||||||||||
| def pop(self, key: "dt.date", default: "T") -> "t.Union[Day, T]": ... | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def pop(self, key, default=NOT_SET): | ||||||||||||||||||||||
| if default == NOT_SET: | ||||||||||||||||||||||
| return self.cache.pop(key) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return self.cache.pop(key, default) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __contains__(self, key: 'dt.date') -> bool: | ||||||||||||||||||||||
| return key in self.cache | ||||||||||||||||||||||
| query = Day.select().where(Day.market == self.market, Day.date == key) | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| item = query.get() | ||||||||||||||||||||||
| except pw.DoesNotExist: | ||||||||||||||||||||||
| if default is NOT_SET: | ||||||||||||||||||||||
| raise KeyError(key) from None | ||||||||||||||||||||||
| return default | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if default is NOT_SET: | ||||||||||||||||||||||
| item.delete_instance(recursive=False) | ||||||||||||||||||||||
| return item | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __contains__(self, key: "dt.date") -> bool: | ||||||||||||||||||||||
| return ( | ||||||||||||||||||||||
| key in Day.select().where(Day.market == self.market, Day.date == key).get() | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+58
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix incorrect implementation of contains The current implementation tries to check if a date is within a Day object, which is incorrect. It should check if the query returns any results. def __contains__(self, key: "dt.date") -> bool:
- return (
- key in Day.select().where(Day.market == self.market, Day.date == key).get()
- )
+ try:
+ Day.select().where(Day.market == self.market, Day.date == key).get()
+ return True
+ except pw.DoesNotExist:
+ return False📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
utils.get_db()is called without the required argumentget_db(cache)expects a cache dict, but no argument is supplied, so importing this module raisesTypeError: get_db() missing 1 required positional argument: 'cache'.Two minimal fixes:
or refactor
utils.get_dbto accept no arguments and return a singleton database.Either way, the current code prevents the package from being imported.
📝 Committable suggestion