Skip to content
Merged
179 changes: 179 additions & 0 deletions pyabc2/sources/the_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
import os
import warnings
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -477,6 +478,184 @@ def load_meta(
return df


def _consume(
endpoint: str,
*,
pages: int | None = None,
size: int = 50,
max_threads: int = 1,
**params,
) -> list[dict]:
"""Consume paginated The Session API endpoint, returning a list of entries.

Parameters
----------
endpoint
The API endpoint, e.g. ``'/tunes/popular'``.
pages
Number of pages to retrieve.
Default: all pages.
size
Number of entries per page.
Corresponds to the ``perpage`` API parameter.
Default: 50 (maximum).
max_threads
Maximum number of threads to use.
Default: 1 (no multi-threading).
**params
Additional parameters to pass to the API.
For example, ``sortby=popular`` works for some endpoints.
Note that these, if provided, will be ignored: ``format``, ``perpage``, ``page``.
"""
import requests

if not endpoint.startswith("/"):
endpoint = "/" + endpoint

if not 1 <= size <= 50:
raise ValueError("`size` must be between 1 and 50 (inclusive).")
if pages is not None and pages < 1:
raise ValueError("`pages` must be >= 1.")
if max_threads < 1:
raise ValueError("`max_threads` must be >= 1.")

base_url = "https://thesession.org"

params.update(
{
"format": "json",
"perpage": size,
}
)

def get_page(page: int) -> dict:
page_params = params.copy()
page_params["page"] = page
url = base_url + endpoint
r = requests.get(url, timeout=5, params=page_params)
r.raise_for_status()
return r.json()

# Even for page out of bounds
# https://thesession.org/tunes/popular?format=json&perpage=50&page=1000000
# we get 'pages' (page count) and 'total' (entry count)
# (though the key that contains the data we want varies by endpoint)
# So start by getting the first page, and then we can multithread the rest
first_page = get_page(1)
if pages is None:
pages = first_page.get("pages", 1)
assert isinstance(pages, int)
parallel = pages > 2 and max_threads > 1

page_range = range(2, pages + 1)
if parallel:
from multiprocessing.pool import ThreadPool

with ThreadPool(min(max_threads, pages - 1)) as pool:
remaining_pages = pool.map(get_page, page_range)
else:
remaining_pages = [get_page(page) for page in page_range]

return [first_page] + remaining_pages


def get_tune_collections(tune_id: int) -> "pandas.DataFrame":
"""Get data about the other collections a tune is in."""
# https://thesession.org/tunes/1/collections?format=json
import pandas as pd

endpoint = f"/tunes/{tune_id}/collections"
(res,) = _consume(endpoint)

return pd.DataFrame(res["collections"]).rename(
columns={
"id": "collection_id",
"name": "collection_name",
"url": "collection_page",
# ^ https://thesession.org/tunes/collections/ID
"identifier": "collection_tune_id",
# ^ sometimes string ID (e.g. for print book), sometimes URL (e.g. for Norbeck)
}
)


def _tune_id_from_url(url: str) -> int:
from urllib.parse import urlsplit

res = urlsplit(url)
return int(res.path.split("/")[-1])


def get_member_set(member_id: int, set_id: int) -> list[dict]:
"""Get information about the tunes in a specific member's set.

Parameters
----------
member_id
Numeric identifier of the member on The Session.
For example, Jeremy is ``1`` (https://thesession.org/members/1).
set_id
Numeric identifier of the set belonging to ``member_id``.
"""

endpoint = f"/members/{member_id}/sets/{set_id}"
(res,) = _consume(endpoint)

tunes = []
for setting in res["settings"]:
d = {
"name": setting["name"],
"tune_id": _tune_id_from_url(setting["url"]),
"setting_id": setting["id"],
"type": setting["type"],
"key": setting["key"],
}
tunes.append(d)

return tunes


def get_member_sets(member_id: int, **kwargs) -> list[list[dict]]:
"""Get information about all sets belonging to a specific member.

Parameters
----------
member_id
Numeric identifier of the member on The Session.
For example, Jeremy is ``1`` (https://thesession.org/members/1).
**kwargs
Additional parameters passed to :func:`_consume`,
e.g. ``pages``, ``size``, ``max_threads``.

See Also
--------
get_member_set
"""

endpoint = f"/members/{member_id}/sets"

if "max_threads" not in kwargs:
kwargs["max_threads"] = 4
results = _consume(endpoint, **kwargs)

sets = []
for set in chain.from_iterable(res["sets"] for res in results):
sets.append(
[
{
"name": setting["name"],
"tune_id": _tune_id_from_url(setting["url"]),
"setting_id": setting["id"],
"type": setting["type"],
"key": setting["key"],
}
for setting in set["settings"]
]
)

return sets


if __name__ == "__main__": # pragma: no cover
tune = load_url("https://thesession.org/tunes/10000")
print(tune)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,52 @@ def test_bill_black_load():
lst = bill_black.load_meta()
assert len(lst) > 0
assert lst[0].startswith("X:")


def test_the_session_get_tune_collections():
df = the_session.get_tune_collections(1) # Cooley's
assert not df.empty


def test_the_session_get_member_set():
tunes = the_session.get_member_set(65013, 106212)
assert len(tunes) == 3
d = tunes[0]
assert d["name"] == "Garech's Wedding"
assert d["tune_id"] == 2620
assert d["setting_id"] == 31341


def test_the_session_get_member_sets():
sets = the_session.get_member_sets(65013)
assert len(sets) >= 1
d = sets[0][0]
assert d["name"] == "Garech's Wedding"
assert d["tune_id"] == 2620
assert d["setting_id"] == 31341


def test_the_session_get_member_sets_multipage():
sets = the_session.get_member_sets(1, pages=3, size=2, max_threads=2)
assert len(sets) == 6
d = sets[0][0]
assert d["name"] == "Toss The Feathers"
assert d["tune_id"] == 138


def test_the_session_consume_validation():
f = the_session.get_member_sets

with pytest.raises(ValueError, match="`size`"):
_ = f(1, size=1000)

with pytest.raises(ValueError, match="`pages`"):
_ = f(1, pages=0)

with pytest.raises(ValueError, match="`max_threads`"):
_ = f(1, max_threads=0)


def test_the_session_consume_auto_leading_slash():
(d,) = the_session._consume("tunes/22878")
assert d["name"] == "Jack Farrell's"