@@ -13,6 +13,7 @@ import DiffOpt
 import Ipopt
 import ChainRulesCore
 import Flux
+import MLDatasets
 import Statistics
 import Base.Iterators: repeated
 using LinearAlgebra
@@ -22,67 +23,46 @@ using LinearAlgebra
 # Define a relu through an optimization problem solved by a quadratic solver.
 # Return the solution of the problem.
 function matrix_relu(
-    y::AbstractArray{T};
+    y::Matrix;
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
-    _x = zeros(size(y))
-    N = length(y[:, 1])
+)
+    N, M = size(y)
     empty!(model)
     set_silent(model)
-    @variable(model, x[1:N] >= 0)
-    for i in 1:size(y, 2)
-        @objective(
-            model,
-            Min,
-            dot(x, x) - 2dot(y[:, i], x)
-        )
-        optimize!(model)
-        _x[:, i] = value.(x)
-    end
-    return _x
+    @variable(model, x[1:N, 1:M] >= 0)
+    @objective(model, Min, x[:]'x[:] - 2y[:]'x[:])
+    optimize!(model)
+    return value.(x)
 end
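# As a quick sanity check, minimizing x'x - 2y'x over x >= 0 gives
# x* = max(y, 0) entrywise, so the solver-based relu should match the
# elementwise relu. A minimal sketch, assuming Ipopt's default tolerances
# (well below 1e-4):
y_test = randn(4, 3)
x_test = matrix_relu(y_test)
maximum(abs.(x_test .- max.(y_test, 0.0))) <= 1e-4 ## should evaluate to true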


 # Define the backward differentiation rule for the function we defined above.
-function ChainRulesCore.rrule(
-    ::typeof(matrix_relu),
-    y::AbstractArray{T};
+function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where T
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
     pv = matrix_relu(y, model = model)
     function pullback_matrix_relu(dl_dx)
         ## some value from the backpropagation (e.g., loss) is denoted by `l`
         ## so `dl_dy` is the derivative of `l` wrt `y`
         x = model[:x] ## load decision variable `x` into scope
         dl_dy = zeros(T, size(dl_dx))
-        dl_dq = zeros(T, size(dl_dx)) ## for step-by-step explanation
-        for i in 1:size(y, 2)
-            ## set sensitivities
-            MOI.set.(
-                model,
-                DiffOpt.BackwardInVariablePrimal(),
-                x,
-                dl_dx[:, i]
-            )
-            ## compute grad
-            DiffOpt.backward(model)
-            ## return gradient wrt objective function parameters
-            obj_exp = MOI.get(
-                model,
-                DiffOpt.BackwardOutObjective()
-            )
-            dl_dq[:, i] = JuMP.coefficient.(obj_exp, x) ## coeff of `x` in q'x = -2y'x
-            dq_dy = -2 ## dq/dy = -2
-            dl_dy[:, i] = dl_dq[:, i] * dq_dy
-        end
+        dl_dq = zeros(T, size(dl_dx))
+        ## set sensitivities
+        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x[:], dl_dx[:])
+        ## compute grad
+        DiffOpt.backward(model)
+        ## return gradient wrt objective function parameters
+        obj_exp = MOI.get(model, DiffOpt.BackwardOutObjective())
+        ## coeff of `x` in q'x = -2y'x
+        dl_dq[:] .= JuMP.coefficient.(obj_exp, x[:])
+        dq_dy = -2 ## dq/dy = -2
+        dl_dy[:] .= dl_dq[:] * dq_dy
         return (ChainRulesCore.NoTangent(), dl_dy,)
     end
     return pv, pullback_matrix_relu
 end
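# Since the forward solution is x* = max(y, 0), the pullback should pass a
# sensitivity through where y > 0 and block it where y < 0. A hedged sketch,
# assuming solver tolerances keep the result near that ideal gradient:
y_test = [1.0 -2.0; 3.0 -0.5]
primal, pullback = ChainRulesCore.rrule(matrix_relu, y_test)
_, dl_dy_test = pullback(ones(size(y_test)))
dl_dy_test ## expected ≈ [1.0 0.0; 1.0 0.0], up to solver tolerance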

 # For more details about backpropagation, visit [Introduction, ChainRulesCore.jl](https://juliadiff.org/ChainRulesCore.jl/dev/).
 # ## prepare data
-import MLDatasets
 N = 1000
 imgs = MLDatasets.MNIST.traintensor(1:N)
 labels = MLDatasets.MNIST.trainlabels(1:N);
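# The preprocessing that builds `train_X` and `train_Y` lies outside this
# hunk; a hedged sketch of the usual steps (flatten each 28 x 28 image into a
# length-784 column, one-hot encode the labels) would be:
train_X = float(reshape(imgs, 28 * 28, N))
train_Y = Flux.onehotbatch(labels, 0:9);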
@@ -99,7 +79,7 @@ test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);

 # Network structure

-inner = 15
+inner = 10

 m = Flux.Chain(
     Flux.Dense(784, inner), # 784 being image linear dimension (28 x 28)
@@ -112,7 +92,8 @@ m = Flux.Chain(
 # The original data is repeated `epochs` times because `Flux.train!` only
 # loops through the data set once

-epochs = 5
+epochs = 50 # ~1 minute (i7 8th gen with 16gb RAM)
+## epochs = 100 # leads to 77.8% in about 2 minutes

 dataset = repeated((train_X, train_Y), epochs);

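# The repeated `dataset` is meant to be consumed by a training loop. A hedged
# sketch, assuming the chain ends in a softmax and using the implicit-parameter
# `Flux.train!` API of this Flux version (the names `opt` and `loss` are
# illustrative, not from the diff):
opt = Flux.ADAM()
loss(x, y) = Flux.Losses.crossentropy(m(x), y)
Flux.train!(loss, Flux.params(m), dataset, opt)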