diff --git a/pyabc2/sources/the_session.py b/pyabc2/sources/the_session.py index 203e4e4..d608be3 100644 --- a/pyabc2/sources/the_session.py +++ b/pyabc2/sources/the_session.py @@ -14,6 +14,7 @@ import logging import os import warnings +from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -477,6 +478,184 @@ def load_meta( return df +def _consume( + endpoint: str, + *, + pages: int | None = None, + size: int = 50, + max_threads: int = 1, + **params, +) -> list[dict]: + """Consume paginated The Session API endpoint, returning a list of entries. + + Parameters + ---------- + endpoint + The API endpoint, e.g. ``'/tunes/popular'``. + pages + Number of pages to retrieve. + Default: all pages. + size + Number of entries per page. + Corresponds to the ``perpage`` API parameter. + Default: 50 (maximum). + max_threads + Maximum number of threads to use. + Default: 1 (no multi-threading). + **params + Additional parameters to pass to the API. + For example, ``sortby=popular`` works for some endpoints. + Note that these, if provided, will be ignored: ``format``, ``perpage``, ``page``. + """ + import requests + + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + if not 1 <= size <= 50: + raise ValueError("`size` must be between 1 and 50 (inclusive).") + if pages is not None and pages < 1: + raise ValueError("`pages` must be >= 1.") + if max_threads < 1: + raise ValueError("`max_threads` must be >= 1.") + + base_url = "https://thesession.org" + + params.update( + { + "format": "json", + "perpage": size, + } + ) + + def get_page(page: int) -> dict: + page_params = params.copy() + page_params["page"] = page + url = base_url + endpoint + r = requests.get(url, timeout=5, params=page_params) + r.raise_for_status() + return r.json() + + # Even for page out of bounds + # https://thesession.org/tunes/popular?format=json&perpage=50&page=1000000 + # we get 'pages' (page count) and 'total' (entry count) + # (though the key that contains the data we want varies by endpoint) + # So start by getting the first page, and then we can multithread the rest + first_page = get_page(1) + if pages is None: + pages = first_page.get("pages", 1) + assert isinstance(pages, int) + parallel = pages > 2 and max_threads > 1 + + page_range = range(2, pages + 1) + if parallel: + from multiprocessing.pool import ThreadPool + + with ThreadPool(min(max_threads, pages - 1)) as pool: + remaining_pages = pool.map(get_page, page_range) + else: + remaining_pages = [get_page(page) for page in page_range] + + return [first_page] + remaining_pages + + +def get_tune_collections(tune_id: int) -> "pandas.DataFrame": + """Get data about the other collections a tune is in.""" + # https://thesession.org/tunes/1/collections?format=json + import pandas as pd + + endpoint = f"/tunes/{tune_id}/collections" + (res,) = _consume(endpoint) + + return pd.DataFrame(res["collections"]).rename( + columns={ + "id": "collection_id", + "name": "collection_name", + "url": "collection_page", + # ^ https://thesession.org/tunes/collections/ID + "identifier": "collection_tune_id", + # ^ sometimes string ID (e.g. for print book), sometimes URL (e.g. for Norbeck) + } + ) + + +def _tune_id_from_url(url: str) -> int: + from urllib.parse import urlsplit + + res = urlsplit(url) + return int(res.path.split("/")[-1]) + + +def get_member_set(member_id: int, set_id: int) -> list[dict]: + """Get information about the tunes in a specific member's set. + + Parameters + ---------- + member_id + Numeric identifier of the member on The Session. + For example, Jeremy is ``1`` (https://thesession.org/members/1). + set_id + Numeric identifier of the set belonging to ``member_id``. + """ + + endpoint = f"/members/{member_id}/sets/{set_id}" + (res,) = _consume(endpoint) + + tunes = [] + for setting in res["settings"]: + d = { + "name": setting["name"], + "tune_id": _tune_id_from_url(setting["url"]), + "setting_id": setting["id"], + "type": setting["type"], + "key": setting["key"], + } + tunes.append(d) + + return tunes + + +def get_member_sets(member_id: int, **kwargs) -> list[list[dict]]: + """Get information about all sets belonging to a specific member. + + Parameters + ---------- + member_id + Numeric identifier of the member on The Session. + For example, Jeremy is ``1`` (https://thesession.org/members/1). + **kwargs + Additional parameters passed to :func:`_consume`, + e.g. ``pages``, ``size``, ``max_threads``. + + See Also + -------- + get_member_set + """ + + endpoint = f"/members/{member_id}/sets" + + if "max_threads" not in kwargs: + kwargs["max_threads"] = 4 + results = _consume(endpoint, **kwargs) + + sets = [] + for set in chain.from_iterable(res["sets"] for res in results): + sets.append( + [ + { + "name": setting["name"], + "tune_id": _tune_id_from_url(setting["url"]), + "setting_id": setting["id"], + "type": setting["type"], + "key": setting["key"], + } + for setting in set["settings"] + ] + ) + + return sets + + if __name__ == "__main__": # pragma: no cover tune = load_url("https://thesession.org/tunes/10000") print(tune) diff --git a/tests/test_sources.py b/tests/test_sources.py index 84b86d4..1a23492 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -474,3 +474,52 @@ def test_bill_black_load(): lst = bill_black.load_meta() assert len(lst) > 0 assert lst[0].startswith("X:") + + +def test_the_session_get_tune_collections(): + df = the_session.get_tune_collections(1) # Cooley's + assert not df.empty + + +def test_the_session_get_member_set(): + tunes = the_session.get_member_set(65013, 106212) + assert len(tunes) == 3 + d = tunes[0] + assert d["name"] == "Garech's Wedding" + assert d["tune_id"] == 2620 + assert d["setting_id"] == 31341 + + +def test_the_session_get_member_sets(): + sets = the_session.get_member_sets(65013) + assert len(sets) >= 1 + d = sets[0][0] + assert d["name"] == "Garech's Wedding" + assert d["tune_id"] == 2620 + assert d["setting_id"] == 31341 + + +def test_the_session_get_member_sets_multipage(): + sets = the_session.get_member_sets(1, pages=3, size=2, max_threads=2) + assert len(sets) == 6 + d = sets[0][0] + assert d["name"] == "Toss The Feathers" + assert d["tune_id"] == 138 + + +def test_the_session_consume_validation(): + f = the_session.get_member_sets + + with pytest.raises(ValueError, match="`size`"): + _ = f(1, size=1000) + + with pytest.raises(ValueError, match="`pages`"): + _ = f(1, pages=0) + + with pytest.raises(ValueError, match="`max_threads`"): + _ = f(1, max_threads=0) + + +def test_the_session_consume_auto_leading_slash(): + (d,) = the_session._consume("tunes/22878") + assert d["name"] == "Jack Farrell's"