Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
messages::AggregatorCapabilities,
};

use crate::{AggregatorDiscoverer, AggregatorEndpoint};
use crate::{AggregatorDiscoverer, AggregatorEndpoint, model::AggregatorEndpointWithCapabilities};

/// Required capabilities for an aggregator.
#[derive(Clone, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -50,14 +50,14 @@
/// An aggregator discoverer for specific capabilities.
pub struct CapableAggregatorDiscoverer {
required_capabilities: RequiredAggregatorCapabilities,
inner_discoverer: Arc<dyn AggregatorDiscoverer>,
inner_discoverer: Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>,
}

impl CapableAggregatorDiscoverer {
/// Creates a new `CapableAggregatorDiscoverer` instance with the provided capabilities.
pub fn new(
capabilities: RequiredAggregatorCapabilities,
inner_discoverer: Arc<dyn AggregatorDiscoverer>,
inner_discoverer: Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>,
) -> Self {
Self {
required_capabilities: capabilities,
Expand All @@ -67,11 +67,11 @@
}

#[async_trait::async_trait]
impl AggregatorDiscoverer for CapableAggregatorDiscoverer {
impl AggregatorDiscoverer<AggregatorEndpointWithCapabilities> for CapableAggregatorDiscoverer {
async fn get_available_aggregators(
&self,
network: MithrilNetwork,
) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpointWithCapabilities>>> {
let aggregator_endpoints = self.inner_discoverer.get_available_aggregators(network).await?;

Ok(Box::new(CapableAggregatorDiscovererIterator {
Expand All @@ -88,20 +88,19 @@
}

impl Iterator for CapableAggregatorDiscovererIterator {
type Item = AggregatorEndpoint;
type Item = AggregatorEndpointWithCapabilities;

fn next(&mut self) -> Option<Self::Item> {
for aggregator_endpoint in self.inner_iterator.by_ref() {
let aggregator_endpoint_clone = aggregator_endpoint.clone();
let aggregator_capabilities = tokio::task::block_in_place(move || {
tokio::runtime::Handle::current().block_on(async move {
aggregator_endpoint_clone.retrieve_capabilities().await
})
});
if let Ok(aggregator_capabilities) = aggregator_capabilities
&& self.required_capabilities.matches(&aggregator_capabilities)
if let Ok(aggregator_with_capabilities) =
AggregatorEndpointWithCapabilities::try_from(aggregator_endpoint)
{
return Some(aggregator_endpoint);
if self
.required_capabilities
.matches(&aggregator_with_capabilities.capabilities())

Check warning

Code scanning / clippy

this expression creates a reference which is immediately dereferenced by the compiler Warning

this expression creates a reference which is immediately dereferenced by the compiler

Check warning

Code scanning / clippy

this expression creates a reference which is immediately dereferenced by the compiler Warning

this expression creates a reference which is immediately dereferenced by the compiler
{
return Some(aggregator_with_capabilities);
}
}

Check warning

Code scanning / clippy

this if statement can be collapsed Warning

this if statement can be collapsed

Check warning

Code scanning / clippy

this if statement can be collapsed Warning

this if statement can be collapsed
}

Expand Down Expand Up @@ -292,7 +291,7 @@
.await
.unwrap();

let next_aggregator = aggregators.next();
let next_aggregator = aggregators.next().map(|endpoint| endpoint.into());
aggregator_server_mock.assert();
assert_eq!(
Some(AggregatorEndpoint::new(aggregator_server.url("/"))),
Expand Down Expand Up @@ -369,7 +368,7 @@
.await
.unwrap();

let next_aggregator = aggregators.next();
let next_aggregator = aggregators.next().map(|endpoint| endpoint.into());
aggregator_server_mock_1.assert();
aggregator_server_mock_2.assert();
assert_eq!(
Expand Down Expand Up @@ -438,7 +437,7 @@
.await
.unwrap();

let next_aggregator = aggregators.next();
let next_aggregator = aggregators.next().map(|endpoint| endpoint.into());
aggregator_server_mock_1.assert();
aggregator_server_mock_2.assert();
aggregator_server_mock_3.assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Default for HttpConfigAggregatorDiscoverer {
}

#[async_trait::async_trait]
impl AggregatorDiscoverer for HttpConfigAggregatorDiscoverer {
impl AggregatorDiscoverer<AggregatorEndpoint> for HttpConfigAggregatorDiscoverer {
async fn get_available_aggregators(
&self,
network: MithrilNetwork,
Expand Down
6 changes: 2 additions & 4 deletions internal/mithril-aggregator-discovery/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@

use mithril_common::{StdResult, entities::MithrilNetwork};

use crate::model::AggregatorEndpoint;

/// An aggregator discoverer.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait AggregatorDiscoverer: Sync + Send {
pub trait AggregatorDiscoverer<T: Send + Sync>: Sync + Send {
/// Get an iterator over a list of available aggregators in a Mithril network.
///
/// Note: there is no guarantee that the returned aggregators are sorted, complete or up-to-date.
async fn get_available_aggregators(
&self,
network: MithrilNetwork,
) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>>;
) -> StdResult<Box<dyn Iterator<Item = T>>>;
}
2 changes: 1 addition & 1 deletion internal/mithril-aggregator-discovery/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ pub mod test;
pub use capabilities_discoverer::{CapableAggregatorDiscoverer, RequiredAggregatorCapabilities};
pub use http_config_discoverer::HttpConfigAggregatorDiscoverer;
pub use interface::AggregatorDiscoverer;
pub use model::AggregatorEndpoint;
pub use model::{AggregatorEndpoint, AggregatorEndpointWithCapabilities};
pub use rand_discoverer::ShuffleAggregatorDiscoverer;
52 changes: 51 additions & 1 deletion internal/mithril-aggregator-discovery/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::time::Duration;
use serde::Serialize;

use mithril_aggregator_client::{AggregatorHttpClient, query::GetAggregatorFeaturesQuery};
use mithril_common::{StdResult, messages::AggregatorCapabilities};
use mithril_common::{StdError, StdResult, messages::AggregatorCapabilities};

/// Representation of an aggregator endpoint
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
Expand Down Expand Up @@ -32,6 +32,12 @@ impl AggregatorEndpoint {
}
}

impl From<AggregatorEndpointWithCapabilities> for AggregatorEndpoint {
fn from(endpoint: AggregatorEndpointWithCapabilities) -> Self {
Self::new(endpoint.url)
}
}

impl From<AggregatorEndpoint> for String {
fn from(endpoint: AggregatorEndpoint) -> Self {
endpoint.url
Expand All @@ -43,3 +49,47 @@ impl std::fmt::Display for AggregatorEndpoint {
write!(f, "{}", self.url)
}
}

/// Representation of an aggregator endpoint with capabilities
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct AggregatorEndpointWithCapabilities {
url: String,
capabilities: AggregatorCapabilities,
}
impl AggregatorEndpointWithCapabilities {
/// Create a new AggregatorEndpointWithCapabilities instance
pub fn new(url: String, capabilities: AggregatorCapabilities) -> Self {
Self { url, capabilities }
}

/// Get the capabilities of the aggregator
pub fn capabilities(&self) -> &AggregatorCapabilities {
&self.capabilities
}
}

impl TryFrom<AggregatorEndpoint> for AggregatorEndpointWithCapabilities {
type Error = StdError;

fn try_from(endpoint: AggregatorEndpoint) -> Result<Self, Self::Error> {
let endpoint_clone = endpoint.clone();
let aggregator_capabilities = tokio::task::block_in_place(move || {
tokio::runtime::Handle::current()
.block_on(async move { endpoint_clone.retrieve_capabilities().await })
});

Ok(Self::new(endpoint.url, aggregator_capabilities?))
}
}

impl From<AggregatorEndpointWithCapabilities> for String {
fn from(endpoint: AggregatorEndpointWithCapabilities) -> Self {
endpoint.url
}
}

impl std::fmt::Display for AggregatorEndpointWithCapabilities {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.url)
}
}
11 changes: 8 additions & 3 deletions internal/mithril-aggregator-discovery/src/rand_discoverer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ use crate::{AggregatorDiscoverer, AggregatorEndpoint};
/// A discoverer that returns a random set of aggregators
pub struct ShuffleAggregatorDiscoverer<R: Rng + Send + Sized> {
random_generator: Arc<Mutex<Box<R>>>,
inner_discoverer: Arc<dyn AggregatorDiscoverer>,
inner_discoverer: Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>,
}

impl<R: Rng + Send + Sized> ShuffleAggregatorDiscoverer<R> {
/// Creates a new `ShuffleAggregatorDiscoverer` instance with the provided inner discoverer.
pub fn new(inner_discoverer: Arc<dyn AggregatorDiscoverer>, random_generator: R) -> Self {
pub fn new(
inner_discoverer: Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>,
random_generator: R,
) -> Self {
Self {
inner_discoverer,
random_generator: Arc::new(Mutex::new(Box::new(random_generator))),
Expand All @@ -24,7 +27,9 @@ impl<R: Rng + Send + Sized> ShuffleAggregatorDiscoverer<R> {
}

#[async_trait::async_trait]
impl<R: Rng + Send + Sized> AggregatorDiscoverer for ShuffleAggregatorDiscoverer<R> {
impl<R: Rng + Send + Sized> AggregatorDiscoverer<AggregatorEndpoint>
for ShuffleAggregatorDiscoverer<R>
{
async fn get_available_aggregators(
&self,
network: MithrilNetwork,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl AggregatorDiscovererFake {
}

#[async_trait::async_trait]
impl AggregatorDiscoverer for AggregatorDiscovererFake {
impl AggregatorDiscoverer<AggregatorEndpoint> for AggregatorDiscovererFake {
async fn get_available_aggregators(
&self,
_network: MithrilNetwork,
Expand Down
22 changes: 20 additions & 2 deletions mithril-client-cli/src/commands/tools/aggregator_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,28 @@ impl AggregatorDiscoveryCommand {
} else {
let lines = lines
.into_iter()
.map(|endpoint| vec![endpoint.cell()])
.map(|endpoint| {
let endpoint_clone = endpoint.clone();
let capabilities = endpoint_clone.capabilities();
vec![
endpoint.cell(),
capabilities.aggregate_signature_type.cell(),
capabilities
.signed_entity_types
.iter()
.map(|signed_entity_type| signed_entity_type.to_string())
.collect::<Vec<_>>()
.join(",")
.cell(),
]
})
.collect::<Vec<_>>()
.table()
.title(vec!["Aggregator Endpoint".cell()]);
.title(vec![
"Aggregator Endpoint".cell(),
"Aggregate Signature Type".cell(),
"Signed Entity Types".cell(),
]);
print_stdout(lines)?;
}

Expand Down
20 changes: 12 additions & 8 deletions mithril-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use slog::{Logger, o};

#[cfg(not(target_family = "wasm"))]
use mithril_aggregator_discovery::{
AggregatorDiscoverer, AggregatorEndpoint, CapableAggregatorDiscoverer,
HttpConfigAggregatorDiscoverer, RequiredAggregatorCapabilities, ShuffleAggregatorDiscoverer,
AggregatorDiscoverer, AggregatorEndpoint, AggregatorEndpointWithCapabilities,
CapableAggregatorDiscoverer, HttpConfigAggregatorDiscoverer, RequiredAggregatorCapabilities,
ShuffleAggregatorDiscoverer,
};
use mithril_common::api_version::APIVersionProvider;
use mithril_common::{MITHRIL_CLIENT_TYPE_HEADER, MITHRIL_ORIGIN_TAG_HEADER};
Expand Down Expand Up @@ -201,7 +202,7 @@ pub struct ClientBuilder {
#[cfg(not(target_family = "wasm"))]
aggregator_capabilities: Option<RequiredAggregatorCapabilities>,
#[cfg(not(target_family = "wasm"))]
aggregator_discoverer: Option<Arc<dyn AggregatorDiscoverer>>,
aggregator_discoverer: Option<Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>>,
genesis_verification_key: Option<GenesisVerificationKey>,
origin_tag: Option<String>,
client_type: Option<String>,
Expand Down Expand Up @@ -295,7 +296,7 @@ impl ClientBuilder {
#[cfg(not(target_family = "wasm"))]
pub fn with_aggregator_discoverer(
mut self,
discoverer: Arc<dyn AggregatorDiscoverer>,
discoverer: Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>>,
) -> ClientBuilder {
self.aggregator_discoverer = Some(discoverer);

Expand Down Expand Up @@ -432,7 +433,7 @@ impl ClientBuilder {
pub fn discover_aggregator(
&self,
network: &MithrilNetwork,
) -> MithrilResult<impl Iterator<Item = AggregatorEndpoint>> {
) -> MithrilResult<impl Iterator<Item = AggregatorEndpointWithCapabilities>> {
let discoverer = self
.aggregator_discoverer
.clone()
Expand All @@ -441,9 +442,12 @@ impl ClientBuilder {
Arc::new(CapableAggregatorDiscoverer::new(
capabilities.to_owned(),
discoverer.clone(),
)) as Arc<dyn AggregatorDiscoverer>
)) as Arc<dyn AggregatorDiscoverer<AggregatorEndpointWithCapabilities>>
} else {
discoverer as Arc<dyn AggregatorDiscoverer>
Arc::new(CapableAggregatorDiscoverer::new(
RequiredAggregatorCapabilities::All,
discoverer.clone(),
)) as Arc<dyn AggregatorDiscoverer<AggregatorEndpointWithCapabilities>>
};

tokio::task::block_in_place(move || {
Expand All @@ -458,7 +462,7 @@ impl ClientBuilder {

/// Default aggregator discoverer to use to find the aggregator endpoint when in automatic discovery.
#[cfg(not(target_family = "wasm"))]
fn default_aggregator_discoverer() -> Arc<dyn AggregatorDiscoverer> {
fn default_aggregator_discoverer() -> Arc<dyn AggregatorDiscoverer<AggregatorEndpoint>> {
Arc::new(ShuffleAggregatorDiscoverer::new(
Arc::new(HttpConfigAggregatorDiscoverer::default()),
{
Expand Down
Loading