diff --git a/src/types/arrow.rs b/src/types/arrow.rs index f0a74949..cfac4aea 100644 --- a/src/types/arrow.rs +++ b/src/types/arrow.rs @@ -356,6 +356,8 @@ impl<'brand> DisconnectConstructible<'brand, Option<&Arrow<'brand>>> for Arrow<' impl<'brand> JetConstructible<'brand> for Arrow<'brand> { fn jet(inference_context: &Context<'brand>, jet: &dyn Jet) -> Self { + inference_context.check_jet(jet); + Arrow { source: jet.source_ty().to_type(inference_context), target: jet.target_ty().to_type(inference_context), diff --git a/src/types/context.rs b/src/types/context.rs index 306bc2bd..7992d2ae 100644 --- a/src/types/context.rs +++ b/src/types/context.rs @@ -14,6 +14,7 @@ //! the other. //! +use std::any::TypeId; use std::fmt; use std::marker::PhantomData; use std::sync::{Arc, Mutex, MutexGuard}; @@ -21,6 +22,7 @@ use std::sync::{Arc, Mutex, MutexGuard}; use ghost_cell::GhostToken; use crate::dag::{Dag, DagLike}; +use crate::jet::Jet; use super::{ Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken, @@ -48,6 +50,8 @@ pub struct Context<'brand> { struct ContextInner<'brand> { slab: Vec>, + /// Concrete jet type registered in this context, if any. + jet_type: Option, } impl fmt::Debug for Context<'_> { @@ -81,7 +85,10 @@ impl<'brand> Context<'brand> { Context { inner: Arc::new(Mutex::new(WithGhostToken { token, - inner: ContextInner { slab: vec![] }, + inner: ContextInner { + slab: vec![], + jet_type: None, + }, })), } } @@ -147,6 +154,23 @@ impl<'brand> Context<'brand> { } } + /// Asserts that all jets in this context have the same concrete type. + /// + /// Records the jet's type on first call, panics on subsequent calls with + /// a different concrete type. + pub fn check_jet(&self, jet: &dyn Jet) { + let new_id = jet.as_any().type_id(); + let mut lock = self.lock(); + + if let Some(existing_id) = lock.inner.jet_type { + assert!(existing_id == new_id, "mixed jet types in context"); + + return; + } + + lock.inner.jet_type = Some(new_id); + } + /// Accesses a bound. pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> { let lock = self.lock(); diff --git a/src/types/mod.rs b/src/types/mod.rs index 39e4d4cd..070faf77 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -415,6 +415,8 @@ impl fmt::Display for Type<'_> { mod tests { use super::*; + use crate::jet::Core; + use crate::node::JetConstructible; use crate::node::{ConstructNode, CoreConstructible}; #[test] @@ -446,4 +448,24 @@ mod tests { let _ = format!("{:?}", case.arrow().source); }); } + + #[test] + fn check_jet_same_type_ok() { + Context::with_context(|ctx| { + let _ = Arc::::jet(&ctx, &Core::Add32); + let _ = Arc::::jet(&ctx, &Core::Subtract32); + }); + } + + #[cfg(feature = "elements")] + #[test] + #[should_panic(expected = "mixed jet types in context")] + fn check_jet_different_types_panics() { + use crate::jet::Elements; + + Context::with_context(|ctx| { + let _ = Arc::::jet(&ctx, &Core::Add32); + let _ = Arc::::jet(&ctx, &Elements::Add32); + }); + } }