11module GepUtils
22
3-
43using OrderedCollections
54using DynamicExpressions
65using LinearAlgebra
76using Optim
87using LineSearches
98using Zygote
109using Serialization
10+ using Statistics
11+ using Base. Threads: @spawn
1112
12- export Toolbox
13- export find_indices_with_sum, compile_djl_datatype, optimize_constants, minmax_scale, float16_scale, isclose
14- export save_state, load_state
1513
14+ export find_indices_with_sum, compile_djl_datatype, optimize_constants!, minmax_scale, float16_scale, isclose
15+ export save_state, load_state
16+ export create_history_recorder, record_history!, record!, close_recorder!
17+ export HistoryRecorder, OptimizationHistory
18+
"""
    OptimizationHistory{T<:AbstractFloat}

Per-epoch training statistics: training and validation loss plus the mean and
standard deviation of the population fitness vector. All four vectors are
preallocated to `epochs` slots and are filled in later (entries are undefined
until written).
"""
struct OptimizationHistory{T<:AbstractFloat}
    train_loss::Vector{T}
    val_loss::Vector{T}
    train_mean::Vector{T}
    train_std::Vector{T}

    function OptimizationHistory(epochs::Int, ::Type{T}) where {T<:AbstractFloat}
        # Four identically-sized, uninitialized buffers, one per tracked metric.
        buffers = ntuple(_ -> Vector{T}(undef, epochs), 4)
        return new{T}(buffers...)
    end
end
1634
"""
    HistoryRecorder{T<:AbstractFloat}

Asynchronous history recorder. Owns a bounded `Channel` of
`(epoch, train_loss, val_loss, fit_vector)` tuples, a background task that
drains the channel via `record_history!`, and the `OptimizationHistory`
being filled in.
"""
struct HistoryRecorder{T<:AbstractFloat}
    channel::Channel{Tuple{Int,T,T,Vector{T}}}
    task::Task
    history::OptimizationHistory{T}

    function HistoryRecorder(epochs::Int, ::Type{T}; buffer_size::Int=32) where {T<:AbstractFloat}
        hist = OptimizationHistory(epochs, T)
        ch = Channel{Tuple{Int,T,T,Vector{T}}}(buffer_size)
        # Consumer task runs until the channel is closed (see close_recorder!).
        consumer = @spawn record_history!(ch, hist)
        return new{T}(ch, consumer, hist)
    end
end
1747
18- struct Toolbox
19- gene_count:: Int
20- head_len:: Int
21- symbols:: OrderedDict{Int8,Int8}
22- gene_connections:: Vector{Int8}
23- headsyms:: Vector{Int8}
24- unary_syms:: Vector{Int8}
25- tailsyms:: Vector{Int8}
26- arrity_by_id:: OrderedDict{Int8,Int8}
27- callbacks:: Dict
28- nodes:: OrderedDict
29- gen_start_indices:: Vector{Int}
30- gep_probs:: Dict{String,AbstractFloat}
31- unary_prob:: Real
32- fitness_reset:: Tuple
33- preamble_syms:: Vector{Int8}
34- len_preamble:: Int8
35-
36-
37- function Toolbox (gene_count:: Int , head_len:: Int , symbols:: OrderedDict{Int8,Int8} , gene_connections:: Vector{Int8} ,
38- callbacks:: Dict , nodes:: OrderedDict , gep_probs:: Dict{String,AbstractFloat} ;
39- unary_prob:: Real = 0.4 , fitness_reset:: Tuple = (Inf , NaN ), preamble_syms= Int8[])
40- gene_len = head_len * 2 + 1
41- headsyms = [key for (key, arity) in symbols if arity == 2 ]
42- unary_syms = [key for (key, arity) in symbols if arity == 1 ]
43- tailsyms = [key for (key, arity) in symbols if arity < 1 && ! (key in preamble_syms)]
44- len_preamble = length (preamble_syms) == 0 ? 0 : gene_count
45- gen_start_indices = [gene_count + len_preamble + (gene_len * (i - 1 )) for i in 1 : gene_count] # depending on the usage should shift everthing
46- new (gene_count, head_len, symbols, gene_connections, headsyms, unary_syms, tailsyms, symbols,
47- callbacks, nodes, gen_start_indices, gep_probs, unary_prob, fitness_reset, preamble_syms, len_preamble)
"""
    record_history!(channel, history)

Drain `channel` until it is closed, storing each `(epoch, train_loss,
val_loss, fit_vector)` message into `history` at index `epoch`. Raw losses
are kept as-is; the fitness vector is summarized by its `mean` and `std`.
"""
@inline function record_history!(
    channel::Channel{Tuple{Int,T,T,Vector{T}}},
    history::OptimizationHistory{T},
) where {T<:AbstractFloat}
    # Iterating a Channel blocks on put! and terminates once it is closed.
    for msg in channel
        epoch, t_loss, v_loss, fitness = msg
        @inbounds begin
            history.train_loss[epoch] = t_loss
            history.val_loss[epoch] = v_loss
            history.train_mean[epoch] = mean(fitness)
            history.train_std[epoch] = std(fitness)
        end
    end
end
5062
"""
    record!(recorder, epoch, train_loss, val_loss, fit_vector)

Enqueue one epoch's metrics onto the recorder's channel. Blocks when the
channel buffer is full; the actual bookkeeping happens on the recorder's
background task.
"""
@inline function record!(
    recorder::HistoryRecorder{T},
    epoch::Int,
    train_loss::T,
    val_loss::T,
    fit_vector::Vector{T},
) where {T<:AbstractFloat}
    payload = (epoch, train_loss, val_loss, fit_vector)
    put!(recorder.channel, payload)
end
72+
73+
"""
    close_recorder!(recorder)

Shut down the recorder: close the channel so the consumer loop terminates,
then block until the background task has finished flushing its queue.
"""
@inline function close_recorder!(recorder::HistoryRecorder)
    close(recorder.channel)
    wait(recorder.task)
end
78+
79+
"""
    isclose(a, b; rtol=1e-5, atol=1e-8) -> Bool

Return `true` when `a` and `b` are approximately equal under the asymmetric
absolute/relative tolerance test `|a - b| <= atol + rtol * |b|` (the same
criterion as `numpy.isclose`).

The tolerance keywords are typed `Real` rather than `T`: with the old
`rtol::T` annotation the `Float64` defaults were converted to `T`, so any
integer instantiation (e.g. `isclose(3, 3)`) threw an `InexactError`
(`convert(Int, 1e-5)`) even though `T<:Number` admits integers.
"""
function isclose(a::T, b::T; rtol::Real=1e-5, atol::Real=1e-8) where {T<:Number}
    return abs(a - b) <= (atol + rtol * abs(b))
end
@@ -110,29 +139,21 @@ function retrieve_constants_from_node(node::Node)
110139end
111140
112141
113- function optimize_constants (
142+ @inline function optimize_constants! (
114143 node:: Node ,
115- x_data:: AbstractArray{T} ,
116- y_data:: AbstractArray{T} ,
117- loss:: Function ,
118- operators:: AbstractOperatorEnum ;
144+ loss:: Function ;
119145 opt_method:: Symbol = :cg ,
120146 max_iterations:: Int = 250 ,
121147 n_restarts:: Int = 3
122- ) where {T <: AbstractFloat }
148+ )
123149
124150 nconst = count_constants (node)
125151
126152 if nconst == 0
127153 return node, 0.0
128154 end
129-
130- function f (tree:: Node )
131- y_pred, flag = eval_tree_array (tree, x_data, operators)
132- return loss (y_pred, y_data)
133- end
134-
135- baseline = f (node)
155+
156+ baseline = loss (node)
136157 best_node = deepcopy (node)
137158 best_loss = baseline
138159
@@ -157,7 +178,7 @@ function optimize_constants(
157178 end
158179 end
159180
160- result = Optim. optimize (f , current_node, algorithm, optimizer_options)
181+ result = Optim. optimize (loss , current_node, algorithm, optimizer_options)
161182
162183 if result. minimum < best_loss
163184 best_node = result. minimizer
0 commit comments