diff --git a/BREAKING.md b/BREAKING.md index aa965fac..9e818fc0 100644 --- a/BREAKING.md +++ b/BREAKING.md @@ -2,6 +2,40 @@ This document describes breaking changes in CTModels releases and how to migrate your code. +## [0.9.6] - 2026-03-10 + +**No breaking changes** - This release adds a dedicated costate time grid while maintaining full backward compatibility. + +### New Features (Non-Breaking) + +- **4-Grid Time System**: `build_solution` now supports 4 independent time grids + - New signature: `build_solution(ocp, T_state, T_control, T_costate, T_path, X, U, v, P; ...)` + - Legacy signature preserved: `build_solution(ocp, T, X, U, v, P; ...)` still works + - Automatic grid optimization when all grids are identical + +- **Costate Grid Independence**: Costate now has its own dedicated time grid + - `time_grid(sol, :costate)` returns costate-specific grid + - `clean_component_symbols((:costate,))` → `(:costate,)` (was `(:state,)` before) + - Enables different discretizations for state and costate (e.g., symplectic integrators) + +- **Enhanced Serialization**: Multi-grid format includes `time_grid_costate` + - Backward compatible: old files without `time_grid_costate` use `T_state` as fallback + - Forward compatible: new files with 4 grids work with updated readers + +### Migration Notes + +**No action required** - All existing code continues to work unchanged: + +```julia +# Legacy single-grid code (still works) +sol = build_solution(ocp, T, X, U, v, P; objective=obj, ...) + +# New multi-grid code (optional) +sol = build_solution(ocp, T_state, T_control, T_costate, T_path, X, U, v, P; objective=obj, ...) +``` + +The package automatically detects and handles both formats. All tests pass (3324/3324). + ## [0.9.5] - 2026-03-09 **No breaking changes** - This release focuses on internal API cleanup with no impact on public functionality. diff --git a/CHANGELOG.md b/CHANGELOG.md index 1355a485..387b43b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.6] - 2026-03-10 + +### Added + +- **Dedicated Costate Time Grid**: Reintroduced independent `T_costate` time grid for costate trajectories + - `build_solution` now accepts 4 independent time grids: `T_state`, `T_control`, `T_costate`, `T_path` + - Costate can now use a different discretization from state (e.g., for symplectic integrators) + - `MultipleTimeGridModel` extended to include `:costate` grid + - `clean_component_symbols` updated to map `:costate` → `:costate` (own grid) + - `time_grid(sol, :costate)` now returns the costate-specific grid + - All tests passing (3324/3324) + +- **Enhanced Serialization**: Multi-grid format now includes `time_grid_costate` + - JSON/JLD export includes dedicated costate grid + - Backward compatibility: files without `time_grid_costate` use `T_state` as fallback + - Automatic format detection and conversion + +- **Comprehensive Documentation**: Added detailed docstrings explaining time grid semantics + - `build_solution`: 173 lines of detailed documentation on 4-grid system + - `_serialize_solution`: 128 lines explaining serialization formats + - `_discretize_all_components`: 41 lines on grid-component associations + - Complete examples and usage patterns + +### Changed + +- **Time Grid Validation**: `time_grid` getter now accepts `:costate` for both `UnifiedTimeGridModel` and `MultipleTimeGridModel` +- **Legacy Signature**: `build_solution(ocp, T, X, U, v, P; ...)` now forwards to 4-grid version with `T_state = T_control = T_costate = T_path = T` +- **Plotting**: Costate now maps to its dedicated grid in `_map_to_time_grid_component` + +### Fixed + +- **Grid Optimization**: Solutions with identical grids automatically use `UnifiedTimeGridModel` for memory efficiency +- **Test Coverage**: All multi-grid tests updated to use 4-grid signature and verify costate grid independence + ## [0.9.5] - 2026-03-09 ### Fixed diff --git a/Project.toml b/Project.toml index d9838b79..5de1d4bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CTModels" uuid = "34c4fa32-2049-4079-8329-de33c2a22e2d" -version = "0.9.5" +version = "0.9.6" authors = ["Olivier Cots "] [deps] diff --git a/ext/CTModelsJSON.jl b/ext/CTModelsJSON.jl index 968e0be6..ae86a4a9 100644 --- a/ext/CTModelsJSON.jl +++ b/ext/CTModelsJSON.jl @@ -5,6 +5,42 @@ using DocStringExtensions using JSON3 +# ============================================================================ +# Private helpers for JSON matrix conversion +# ============================================================================ + +# Liste des champs matriciels à convertir +const _MATRIX_FIELDS = ["state", "control", "costate"] +const _OPTIONAL_MATRIX_FIELDS = [ + "path_constraints_dual", + "state_constraints_lb_dual", + "state_constraints_ub_dual", + "control_constraints_lb_dual", + "control_constraints_ub_dual", +] + +""" +Convert Matrix fields to Vector{Vector} for JSON3 export. + +JSON3 flattens Matrix{Float64} into 1D arrays, losing the 2D structure. +This function converts all matrix fields to Vector{Vector} format to preserve dimensions. +""" +function _convert_matrices_for_json!(blob::Dict) + # Convert required matrix fields + for key in _MATRIX_FIELDS + if haskey(blob, key) && blob[key] isa Matrix + blob[key] = CTModels.Utils.matrix2vec(blob[key], 1) + end + end + + # Convert optional matrix fields (can be nothing) + for key in _OPTIONAL_MATRIX_FIELDS + if haskey(blob, key) && !isnothing(blob[key]) && blob[key] isa Matrix + blob[key] = CTModels.Utils.matrix2vec(blob[key], 1) + end + end +end + # ============================================================================ # Private helper: broadcast with Nothing fallback # ============================================================================ @@ -150,36 +186,14 @@ julia> export_ocp_solution(JSON3Tag(), sol; filename="mysolution") function CTModels.export_ocp_solution( ::CTModels.JSON3Tag, sol::CTModels.Solution; filename::String ) - T = CTModels.time_grid(sol) - - blob = Dict( - "time_grid" => CTModels.time_grid(sol), - "state" => _apply_over_grid(CTModels.state(sol), T), - "control" => _apply_over_grid(CTModels.control(sol), T), - "variable" => CTModels.variable(sol), - "costate" => _apply_over_grid(CTModels.costate(sol), T), - "objective" => CTModels.objective(sol), - "iterations" => CTModels.iterations(sol), - "constraints_violation" => CTModels.constraints_violation(sol), - "message" => CTModels.message(sol), - "status" => CTModels.status(sol), - "successful" => CTModels.successful(sol), - "path_constraints_dual" => _apply_over_grid(CTModels.path_constraints_dual(sol), T), - "state_constraints_lb_dual" => - _apply_over_grid(CTModels.state_constraints_lb_dual(sol), T), - "state_constraints_ub_dual" => - _apply_over_grid(CTModels.state_constraints_ub_dual(sol), T), - "control_constraints_lb_dual" => - _apply_over_grid(CTModels.control_constraints_lb_dual(sol), T), - "control_constraints_ub_dual" => - _apply_over_grid(CTModels.control_constraints_ub_dual(sol), T), - "boundary_constraints_dual" => CTModels.boundary_constraints_dual(sol), # ctVector or Nothing - "variable_constraints_lb_dual" => CTModels.variable_constraints_lb_dual(sol), # ctVector or Nothing - "variable_constraints_ub_dual" => CTModels.variable_constraints_ub_dual(sol), # ctVector or Nothing - ) + # Use unified serialization that handles both unified and multiple time grids + blob = CTModels.OCP._serialize_solution(sol) + + # Convert Matrix → Vector{Vector} for JSON (to avoid JSON3 flattening) + _convert_matrices_for_json!(blob) # Serialize infos and get Symbol type metadata - infos_serialized, symbol_keys = _serialize_infos(CTModels.infos(sol)) + infos_serialized, symbol_keys = _serialize_infos(blob["infos"]) blob["infos"] = infos_serialized blob["infos_symbol_keys"] = symbol_keys @@ -191,55 +205,35 @@ function CTModels.export_ocp_solution( end """ -$(TYPEDSIGNATURES) - -Convert JSON3 array data to `Matrix{Float64}` for trajectory import. - -# Context +Convert a JSON field (Vector{Vector} via stack) to Matrix{Float64}. -When importing JSON data, `stack(blob[field]; dims=1)` returns different types -depending on the dimensionality of the original trajectory: -- **1D trajectories** (e.g., scalar control): `stack()` → `Vector{Float64}` -- **Multi-D trajectories** (e.g., 2D state): `stack()` → `Matrix{Float64}` - -This function normalizes both cases to `Matrix{Float64}` as required by `build_solution`. +JSON exports matrices as Vector{Vector}. After `stack(blob[field]; dims=1)`, +we get either a Matrix (multi-D) or Vector (1D). This normalizes to Matrix. # Arguments -- `data`: Output from `stack(blob[field]; dims=1)`, either `Vector` or `Matrix` +- `blob_field`: JSON array field (Vector of Vectors) # Returns -- `Matrix{Float64}`: Properly shaped matrix `(n_time_points, n_dim)` for `build_solution` - -# Implementation Details - -- **Vector case**: Converts `Vector{Float64}` of length `n` to `Matrix{Float64}(n, 1)` - using `reduce(hcat, data)'` to preserve time-series ordering -- **Matrix case**: Direct conversion to `Matrix{Float64}` - -# Examples +- `Matrix{Float64}`: (n_time_points, n_dim) +""" +function _json_to_matrix(blob_field)::Matrix{Float64} + stacked = stack(blob_field; dims=1) + # 1D case: stack() returns Vector → reshape to (n, 1) Matrix + # Multi-D case: stack() returns Matrix → use directly + return stacked isa Vector ? reshape(stacked, :, 1) : Matrix{Float64}(stacked) +end -```julia -# 1D control trajectory (101 time points) -control_data = [5.99, 5.93, ..., -5.99] # Vector{Float64} -control_matrix = _json_array_to_matrix(control_data) -# → Matrix{Float64}(101, 1) +""" +Convert an optional JSON field to Matrix{Float64} or nothing. -# 2D state trajectory (101 time points, 2 dimensions) -state_data = [1.0 2.0; 1.1 2.1; ...] # Matrix{Float64}(101, 2) -state_matrix = _json_array_to_matrix(state_data) -# → Matrix{Float64}(101, 2) -``` +# Arguments +- `blob_field`: JSON array field or nothing -# See Also -- Test coverage: `test/suite/serialization/test_export_import.jl` - (testset "JSON stack() behavior investigation") +# Returns +- `Matrix{Float64}` or `nothing` """ -function _json_array_to_matrix(data)::Matrix{Float64} - if data isa Vector - return Matrix{Float64}(reduce(hcat, data)') - else - return Matrix{Float64}(data) - end +function _json_to_optional_matrix(blob_field) + return isnothing(blob_field) ? nothing : _json_to_matrix(blob_field) end """ @@ -275,45 +269,17 @@ function CTModels.import_ocp_solution( json_string = read(filename * ".json", String) blob = JSON3.read(json_string) - # get state - X = _json_array_to_matrix(stack(blob["state"]; dims=1)) - - # get control - U = _json_array_to_matrix(stack(blob["control"]; dims=1)) - - # get costate - P = _json_array_to_matrix(stack(blob["costate"]; dims=1)) + # Convert JSON arrays (Vector{Vector}) back to Matrix{Float64} + X = _json_to_matrix(blob["state"]) + U = _json_to_matrix(blob["control"]) + P = _json_to_matrix(blob["costate"]) - # get dual path constraints: convert to matrix - path_constraints_dual = if isnothing(blob["path_constraints_dual"]) - nothing - else - _json_array_to_matrix(stack(blob["path_constraints_dual"]; dims=1)) - end - - # get state constraints (and dual): convert to matrix - state_constraints_lb_dual = if isnothing(blob["state_constraints_lb_dual"]) - nothing - else - _json_array_to_matrix(stack(blob["state_constraints_lb_dual"]; dims=1)) - end - state_constraints_ub_dual = if isnothing(blob["state_constraints_ub_dual"]) - nothing - else - _json_array_to_matrix(stack(blob["state_constraints_ub_dual"]; dims=1)) - end - - # get control constraints (and dual): convert to matrix - control_constraints_lb_dual = if isnothing(blob["control_constraints_lb_dual"]) - nothing - else - _json_array_to_matrix(stack(blob["control_constraints_lb_dual"]; dims=1)) - end - control_constraints_ub_dual = if isnothing(blob["control_constraints_ub_dual"]) - nothing - else - _json_array_to_matrix(stack(blob["control_constraints_ub_dual"]; dims=1)) - end + # Convert optional dual matrices + path_constraints_dual = _json_to_optional_matrix(blob["path_constraints_dual"]) + state_constraints_lb_dual = _json_to_optional_matrix(blob["state_constraints_lb_dual"]) + state_constraints_ub_dual = _json_to_optional_matrix(blob["state_constraints_ub_dual"]) + control_constraints_lb_dual = _json_to_optional_matrix(blob["control_constraints_lb_dual"]) + control_constraints_ub_dual = _json_to_optional_matrix(blob["control_constraints_ub_dual"]) # get dual of boundary constraints: no conversion needed boundary_constraints_dual = blob["boundary_constraints_dual"] @@ -364,11 +330,19 @@ function CTModels.import_ocp_solution( # Add time grid data (format detection handled by helper) if haskey(blob, "time_grid_state") - # New format: multiple time grids + # Multiple time grids format data["time_grid_state"] = blob.time_grid_state data["time_grid_control"] = blob.time_grid_control - data["time_grid_costate"] = blob.time_grid_costate - data["time_grid_dual"] = blob.time_grid_dual + # Support time_grid_costate (backward compatibility: if missing, will use T_state in reconstruction) + if haskey(blob, "time_grid_costate") + data["time_grid_costate"] = blob.time_grid_costate + end + # Support both new (time_grid_path) and legacy (time_grid_dual) keys + if haskey(blob, "time_grid_path") + data["time_grid_path"] = blob.time_grid_path + elseif haskey(blob, "time_grid_dual") + data["time_grid_path"] = blob.time_grid_dual + end else # Legacy format: single time grid data["time_grid"] = blob.time_grid diff --git a/ext/plot.jl b/ext/plot.jl index 02c02cb4..90c55940 100644 --- a/ext/plot.jl +++ b/ext/plot.jl @@ -1590,10 +1590,10 @@ function _map_to_time_grid_component(sym::Symbol)::Symbol :time => error("Internal error: :time should not be mapped") :state => :state :control => :control - :costate => :costate + :costate => :costate # Costate has its own grid :control_norm => :control # Map control_norm to control for time grid - :path_constraint => :state # Map path_constraint to state for time grid - :dual_path_constraint => :dual # Map dual_path_constraint to dual for time grid + :path_constraint => :path # Path constraints use the path grid + :dual_path_constraint => :path # Path constraint duals use the path grid _ => error("Internal error: unknown component $sym for time grid mapping") end end diff --git a/src/OCP/Building/discretization_utils.jl b/src/OCP/Building/discretization_utils.jl index 38885ed3..48eb2490 100644 --- a/src/OCP/Building/discretization_utils.jl +++ b/src/OCP/Building/discretization_utils.jl @@ -87,3 +87,4 @@ See also: `_discretize_function` function _discretize_dual(dual_func::Union{Function,Nothing}, T, dim::Int=-1) return isnothing(dual_func) ? nothing : _discretize_function(dual_func, T, dim) end + diff --git a/src/OCP/Building/solution.jl b/src/OCP/Building/solution.jl index 6f3b843a..cd597d84 100644 --- a/src/OCP/Building/solution.jl +++ b/src/OCP/Building/solution.jl @@ -1,50 +1,183 @@ """ $(TYPEDSIGNATURES) -Build a solution from the optimal control problem, the time grid, the state, control, variable, and dual variables. +Build a solution from an optimal control problem with independent time grids for each component. + +This function constructs a `Solution` object by assembling trajectory data (state, control, costate, +path constraint duals) defined on potentially different time discretizations. The solution automatically +creates interpolated functions to evaluate trajectories at arbitrary time points, and optimizes storage +when all grids are identical. + +# Time Grid Semantics + +The solution supports **four independent time grids**, each associated with a specific trajectory component: + +- **`T_state`**: Time grid for the state trajectory `X` and state box constraint duals + - Defines discretization points where state values are known + - State box constraint duals (`state_constraints_lb_dual`, `state_constraints_ub_dual`) share this grid + +- **`T_control`**: Time grid for the control trajectory `U` and control box constraint duals + - Defines discretization points where control values are known + - Control box constraint duals (`control_constraints_lb_dual`, `control_constraints_ub_dual`) share this grid + - May differ from `T_state` (e.g., coarser discretization for piecewise constant controls) + +- **`T_costate`**: Time grid for the costate (adjoint) trajectory `P` + - Defines discretization points where costate values are known + - Independent from state grid to accommodate different numerical schemes + - Example: symplectic integrators may use different grids for state and costate + +- **`T_path`**: Time grid for path constraint duals (can be `nothing`) + - Defines discretization points for path constraint dual variables + - Set to `nothing` if no path constraints exist + - When `nothing`, internally defaults to `T_state` for consistency + +**Grid Optimization**: If all non-nothing grids are identical, the solution uses `UnifiedTimeGridModel` +for memory efficiency. Otherwise, it uses `MultipleTimeGridModel` to store each grid separately. + +# Trajectory Data Formats + +Trajectory data (`X`, `U`, `P`, `path_constraints_dual`) can be provided in two formats: + +1. **Matrix format**: `Matrix{Float64}` with dimensions `(n_points, n_dim)` + - Each row corresponds to a time point in the associated grid + - Each column corresponds to a component dimension + - Example: `X` is `(length(T_state), state_dimension(ocp))` + +2. **Function format**: `Function` that takes time `t::Float64` and returns a vector + - Allows analytical or pre-interpolated trajectories + - Function signature: `t -> Vector{Float64}` of appropriate dimension + - Useful for exact solutions or when data is already interpolated # Arguments -- `ocp::Model`: the optimal control problem. -- `T::Vector{Float64}`: the time grid. -- `X::Matrix{Float64}`: the state trajectory. -- `U::Matrix{Float64}`: the control trajectory. -- `v::Vector{Float64}`: the variable trajectory. -- `P::Matrix{Float64}`: the costate trajectory. -- `objective::Float64`: the objective value. -- `iterations::Int`: the number of iterations. -- `constraints_violation::Float64`: the constraints violation. -- `message::String`: the message associated to the status criterion. -- `status::Symbol`: the status criterion. -- `successful::Bool`: the successful status. -- `path_constraints_dual::Matrix{Float64}`: the dual of the path constraints. -- `boundary_constraints_dual::Vector{Float64}`: the dual of the boundary constraints. -- `state_constraints_lb_dual::Matrix{Float64}`: the lower bound dual of the state constraints. -- `state_constraints_ub_dual::Matrix{Float64}`: the upper bound dual of the state constraints. -- `control_constraints_lb_dual::Matrix{Float64}`: the lower bound dual of the control constraints. -- `control_constraints_ub_dual::Matrix{Float64}`: the upper bound dual of the control constraints. -- `variable_constraints_lb_dual::Vector{Float64}`: the lower bound dual of the variable constraints. -- `variable_constraints_ub_dual::Vector{Float64}`: the upper bound dual of the variable constraints. -- `infos::Dict{Symbol,Any}`: additional solver information dictionary. +## Required Positional Arguments + +- `ocp::Model`: The optimal control problem model defining dimensions and structure +- `T_state::Vector{Float64}`: Time grid for state trajectory (must be strictly increasing) +- `T_control::Vector{Float64}`: Time grid for control trajectory (must be strictly increasing) +- `T_costate::Vector{Float64}`: Time grid for costate trajectory (must be strictly increasing) +- `T_path::Union{Vector{Float64},Nothing}`: Time grid for path constraint duals (or `nothing`) +- `X::Union{Matrix{Float64},Function}`: State trajectory data +- `U::Union{Matrix{Float64},Function}`: Control trajectory data +- `v::Vector{Float64}`: Variable values (static optimization variables, not time-dependent) +- `P::Union{Matrix{Float64},Function}`: Costate (adjoint) trajectory data + +## Required Keyword Arguments + +- `objective::Float64`: Optimal objective function value +- `iterations::Int`: Number of solver iterations performed +- `constraints_violation::Float64`: Maximum constraint violation (feasibility measure) +- `message::String`: Solver status message (e.g., "Solve_Succeeded", "Iteration_Limit") +- `status::Symbol`: Solver termination status (e.g., `:Solve_Succeeded`, `:Iteration_Limit`) +- `successful::Bool`: Whether the solve was successful (true/false) + +## Optional Keyword Arguments (Dual Variables) + +All dual variable arguments default to `nothing` if not provided: + +- `path_constraints_dual::Union{Matrix{Float64},Function,Nothing}`: Path constraint duals on `T_path` grid +- `boundary_constraints_dual::Union{Vector{Float64},Nothing}`: Boundary constraint duals (time-independent) +- `state_constraints_lb_dual::Union{Matrix{Float64},Nothing}`: State lower bound duals on `T_state` grid +- `state_constraints_ub_dual::Union{Matrix{Float64},Nothing}`: State upper bound duals on `T_state` grid +- `control_constraints_lb_dual::Union{Matrix{Float64},Nothing}`: Control lower bound duals on `T_control` grid +- `control_constraints_ub_dual::Union{Matrix{Float64},Nothing}`: Control upper bound duals on `T_control` grid +- `variable_constraints_lb_dual::Union{Vector{Float64},Nothing}`: Variable lower bound duals (time-independent) +- `variable_constraints_ub_dual::Union{Vector{Float64},Nothing}`: Variable upper bound duals (time-independent) +- `infos::Dict{Symbol,Any}`: Additional solver-specific information (default: empty dict) # Returns -- `sol::Solution`: the optimal control solution. +- `sol::Solution`: Complete solution object with interpolated trajectory functions and metadata + +# Example + +```julia +using CTModels + +# Build OCP +ocp = Model(...) +state!(ocp, 2) +control!(ocp, 1) +# ... define dynamics, objective, etc. + +# Define independent time grids +T_state = collect(LinRange(0.0, 1.0, 101)) # Fine state grid (101 points) +T_control = collect(LinRange(0.0, 1.0, 51)) # Coarser control grid (51 points) +T_costate = collect(LinRange(0.0, 1.0, 76)) # Custom costate grid (76 points) +T_path = collect(LinRange(0.0, 1.0, 61)) # Path constraint grid (61 points) + +# Trajectory data (matrix format) +X = rand(101, 2) # State on T_state grid +U = rand(51, 1) # Control on T_control grid +P = rand(76, 2) # Costate on T_costate grid +v = [0.5, 1.2] # Static variables + +# Build solution +sol = build_solution( + ocp, + T_state, T_control, T_costate, T_path, + X, U, v, P; + objective=1.23, + iterations=50, + constraints_violation=1e-8, + message="Optimal", + status=:first_order, + successful=true +) + +# Access trajectories (automatically interpolated) +x_at_t = state(sol)(0.5) # Interpolated from T_state grid +u_at_t = control(sol)(0.5) # Interpolated from T_control grid +p_at_t = costate(sol)(0.5) # Interpolated from T_costate grid + +# Query time grids +time_grid(sol, :state) # Returns T_state +time_grid(sol, :control) # Returns T_control +time_grid(sol, :costate) # Returns T_costate +``` # Notes -The dimensions of box constraint dual variables (`state_constraints_*_dual`, `control_constraints_*_dual`, -`variable_constraints_*_dual`) correspond to the **state/control/variable dimension**, not the number of -constraint declarations. If multiple constraints are declared on the same component (e.g., `x₂(t) ≤ 1.2` -and `x₂(t) ≤ 2.0`), only the last bound value is retained, and a warning is emitted during model construction. +## Box Constraint Dual Dimensions + +The dimensions of box constraint dual variables correspond to the **component dimension**, not the +number of constraint declarations: + +- `state_constraints_*_dual`: Dimension `(length(T_state), state_dimension(ocp))` +- `control_constraints_*_dual`: Dimension `(length(T_control), control_dimension(ocp))` +- `variable_constraints_*_dual`: Dimension `variable_dimension(ocp)` + +If multiple constraints are declared on the same component (e.g., `x₂(t) ≤ 1.2` and `x₂(t) ≤ 2.0`), +only the last bound value is retained, and a warning is emitted during model construction. + +## Grid Validation + +All time grids must be: +- Strictly increasing: `T[i] < T[i+1]` for all `i` +- Non-empty: At least one time point +- Finite: No `Inf` or `NaN` values + +The function automatically validates and fixes grids (e.g., converts ranges to vectors). + +## Memory Optimization +When all grids are identical, the solution uses `UnifiedTimeGridModel` to store a single grid, +reducing memory overhead. This is detected automatically. + +## Backward Compatibility + +A legacy signature `build_solution(ocp, T, X, U, v, P; ...)` exists for single-grid solutions. +It internally calls this multi-grid version with `T_state = T_control = T_costate = T_path = T`. + +See also: `Solution`, `UnifiedTimeGridModel`, `MultipleTimeGridModel`, +`time_grid`, `state`, `control`, `costate` """ function build_solution( ocp::Model, T_state::Vector{Float64}, T_control::Vector{Float64}, T_costate::Vector{Float64}, - T_dual::Union{Vector{Float64},Nothing}, + T_path::Union{Vector{Float64},Nothing}, X::TX, U::TU, v::Vector{Float64}, @@ -80,10 +213,10 @@ function build_solution( T_state = _validate_and_fix_time_grid(T_state, "state") T_control = _validate_and_fix_time_grid(T_control, "control") T_costate = _validate_and_fix_time_grid(T_costate, "costate") - T_dual = isnothing(T_dual) ? nothing : _validate_and_fix_time_grid(T_dual, "dual") + T_path = isnothing(T_path) ? nothing : _validate_and_fix_time_grid(T_path, "path") # Detect if all non-nothing grids are identical - non_nothing_grids = filter(g -> !isnothing(g), [T_state, T_control, T_costate, T_dual]) + non_nothing_grids = filter(g -> !isnothing(g), [T_state, T_control, T_costate, T_path]) all_identical = length(non_nothing_grids) <= 1 || all(g -> g == first(non_nothing_grids), non_nothing_grids) @@ -92,19 +225,19 @@ function build_solution( time_grid = if all_identical UnifiedTimeGridModel(first(non_nothing_grids)) else - # For dual grid, use T_state if T_dual is nothing (path constraints share state grid) - T_dual_safe = isnothing(T_dual) ? T_state : T_dual + # For path grid, use T_state if T_path is nothing (path constraints share state grid) + T_path_safe = isnothing(T_path) ? T_state : T_path MultipleTimeGridModel(; state=T_state, control=T_control, costate=T_costate, - path=T_dual_safe, - dual=T_dual_safe, + path=T_path_safe, ) end # Build interpolated functions for state, control, and costate # Using unified API with validation and deepcopy+scalar wrapping + # Note: costate uses its own grid (T_costate) fx = build_interpolated_function(X, T_state, dim_x, TX; expected_dim=dim_x) fu = build_interpolated_function(U, T_control, dim_u, TU; expected_dim=dim_u) fp = build_interpolated_function( @@ -114,40 +247,45 @@ function build_solution( # nonlinear constraints and dual variables (optional, can be nothing) # Note: dim is set to dim_path_constraints_nl for proper scalar wrapping + # Path constraints duals share the path grid (T_path) fpcd = build_interpolated_function( path_constraints_dual, - T_dual, + T_path, dim_path_constraints_nl(ocp), TPCD; allow_nothing=true, ) # box constraints multipliers (optional, can be nothing) + # Note: No expected_dim validation for box constraints because they use + # dim_*_constraints_box which may differ from state/control dimensions + # State box constraint duals share the state grid (T_state) fscbd = build_interpolated_function( state_constraints_lb_dual, - T_dual, - dim_x, + T_state, + dim_state_constraints_box(ocp), Union{Matrix{Float64},Nothing}; allow_nothing=true, ) fscud = build_interpolated_function( state_constraints_ub_dual, - T_dual, - dim_x, + T_state, + dim_state_constraints_box(ocp), Union{Matrix{Float64},Nothing}; allow_nothing=true, ) + # Control box constraint duals share the control grid (T_control) fccbd = build_interpolated_function( control_constraints_lb_dual, - T_dual, - dim_u, + T_control, + dim_control_constraints_box(ocp), Union{Matrix{Float64},Nothing}; allow_nothing=true, ) fccud = build_interpolated_function( control_constraints_ub_dual, - T_dual, - dim_u, + T_control, + dim_control_constraints_box(ocp), Union{Matrix{Float64},Nothing}; allow_nothing=true, ) @@ -716,8 +854,10 @@ Return the time grid for a specific component. # Arguments - `sol::Solution`: The solution (unified or multiple time grids) -- `component::Symbol`: The component (:state, :control, :costate, :path, :dual) - Plural forms (:states, :controls, :costates, :duals) are also accepted +- `component::Symbol`: The component (:state, :control, :path) + Also accepted: :costate/:costates (→ :state), :dual/:duals (→ :path), + :state_box_constraint(s) (→ :state), :control_box_constraint(s) (→ :control), + plural forms (:states, :controls) # Returns - `TimesDisc`: The time grid for the specified component @@ -731,9 +871,10 @@ Return the time grid for a specific component. # Examples ```julia-repl -julia> time_grid(sol, :state) # Works for both unified and multiple grids -julia> time_grid(sol, :control) # Works for both unified and multiple grids -julia> time_grid(sol, :states) # Plural form also works +julia> time_grid(sol, :state) # Works for both unified and multiple grids +julia> time_grid(sol, :control) # Works for both unified and multiple grids +julia> time_grid(sol, :costate) # Maps to :state grid +julia> time_grid(sol, :dual) # Maps to :path grid ``` """ function time_grid( @@ -755,13 +896,13 @@ function time_grid( component_clean = clean_component_symbols((component,))[1] # Validate component - if component_clean ∉ (:state, :control, :costate, :path, :dual) + if component_clean ∉ (:state, :control, :costate, :path) # ⚠️ Applying Exception Rule: Invalid component symbol throw( CTBase.Exceptions.IncorrectArgument( "Invalid component for time grid access"; got=string(component), - expected="one of :state, :control, :costate, :path, :dual (or plural forms)", + expected="one of :state, :control, :costate, :path (or aliases like :dual, plural forms)", suggestion="Use time_grid(sol, :state) or another valid component", context="time_grid for UnifiedTimeGridModel", ), @@ -779,8 +920,10 @@ Return the time grid for a specific component in solutions with multiple time gr # Arguments - `sol::Solution`: The solution with multiple time grids -- `component::Symbol`: The component (:state, :control, :costate, :path, :dual) - Plural forms (:states, :controls, :costates, :duals) are also accepted +- `component::Symbol`: The component (:state, :control, :path) + Also accepted: :costate/:costates (→ :state), :dual/:duals (→ :path), + :state_box_constraint(s) (→ :state), :control_box_constraint(s) (→ :control), + plural forms (:states, :controls) # Returns - `TimesDisc`: The time grid for the specified component @@ -790,9 +933,10 @@ Return the time grid for a specific component in solutions with multiple time gr # Examples ```julia-repl -julia> time_grid(sol, :state) # Get state time grid -julia> time_grid(sol, :control) # Get control time grid -julia> time_grid(sol, :states) # Plural form also works +julia> time_grid(sol, :state) # Get state time grid +julia> time_grid(sol, :control) # Get control time grid +julia> time_grid(sol, :costate) # Maps to state time grid +julia> time_grid(sol, :dual) # Maps to path time grid ``` """ function time_grid( @@ -814,13 +958,13 @@ function time_grid( component_clean = clean_component_symbols((component,))[1] # Validate component - if component_clean ∉ (:state, :control, :costate, :path, :dual) + if component_clean ∉ (:state, :control, :costate, :path) # ⚠️ Applying Exception Rule: Invalid component symbol throw( CTBase.Exceptions.IncorrectArgument( "Invalid component for time grid access"; got=string(component), - expected="one of :state, :control, :costate, :path, :dual (or plural forms)", + expected="one of :state, :control, :costate, :path (or aliases like :dual, plural forms)", suggestion="Use time_grid(sol, :state) or another valid component", context="time_grid for MultipleTimeGridModel", ), @@ -868,7 +1012,7 @@ function time_grid( CTBase.Exceptions.IncorrectArgument( "Component must be specified for solutions with multiple time grids"; got="no component specified", - expected="time_grid(sol, :component) where component ∈ {:state, :control, :costate, :path, :dual}", + expected="time_grid(sol, :component) where component ∈ {:state, :control, :path}", suggestion="Specify which time grid to access, e.g., time_grid(sol, :state)", context="time_grid for MultipleTimeGridModel", ), @@ -1129,43 +1273,131 @@ end """ $(TYPEDSIGNATURES) -Serialize a solution into discrete data for export (JLD2, JSON, etc.). +Serialize a solution into discrete data for export to persistent storage (JLD2, JSON, etc.). + +This function converts a `Solution` object (which may contain interpolated functions) into a +fully discrete, serializable representation. All trajectory functions are evaluated on their +respective time grids and stored as matrices. The serialization format automatically adapts +based on whether the solution uses unified or multiple time grids. + +# Serialization Formats + +The function produces two different formats depending on the solution's time grid model: + +## Unified Time Grid Format (Legacy) + +When all grids are identical (`UnifiedTimeGridModel`), produces: +```julia +Dict( + "time_grid" => T, # Single grid for all components + "state" => Matrix, # Discretized on T + "control" => Matrix, # Discretized on T + "costate" => Matrix, # Discretized on T + "path_constraints_dual" => Matrix, # Discretized on T + # ... other fields +) +``` -Extracts all data from a solution and converts it into a serializable format -(matrices, vectors, scalars). Functions are discretized on the time grid. -Uses public getters to access solution fields. +## Multiple Time Grids Format + +When grids differ (`MultipleTimeGridModel`), produces: +```julia +Dict( + "time_grid_state" => T_state, # State-specific grid + "time_grid_control" => T_control, # Control-specific grid + "time_grid_costate" => T_costate, # Costate-specific grid + "time_grid_path" => T_path, # Path constraints grid + "state" => Matrix, # Discretized on T_state + "control" => Matrix, # Discretized on T_control + "costate" => Matrix, # Discretized on T_costate + "path_constraints_dual" => Matrix, # Discretized on T_path + # ... other fields +) +``` # Arguments -- `sol::Solution`: Solution to serialize. + +- `sol::Solution`: Solution object to serialize (may contain functions or matrices) # Returns -- `Dict{String, Any}`: Dictionary containing all discrete data: - - `"time_grid"`: Time grid - - `"state"`, `"control"`, `"costate"`: Discretized matrices - - `"variable"`: Variable vector - - `"objective"`: Scalar value - - Discretized dual functions (can be `nothing`) - - Boundary and variable duals (vectors) - - Solver information -# Notes -- Functions are discretized via `_discretize_function`. -- `nothing` duals are preserved as `nothing`. -- Compatible with `build_solution` for reconstruction. +- `Dict{String, Any}`: Complete serializable dictionary containing: + - **Time grids**: Either single `"time_grid"` or four separate grids + - **Trajectories**: `"state"`, `"control"`, `"costate"` as `Matrix{Float64}` + - **Variable**: `"variable"` as `Vector{Float64}` (time-independent) + - **Objective**: `"objective"` as `Float64` + - **Dual variables**: All constraint duals (can be `nothing` if not present) + - `"path_constraints_dual"`: Path constraint duals on path grid + - `"state_constraints_lb_dual"`, `"state_constraints_ub_dual"`: State box duals on state grid + - `"control_constraints_lb_dual"`, `"control_constraints_ub_dual"`: Control box duals on control grid + - `"boundary_constraints_dual"`: Boundary duals (time-independent vector) + - `"variable_constraints_lb_dual"`, `"variable_constraints_ub_dual"`: Variable duals (vectors) + - **Solver info**: `"iterations"`, `"message"`, `"status"`, `"successful"`, `"constraints_violation"`, `"infos"` + +# Discretization Behavior + +- **Function trajectories**: Evaluated at each point of their associated time grid +- **Matrix trajectories**: Copied as-is (already discrete) +- **Nothing duals**: Preserved as `nothing` in the dictionary +- **Grid association**: Each component is discretized on its correct grid: + - State and state box duals → `T_state` + - Control and control box duals → `T_control` + - Costate → `T_costate` + - Path constraint duals → `T_path` # Example + ```julia -sol = solve(ocp) -data = CTModels._serialize_solution(sol) -# Reconstruction -sol_reconstructed = CTModels.build_solution( - ocp, data["time_grid"], data["state"], data["control"], - data["variable"], data["costate"]; - objective=data["objective"], ... -) +using CTModels + +# Solve OCP with multiple grids +sol = solve(ocp, strategy=MyStrategy()) + +# Serialize to dictionary +data = _serialize_solution(sol) + +# Check format +if haskey(data, "time_grid_state") + # Multiple grids format + println("State grid: ", length(data["time_grid_state"]), " points") + println("Control grid: ", length(data["time_grid_control"]), " points") + println("Costate grid: ", length(data["time_grid_costate"]), " points") +else + # Unified grid format + println("Unified grid: ", length(data["time_grid"]), " points") +end + +# Export to file (handled by extensions) +export_ocp_solution(sol; filename="solution", format=:JLD) + +# Reconstruct from data +sol_reconstructed = _reconstruct_solution_from_data(ocp, data) +``` + +# Notes + +## Backward Compatibility + +The serialization format is designed for backward compatibility: +- Old files with single `"time_grid"` can be read (costate defaults to state grid) +- New files with four grids are forward-compatible with updated readers +- The `_reconstruct_solution_from_data` function handles both formats automatically + +## Memory Efficiency + +When all grids are identical, the unified format avoids storing redundant grid data, +reducing file size and memory usage. + +## Round-Trip Guarantee + +The serialized data is fully compatible with `build_solution` for exact reconstruction: +```julia +data = _serialize_solution(sol) +sol_new = build_solution(ocp, data["time_grid_state"], ...; objective=data["objective"], ...) ``` -See also: `build_solution`, `_discretize_function` +See also: [`build_solution`](@ref), [`_reconstruct_solution_from_data`](@ref), +[`export_ocp_solution`](@ref), [`import_ocp_solution`](@ref) """ function _serialize_solution(sol::Solution)::Dict{String,Any} # Use public getters @@ -1177,85 +1409,71 @@ function _serialize_solution(sol::Solution)::Dict{String,Any} end """ -Serialize solution for unified time grid (legacy format). -""" -function _serialize_solution(::UnifiedTimeGridModel, sol::Solution, dim_x::Int, dim_u::Int) - # Legacy format: single time grid - T = time_grid(sol) +$(TYPEDSIGNATURES) - return Dict( - "time_grid" => T, - "state" => _discretize_function(state(sol), T, dim_x), - "control" => _discretize_function(control(sol), T, dim_u), - "costate" => _discretize_function(costate(sol), T, dim_x), - "variable" => variable(sol), - "objective" => objective(sol), +Discretize all solution components on their respective time grids for serialization. - # Discretize dual functions (can be nothing) - "path_constraints_dual" => _discretize_dual(path_constraints_dual(sol), T), - "state_constraints_lb_dual" => _discretize_dual(state_constraints_lb_dual(sol), T), - "state_constraints_ub_dual" => _discretize_dual(state_constraints_ub_dual(sol), T), - "control_constraints_lb_dual" => - _discretize_dual(control_constraints_lb_dual(sol), T), - "control_constraints_ub_dual" => - _discretize_dual(control_constraints_ub_dual(sol), T), +This internal helper function extracts the common discretization logic shared by both +`UnifiedTimeGridModel` and `MultipleTimeGridModel` serialization. It evaluates all +trajectory functions on their associated time grids and assembles them into a dictionary. - # Boundary and variable duals (vectors, not functions) - "boundary_constraints_dual" => boundary_constraints_dual(sol), - "variable_constraints_lb_dual" => variable_constraints_lb_dual(sol), - "variable_constraints_ub_dual" => variable_constraints_ub_dual(sol), +# Grid-Component Association - # Solver info - "iterations" => iterations(sol), - "message" => message(sol), - "status" => status(sol), - "successful" => successful(sol), - "constraints_violation" => constraints_violation(sol), - "infos" => infos(sol), - ) -end +Each component is discretized on its semantically correct time grid: -""" -Serialize solution for multiple time grids format. -""" -function _serialize_solution(::MultipleTimeGridModel, sol::Solution, dim_x::Int, dim_u::Int) - # Multiple time grids format - T_state = time_grid(sol, :state) - T_control = time_grid(sol, :control) - T_costate = time_grid(sol, :costate) - T_dual = time_grid(sol, :dual) # Same as :path +- **State trajectory** → `T_state` grid +- **Control trajectory** → `T_control` grid +- **Costate trajectory** → `T_costate` grid +- **Path constraint duals** → `T_path` grid +- **State box constraint duals** (lb/ub) → `T_state` grid +- **Control box constraint duals** (lb/ub) → `T_control` grid +- **Boundary/variable duals** → Time-independent (vectors, not discretized) + +# Arguments + +- `sol::Solution`: Solution object containing trajectory functions +- `T_state::Vector{Float64}`: Time grid for state discretization +- `T_control::Vector{Float64}`: Time grid for control discretization +- `T_costate::Vector{Float64}`: Time grid for costate discretization +- `T_path::Vector{Float64}`: Time grid for path constraint dual discretization +- `dim_x::Int`: State dimension (for validation) +- `dim_u::Int`: Control dimension (for validation) + +# Returns - return Dict( - # Multiple time grids - "time_grid_state" => T_state, - "time_grid_control" => T_control, - "time_grid_costate" => T_costate, - "time_grid_dual" => T_dual, +- `Dict{String, Any}`: Dictionary with all discretized components (grids not included) - # Discretized functions with appropriate grids +# Notes + +This function does NOT include time grid data in the returned dictionary. The calling +function (`_serialize_solution` for `UnifiedTimeGridModel` or `MultipleTimeGridModel`) +is responsible for adding the appropriate grid keys. + +See also: [`_serialize_solution`](@ref), [`_discretize_function`](@ref), [`_discretize_dual`](@ref) +""" +function _discretize_all_components( + sol::Solution, + T_state::Vector{Float64}, + T_control::Vector{Float64}, + T_costate::Vector{Float64}, + T_path::Vector{Float64}, + dim_x::Int, + dim_u::Int, +)::Dict{String,Any} + return Dict{String,Any}( "state" => _discretize_function(state(sol), T_state, dim_x), "control" => _discretize_function(control(sol), T_control, dim_u), "costate" => _discretize_function(costate(sol), T_costate, dim_x), "variable" => variable(sol), "objective" => objective(sol), - - # Discretize dual functions with dual grid - "path_constraints_dual" => _discretize_dual(path_constraints_dual(sol), T_dual), - "state_constraints_lb_dual" => - _discretize_dual(state_constraints_lb_dual(sol), T_dual), - "state_constraints_ub_dual" => - _discretize_dual(state_constraints_ub_dual(sol), T_dual), - "control_constraints_lb_dual" => - _discretize_dual(control_constraints_lb_dual(sol), T_dual), - "control_constraints_ub_dual" => - _discretize_dual(control_constraints_ub_dual(sol), T_dual), - - # Boundary and variable duals (vectors, not functions) + "path_constraints_dual" => _discretize_dual(path_constraints_dual(sol), T_path, dim_path_constraints_nl(sol)), + "state_constraints_lb_dual" => _discretize_dual(state_constraints_lb_dual(sol), T_state, dim_state_constraints_box(sol)), + "state_constraints_ub_dual" => _discretize_dual(state_constraints_ub_dual(sol), T_state, dim_state_constraints_box(sol)), + "control_constraints_lb_dual" => _discretize_dual(control_constraints_lb_dual(sol), T_control, dim_control_constraints_box(sol)), + "control_constraints_ub_dual" => _discretize_dual(control_constraints_ub_dual(sol), T_control, dim_control_constraints_box(sol)), "boundary_constraints_dual" => boundary_constraints_dual(sol), "variable_constraints_lb_dual" => variable_constraints_lb_dual(sol), "variable_constraints_ub_dual" => variable_constraints_ub_dual(sol), - - # Solver info "iterations" => iterations(sol), "message" => message(sol), "status" => status(sol), @@ -1264,3 +1482,122 @@ function _serialize_solution(::MultipleTimeGridModel, sol::Solution, dim_x::Int, "infos" => infos(sol), ) end + +""" +$(TYPEDSIGNATURES) + +Serialize solution with unified time grid (legacy single-grid format). + +This method handles solutions where all components share the same time grid. It produces +the legacy format with a single `"time_grid"` key, which is backward-compatible with +older versions of the package. + +# Format Produced + +```julia +Dict( + "time_grid" => T, # Single unified grid + "state" => Matrix, # All components discretized on T + "control" => Matrix, + "costate" => Matrix, + # ... all other fields +) +``` + +# Arguments + +- `::UnifiedTimeGridModel`: Time grid model type (dispatch parameter) +- `sol::Solution`: Solution to serialize +- `dim_x::Int`: State dimension +- `dim_u::Int`: Control dimension + +# Returns + +- `Dict{String, Any}`: Serialized data with single time grid + +# Notes + +This format is used when `build_solution` is called with identical grids for all components, +or when using the legacy single-grid signature. It ensures backward compatibility with files +created before the multi-grid feature was introduced. + +See also: [`_serialize_solution(::MultipleTimeGridModel, ...)`](@ref) +""" +function _serialize_solution(::UnifiedTimeGridModel, sol::Solution, dim_x::Int, dim_u::Int) + # Legacy format: single time grid + T = time_grid(sol) + + # Discretize all components + data = _discretize_all_components(sol, T, T, T, T, dim_x, dim_u) + + # Add time grid + data["time_grid"] = T + + return data +end + +""" +$(TYPEDSIGNATURES) + +Serialize solution with multiple independent time grids (modern format). + +This method handles solutions where different components use different time grids. It produces +the modern format with four separate grid keys (`time_grid_state`, `time_grid_control`, +`time_grid_costate`, `time_grid_path`), preserving the independent discretizations. + +# Format Produced + +```julia +Dict( + "time_grid_state" => T_state, # State-specific grid + "time_grid_control" => T_control, # Control-specific grid + "time_grid_costate" => T_costate, # Costate-specific grid + "time_grid_path" => T_path, # Path constraints grid + "state" => Matrix, # Discretized on T_state + "control" => Matrix, # Discretized on T_control + "costate" => Matrix, # Discretized on T_costate + "path_constraints_dual" => Matrix, # Discretized on T_path + # ... all other fields +) +``` + +# Arguments + +- `::MultipleTimeGridModel`: Time grid model type (dispatch parameter) +- `sol::Solution`: Solution to serialize +- `dim_x::Int`: State dimension +- `dim_u::Int`: Control dimension + +# Returns + +- `Dict{String, Any}`: Serialized data with four independent time grids + +# Notes + +This format is used when `build_solution` is called with different grids for different +components. It allows numerical schemes to use optimal discretizations for each component +(e.g., finer grid for state, coarser for control, custom for costate). + +The reconstruction function `_reconstruct_solution_from_data` detects this format by checking +for the presence of `"time_grid_state"` key and handles it appropriately. + +See also: [`_serialize_solution(::UnifiedTimeGridModel, ...)`](@ref), [`build_solution`](@ref) +""" +function _serialize_solution(::MultipleTimeGridModel, sol::Solution, dim_x::Int, dim_u::Int) + # Multiple time grids format + T_state = time_grid(sol, :state) + T_control = time_grid(sol, :control) + T_costate = time_grid(sol, :costate) + T_path = time_grid(sol, :path) + + # Discretize all components + data = _discretize_all_components(sol, T_state, T_control, T_costate, T_path, dim_x, dim_u) + + # Add multiple time grids + data["time_grid_state"] = T_state + data["time_grid_control"] = T_control + data["time_grid_costate"] = T_costate + data["time_grid_path"] = T_path + + return data +end diff --git a/src/OCP/Types/solution.jl b/src/OCP/Types/solution.jl index 31b624c9..0e4576a5 100644 --- a/src/OCP/Types/solution.jl +++ b/src/OCP/Types/solution.jl @@ -48,11 +48,10 @@ Used when variables have different discretisations (e.g., different grid densiti # Fields - `grids::NamedTuple`: Named tuple with time grids for each component: - - `state::TimesDisc`: State trajectory time grid - - `control::TimesDisc`: Control trajectory time grid - - `costate::TimesDisc`: Costate trajectory time grid - - `path::TimesDisc`: Path constraints and duals time grid - - `dual::TimesDisc`: Alias for path constraints grid (same physical grid) + - `state::TimesDisc`: Time grid for state and state box constraint duals + - `control::TimesDisc`: Time grid for control and control box constraint duals + - `costate::TimesDisc`: Time grid for costate + - `path::TimesDisc`: Time grid for path constraints and their duals # Example @@ -61,8 +60,9 @@ julia> using CTModels julia> T_state = LinRange(0, 1, 101) julia> T_control = LinRange(0, 1, 51) +julia> T_costate = LinRange(0, 1, 76) julia> tg = CTModels.MultipleTimeGridModel( - state=T_state, control=T_control, costate=T_state, path=T_state, dual=T_state + state=T_state, control=T_control, costate=T_costate, path=T_state ) julia> length(tg.grids.state) 101 @@ -70,8 +70,8 @@ julia> length(tg.grids.state) """ struct MultipleTimeGridModel <: AbstractTimeGridModel grids::NamedTuple{ - (:state, :control, :costate, :path, :dual), - Tuple{TimesDisc,TimesDisc,TimesDisc,TimesDisc,TimesDisc}, + (:state, :control, :costate, :path), + Tuple{TimesDisc,TimesDisc,TimesDisc,TimesDisc}, } end @@ -81,11 +81,10 @@ $(TYPEDSIGNATURES) Construct a `MultipleTimeGridModel` with keyword arguments for each component time grid. # Arguments -- `state`: Time grid for state variables -- `control`: Time grid for control variables -- `costate`: Time grid for costate variables -- `path`: Time grid for path constraints -- `dual`: Time grid for dual variables +- `state`: Time grid for state and state box constraint duals +- `control`: Time grid for control and control box constraint duals +- `costate`: Time grid for costate +- `path`: Time grid for path constraints and their duals # Returns - `MultipleTimeGridModel`: A model containing all component time grids @@ -94,12 +93,12 @@ Construct a `MultipleTimeGridModel` with keyword arguments for each component ti ```julia-repl julia> T_state = LinRange(0, 1, 101) julia> T_control = LinRange(0, 1, 51) +julia> T_costate = LinRange(0, 1, 76) julia> mtgm = MultipleTimeGridModel( state=T_state, control=T_control, - costate=T_state, - path=T_state, - dual=T_state + costate=T_costate, + path=T_state ) ``` """ @@ -108,10 +107,9 @@ function MultipleTimeGridModel(; control::TimesDisc, costate::TimesDisc, path::TimesDisc, - dual::TimesDisc, ) return MultipleTimeGridModel(( - state=state, control=control, costate=costate, path=path, dual=dual + state=state, control=control, costate=costate, path=path )) end @@ -124,8 +122,11 @@ $(TYPEDSIGNATURES) Clean and standardize component symbols for time grid access. # Behavior -- Converts plural forms (`:states`, `:costates`, etc.) to their singular equivalents. -- Maps ambiguous terms (`:constraint`, `:constraints`, `:cons`) to `:path`. +- Maps all component symbols to their canonical time grid: `:state`, `:control`, `:costate`, or `:path`. +- `:costate`, `:costates` map to `:costate` (costate has its own grid). +- `:dual`, `:duals`, `:constraint`, `:constraints`, `:cons` map to `:path`. +- `:state_box_constraint(s)` maps to `:state`. +- `:control_box_constraint(s)` maps to `:control`. - Removes duplicate symbols. # Arguments @@ -137,20 +138,25 @@ Clean and standardize component symbols for time grid access. # Example ```julia-repl julia> clean_component_symbols((:states, :controls, :costate, :constraint, :duals)) -# → (:state, :control, :costate, :path, :dual) +# → (:state, :control, :costate, :path) ``` """ function clean_component_symbols(description) - # remove the nouns in plural form + # map all component symbols to their canonical time grid description = replace( description, :states => :state, :costates => :costate, :controls => :control, + :state_box_constraints => :state, + :state_box_constraint => :state, + :control_box_constraints => :control, + :control_box_constraint => :control, :constraints => :path, :constraint => :path, :cons => :path, - :duals => :dual, + :duals => :path, + :dual => :path, ) # remove the duplicates while preserving order seen = Set{Symbol}() diff --git a/src/Serialization/reconstruction_helpers.jl b/src/Serialization/reconstruction_helpers.jl index b8292315..f0910a5f 100644 --- a/src/Serialization/reconstruction_helpers.jl +++ b/src/Serialization/reconstruction_helpers.jl @@ -24,7 +24,7 @@ Reconstruct a solution from imported data, detecting the format (single vs multi - `Solution`: Reconstructed solution with appropriate time grid model # Notes -- If `time_grid_state` key exists, assumes new multiple time grid format +- If `time_grid_state` key exists, assumes multiple time grid format - Otherwise, uses legacy single time grid format - Handles both raw vectors and TimeGridModel objects for legacy format @@ -48,11 +48,14 @@ function _reconstruct_solution_from_data( ) # Detect format and extract time grids if haskey(data, "time_grid_state") - # New format: multiple time grids + # Multiple time grids format T_state = _extract_time_vector(data["time_grid_state"]) T_control = _extract_time_vector(data["time_grid_control"]) - T_costate = _extract_time_vector(data["time_grid_costate"]) - T_dual = _extract_time_vector(data["time_grid_dual"]) + # Backward compatibility: if time_grid_costate is missing, use T_state + T_costate = haskey(data, "time_grid_costate") ? _extract_time_vector(data["time_grid_costate"]) : T_state + # Support both new (time_grid_path) and legacy (time_grid_dual) keys + T_path_key = haskey(data, "time_grid_path") ? "time_grid_path" : "time_grid_dual" + T_path = _extract_time_vector(data[T_path_key]) # Reconstruct solution with multiple time grids return OCP.build_solution( @@ -60,7 +63,7 @@ function _reconstruct_solution_from_data( T_state, T_control, T_costate, - T_dual, + T_path, data["state"], data["control"], _extract_time_vector(data["variable"]), diff --git a/test/suite/ocp/test_solution_multi_grids.jl b/test/suite/ocp/test_solution_multi_grids.jl index c40a771c..d4933cdb 100644 --- a/test/suite/ocp/test_solution_multi_grids.jl +++ b/test/suite/ocp/test_solution_multi_grids.jl @@ -35,22 +35,20 @@ function test_solution_multi_grids() T_state = LinRange(0, 1, 101) T_control = LinRange(0, 1, 51) T_costate = LinRange(0, 1, 76) - T_dual = LinRange(0, 1, 101) + T_path = LinRange(0, 1, 61) mtgm = CTModels.MultipleTimeGridModel( state=T_state, control=T_control, costate=T_costate, - path=T_dual, - dual=T_dual, + path=T_path, ) Test.@test mtgm isa CTModels.MultipleTimeGridModel Test.@test mtgm isa CTModels.AbstractTimeGridModel Test.@test mtgm.grids.state == T_state Test.@test mtgm.grids.control == T_control Test.@test mtgm.grids.costate == T_costate - Test.@test mtgm.grids.path == T_dual - Test.@test mtgm.grids.dual == T_dual + Test.@test mtgm.grids.path == T_path end end @@ -60,32 +58,45 @@ function test_solution_multi_grids() Test.@testset "Component Symbol Cleaning" begin Test.@testset "clean_component_symbols" begin - # Test singular forms (unchanged) + # Test canonical forms (unchanged) Test.@test CTModels.clean_component_symbols((:state,)) == (:state,) Test.@test CTModels.clean_component_symbols((:control,)) == (:control,) - Test.@test CTModels.clean_component_symbols((:costate,)) == (:costate,) Test.@test CTModels.clean_component_symbols((:path,)) == (:path,) - Test.@test CTModels.clean_component_symbols((:dual,)) == (:dual,) - # Test plural forms (converted to singular) + # Test costate maps to costate (has its own grid) + Test.@test CTModels.clean_component_symbols((:costate,)) == (:costate,) + Test.@test CTModels.clean_component_symbols((:costates,)) == (:costate,) + + # Test dual maps to path (shares path grid) + Test.@test CTModels.clean_component_symbols((:dual,)) == (:path,) + Test.@test CTModels.clean_component_symbols((:duals,)) == (:path,) + + # Test plural forms Test.@test CTModels.clean_component_symbols((:states,)) == (:state,) Test.@test CTModels.clean_component_symbols((:controls,)) == (:control,) - Test.@test CTModels.clean_component_symbols((:costates,)) == (:costate,) - Test.@test CTModels.clean_component_symbols((:duals,)) == (:dual,) # Test ambiguous terms (mapped to :path) Test.@test CTModels.clean_component_symbols((:constraint,)) == (:path,) Test.@test CTModels.clean_component_symbols((:constraints,)) == (:path,) Test.@test CTModels.clean_component_symbols((:cons,)) == (:path,) - # Test mixed input + # Test box constraint aliases + Test.@test CTModels.clean_component_symbols((:state_box_constraint,)) == (:state,) + Test.@test CTModels.clean_component_symbols((:state_box_constraints,)) == (:state,) + Test.@test CTModels.clean_component_symbols((:control_box_constraint,)) == (:control,) + Test.@test CTModels.clean_component_symbols((:control_box_constraints,)) == (:control,) + + # Test mixed input (costate→state, dual→path, so only 3 unique) Test.@test CTModels.clean_component_symbols(( :states, :controls, :constraint, :duals - )) == (:state, :control, :path, :dual) + )) == (:state, :control, :path) # Test duplicate removal Test.@test CTModels.clean_component_symbols((:state, :state)) == (:state,) Test.@test CTModels.clean_component_symbols((:states, :state)) == (:state,) + Test.@test CTModels.clean_component_symbols((:costate, :costate)) == (:costate,) + Test.@test CTModels.clean_component_symbols((:costates, :costate)) == (:costate,) + Test.@test CTModels.clean_component_symbols((:dual, :path)) == (:path,) end end @@ -162,11 +173,11 @@ function test_solution_multi_grids() T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) T_costate = collect(LinRange(0, 1, 76)) - T_dual = collect(LinRange(0, 1, 101)) + T_path = collect(LinRange(0, 1, 61)) X = [1.0 - t/100 for t in 1:101, i in 1:2] U = [sin(2π * t/50) for t in 1:51, i in 1:1] - P = zeros(76, 2) + P = zeros(76, 2) # Costate has its own grid v = Float64[] sol = CTModels.build_solution( @@ -174,7 +185,7 @@ function test_solution_multi_grids() T_state, T_control, T_costate, - T_dual, + T_path, X, U, v, @@ -191,19 +202,20 @@ function test_solution_multi_grids() Test.@test CTModels.time_grid(sol, :state) == T_state Test.@test CTModels.time_grid(sol, :control) == T_control Test.@test CTModels.time_grid(sol, :costate) == T_costate - Test.@test CTModels.time_grid(sol, :dual) == T_dual - Test.@test CTModels.time_grid(sol, :path) == T_dual # Same as dual + Test.@test CTModels.time_grid(sol, :path) == T_path + # Dual maps to path grid + Test.@test CTModels.time_grid(sol, :dual) == T_path end - Test.@testset "Nothing dual grid" begin + Test.@testset "Nothing path grid" begin T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) T_costate = collect(LinRange(0, 1, 76)) - T_dual = nothing + T_path = nothing X = [1.0 - t/100 for t in 1:101, i in 1:2] U = [sin(2π * t/50) for t in 1:51, i in 1:1] - P = zeros(76, 2) + P = zeros(76, 2) # Costate has its own grid v = Float64[] sol = CTModels.build_solution( @@ -211,7 +223,7 @@ function test_solution_multi_grids() T_state, T_control, T_costate, - T_dual, + T_path, X, U, v, @@ -228,8 +240,9 @@ function test_solution_multi_grids() Test.@test CTModels.time_grid(sol, :state) == T_state Test.@test CTModels.time_grid(sol, :control) == T_control Test.@test CTModels.time_grid(sol, :costate) == T_costate - Test.@test CTModels.time_grid(sol, :dual) == T_state # Falls back to state grid + # Path grid falls back to state grid when nothing Test.@test CTModels.time_grid(sol, :path) == T_state + Test.@test CTModels.time_grid(sol, :dual) == T_state end end @@ -313,19 +326,21 @@ function test_solution_multi_grids() Test.@testset "MultipleTimeGridModel getters" begin T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) + T_costate = collect(LinRange(0, 1, 76)) + T_path = collect(LinRange(0, 1, 61)) # Create data matching the grid sizes X_multi = [1.0 - t/100 for t in 1:101, i in 1:2] # 101 points for state U_multi = [sin(2π * t/50) for t in 1:51, i in 1:1] # 51 points for control - P_multi = zeros(101, 2) # 101 points for costate + P_multi = zeros(76, 2) # 76 points for costate (has its own grid) v_multi = Float64[] sol = CTModels.build_solution( ocp, T_state, T_control, - T_state, - T_state, + T_costate, + T_path, X_multi, U_multi, v_multi, @@ -344,9 +359,9 @@ function test_solution_multi_grids() # Should work with component specification Test.@test CTModels.time_grid(sol, :state) == T_state Test.@test CTModels.time_grid(sol, :control) == T_control - Test.@test CTModels.time_grid(sol, :costate) == T_state - Test.@test CTModels.time_grid(sol, :dual) == T_state - Test.@test CTModels.time_grid(sol, :path) == T_state + Test.@test CTModels.time_grid(sol, :costate) == T_costate + Test.@test CTModels.time_grid(sol, :dual) == T_path + Test.@test CTModels.time_grid(sol, :path) == T_path # Test plural forms Test.@test CTModels.time_grid(sol, :states) == T_state @@ -399,11 +414,11 @@ function test_solution_multi_grids() T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) T_costate = collect(LinRange(0, 1, 76)) - T_dual = collect(LinRange(0, 1, 101)) + T_path = collect(LinRange(0, 1, 61)) X = [1.0 - t/100 for t in 1:101, i in 1:2] U = [sin(2π * t/50) for t in 1:51, i in 1:1] - P = zeros(76, 2) + P = zeros(76, 2) # Costate has its own grid v = Float64[] sol = CTModels.build_solution( @@ -411,7 +426,7 @@ function test_solution_multi_grids() T_state, T_control, T_costate, - T_dual, + T_path, X, U, v, @@ -431,16 +446,17 @@ function test_solution_multi_grids() Test.@test haskey(data, "time_grid_state") Test.@test haskey(data, "time_grid_control") Test.@test haskey(data, "time_grid_costate") - Test.@test haskey(data, "time_grid_dual") + Test.@test haskey(data, "time_grid_path") - # Should not have legacy single time grid + # Should not have legacy single time grid or old keys Test.@test !haskey(data, "time_grid") + Test.@test !haskey(data, "time_grid_dual") # Time grids should match Test.@test data["time_grid_state"] == T_state Test.@test data["time_grid_control"] == T_control Test.@test data["time_grid_costate"] == T_costate - Test.@test data["time_grid_dual"] == T_dual + Test.@test data["time_grid_path"] == T_path end end @@ -552,17 +568,19 @@ function test_solution_multi_grids() T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) + T_costate = collect(LinRange(0, 1, 76)) + T_path = collect(LinRange(0, 1, 61)) X = [1.0 - t/100 for t in 1:101, i in 1:2] U = [sin(2π * t/50) for t in 1:51, i in 1:1] - P = zeros(101, 2) + P = zeros(76, 2) v = Float64[] sol = CTModels.build_solution( ocp, T_state, T_control, - T_state, - T_state, + T_costate, + T_path, X, U, v, @@ -658,19 +676,21 @@ function test_solution_multi_grids() Test.@testset "MultipleTimeGridModel type stability" begin T_state = collect(LinRange(0, 1, 101)) T_control = collect(LinRange(0, 1, 51)) + T_costate = collect(LinRange(0, 1, 76)) + T_path = collect(LinRange(0, 1, 61)) # Create data matching the grid sizes X_stab = [1.0 - t/100 for t in 1:101, i in 1:2] # 101 points for state U_stab = [sin(2π * t/50) for t in 1:51, i in 1:1] # 51 points for control - P_stab = zeros(101, 2) # 101 points for costate + P_stab = zeros(76, 2) # 76 points for costate v_stab = Float64[] sol = CTModels.build_solution( ocp, T_state, T_control, - T_state, - T_state, + T_costate, + T_path, X_stab, U_stab, v_stab, diff --git a/test/suite/serialization/test_export_import.jl b/test/suite/serialization/test_export_import.jl index 328f8d8d..37bd8709 100644 --- a/test/suite/serialization/test_export_import.jl +++ b/test/suite/serialization/test_export_import.jl @@ -383,10 +383,11 @@ function test_export_import() x_func = CTModels.state(sol) for (i, t) in enumerate(T_orig) x_expected = x_func(t) - x_from_json = if state_json[i] isa Number - state_json[i] - else - Vector{Float64}(state_json[i]) + # After fix: state_json[i] is always a vector (even for 1D states) + x_from_json = Vector{Float64}(state_json[i]) + # For 1D states, extract scalar to match x_expected type + if length(x_from_json) == 1 && x_expected isa Number + x_from_json = x_from_json[1] end Test.@test x_from_json ≈ x_expected atol = 1e-8 end @@ -397,10 +398,11 @@ function test_export_import() u_func = CTModels.control(sol) for (i, t) in enumerate(T_orig) u_expected = u_func(t) - u_from_json = if control_json[i] isa Number - control_json[i] - else - Vector{Float64}(control_json[i]) + # After fix: control_json[i] is always a vector (even for 1D controls) + u_from_json = Vector{Float64}(control_json[i]) + # For 1D controls, extract scalar + if length(u_from_json) == 1 + u_from_json = u_from_json[1] end Test.@test u_from_json ≈ u_expected atol = 1e-8 end @@ -411,10 +413,11 @@ function test_export_import() p_func = CTModels.costate(sol) for (i, t) in enumerate(T_orig) p_expected = p_func(t) - p_from_json = if costate_json[i] isa Number - costate_json[i] - else - Vector{Float64}(costate_json[i]) + # After fix: costate_json[i] is always a vector + p_from_json = Vector{Float64}(costate_json[i]) + # For 1D costates, extract scalar to match p_expected type + if length(p_from_json) == 1 && p_expected isa Number + p_from_json = p_from_json[1] end Test.@test p_from_json ≈ p_expected atol = 1e-8 end @@ -977,14 +980,13 @@ function test_export_import() Test.@testset "JSON stack() behavior investigation" verbose = VERBOSE showtiming = SHOWTIMING begin - # Empirical investigation: When does stack() return Vector vs Matrix? - # This validates the need for the conditional in _json_array_to_matrix + # Empirical investigation: JSON export format verification # - # Findings: - # - Multi-dimensional trajectories (state, costate): stack() → Matrix - # - 1-dimensional trajectories (control in TestProblems.solution_example): stack() → Vector + # After fix: All trajectories are exported as Vector{Vector} to preserve structure + # - Multi-dimensional trajectories (state, costate): Vector{Vector} with each element being a vector + # - 1-dimensional trajectories (control): Vector{Vector} with each element being a 1-element vector # - # This proves the refactoring with _json_array_to_matrix is correct and necessary. + # This ensures proper round-trip serialization without dimension loss. ocp, sol = TestProblems.solution_example() @@ -996,16 +998,21 @@ function test_export_import() blob = JSON3.read(json_string) # Test state (multi-dimensional: 2D in TestProblems.solution_example) + # Now exported as Vector{Vector}, so stack() returns Matrix state_stacked = stack(blob["state"]; dims=1) - Test.@test state_stacked isa Matrix # Multi-D → Matrix + Test.@test state_stacked isa Matrix # Vector{Vector} → Matrix + Test.@test size(state_stacked, 2) == 2 # 2D state # Test control (1-dimensional in TestProblems.solution_example) + # Now exported as Vector{Vector}, so stack() returns Matrix (N×1) control_stacked = stack(blob["control"]; dims=1) - Test.@test control_stacked isa Vector # 1D → Vector + Test.@test control_stacked isa Matrix # Vector{Vector} → Matrix + Test.@test size(control_stacked, 2) == 1 # 1D control # Test costate (multi-dimensional: 2D) costate_stacked = stack(blob["costate"]; dims=1) - Test.@test costate_stacked isa Matrix # Multi-D → Matrix + Test.@test costate_stacked isa Matrix # Vector{Vector} → Matrix + Test.@test size(costate_stacked, 2) == 2 # 2D costate # Verify import works correctly (indirect test of _json_array_to_matrix) sol_reloaded = CTModels.import_ocp_solution( diff --git a/test/suite/serialization/test_multi_grids.jl b/test/suite/serialization/test_multi_grids.jl new file mode 100644 index 00000000..32273613 --- /dev/null +++ b/test/suite/serialization/test_multi_grids.jl @@ -0,0 +1,566 @@ +module TestMultiGrids + +import Test +import CTModels +import JLD2 +import JSON3 +import CTBase.Exceptions + +include(joinpath("..", "..", "problems", "TestProblems.jl")) +import .TestProblems + +const VERBOSE = isdefined(Main, :TestData) ? Main.TestData.VERBOSE : true +const SHOWTIMING = isdefined(Main, :TestData) ? Main.TestData.SHOWTIMING : true + +function remove_if_exists(filename::String) + isfile(filename) && rm(filename) +end + +function test_multi_grids() + Test.@testset "Multi-Grid Serialization Tests" verbose=VERBOSE showtiming=SHOWTIMING begin + + # ==================================================================== + # UNIT TESTS - Abstract Types + # ==================================================================== + + Test.@testset "Abstract Types" begin + # Pure unit tests for multi-grid serialization functionality + end + + # ==================================================================== + # INTEGRATION TESTS - Multi-Grid Support + # ==================================================================== + + # Create base solution with unified grid + ocp, sol_unified = TestProblems.solution_example() + + # Extract data from unified solution + T_unified = CTModels.time_grid(sol_unified) + X = CTModels.state(sol_unified).(T_unified) + U = CTModels.control(sol_unified).(T_unified) + P = CTModels.costate(sol_unified).(T_unified) + v = CTModels.variable(sol_unified) + + # Convert to matrices + dim_x = CTModels.state_dimension(sol_unified) + dim_u = CTModels.control_dimension(sol_unified) + X_mat = hcat([x for x in X]...)' + U_mat = hcat([u isa Number ? [u] : u for u in U]...)' + P_mat = hcat([p for p in P]...)' + + # ==================================================================== + # Test 1: Unified Grid (should use UnifiedTimeGridModel) + # ==================================================================== + + Test.@testset "Unified grid detection" begin + # Create solution with same grid for all components using functions + T = collect(LinRange(0.0, 1.0, 11)) + + # Use functions (simpler and more robust) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, + T, T, T, T, # All grids identical + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create UnifiedTimeGridModel (optimization) + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + + # time_grid without argument should work + T_retrieved = CTModels.time_grid(sol) + Test.@test T_retrieved ≈ T + end + + # ==================================================================== + # Test 2: Multiple Grids (should use MultipleTimeGridModel) + # ==================================================================== + + Test.@testset "Multiple grids detection" begin + # Create solution with different grids using functions + T_state = collect(LinRange(0.0, 1.0, 21)) # Fine grid + T_control = collect(LinRange(0.0, 1.0, 11)) # Coarse grid + T_costate = collect(LinRange(0.0, 1.0, 16)) # Medium grid + T_path = collect(LinRange(0.0, 1.0, 21)) # Fine grid + + # Use functions instead of matrices (simpler) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_path, # Different grids + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_multi) isa CTModels.MultipleTimeGridModel + + # time_grid with component should work + Test.@test CTModels.time_grid(sol_multi, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_multi, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_multi, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_multi, :dual) ≈ T_path + end + + # ==================================================================== + # Test 3: JLD2 Export/Import with Multiple Grids + # ==================================================================== + + Test.@testset "JLD2 multi-grid round-trip" begin + # Create solution with different grids using functions + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 16)) + T_path = collect(LinRange(0.0, 1.0, 21)) + # T_path same as T_state for this test + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_path, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Export + CTModels.export_ocp_solution(sol_multi; filename="multi_grid_test", format=:JLD) + + # Import + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="multi_grid_test", format=:JLD) + + # Verify time grid model type + Test.@test CTModels.time_grid_model(sol_reloaded) isa CTModels.MultipleTimeGridModel + + # Verify grids are preserved + Test.@test CTModels.time_grid(sol_reloaded, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_reloaded, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_reloaded, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_reloaded, :dual) ≈ T_path + + # Verify data integrity + Test.@test CTModels.objective(sol_reloaded) ≈ CTModels.objective(sol_multi) + Test.@test CTModels.variable(sol_reloaded) ≈ CTModels.variable(sol_multi) + + # Verify trajectories at their respective grids + for t in T_state + Test.@test CTModels.state(sol_reloaded)(t) ≈ CTModels.state(sol_multi)(t) atol=1e-8 + end + for t in T_control + Test.@test CTModels.control(sol_reloaded)(t) ≈ CTModels.control(sol_multi)(t) atol=1e-8 + end + + remove_if_exists("multi_grid_test.jld2") + end + + # ==================================================================== + # Test 4: Error Handling for MultipleTimeGridModel + # ==================================================================== + + Test.@testset "Error handling - MultipleTimeGridModel" begin + # Create a multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 16)) + T_path = collect(LinRange(0.0, 1.0, 21)) + # T_path same as T_state for this test + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_path, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # time_grid without component should throw error + Test.@test_throws Exceptions.IncorrectArgument CTModels.time_grid(sol_multi) + + # Invalid component should throw error + Test.@test_throws Exceptions.IncorrectArgument CTModels.time_grid(sol_multi, :invalid) + end + + # ==================================================================== + # Test 5: Component Symbol Mapping + # ==================================================================== + + Test.@testset "Component symbol mapping" begin + # Create a multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 16)) + T_path = collect(LinRange(0.0, 1.0, 21)) + # T_path same as T_state for this test + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_path, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Test plural forms work + Test.@test CTModels.time_grid(sol_multi, :states) ≈ T_state + Test.@test CTModels.time_grid(sol_multi, :controls) ≈ T_control + Test.@test CTModels.time_grid(sol_multi, :costates) ≈ T_costate + + # Test path/dual equivalence + Test.@test CTModels.time_grid(sol_multi, :path) ≈ T_path + Test.@test CTModels.time_grid(sol_multi, :dual) ≈ T_path + end + + # ==================================================================== + # Test 6: Edge Cases + # ==================================================================== + + Test.@testset "Edge cases" begin + # Test with T_path = nothing + T_state = collect(LinRange(0.0, 1.0, 11)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 11)) + T_path = collect(LinRange(0.0, 1.0, 11)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, nothing, # T_path = nothing + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should still work (uses T_state for dual) + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol) ≈ T_state + end + + # ==================================================================== + # Test 7: Unified vs Multiple Grid Optimization + # ==================================================================== + + Test.@testset "Unified grid optimization" begin + # When all grids are identical, should optimize to UnifiedTimeGridModel + T = collect(LinRange(0.0, 1.0, 11)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + # Pass same grid 4 times + sol = CTModels.build_solution( + ocp, + T, T, T, T, # All identical + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should detect and optimize to UnifiedTimeGridModel + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol) ≈ T + + # Now with different grids + T_control_diff = collect(LinRange(0.0, 1.0, 6)) + + sol_multi = CTModels.build_solution( + ocp, + T, T_control_diff, T, T, # One different + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should use MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_multi) isa CTModels.MultipleTimeGridModel + Test.@test CTModels.time_grid(sol_multi, :state) ≈ T + Test.@test CTModels.time_grid(sol_multi, :control) ≈ T_control_diff + end + + # ==================================================================== + # Test 8: Serialization Internal Structure + # ==================================================================== + + Test.@testset "Serialization structure" begin + # Test UnifiedTimeGridModel serialization + T = collect(LinRange(0.0, 1.0, 11)) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_uni = CTModels.build_solution( + ocp, T, T, T, T, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize and check structure + data_uni = CTModels.OCP._serialize_solution(sol_uni) + + # Should have legacy format keys + Test.@test haskey(data_uni, "time_grid") + Test.@test !haskey(data_uni, "time_grid_state") + Test.@test data_uni["time_grid"] ≈ T + + # Test MultipleTimeGridModel serialization + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 16)) + T_path = collect(LinRange(0.0, 1.0, 21)) + # T_path same as T_state for this test + + sol_multi = CTModels.build_solution( + ocp, T_state, T_control, T_costate, T_path, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize and check structure + data_multi = CTModels.OCP._serialize_solution(sol_multi) + + # Should have multi-grid format keys + Test.@test haskey(data_multi, "time_grid_state") + Test.@test haskey(data_multi, "time_grid_control") + Test.@test haskey(data_multi, "time_grid_costate") + Test.@test haskey(data_multi, "time_grid_path") + Test.@test !haskey(data_multi, "time_grid") + + # Verify grid values + Test.@test data_multi["time_grid_state"] ≈ T_state + Test.@test data_multi["time_grid_control"] ≈ T_control + Test.@test data_multi["time_grid_costate"] ≈ T_costate + Test.@test data_multi["time_grid_path"] ≈ T_path + end + + # ==================================================================== + # Test 9: Extreme Grid Sizes + # ==================================================================== + + Test.@testset "Extreme grid sizes" begin + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + # Very different grid sizes + T_state_large = collect(LinRange(0.0, 1.0, 1001)) # Fine grid + T_control_small = collect(LinRange(0.0, 1.0, 11)) # Coarse grid + T_costate_medium = collect(LinRange(0.0, 1.0, 101)) # Medium grid + T_path_large = collect(LinRange(0.0, 1.0, 1001)) + + sol_extreme = CTModels.build_solution( + ocp, + T_state_large, T_control_small, T_costate_medium, T_path_large, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_extreme) isa CTModels.MultipleTimeGridModel + + # Verify grids + Test.@test length(CTModels.time_grid(sol_extreme, :state)) == 1001 + Test.@test length(CTModels.time_grid(sol_extreme, :control)) == 11 + Test.@test CTModels.time_grid(sol_extreme, :state) ≈ T_state_large + Test.@test CTModels.time_grid(sol_extreme, :control) ≈ T_control_small + + # Minimum grid size (2 points) + T_min = collect(LinRange(0.0, 1.0, 2)) + + sol_min = CTModels.build_solution( + ocp, T_min, T_min, T_min, T_min, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should work with minimum grid + Test.@test CTModels.time_grid_model(sol_min) isa CTModels.UnifiedTimeGridModel + Test.@test length(CTModels.time_grid(sol_min)) == 2 + end + + # ==================================================================== + # Test 10: Grid Reconstruction from Serialized Data + # ==================================================================== + + Test.@testset "Grid reconstruction" begin + # Create multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 16)) + T_path = collect(LinRange(0.0, 1.0, 21)) + # T_path same as T_state for this test + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_orig = CTModels.build_solution( + ocp, T_state, T_control, T_costate, T_path, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize + data = CTModels.OCP._serialize_solution(sol_orig) + + # Reconstruct using helper function + sol_reconstructed = CTModels.Serialization._reconstruct_solution_from_data( + ocp, data; + path_constraints_dual=data["path_constraints_dual"], + boundary_constraints_dual=data["boundary_constraints_dual"], + state_constraints_lb_dual=data["state_constraints_lb_dual"], + state_constraints_ub_dual=data["state_constraints_ub_dual"], + control_constraints_lb_dual=data["control_constraints_lb_dual"], + control_constraints_ub_dual=data["control_constraints_ub_dual"], + variable_constraints_lb_dual=data["variable_constraints_lb_dual"], + variable_constraints_ub_dual=data["variable_constraints_ub_dual"], + infos=get(data, "infos", Dict{Symbol,Any}()), + ) + + # Verify reconstruction + Test.@test CTModels.time_grid_model(sol_reconstructed) isa CTModels.MultipleTimeGridModel + Test.@test CTModels.time_grid(sol_reconstructed, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_reconstructed, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_reconstructed, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_reconstructed, :dual) ≈ T_path + Test.@test CTModels.objective(sol_reconstructed) ≈ CTModels.objective(sol_orig) + end + + # ==================================================================== + # Test 11: Backward Compatibility - Legacy Format Detection + # ==================================================================== + + Test.@testset "Legacy format detection" begin + # Create a legacy-format data structure (single time_grid) + T = collect(LinRange(0.0, 1.0, 11)) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, T, T, T, T, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize (should produce legacy format) + data = CTModels.OCP._serialize_solution(sol) + + # Verify legacy format + Test.@test haskey(data, "time_grid") + Test.@test !haskey(data, "time_grid_state") + + # Reconstruct from legacy format + sol_from_legacy = CTModels.Serialization._reconstruct_solution_from_data( + ocp, data; + path_constraints_dual=data["path_constraints_dual"], + boundary_constraints_dual=data["boundary_constraints_dual"], + state_constraints_lb_dual=data["state_constraints_lb_dual"], + state_constraints_ub_dual=data["state_constraints_ub_dual"], + control_constraints_lb_dual=data["control_constraints_lb_dual"], + control_constraints_ub_dual=data["control_constraints_ub_dual"], + variable_constraints_lb_dual=data["variable_constraints_lb_dual"], + variable_constraints_ub_dual=data["variable_constraints_ub_dual"], + infos=get(data, "infos", Dict{Symbol,Any}()), + ) + + # Should create UnifiedTimeGridModel from legacy format + Test.@test CTModels.time_grid_model(sol_from_legacy) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol_from_legacy) ≈ T + end + + # ==================================================================== + # TODO: Add JSON tests once matrix dimension issues are fixed + # TODO: Add tests with path_constraints_dual on multi-grids + # ==================================================================== + end +end + +end # module + +# CRITICAL: Redefine in outer scope for TestRunner +test_multi_grids() = TestMultiGrids.test_multi_grids() diff --git a/test/suite/serialization/test_multi_grids.jl.bak b/test/suite/serialization/test_multi_grids.jl.bak new file mode 100644 index 00000000..7a00c49c --- /dev/null +++ b/test/suite/serialization/test_multi_grids.jl.bak @@ -0,0 +1,560 @@ +module TestMultiGrids + +import Test +import CTModels +import JLD2 +import JSON3 +import CTBase.Exceptions + +include(joinpath("..", "..", "problems", "TestProblems.jl")) +import .TestProblems + +const VERBOSE = isdefined(Main, :TestData) ? Main.TestData.VERBOSE : true +const SHOWTIMING = isdefined(Main, :TestData) ? Main.TestData.SHOWTIMING : true + +function remove_if_exists(filename::String) + isfile(filename) && rm(filename) +end + +function test_multi_grids() + Test.@testset "Multi-Grid Serialization Tests" verbose=VERBOSE showtiming=SHOWTIMING begin + + # ==================================================================== + # UNIT TESTS - Abstract Types + # ==================================================================== + + Test.@testset "Abstract Types" begin + # Pure unit tests for multi-grid serialization functionality + end + + # ==================================================================== + # INTEGRATION TESTS - Multi-Grid Support + # ==================================================================== + + # Create base solution with unified grid + ocp, sol_unified = TestProblems.solution_example() + + # Extract data from unified solution + T_unified = CTModels.time_grid(sol_unified) + X = CTModels.state(sol_unified).(T_unified) + U = CTModels.control(sol_unified).(T_unified) + P = CTModels.costate(sol_unified).(T_unified) + v = CTModels.variable(sol_unified) + + # Convert to matrices + dim_x = CTModels.state_dimension(sol_unified) + dim_u = CTModels.control_dimension(sol_unified) + X_mat = hcat([x for x in X]...)' + U_mat = hcat([u isa Number ? [u] : u for u in U]...)' + P_mat = hcat([p for p in P]...)' + + # ==================================================================== + # Test 1: Unified Grid (should use UnifiedTimeGridModel) + # ==================================================================== + + Test.@testset "Unified grid detection" begin + # Create solution with same grid for all components using functions + T = collect(LinRange(0.0, 1.0, 11)) + + # Use functions (simpler and more robust) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, + T, T, T, T, # All grids identical + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create UnifiedTimeGridModel (optimization) + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + + # time_grid without argument should work + T_retrieved = CTModels.time_grid(sol) + Test.@test T_retrieved ≈ T + end + + # ==================================================================== + # Test 2: Multiple Grids (should use MultipleTimeGridModel) + # ==================================================================== + + Test.@testset "Multiple grids detection" begin + # Create solution with different grids using functions + T_state = collect(LinRange(0.0, 1.0, 21)) # Fine grid + T_control = collect(LinRange(0.0, 1.0, 11)) # Coarse grid + T_costate = collect(LinRange(0.0, 1.0, 21)) # Fine grid + T_dual = collect(LinRange(0.0, 1.0, 21)) # Fine grid + + # Use functions instead of matrices (simpler) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_dual, # Different grids + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_multi) isa CTModels.MultipleTimeGridModel + + # time_grid with component should work + Test.@test CTModels.time_grid(sol_multi, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_multi, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_multi, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_multi, :dual) ≈ T_dual + end + + # ==================================================================== + # Test 3: JLD2 Export/Import with Multiple Grids + # ==================================================================== + + Test.@testset "JLD2 multi-grid round-trip" begin + # Create solution with different grids using functions + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 21)) + T_dual = collect(LinRange(0.0, 1.0, 21)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_dual, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Export + CTModels.export_ocp_solution(sol_multi; filename="multi_grid_test", format=:JLD) + + # Import + sol_reloaded = CTModels.import_ocp_solution(ocp; filename="multi_grid_test", format=:JLD) + + # Verify time grid model type + Test.@test CTModels.time_grid_model(sol_reloaded) isa CTModels.MultipleTimeGridModel + + # Verify grids are preserved + Test.@test CTModels.time_grid(sol_reloaded, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_reloaded, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_reloaded, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_reloaded, :dual) ≈ T_dual + + # Verify data integrity + Test.@test CTModels.objective(sol_reloaded) ≈ CTModels.objective(sol_multi) + Test.@test CTModels.variable(sol_reloaded) ≈ CTModels.variable(sol_multi) + + # Verify trajectories at their respective grids + for t in T_state + Test.@test CTModels.state(sol_reloaded)(t) ≈ CTModels.state(sol_multi)(t) atol=1e-8 + end + for t in T_control + Test.@test CTModels.control(sol_reloaded)(t) ≈ CTModels.control(sol_multi)(t) atol=1e-8 + end + + remove_if_exists("multi_grid_test.jld2") + end + + # ==================================================================== + # Test 4: Error Handling for MultipleTimeGridModel + # ==================================================================== + + Test.@testset "Error handling - MultipleTimeGridModel" begin + # Create a multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 21)) + T_dual = collect(LinRange(0.0, 1.0, 21)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_dual, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # time_grid without component should throw error + Test.@test_throws Exceptions.IncorrectArgument CTModels.time_grid(sol_multi) + + # Invalid component should throw error + Test.@test_throws Exceptions.IncorrectArgument CTModels.time_grid(sol_multi, :invalid) + end + + # ==================================================================== + # Test 5: Component Symbol Mapping + # ==================================================================== + + Test.@testset "Component symbol mapping" begin + # Create a multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 21)) + T_dual = collect(LinRange(0.0, 1.0, 21)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_multi = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, T_dual, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Test plural forms work + Test.@test CTModels.time_grid(sol_multi, :states) ≈ T_state + Test.@test CTModels.time_grid(sol_multi, :controls) ≈ T_control + Test.@test CTModels.time_grid(sol_multi, :costates) ≈ T_costate + + # Test path/dual equivalence + Test.@test CTModels.time_grid(sol_multi, :path) ≈ T_dual + Test.@test CTModels.time_grid(sol_multi, :dual) ≈ T_dual + end + + # ==================================================================== + # Test 6: Edge Cases + # ==================================================================== + + Test.@testset "Edge cases" begin + # Test with T_dual = nothing + T_state = collect(LinRange(0.0, 1.0, 11)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 11)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, + T_state, T_control, T_costate, nothing, # T_dual = nothing + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should still work (uses T_state for dual) + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol) ≈ T_state + end + + # ==================================================================== + # Test 7: Unified vs Multiple Grid Optimization + # ==================================================================== + + Test.@testset "Unified grid optimization" begin + # When all grids are identical, should optimize to UnifiedTimeGridModel + T = collect(LinRange(0.0, 1.0, 11)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + # Pass same grid 4 times + sol = CTModels.build_solution( + ocp, + T, T, T, T, # All identical + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should detect and optimize to UnifiedTimeGridModel + Test.@test CTModels.time_grid_model(sol) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol) ≈ T + + # Now with different grids + T_control_diff = collect(LinRange(0.0, 1.0, 6)) + + sol_multi = CTModels.build_solution( + ocp, + T, T_control_diff, T, T, # One different + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should use MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_multi) isa CTModels.MultipleTimeGridModel + Test.@test CTModels.time_grid(sol_multi, :state) ≈ T + Test.@test CTModels.time_grid(sol_multi, :control) ≈ T_control_diff + end + + # ==================================================================== + # Test 8: Serialization Internal Structure + # ==================================================================== + + Test.@testset "Serialization structure" begin + # Test UnifiedTimeGridModel serialization + T = collect(LinRange(0.0, 1.0, 11)) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_uni = CTModels.build_solution( + ocp, T, T, T, T, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize and check structure + data_uni = CTModels.OCP._serialize_solution(sol_uni) + + # Should have legacy format keys + Test.@test haskey(data_uni, "time_grid") + Test.@test !haskey(data_uni, "time_grid_state") + Test.@test data_uni["time_grid"] ≈ T + + # Test MultipleTimeGridModel serialization + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 21)) + T_dual = collect(LinRange(0.0, 1.0, 21)) + + sol_multi = CTModels.build_solution( + ocp, T_state, T_control, T_costate, T_dual, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize and check structure + data_multi = CTModels.OCP._serialize_solution(sol_multi) + + # Should have multi-grid format keys + Test.@test haskey(data_multi, "time_grid_state") + Test.@test haskey(data_multi, "time_grid_control") + Test.@test haskey(data_multi, "time_grid_costate") + Test.@test haskey(data_multi, "time_grid_dual") + Test.@test !haskey(data_multi, "time_grid") + + # Verify grid values + Test.@test data_multi["time_grid_state"] ≈ T_state + Test.@test data_multi["time_grid_control"] ≈ T_control + Test.@test data_multi["time_grid_costate"] ≈ T_costate + Test.@test data_multi["time_grid_dual"] ≈ T_dual + end + + # ==================================================================== + # Test 9: Extreme Grid Sizes + # ==================================================================== + + Test.@testset "Extreme grid sizes" begin + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + # Very different grid sizes + T_state_large = collect(LinRange(0.0, 1.0, 1001)) # Fine grid + T_control_small = collect(LinRange(0.0, 1.0, 11)) # Coarse grid + T_costate_large = collect(LinRange(0.0, 1.0, 1001)) + T_dual_large = collect(LinRange(0.0, 1.0, 1001)) + + sol_extreme = CTModels.build_solution( + ocp, + T_state_large, T_control_small, T_costate_large, T_dual_large, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should create MultipleTimeGridModel + Test.@test CTModels.time_grid_model(sol_extreme) isa CTModels.MultipleTimeGridModel + + # Verify grids + Test.@test length(CTModels.time_grid(sol_extreme, :state)) == 1001 + Test.@test length(CTModels.time_grid(sol_extreme, :control)) == 11 + Test.@test CTModels.time_grid(sol_extreme, :state) ≈ T_state_large + Test.@test CTModels.time_grid(sol_extreme, :control) ≈ T_control_small + + # Minimum grid size (2 points) + T_min = collect(LinRange(0.0, 1.0, 2)) + + sol_min = CTModels.build_solution( + ocp, T_min, T_min, T_min, T_min, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Should work with minimum grid + Test.@test CTModels.time_grid_model(sol_min) isa CTModels.UnifiedTimeGridModel + Test.@test length(CTModels.time_grid(sol_min)) == 2 + end + + # ==================================================================== + # Test 10: Grid Reconstruction from Serialized Data + # ==================================================================== + + Test.@testset "Grid reconstruction" begin + # Create multi-grid solution + T_state = collect(LinRange(0.0, 1.0, 21)) + T_control = collect(LinRange(0.0, 1.0, 11)) + T_costate = collect(LinRange(0.0, 1.0, 21)) + T_dual = collect(LinRange(0.0, 1.0, 21)) + + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol_orig = CTModels.build_solution( + ocp, T_state, T_control, T_costate, T_dual, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize + data = CTModels.OCP._serialize_solution(sol_orig) + + # Reconstruct using helper function + sol_reconstructed = CTModels.Serialization._reconstruct_solution_from_data( + ocp, data; + path_constraints_dual=data["path_constraints_dual"], + boundary_constraints_dual=data["boundary_constraints_dual"], + state_constraints_lb_dual=data["state_constraints_lb_dual"], + state_constraints_ub_dual=data["state_constraints_ub_dual"], + control_constraints_lb_dual=data["control_constraints_lb_dual"], + control_constraints_ub_dual=data["control_constraints_ub_dual"], + variable_constraints_lb_dual=data["variable_constraints_lb_dual"], + variable_constraints_ub_dual=data["variable_constraints_ub_dual"], + infos=get(data, "infos", Dict{Symbol,Any}()), + ) + + # Verify reconstruction + Test.@test CTModels.time_grid_model(sol_reconstructed) isa CTModels.MultipleTimeGridModel + Test.@test CTModels.time_grid(sol_reconstructed, :state) ≈ T_state + Test.@test CTModels.time_grid(sol_reconstructed, :control) ≈ T_control + Test.@test CTModels.time_grid(sol_reconstructed, :costate) ≈ T_costate + Test.@test CTModels.time_grid(sol_reconstructed, :dual) ≈ T_dual + Test.@test CTModels.objective(sol_reconstructed) ≈ CTModels.objective(sol_orig) + end + + # ==================================================================== + # Test 11: Backward Compatibility - Legacy Format Detection + # ==================================================================== + + Test.@testset "Legacy format detection" begin + # Create a legacy-format data structure (single time_grid) + T = collect(LinRange(0.0, 1.0, 11)) + X_func = CTModels.state(sol_unified) + U_func = CTModels.control(sol_unified) + P_func = CTModels.costate(sol_unified) + + sol = CTModels.build_solution( + ocp, T, T, T, T, + X_func, U_func, v, P_func; + objective=CTModels.objective(sol_unified), + iterations=CTModels.iterations(sol_unified), + constraints_violation=CTModels.constraints_violation(sol_unified), + message=CTModels.message(sol_unified), + status=CTModels.status(sol_unified), + successful=CTModels.successful(sol_unified), + ) + + # Serialize (should produce legacy format) + data = CTModels.OCP._serialize_solution(sol) + + # Verify legacy format + Test.@test haskey(data, "time_grid") + Test.@test !haskey(data, "time_grid_state") + + # Reconstruct from legacy format + sol_from_legacy = CTModels.Serialization._reconstruct_solution_from_data( + ocp, data; + path_constraints_dual=data["path_constraints_dual"], + boundary_constraints_dual=data["boundary_constraints_dual"], + state_constraints_lb_dual=data["state_constraints_lb_dual"], + state_constraints_ub_dual=data["state_constraints_ub_dual"], + control_constraints_lb_dual=data["control_constraints_lb_dual"], + control_constraints_ub_dual=data["control_constraints_ub_dual"], + variable_constraints_lb_dual=data["variable_constraints_lb_dual"], + variable_constraints_ub_dual=data["variable_constraints_ub_dual"], + infos=get(data, "infos", Dict{Symbol,Any}()), + ) + + # Should create UnifiedTimeGridModel from legacy format + Test.@test CTModels.time_grid_model(sol_from_legacy) isa CTModels.UnifiedTimeGridModel + Test.@test CTModels.time_grid(sol_from_legacy) ≈ T + end + + # ==================================================================== + # TODO: Add JSON tests once matrix dimension issues are fixed + # TODO: Add tests with path_constraints_dual on multi-grids + # ==================================================================== + end +end + +end # module + +# CRITICAL: Redefine in outer scope for TestRunner +test_multi_grids() = TestMultiGrids.test_multi_grids()