Skip to content

Commit 39fd937

Browse files
authored
Merge pull request #1 from maxreiss123/develop
Develop
2 parents 3ececd5 + 0e6f109 commit 39fd937

File tree

9 files changed: +558 additions, -459 deletions

src/Entities.jl

Lines changed: 394 additions & 1 deletion
Large diffs are not rendered by default.

src/Gep.jl

Lines changed: 83 additions & 400 deletions
Large diffs are not rendered by default.

src/JGep.jl

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,7 @@ export get_loss_function
1919

2020

2121
using .GepUtils
22-
export Toolbox
23-
export find_indices_with_sum, compile_djl_datatype, optimize_constants, minmax_scale, float16_scale, isclose
22+
export find_indices_with_sum, compile_djl_datatype, optimize_constants!, minmax_scale, float16_scale, isclose
2423
export save_state, load_state
2524

2625

@@ -41,8 +40,12 @@ export equal_unit_forward, mul_unit_forward, div_unit_forward, zero_unit_backwar
4140
export get_feature_dims_json, get_target_dim_json, retrieve_coeffs_based_on_similarity
4241

4342

44-
using .SymbolicEntities
43+
using .GepEntities
44+
export Chromosome, Toolbox
4545
export AbstractSymbol, FunctionalSymbol, BasicSymbol, SymbolConfig
46+
export fitness, set_fitness!
47+
export generate_gene, generate_preamle!, compile_expression!, generate_chromosome, generate_population
48+
export genetic_operations!
4649

4750

4851
end

src/Util.jl

Lines changed: 69 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1,53 +1,82 @@
11
module GepUtils
22

3-
43
using OrderedCollections
54
using DynamicExpressions
65
using LinearAlgebra
76
using Optim
87
using LineSearches
98
using Zygote
109
using Serialization
10+
using Statistics
11+
using Base.Threads: @spawn
1112

12-
export Toolbox
13-
export find_indices_with_sum, compile_djl_datatype, optimize_constants, minmax_scale, float16_scale, isclose
14-
export save_state, load_state
1513

14+
export find_indices_with_sum, compile_djl_datatype, optimize_constants!, minmax_scale, float16_scale, isclose
15+
export save_state, load_state
16+
export create_history_recorder, record_history!, record!, close_recorder!
17+
export HistoryRecorder, OptimizationHistory
18+
19+
struct OptimizationHistory{T<:AbstractFloat}
20+
train_loss::Vector{T}
21+
val_loss::Vector{T}
22+
train_mean::Vector{T}
23+
train_std::Vector{T}
24+
25+
function OptimizationHistory(epochs::Int, ::Type{T}) where T<:AbstractFloat
26+
return new{T}(
27+
Vector{T}(undef, epochs),
28+
Vector{T}(undef, epochs),
29+
Vector{T}(undef, epochs),
30+
Vector{T}(undef, epochs)
31+
)
32+
end
33+
end
1634

35+
struct HistoryRecorder{T<:AbstractFloat}
36+
channel::Channel{Tuple{Int,T,T,Vector{T}}}
37+
task::Task
38+
history::OptimizationHistory{T}
39+
40+
function HistoryRecorder(epochs::Int, ::Type{T}; buffer_size::Int=32) where T<:AbstractFloat
41+
history = OptimizationHistory(epochs, T)
42+
channel = Channel{Tuple{Int,T,T,Vector{T}}}(buffer_size)
43+
task = @spawn record_history!(channel, history)
44+
return new{T}(channel, task, history)
45+
end
46+
end
1747

18-
struct Toolbox
19-
gene_count::Int
20-
head_len::Int
21-
symbols::OrderedDict{Int8,Int8}
22-
gene_connections::Vector{Int8}
23-
headsyms::Vector{Int8}
24-
unary_syms::Vector{Int8}
25-
tailsyms::Vector{Int8}
26-
arrity_by_id::OrderedDict{Int8,Int8}
27-
callbacks::Dict
28-
nodes::OrderedDict
29-
gen_start_indices::Vector{Int}
30-
gep_probs::Dict{String,AbstractFloat}
31-
unary_prob::Real
32-
fitness_reset::Tuple
33-
preamble_syms::Vector{Int8}
34-
len_preamble::Int8
35-
36-
37-
function Toolbox(gene_count::Int, head_len::Int, symbols::OrderedDict{Int8,Int8}, gene_connections::Vector{Int8},
38-
callbacks::Dict, nodes::OrderedDict, gep_probs::Dict{String,AbstractFloat};
39-
unary_prob::Real=0.4, fitness_reset::Tuple=(Inf, NaN), preamble_syms=Int8[])
40-
gene_len = head_len * 2 + 1
41-
headsyms = [key for (key, arity) in symbols if arity == 2]
42-
unary_syms = [key for (key, arity) in symbols if arity == 1]
43-
tailsyms = [key for (key, arity) in symbols if arity < 1 && !(key in preamble_syms)]
44-
len_preamble = length(preamble_syms) == 0 ? 0 : gene_count
45-
gen_start_indices = [gene_count + len_preamble + (gene_len * (i - 1)) for i in 1:gene_count] #depending on the usage should shift everthing
46-
new(gene_count, head_len, symbols, gene_connections, headsyms, unary_syms, tailsyms, symbols,
47-
callbacks, nodes, gen_start_indices, gep_probs, unary_prob, fitness_reset, preamble_syms, len_preamble)
48+
# Usage in record_history!
49+
@inline function record_history!(
50+
channel::Channel{Tuple{Int,T,T,Vector{T}}},
51+
history::OptimizationHistory{T}
52+
) where T<:AbstractFloat
53+
for (epoch, train_loss, val_loss, fit_vector) in channel
54+
@inbounds begin
55+
history.train_loss[epoch] = train_loss
56+
history.val_loss[epoch] = val_loss
57+
history.train_mean[epoch] = mean(fit_vector)
58+
history.train_std[epoch] = std(fit_vector)
59+
end
4860
end
4961
end
5062

63+
@inline function record!(
64+
recorder::HistoryRecorder{T},
65+
epoch::Int,
66+
train_loss::T,
67+
val_loss::T,
68+
fit_vector::Vector{T}
69+
) where T<:AbstractFloat
70+
put!(recorder.channel, (epoch, train_loss, val_loss, fit_vector))
71+
end
72+
73+
74+
@inline function close_recorder!(recorder::HistoryRecorder)
75+
close(recorder.channel)
76+
wait(recorder.task)
77+
end
78+
79+
5180
function isclose(a::T, b::T; rtol::T=1e-5, atol::T=1e-8) where {T<:Number}
5281
return abs(a - b) <= (atol + rtol * abs(b))
5382
end
@@ -110,29 +139,21 @@ function retrieve_constants_from_node(node::Node)
110139
end
111140

112141

113-
function optimize_constants(
142+
@inline function optimize_constants!(
114143
node::Node,
115-
x_data::AbstractArray{T},
116-
y_data::AbstractArray{T},
117-
loss::Function,
118-
operators::AbstractOperatorEnum;
144+
loss::Function;
119145
opt_method::Symbol=:cg,
120146
max_iterations::Int=250,
121147
n_restarts::Int=3
122-
) where {T<:AbstractFloat}
148+
)
123149

124150
nconst = count_constants(node)
125151

126152
if nconst == 0
127153
return node, 0.0
128154
end
129-
130-
function f(tree::Node)
131-
y_pred, flag = eval_tree_array(tree, x_data, operators)
132-
return loss(y_pred, y_data)
133-
end
134-
135-
baseline = f(node)
155+
156+
baseline = loss(node)
136157
best_node = deepcopy(node)
137158
best_loss = baseline
138159

@@ -157,7 +178,7 @@ function optimize_constants(
157178
end
158179
end
159180

160-
result = Optim.optimize(f, current_node, algorithm, optimizer_options)
181+
result = Optim.optimize(loss, current_node, algorithm, optimizer_options)
161182

162183
if result.minimum < best_loss
163184
best_node = result.minimizer

test/Main_min_example.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ y_data = @. x_data[1,:] * x_data[1,:] + x_data[1,:] * x_data[2,:] - 2 * x_data[2
5151

5252
#call the function -> return value yields the best:
5353

54-
best=runGep(1000, 1000,4,10,utilized_syms,operators, callbacks, nodes, x_data,y_data, connection_syms, gep_params;
54+
best,history =runGep(1000, 1000,4,10,utilized_syms,operators, callbacks, nodes, x_data,y_data, connection_syms, gep_params;
5555
loss_fun_str="mse", opt_method_const=:cg, hof=1)
5656
@show string(best[1].fitness)
5757
@show string(best[1].compiled_function)

test/Main_min_with_csv.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ gene_count = 3
8484
head_len = 4
8585

8686
best=runGep(epochs, pop_size, gene_count, head_len, utilized_syms,operators, callbacks, nodes, x_data',y_data, connection_syms, gep_params;
87-
loss_fun_str="mse",x_data_test=x_data_test', y_data_test=y_data_test ,opt_method_const=:cg, hof=1)
87+
8888

8989
#Show the result of the optimization
9090
@show ("Fitness: (loss-fun): ", best[1].fitness)

test/paper_test.jl

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,7 @@ end
2121
function create_symbol_config(
2222
features_names::Vector{String},
2323
constants::Vector{T};
24-
feature_dims::Dict{String, Vector{Float16}} = Dict{String,Vector{Float16}}(),
25-
phy_constants::Union{Dict{String, Vector{Float16}}, Nothing} = nothing
24+
feature_dims::Dict{String, Vector{Float16}} = Dict{String,Vector{Float16}}()
2625
) where T <: AbstractFloat
2726
operators_djl = OperatorEnum(; binary_operators=[*,/,+,-], unary_operators=[sqr,sqrt,sin,cos,exp,log])
2827
nodes_djl = OrderedDict{Int8, Any}()
@@ -240,7 +239,7 @@ function main()
240239
)
241240

242241
start_time = time_ns()
243-
best=runGep(pop_size, generations,
242+
best,_=runGep(pop_size, generations,
244243
gene_count,head_length,utilized_syms, config.operators_djl,
245244
config.callbacks,
246245
config.nodes_djl,

tutorial/JGEP_CSV_demo.ipynb

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tutorial/JGEP_demo.ipynb

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)