@@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
2323const EMPTY_DICT_T = typeof (EMPTY_DICT)
2424
2525@compactify show_methods= false begin
26- @abstract struct BasicSymbolic{T} <: Symbolic{T}
26+ @abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
2727 metadata:: Metadata = NO_METADATA
2828 end
29- struct Sym{T} <: BasicSymbolic{T}
29+ mutable struct Sym{T} <: BasicSymbolic{T}
3030 name:: Symbol = :OOF
3131 end
32- struct Term{T} <: BasicSymbolic{T}
32+ mutable struct Term{T} <: BasicSymbolic{T}
3333 f:: Any = identity # base/num if Pow; issorted if Add/Dict
3434 arguments:: Vector{Any} = EMPTY_ARGS
3535 hash:: RefValue{UInt} = EMPTY_HASH
3636 end
37- struct Mul{T} <: BasicSymbolic{T}
37+ mutable struct Mul{T} <: BasicSymbolic{T}
3838 coeff:: Any = 0 # exp/den if Pow
3939 dict:: EMPTY_DICT_T = EMPTY_DICT
4040 hash:: RefValue{UInt} = EMPTY_HASH
4141 arguments:: Vector{Any} = EMPTY_ARGS
4242 issorted:: RefValue{Bool} = NOT_SORTED
4343 end
44- struct Add{T} <: BasicSymbolic{T}
44+ mutable struct Add{T} <: BasicSymbolic{T}
4545 coeff:: Any = 0 # exp/den if Pow
4646 dict:: EMPTY_DICT_T = EMPTY_DICT
4747 hash:: RefValue{UInt} = EMPTY_HASH
4848 arguments:: Vector{Any} = EMPTY_ARGS
4949 issorted:: RefValue{Bool} = NOT_SORTED
5050 end
51- struct Div{T} <: BasicSymbolic{T}
51+ mutable struct Div{T} <: BasicSymbolic{T}
5252 num:: Any = 1
5353 den:: Any = 1
5454 simplified:: Bool = false
5555 arguments:: Vector{Any} = EMPTY_ARGS
5656 end
57- struct Pow{T} <: BasicSymbolic{T}
57+ mutable struct Pow{T} <: BasicSymbolic{T}
5858 base:: Any = 1
5959 exp:: Any = 1
6060 arguments:: Vector{Any} = EMPTY_ARGS
@@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
7777 end
7878end
7979
80+ const wvd = WeakValueDict {UInt, BasicSymbolic} ()
81+
8082# Same but different error messages
8183@noinline error_on_type () = error (" Internal error: unreachable reached!" )
8284@noinline error_sym () = error (" Sym doesn't have a operation or arguments!" )
@@ -92,7 +94,11 @@ const SIMPLIFIED = 0x01 << 0
9294function ConstructionBase. setproperties (obj:: BasicSymbolic{T} , patch:: NamedTuple ):: BasicSymbolic{T} where T
9395 nt = getproperties (obj)
9496 nt_new = merge (nt, patch)
95- Unityper. rt_constructor (obj){T}(;nt_new... )
97+ # Call outer constructor because hash consing cannot be applied in inner constructor
98+ @compactified obj:: BasicSymbolic begin
99+ Sym => Sym {T} (nt_new. name; nt_new... )
100+ _ => Unityper. rt_constructor (obj){T}(;nt_new... )
101+ end
96102end
97103
98104# ##
@@ -265,6 +271,26 @@ function _isequal(a, b, E)
265271 end
266272end
267273
274+ """
275+ $(TYPEDSIGNATURES)
276+
277+ Checks for equality between two `BasicSymbolic` objects, considering both their
278+ values and metadata.
279+
280+ The default `Base.isequal` function for `BasicSymbolic` only compares their expressions
281+ and ignores metadata. This does not help deal with hash collisions when metadata is
282+ relevant for distinguishing expressions, particularly in hashing contexts. This function
283+ provides a stricter equality check that includes metadata comparison, preventing
284+ such collisions.
285+
286+ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and
287+ downstream packages like `ModelingToolkit.jl`, hence the need for this separate
288+ function.
289+ """
290+ function isequal_with_metadata (a:: BasicSymbolic , b:: BasicSymbolic ):: Bool
291+ isequal (a, b) && isequal (metadata (a), metadata (b))
292+ end
293+
268294Base. one ( s:: Symbolic ) = one ( symtype (s))
269295Base. zero (s:: Symbolic ) = zero (symtype (s))
270296
@@ -307,12 +333,61 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
307333 end
308334end
309335
336+ """
337+ $(TYPEDSIGNATURES)
338+
339+ Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and
340+ symtype.
341+
342+ This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic`
343+ objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also
344+ includes the metadata and symtype in the hash calculation. This can be beneficial for hash
345+ consing, allowing for more effective deduplication of symbolically equivalent expressions
346+ with different metadata or symtypes.
347+ """
348+ hash2 (s:: BasicSymbolic ) = hash2 (s, zero (UInt))
349+ function hash2 (s:: BasicSymbolic{T} , salt:: UInt ):: UInt where {T}
350+ hash (metadata (s), hash (T, hash (s, salt)))
351+ end
352+
310353# ##
311354# ## Constructors
312355# ##
313356
314- function Sym {T} (name:: Symbol ; kw... ) where T
315- Sym {T} (; name= name, kw... )
357+ """
358+ $(TYPEDSIGNATURES)
359+
360+ Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
361+
362+ This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
363+ custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
364+ objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
365+ different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
366+ which includes metadata comparison, is used to confirm the equivalence of objects with
367+ matching hashes. If an equivalent object is found, the existing object is returned;
368+ otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
369+ for runtime code generation, and supports built-in common subexpression elimination,
370+ particularly when working with symbolic objects with metadata.
371+
372+ Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
373+ stored, allowing objects that are no longer strongly referenced to be garbage collected.
374+ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
375+ `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
376+ original behavior of those functions.
377+ """
378+ function BasicSymbolic (s:: BasicSymbolic ):: BasicSymbolic
379+ h = hash2 (s)
380+ t = get! (wvd, h, s)
381+ if t === s || isequal_with_metadata (t, s)
382+ return t
383+ else
384+ return s
385+ end
386+ end
387+
388+ function Sym {T} (name:: Symbol ; kw... ) where {T}
389+ s = Sym {T} (; name, kw... )
390+ BasicSymbolic (s)
316391end
317392
318393function Term {T} (f, args; kw... ) where T
0 commit comments