From ca461fb9fe73d421c3c79a550445036c72cf0861 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 26 Nov 2025 14:11:53 -0800 Subject: [PATCH 1/8] introduce cuda sdpa Differential Revision: D87950475 --- backends/cuda/runtime/TARGETS | 4 + backends/cuda/runtime/shims/sdpa.cu | 649 +++++++++++++++ backends/cuda/runtime/shims/sdpa.cuh | 282 +++++++ backends/cuda/runtime/shims/sdpa.h | 104 +++ backends/cuda/runtime/shims/tests/targets.bzl | 1 + ...orch_cuda_scaled_dot_product_attention.cpp | 781 ++++++++++++++++++ 6 files changed, 1821 insertions(+) create mode 100644 backends/cuda/runtime/shims/sdpa.cu create mode 100644 backends/cuda/runtime/shims/sdpa.cuh create mode 100644 backends/cuda/runtime/shims/sdpa.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index a85f3a7e6a3..01dabee9086 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -53,6 +53,7 @@ runtime.cxx_library( "shims/cuda_guard.cpp", "shims/int4mm.cu", "shims/memory.cpp", + "shims/sdpa.cu", "shims/tensor_attribute.cpp", ], headers = [ @@ -61,6 +62,8 @@ runtime.cxx_library( "shims/int4mm.cuh", "shims/int4mm.h", "shims/memory.h", + "shims/sdpa.cuh", + "shims/sdpa.h", "shims/tensor_attribute.h", "utils.h", ], @@ -84,6 +87,7 @@ runtime.cxx_library( ], external_deps = [ ("cuda", None, "cuda-lazy"), + ("cuda", None, "cublas-lazy"), ], ) diff --git a/backends/cuda/runtime/shims/sdpa.cu b/backends/cuda/runtime/shims/sdpa.cu new file mode 100644 index 00000000000..c15f1f006bc --- /dev/null +++ b/backends/cuda/runtime/shims/sdpa.cu @@ -0,0 +1,649 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +using executorch::backends::aoti::AOTITorchError; +using executorch::runtime::Error; + +// ============================================================================ +// CUDA Kernels for Softmax and Masking +// ============================================================================ + +// Helper function for max with different types +__device__ __forceinline__ float device_max(float a, float b) { + return fmaxf(a, b); +} + +__device__ __forceinline__ __half device_max(__half a, __half b) { + return __hgt(a, b) ? a : b; +} + +__device__ __forceinline__ __nv_bfloat16 device_max(__nv_bfloat16 a, __nv_bfloat16 b) { + #if __CUDA_ARCH__ >= 800 + return __hgt(a, b) ? a : b; + #else + return __float2bfloat16(fmaxf(__bfloat162float(a), __bfloat162float(b))); + #endif +} + +/** + * Softmax kernel with optional causal masking + * + * Computes softmax along the last dimension (seq_len_k) of a 4D tensor. + * Supports causal masking where positions j > i are masked out. + * + * Input: [batch, num_heads, seq_len_q, seq_len_k] + * Output: [batch, num_heads, seq_len_q, seq_len_k] + * + * Each thread processes one row (seq_len_q position). + * + * Note: Supports in-place operation (input == output). + */ +template +__global__ void softmax_with_causal_mask_kernel( + const scalar_t* input, + scalar_t* output, + const int64_t batch, + const int64_t num_heads, + const int64_t seq_len_q, + const int64_t seq_len_k, + const bool is_causal, + const float scale) { + + // Each block processes one row of the attention matrix + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total_rows = batch * num_heads * seq_len_q; + + if (idx >= total_rows) { + return; + } + + // Decode position - we only need i for causal masking + const int64_t i = idx % seq_len_q; + + // Pointer to the start of this row + const int64_t row_offset = idx * seq_len_k; + const scalar_t* input_row = input + row_offset; + scalar_t* output_row = output + row_offset; + + // Find max for numerical stability (two-pass algorithm) + float max_val = -FLT_MAX; + for (int64_t j = 0; j < seq_len_k; ++j) { + if (!is_causal || j <= i) { + float val = static_cast(input_row[j]) * scale; + max_val = fmaxf(max_val, val); + } + } + + // Compute exp and sum + float sum_exp = 0.0f; + for (int64_t j = 0; j < seq_len_k; ++j) { + float val; + if (!is_causal || j <= i) { + val = expf(static_cast(input_row[j]) * scale - max_val); + } else { + val = 0.0f; + } + output_row[j] = static_cast(val); + sum_exp += val; + } + + // Normalize + const float inv_sum = 1.0f / sum_exp; + for (int64_t j = 0; j < seq_len_k; ++j) { + output_row[j] = static_cast(static_cast(output_row[j]) * inv_sum); + } +} + +/** + * Scale kernel - multiply all elements by a scalar + */ +template +__global__ void scale_kernel( + scalar_t* __restrict__ data, + const int64_t size, + const float scale) { + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + data[idx] = static_cast(static_cast(data[idx]) * scale); + } +} + +// ============================================================================ +// cuBLAS Helper Functions +// ============================================================================ + +/** + * Get or create a cuBLAS handle for the current stream + * + * Note: In production, this should use a handle pool or be managed + * by the backend infrastructure. This is a simplified version. + */ +cublasHandle_t get_cublas_handle(cudaStream_t stream) { + static cublasHandle_t handle = nullptr; + + if (handle == nullptr) { + cublasCreate(&handle); + } + + cublasSetStream(handle, stream); + return handle; +} + +/** + * Batched matrix multiplication wrapper for cuBLAS + * + * Computes: C = alpha * op(A) @ op(B) + beta * C + * for a batch of matrices + */ +template +cublasStatus_t batched_gemm( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, int n, int k, + const scalar_t* alpha, + const scalar_t* A, int lda, int64_t strideA, + const scalar_t* B, int ldb, int64_t strideB, + const scalar_t* beta, + scalar_t* C, int ldc, int64_t strideC, + int batchCount); + +// Specializations for different types +template<> +cublasStatus_t batched_gemm( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* A, int lda, int64_t strideA, + const float* B, int ldb, int64_t strideB, + const float* beta, + float* C, int ldc, int64_t strideC, + int batchCount) { + return cublasSgemmStridedBatched( + handle, transa, transb, m, n, k, + alpha, A, lda, strideA, B, ldb, strideB, + beta, C, ldc, strideC, batchCount); +} + +template<> +cublasStatus_t batched_gemm<__half>( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, int n, int k, + const __half* alpha, + const __half* A, int lda, int64_t strideA, + const __half* B, int ldb, int64_t strideB, + const __half* beta, + __half* C, int ldc, int64_t strideC, + int batchCount) { + return cublasHgemmStridedBatched( + handle, transa, transb, m, n, k, + alpha, A, lda, strideA, B, ldb, strideB, + beta, C, ldc, strideC, batchCount); +} + +// Note: BFloat16 uses compute type float internally +template<> +cublasStatus_t batched_gemm<__nv_bfloat16>( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, int n, int k, + const __nv_bfloat16* alpha, + const __nv_bfloat16* A, int lda, int64_t strideA, + const __nv_bfloat16* B, int ldb, int64_t strideB, + const __nv_bfloat16* beta, + __nv_bfloat16* C, int ldc, int64_t strideC, + int batchCount) { + + // cuBLAS BFloat16 GEMM - introduced in CUDA 11+ + #if CUDA_VERSION >= 11000 + // For BFloat16, we need to use cublasGemmStridedBatchedEx + // with compute type CUBLAS_COMPUTE_32F + float alpha_f = 1.0f; + float beta_f = 0.0f; + + return cublasGemmStridedBatchedEx( + handle, + transa, transb, + m, n, k, + &alpha_f, + A, CUDA_R_16BF, lda, strideA, + B, CUDA_R_16BF, ldb, strideB, + &beta_f, + C, CUDA_R_16BF, ldc, strideC, + batchCount, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); + #else + ET_LOG(Error, "BFloat16 GEMM requires CUDA 11.0 or later"); + return CUBLAS_STATUS_NOT_SUPPORTED; + #endif +} + +// ============================================================================ +// Math Fallback Implementation +// ============================================================================ + +/** + * Math fallback implementation for SDPA + * + * This implementation uses cuBLAS for matrix multiplications and custom + * kernels for softmax. It provides maximum compatibility across all CUDA + * devices but may not be as optimized as Flash Attention or Memory Efficient + * Attention. + * + * Algorithm: + * 1. Compute attention scores: S = (Q @ K^T) + * 2. Apply scaling and compute softmax with optional causal mask + * 3. Compute output: O = attention_weights @ V + */ +template +Tensor* sdpa_math_fallback_impl( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + float scale_factor, + cudaStream_t stream) { + + // Get tensor dimensions + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim = query->size(3); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_v = value->size(3); + + // Get cuBLAS handle + cublasHandle_t handle = get_cublas_handle(stream); + + // Step 1: Allocate temporary buffer for attention scores + // Shape: [batch, num_heads, seq_len_q, seq_len_k] + const int64_t scores_size = batch * num_heads * seq_len_q * seq_len_k; + scalar_t* scores_ptr = nullptr; + cudaMalloc(&scores_ptr, scores_size * sizeof(scalar_t)); + if (scores_ptr == nullptr) { + ET_LOG(Error, "sdpa_math_fallback: Failed to allocate scores buffer"); + return nullptr; + } + + // Step 2: Compute Q @ K^T using cuBLAS + // Q: [batch * num_heads, seq_len_q, head_dim] + // K^T: [batch * num_heads, head_dim, seq_len_k] + // Output: [batch * num_heads, seq_len_q, seq_len_k] + + const int m = seq_len_q; + const int n = seq_len_k; + const int k = head_dim; + const int batch_count = batch * num_heads; + + const scalar_t alpha = static_cast(1.0f); + const scalar_t beta = static_cast(0.0f); + + const scalar_t* q_ptr = reinterpret_cast(query->data_ptr()); + const scalar_t* k_ptr = reinterpret_cast(key->data_ptr()); + + // Strides for batched GEMM + const int64_t stride_q = seq_len_q * head_dim; + const int64_t stride_k = seq_len_k * head_dim; + const int64_t stride_scores = seq_len_q * seq_len_k; + + // Q @ K^T + cublasStatus_t status = batched_gemm( + handle, + CUBLAS_OP_T, // Transpose K + CUBLAS_OP_N, // No transpose Q + n, // seq_len_k + m, // seq_len_q + k, // head_dim + &alpha, + k_ptr, k, // K matrix + stride_k, + q_ptr, k, // Q matrix + stride_q, + &beta, + scores_ptr, n, // Output scores + stride_scores, + batch_count); + + if (status != CUBLAS_STATUS_SUCCESS) { + ET_LOG(Error, "sdpa_math_fallback: cuBLAS GEMM failed for Q @ K^T"); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 3: Apply softmax with scaling and optional causal mask + const int threads_per_block = 256; + const int64_t total_rows = batch * num_heads * seq_len_q; + const int num_blocks = (total_rows + threads_per_block - 1) / threads_per_block; + + softmax_with_causal_mask_kernel<<>>( + scores_ptr, + scores_ptr, // in-place + batch, + num_heads, + seq_len_q, + seq_len_k, + is_causal, + scale_factor); + + cudaError_t cuda_err = cudaGetLastError(); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "sdpa_math_fallback: Softmax kernel launch failed: %s", + cudaGetErrorString(cuda_err)); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 4: Allocate output tensor [batch, num_heads, seq_len_q, head_dim_v] + Tensor* output = nullptr; + std::array output_shape = {batch, num_heads, seq_len_q, head_dim_v}; + std::array output_stride = { + num_heads * seq_len_q * head_dim_v, + seq_len_q * head_dim_v, + head_dim_v, + 1}; + + auto dtype_int = static_cast(query->dtype()); + aoti_torch_empty_strided( + 4, + output_shape.data(), + output_stride.data(), + dtype_int, + static_cast(SupportedDevices::CUDA), + 0, + &output); + + if (output == nullptr) { + ET_LOG(Error, "sdpa_math_fallback: Failed to allocate output tensor"); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 5: Compute attention_weights @ V + // attention_weights: [batch * num_heads, seq_len_q, seq_len_k] + // V: [batch * num_heads, seq_len_k, head_dim_v] + // Output: [batch * num_heads, seq_len_q, head_dim_v] + + const int m_v = seq_len_q; + const int n_v = head_dim_v; + const int k_v = seq_len_k; + + const scalar_t* v_ptr = reinterpret_cast(value->data_ptr()); + scalar_t* out_ptr = reinterpret_cast(output->data_ptr()); + + const int64_t stride_v = seq_len_k * head_dim_v; + const int64_t stride_out = seq_len_q * head_dim_v; + + status = batched_gemm( + handle, + CUBLAS_OP_N, // No transpose V + CUBLAS_OP_N, // No transpose attention_weights + n_v, // head_dim_v + m_v, // seq_len_q + k_v, // seq_len_k + &alpha, + v_ptr, n_v, // V matrix + stride_v, + scores_ptr, k_v, // attention_weights + stride_scores, + &beta, + out_ptr, n_v, // Output + stride_out, + batch_count); + + // Cleanup temporary buffers + cudaFree(scores_ptr); + + if (status != CUBLAS_STATUS_SUCCESS) { + ET_LOG(Error, "sdpa_math_fallback: cuBLAS GEMM failed for attention_weights @ V"); + aoti_torch_delete_tensor_object(output); + return nullptr; + } + + return output; +} + +Tensor* sdpa_math_fallback( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + double scale_factor, + cudaStream_t stream) { + + // Dispatch based on dtype + auto dtype = query->dtype(); + + if (dtype == executorch::aten::ScalarType::Float) { + return sdpa_math_fallback_impl( + query, key, value, attn_mask, is_causal, + static_cast(scale_factor), stream); + } else if (dtype == executorch::aten::ScalarType::Half) { + return sdpa_math_fallback_impl<__half>( + query, key, value, attn_mask, is_causal, + static_cast(scale_factor), stream); + } else if (dtype == executorch::aten::ScalarType::BFloat16) { + return sdpa_math_fallback_impl<__nv_bfloat16>( + query, key, value, attn_mask, is_causal, + static_cast(scale_factor), stream); + } else { + ET_LOG(Error, "sdpa_math_fallback: Unsupported dtype"); + return nullptr; + } +} + +/** + * Main entry point for SDPA computation + */ +Tensor* scaled_dot_product_attention_cuda( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + double dropout_p, + bool is_causal, + const double* scale, + bool enable_gqa, + cudaStream_t stream) { + + // Select backend + SDPBackend backend = select_sdp_backend( + query, key, value, attn_mask, dropout_p, is_causal); + + if (backend == SDPBackend::Error) { + ET_LOG(Error, "scaled_dot_product_attention_cuda: No valid backend selected"); + return nullptr; + } + + // Calculate scale factor + double scale_factor = calculate_scale(query, scale); + + // Handle GQA if needed + if (enable_gqa && is_gqa_configuration(query, key, value)) { + if (!validate_gqa(query, key, value)) { + ET_LOG(Error, "scaled_dot_product_attention_cuda: Invalid GQA configuration"); + return nullptr; + } + ET_LOG( + Error, + "scaled_dot_product_attention_cuda: GQA support not yet implemented. " + "Need to repeat K/V heads to match Q heads."); + return nullptr; + } + + // Dispatch to appropriate backend + switch (backend) { + case SDPBackend::Math: + return sdpa_math_fallback( + query, key, value, attn_mask, is_causal, scale_factor, stream); + + case SDPBackend::FlashAttention: + ET_LOG(Error, "Flash Attention backend not yet implemented"); + return nullptr; + + case SDPBackend::MemoryEfficientAttention: + ET_LOG(Error, "Memory Efficient Attention backend not yet implemented"); + return nullptr; + + case SDPBackend::CuDNN: + ET_LOG(Error, "cuDNN backend not yet implemented"); + return nullptr; + + default: + ET_LOG(Error, "Unknown SDPA backend"); + return nullptr; + } +} + +// ============================================================================ +// C API Implementation +// ============================================================================ + +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda_scaled_dot_product_attention( + Tensor* query, + Tensor* key, + Tensor* value, + Tensor* attn_mask, + double dropout_p, + int32_t is_causal, + double* scale, + int32_t enable_gqa, + Tensor** ret0) { + + // Input validation + if (!query || !key || !value || !ret0) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Null pointer input"); + return Error::InvalidArgument; + } + + // Currently only support dropout_p = 0.0 for inference + if (dropout_p != 0.0) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: dropout_p != 0.0 is not supported"); + return Error::InvalidArgument; + } + + // Check tensor dimensions + if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query, Key, Value must be 4D tensors"); + return Error::InvalidArgument; + } + + // Check that Q, K, V have the same dtype + if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query, Key, Value must have the same dtype"); + return Error::InvalidArgument; + } + + // Check tensor shapes + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim_q = query->size(3); + + const int64_t num_heads_kv = key->size(1); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_k = key->size(3); + + const int64_t seq_len_v = value->size(2); + const int64_t head_dim_v = value->size(3); + + // Validate shapes + if (key->size(0) != batch || value->size(0) != batch) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Batch size mismatch"); + return Error::InvalidArgument; + } + + if (seq_len_k != seq_len_v) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Key and Value sequence length mismatch"); + return Error::InvalidArgument; + } + + if (head_dim_q != head_dim_k) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query and Key head dimension mismatch"); + return Error::InvalidArgument; + } + + if (value->size(1) != num_heads_kv) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Key and Value num_heads mismatch"); + return Error::InvalidArgument; + } + + // GQA validation + if (enable_gqa && num_heads % num_heads_kv != 0) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: For GQA, num_heads must be divisible by num_heads_kv"); + return Error::InvalidArgument; + } + + // Validate attn_mask if provided + if (attn_mask) { + if (is_causal) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Cannot use both attn_mask and is_causal"); + return Error::InvalidArgument; + } + } + + // Get CUDA stream + auto stream_result = getCurrentCUDAStream(0); + if (!stream_result.ok()) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Failed to get CUDA stream"); + return Error::Internal; + } + cudaStream_t stream = stream_result.get(); + + // Call the main SDPA function + Tensor* output = scaled_dot_product_attention_cuda( + query, + key, + value, + attn_mask, + dropout_p, + is_causal != 0, + scale, + enable_gqa != 0, + stream); + + if (output == nullptr) { + ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: SDPA computation failed"); + return Error::Internal; + } + + *ret0 = output; + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/sdpa.cuh b/backends/cuda/runtime/shims/sdpa.cuh new file mode 100644 index 00000000000..5cc941f4120 --- /dev/null +++ b/backends/cuda/runtime/shims/sdpa.cuh @@ -0,0 +1,282 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file implements scaled_dot_product_attention for ExecuTorch. +// +// IMPLEMENTATION NOTES: +// --------------------- +// This is NOT a direct port from PyTorch. Instead, we implemented +// a custom Math Fallback using cuBLAS and custom CUDA kernels. +// +// PyTorch reference implementations (for architecture reference only): +// - CPU/General: aten/src/ATen/native/transformers/attention.cpp +// - CUDA: aten/src/ATen/native/transformers/cuda/attention.cu +// +// Key differences from PyTorch: +// - PyTorch uses high-level ATen ops (at::matmul, at::_safe_softmax) +// - We use direct cuBLAS calls and custom softmax kernels +// - Optimized for inference (no dropout, no backward pass) +// - Simplified memory management +// - No ATen/c10 dependencies +// +// PORTING NOTES: +// -------------- +// 1. KERNEL CODE: Adapted from PyTorch attention kernels +// - Math fallback implementation for maximum compatibility +// - Supports Float32, Float16, and BFloat16 dtypes +// - Standard attention computation: softmax(Q @ K^T / scale) @ V +// +// 2. API ADAPTATIONS: +// - Replaced at::Tensor with executorch::backends::aoti::Tensor +// - Output returned via pointer-to-pointer instead of by-value +// - Simplified interface for inference (dropout=0.0 only) +// +// 3. REMOVED FEATURES: +// - Flash Attention backend (requires external library) +// - Memory Efficient Attention backend (requires external library) +// - cuDNN backend (requires cuDNN library) +// - Dropout support (training-only feature) +// - Nested tensor support (complex layout) +// - Backward pass (training-only feature) +// +// 4. INFRASTRUCTURE CHANGES: +// - Removed c10::cuda::CUDAGuard: Device management handled by AOTI backend +// - Removed at::cuda::getCurrentCUDAStream(): Stream passed explicitly +// - Simplified error handling using ExecutorTorch Error codes + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +using executorch::runtime::Error; + +// ============================================================================ +// Utility Functions for SDPA +// ============================================================================ + +// Calculate the scaling factor for attention scores +inline double calculate_scale(const Tensor* query, const double* scale) { + if (scale != nullptr) { + return *scale; + } + // Default: 1 / sqrt(head_dim) + // Query shape: [batch, num_heads, seq_len_q, head_dim] + // head_dim is at index 3 (0-indexed) + const int64_t head_dim = query->size(3); + return 1.0 / std::sqrt(static_cast(head_dim)); +} + +// Check if tensor dtype is supported for SDPA +inline bool is_supported_dtype(const Tensor* tensor) { + auto dtype = tensor->dtype(); + return dtype == executorch::aten::ScalarType::Float || + dtype == executorch::aten::ScalarType::Half || + dtype == executorch::aten::ScalarType::BFloat16; +} + +// ============================================================================ +// Math Fallback Implementation +// ============================================================================ + +// This is the basic, portable implementation that works on all CUDA devices. +// It computes attention using explicit matrix multiplications and softmax: +// 1. Compute scores: S = Q @ K^T * scale +// 2. Apply mask if provided +// 3. Compute attention weights: A = softmax(S) +// 4. Compute output: O = A @ V + +/** + * Math fallback kernel for scaled dot product attention + * + * This is a basic implementation that performs: + * output = softmax(query @ key^T / scale) @ value + * + * Supports: + * - Batch processing + * - Multiple attention heads + * - Optional causal masking + * - Optional explicit attention mask + * - Float32, Float16, BFloat16 dtypes + * + * Note: This implementation is for reference and maximum compatibility. + * For production use, consider using Flash Attention or other optimized backends. + */ +Tensor* sdpa_math_fallback( + const Tensor* query, // [batch, num_heads, seq_len_q, head_dim] + const Tensor* key, // [batch, num_heads_kv, seq_len_k, head_dim] + const Tensor* value, // [batch, num_heads_kv, seq_len_k, head_dim_v] + const Tensor* attn_mask, // Optional: [batch, num_heads, seq_len_q, seq_len_k] or broadcastable + bool is_causal, // Apply causal masking + double scale_factor, // Scaling factor for attention scores + cudaStream_t stream); // CUDA stream for execution + +// ============================================================================ +// Backend Selection +// ============================================================================ + +enum class SDPBackend { + Error = -1, + Math = 0, + FlashAttention = 1, + MemoryEfficientAttention = 2, + CuDNN = 3 +}; + +/** + * Select the best available backend for SDPA based on input parameters + * + * For now, only Math fallback is supported. Future implementations may add: + * - Flash Attention (Ampere+ GPUs) + * - Memory Efficient Attention + * - cuDNN backend + */ +inline SDPBackend select_sdp_backend( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + double dropout_p, + bool is_causal) { + + // Check for unsupported features + if (dropout_p > 0.0) { + ET_LOG(Error, "SDPA: Dropout not supported in inference mode"); + return SDPBackend::Error; + } + + // Check tensor dimensions + if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { + ET_LOG(Error, "SDPA: All inputs must be 4D tensors"); + return SDPBackend::Error; + } + + // Check dtype support + if (!is_supported_dtype(query) || !is_supported_dtype(key) || !is_supported_dtype(value)) { + ET_LOG(Error, "SDPA: Unsupported dtype, only Float32/Float16/BFloat16 supported"); + return SDPBackend::Error; + } + + // Check dtype consistency + if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { + ET_LOG(Error, "SDPA: Query, Key, Value must have the same dtype"); + return SDPBackend::Error; + } + + // For now, always use math fallback + // Future: Add logic to select Flash Attention, MemEff, or cuDNN when available + return SDPBackend::Math; +} + +// ============================================================================ +// Helper Functions for Causal Mask +// ============================================================================ + +/** + * Check if we need to apply causal masking + */ +inline bool needs_causal_mask(bool is_causal, const Tensor* attn_mask) { + if (!is_causal) { + return false; + } + if (attn_mask != nullptr) { + ET_LOG(Error, "SDPA: Cannot use both is_causal=true and explicit attn_mask"); + return false; + } + return true; +} + +// ============================================================================ +// Grouped Query Attention (GQA) Support +// ============================================================================ + +/** + * Check if inputs require GQA handling + * + * GQA allows num_heads_q != num_heads_kv, where num_heads_q must be + * divisible by num_heads_kv. Key and Value heads are repeated to match + * Query heads. + */ +inline bool is_gqa_configuration( + const Tensor* query, + const Tensor* key, + const Tensor* value) { + + const int64_t num_heads_q = query->size(1); + const int64_t num_heads_kv = key->size(1); + + return num_heads_q != num_heads_kv; +} + +/** + * Validate GQA configuration + */ +inline bool validate_gqa( + const Tensor* query, + const Tensor* key, + const Tensor* value) { + + const int64_t num_heads_q = query->size(1); + const int64_t num_heads_kv = key->size(1); + const int64_t num_heads_v = value->size(1); + + // Key and Value must have same num_heads + if (num_heads_kv != num_heads_v) { + ET_LOG(Error, "SDPA GQA: Key and Value must have same num_heads"); + return false; + } + + // Query heads must be divisible by Key/Value heads + if (num_heads_q % num_heads_kv != 0) { + ET_LOG(Error, "SDPA GQA: Query num_heads must be divisible by Key/Value num_heads"); + return false; + } + + return true; +} + +// ============================================================================ +// Main SDPA Entry Point +// ============================================================================ + +/** + * Compute scaled dot product attention + * + * This is the main entry point that selects the appropriate backend + * and dispatches to the corresponding implementation. + * + * Currently only Math fallback is implemented. Future versions may add: + * - Flash Attention + * - Memory Efficient Attention + * - cuDNN backend + */ +Tensor* scaled_dot_product_attention_cuda( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + double dropout_p, + bool is_causal, + const double* scale, + bool enable_gqa, + cudaStream_t stream); + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/sdpa.h b/backends/cuda/runtime/shims/sdpa.h new file mode 100644 index 00000000000..4db08576ca0 --- /dev/null +++ b/backends/cuda/runtime/shims/sdpa.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Performs scaled dot-product attention on CUDA. + * + * This is a port of PyTorch's scaled_dot_product_attention CUDA implementation + * (aten/src/ATen/native/transformers/cuda/attention.cu) adapted for the + * ExecuTorch runtime. + * + * Computes attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V + * + * HARDWARE REQUIREMENTS: + * - CUDA-capable GPU + * - Supports Flash Attention if available (Ampere+ GPUs) + * + * TENSOR REQUIREMENTS: + * @param query Query tensor [batch, num_heads, seq_len_q, head_dim] + * - Must be Float32, Float16, or BFloat16 dtype + * - Must be 4D + * - Must be on CUDA device + * + * @param key Key tensor [batch, num_heads_kv, seq_len_k, head_dim] + * - Must be same dtype as query + * - Must be 4D + * - Must be on CUDA device + * - num_heads_kv can be different from num_heads (for GQA) + * + * @param value Value tensor [batch, num_heads_kv, seq_len_k, head_dim_v] + * - Must be same dtype as query + * - Must be 4D + * - Must be on CUDA device + * + * @param attn_mask Optional attention mask [batch, num_heads, seq_len_q, seq_len_k] + * or broadcastable shape + * - Can be nullptr (no mask) + * - If provided, must be Float32, BFloat16, or Bool dtype + * - Additive mask: positions with large negative values are masked out + * + * @param dropout_p Dropout probability (0.0 to 1.0) + * - Currently only supports 0.0 (no dropout) + * - Must be 0.0 for inference + * + * @param is_causal Whether to apply causal masking + * - If true, applies lower triangular mask + * - Cannot be used together with explicit attn_mask + * + * @param scale Optional scaling factor for attention scores + * - If nullptr, uses 1/sqrt(head_dim) by default + * - If provided, uses the specified value + * + * @param enable_gqa Enable grouped query attention support + * - Allows num_heads_kv != num_heads + * - Query heads must be divisible by key/value heads + * + * @param ret0 Output parameter for attention result + * [batch, num_heads, seq_len_q, head_dim_v] + * - Allocated by this function + * - Same dtype as input tensors + * - Must not be null + * - Caller is responsible for freeing via aoti_torch_delete_tensor_object() + * + * @return AOTITorchError error code: + * - Error::Ok: Success + * - Error::InvalidArgument: Null pointer, wrong dtype, wrong dimensions, + * or invalid parameter combination + * - Error::Internal: CUDA kernel launch failure + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_scaled_dot_product_attention( + Tensor* query, + Tensor* key, + Tensor* value, + Tensor* attn_mask, + double dropout_p, + int32_t is_causal, + double* scale, + int32_t enable_gqa, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index b274ecf3675..0896b3b6a3b 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -34,4 +34,5 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") + cuda_shim_cpp_unittest("aoti_torch_cuda_scaled_dot_product_attention") cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp new file mode 100644 index 00000000000..e2677878ea0 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp @@ -0,0 +1,781 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for SDPA tests +class AOTITorchSDPATest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + } + + void TearDown() override { + // Clean up after each test + cleanup_tensor_metadata(); + } + + // Helper function to create a Float32 tensor filled with a specific value + Tensor* create_float_tensor( + std::vector shape, + float fill_value = 1.0f) { + Tensor* tensor = nullptr; + + // Calculate size + int64_t total_size = 1; + for (auto dim : shape) { + total_size *= dim; + } + + // Calculate strides (row-major) + std::vector strides(shape.size()); + int64_t stride = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = stride; + stride *= shape[i]; + } + + // Create tensor + Error error = aoti_torch_empty_strided( + shape.size(), + shape.data(), + strides.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill with value + std::vector host_data(total_size, fill_value); + cudaMemcpy( + tensor->data_ptr(), + host_data.data(), + total_size * sizeof(float), + cudaMemcpyHostToDevice); + + return tensor; + } + + // Helper function to create a BFloat16 tensor + Tensor* create_bfloat16_tensor( + std::vector shape, + float fill_value = 1.0f) { + Tensor* tensor = nullptr; + + // Calculate size + int64_t total_size = 1; + for (auto dim : shape) { + total_size *= dim; + } + + // Calculate strides (row-major) + std::vector strides(shape.size()); + int64_t stride = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = stride; + stride *= shape[i]; + } + + // Create tensor + Error error = aoti_torch_empty_strided( + shape.size(), + shape.data(), + strides.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill with value + // Note: For simplicity, we'll fill with float and let the runtime handle conversion + // In production, you'd want to properly convert to bfloat16 + std::vector host_data(total_size, fill_value); + cudaMemcpy( + tensor->data_ptr(), + host_data.data(), + total_size * sizeof(float), + cudaMemcpyHostToDevice); + + return tensor; + } + + // Helper to check if output tensor has expected shape + bool check_output_shape( + Tensor* output, + const std::vector& expected_shape) { + if (output == nullptr) { + return false; + } + if (output->dim() != expected_shape.size()) { + return false; + } + for (size_t i = 0; i < expected_shape.size(); ++i) { + if (output->size(i) != expected_shape[i]) { + return false; + } + } + return true; + } + + // Helper to copy tensor data from GPU to CPU for verification + std::vector copy_tensor_to_host(Tensor* tensor) { + int64_t total_size = 1; + for (int i = 0; i < tensor->dim(); ++i) { + total_size *= tensor->size(i); + } + + std::vector host_data(total_size); + cudaMemcpy( + host_data.data(), + tensor->data_ptr(), + total_size * sizeof(float), + cudaMemcpyDeviceToHost); + + return host_data; + } + + // Helper to check if a value is approximately equal (for floating point comparison) + bool approx_equal(float a, float b, float epsilon = 1e-5f) { + return std::abs(a - b) < epsilon; + } +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +// Test basic SDPA with Float32, no causal mask +TEST_F(AOTITorchSDPATest, BasicFunctionalityFloat32) { + // Create tensors: [batch=1, num_heads=2, seq_len=4, head_dim=8] + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr) << "Failed to create query tensor"; + ASSERT_NE(key, nullptr) << "Failed to create key tensor"; + ASSERT_NE(value, nullptr) << "Failed to create value tensor"; + + printf("Testing SDPA Float32: [%ldx%ldx%ldx%ld]\n", batch, num_heads, seq_len, head_dim); + + // Call SDPA + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, + key, + value, + nullptr, // no explicit mask + 0.0, // no dropout + 0, // not causal + nullptr, // default scale + 0, // no GQA + &output); + + // Check result + EXPECT_EQ(error, Error::Ok) << "SDPA should succeed"; + ASSERT_NE(output, nullptr) << "Output should not be null"; + + // Verify output shape: [batch, num_heads, seq_len, head_dim] + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})) + << "Output shape mismatch"; + + printf("SDPA Float32 test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with causal masking +TEST_F(AOTITorchSDPATest, CausalMasking) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 8; + const int64_t head_dim = 16; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with causal masking: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + // Call SDPA with causal mask + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, + key, + value, + nullptr, + 0.0, + 1, // causal mask enabled + nullptr, + 0, + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Causal masking test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with BFloat16 +TEST_F(AOTITorchSDPATest, BFloat16Precision) { + const int64_t batch = 2; + const int64_t num_heads = 4; + const int64_t seq_len = 16; + const int64_t head_dim = 32; + + Tensor* query = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr) << "Failed to create BFloat16 query tensor"; + ASSERT_NE(key, nullptr) << "Failed to create BFloat16 key tensor"; + ASSERT_NE(value, nullptr) << "Failed to create BFloat16 value tensor"; + + printf("Testing SDPA BFloat16: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, + key, + value, + nullptr, + 0.0, + 0, + nullptr, + 0, + &output); + + EXPECT_EQ(error, Error::Ok) << "SDPA BFloat16 should succeed"; + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("BFloat16 precision test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with custom scale factor +TEST_F(AOTITorchSDPATest, CustomScale) { + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with custom scale\n"); + + // Use custom scale instead of default 1/sqrt(head_dim) + double custom_scale = 0.25; + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, + key, + value, + nullptr, + 0.0, + 0, + &custom_scale, // custom scale + 0, + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Custom scale test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test with larger tensors (closer to real-world usage) +TEST_F(AOTITorchSDPATest, LargerTensors) { + const int64_t batch = 4; + const int64_t num_heads = 8; + const int64_t seq_len = 128; + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with larger tensors: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, + key, + value, + nullptr, + 0.0, + 1, // causal + nullptr, + 0, + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Larger tensors test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +// Test null pointer handling +TEST_F(AOTITorchSDPATest, NullPointerHandling) { + Tensor* query = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* key = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* value = create_float_tensor({1, 1, 4, 8}, 1.0f); + Tensor* output = nullptr; + + // Test null query + { + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + nullptr, key, value, nullptr, 0.0, 0, nullptr, 0, &output); + EXPECT_NE(error, Error::Ok) << "Should fail with null query"; + } + + // Test null key + { + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, nullptr, value, nullptr, 0.0, 0, nullptr, 0, &output); + EXPECT_NE(error, Error::Ok) << "Should fail with null key"; + } + + // Test null value + { + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, nullptr, nullptr, 0.0, 0, nullptr, 0, &output); + EXPECT_NE(error, Error::Ok) << "Should fail with null value"; + } + + // Test null output pointer + { + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, nullptr, 0, nullptr); + EXPECT_NE(error, Error::Ok) << "Should fail with null output pointer"; + } + + printf("Null pointer handling tests passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// Test dimension mismatch +TEST_F(AOTITorchSDPATest, DimensionMismatch) { + Tensor* query = create_float_tensor({1, 2, 4, 8}, 0.5f); + Tensor* key = create_float_tensor({1, 2, 6, 8}, 0.5f); // Different seq_len + Tensor* value = create_float_tensor({1, 2, 6, 8}, 1.0f); + Tensor* output = nullptr; + + // This should succeed (Q and K can have different seq_len) + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); + + EXPECT_EQ(error, Error::Ok) << "Different Q and K seq_len should be allowed"; + + if (output != nullptr) { + // Output should have Q's seq_len + EXPECT_EQ(output->size(2), 4) << "Output seq_len should match Query"; + aoti_torch_delete_tensor_object(output); + } + + printf("Dimension handling test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// Test dropout error (should fail since we don't support dropout) +TEST_F(AOTITorchSDPATest, DropoutNotSupported) { + Tensor* query = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* key = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* value = create_float_tensor({1, 1, 4, 8}, 1.0f); + Tensor* output = nullptr; + + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.5, 0, nullptr, 0, &output); // dropout=0.5 + + EXPECT_NE(error, Error::Ok) << "Should fail with non-zero dropout"; + + printf("Dropout rejection test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// ============================================================================ +// Numerical Correctness Tests +// ============================================================================ + +// Test that output values are in reasonable range +TEST_F(AOTITorchSDPATest, OutputValueRangeCheck) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + // Use small values to avoid numerical overflow + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA output value range\n"); + + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + + // Copy output back to CPU for verification + std::vector output_data = copy_tensor_to_host(output); + + // Since V is all 1.0, and softmax produces weights that sum to 1, + // output should be close to 1.0 (weighted average of 1.0) + bool all_in_range = true; + for (size_t i = 0; i < output_data.size(); ++i) { + // Output should be around 1.0 with some tolerance + if (output_data[i] < 0.5f || output_data[i] > 1.5f) { + printf("Output[%zu] = %f is out of expected range [0.5, 1.5]\n", + i, output_data[i]); + all_in_range = false; + } + } + + EXPECT_TRUE(all_in_range) << "Some output values are out of reasonable range"; + + printf("Output value range check passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test with identity Q=K, verify attention weights sum to 1 +TEST_F(AOTITorchSDPATest, IdentityQKTest) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + // When Q=K, attention scores will be uniform (since all positions are equally similar) + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 2.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with Q=K (identity attention)\n"); + + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + + // Copy output back to CPU + std::vector output_data = copy_tensor_to_host(output); + + // When Q=K and V is uniform, output should be close to V + // (since attention weights are uniform due to identical scores) + bool values_correct = true; + for (size_t i = 0; i < output_data.size(); ++i) { + // Output should be close to 2.0 (the value of V) + if (!approx_equal(output_data[i], 2.0f, 0.1f)) { + printf("Output[%zu] = %f, expected ~2.0\n", i, output_data[i]); + values_correct = false; + } + } + + EXPECT_TRUE(values_correct) << "Output values don't match expected for identity Q=K"; + + printf("Identity Q=K test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test that different scales produce different outputs +TEST_F(AOTITorchSDPATest, ScaleEffectTest) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + // Make K different at different positions so attention scores vary + std::vector key_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + // Different values per position: pos 0=0.1, pos 1=0.3, pos 2=0.5, pos 3=0.7 + key_host[pos * head_dim + d] = 0.1f + 0.2f * pos; + } + } + cudaMemcpy( + key->data_ptr(), + key_host.data(), + key_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + // Make V also different at different positions to amplify differences + std::vector value_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + // V values: pos 0=1.0, pos 1=2.0, pos 2=3.0, pos 3=4.0 + value_host[pos * head_dim + d] = static_cast(pos + 1); + } + } + cudaMemcpy( + value->data_ptr(), + value_host.data(), + value_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + printf("Testing SDPA scale effect\n"); + + // Test with default scale + Tensor* output1 = nullptr; + AOTITorchError error1 = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, nullptr, 0, &output1); + ASSERT_EQ(error1, Error::Ok); + ASSERT_NE(output1, nullptr); + + // Test with custom scale (much smaller, should make attention more uniform) + double small_scale = 0.01; + Tensor* output2 = nullptr; + AOTITorchError error2 = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 0, &small_scale, 0, &output2); + ASSERT_EQ(error2, Error::Ok); + ASSERT_NE(output2, nullptr); + + // Copy outputs back to CPU + std::vector output1_data = copy_tensor_to_host(output1); + std::vector output2_data = copy_tensor_to_host(output2); + + // Outputs should be different (scale affects softmax sharpness) + // With varied V values, even small changes in attention weights will produce + // noticeably different outputs + bool outputs_differ = false; + float max_diff = 0.0f; + for (size_t i = 0; i < output1_data.size(); ++i) { + float diff = std::abs(output1_data[i] - output2_data[i]); + max_diff = std::max(max_diff, diff); + if (diff > 0.05f) { // More lenient threshold due to varied V values + outputs_differ = true; + break; + } + } + + printf("Max difference between outputs: %f\n", max_diff); + EXPECT_TRUE(outputs_differ) << "Different scales should produce different outputs (max_diff=" << max_diff << ")"; + + printf("Scale effect test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output1); + aoti_torch_delete_tensor_object(output2); +} + +// Test causal masking correctness +TEST_F(AOTITorchSDPATest, CausalMaskingCorrectness) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + // Create distinct values at different positions in V + // This allows us to verify that causal masking works correctly + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + // Manually set different values for each position in V + // V[position i] = i+1 (so we can track which positions contribute) + std::vector value_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + value_host[pos * head_dim + d] = static_cast(pos + 1); + } + } + cudaMemcpy( + value->data_ptr(), + value_host.data(), + value_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + printf("Testing SDPA causal masking correctness\n"); + + // Run with causal masking + Tensor* output_causal = nullptr; + AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( + query, key, value, nullptr, 0.0, 1, nullptr, 0, &output_causal); + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output_causal, nullptr); + + // Copy output back to CPU + std::vector output_data = copy_tensor_to_host(output_causal); + + // With causal masking: + // - Position 0 can only see position 0, so output[0] should be ~1.0 + // - Position 1 can see positions 0,1, so output[1] should be ~1.5 (average of 1 and 2) + // - Position 2 can see positions 0,1,2, so output[2] should be ~2.0 (average of 1,2,3) + // - Position 3 can see all, so output[3] should be ~2.5 (average of 1,2,3,4) + + std::vector expected_values = {1.0f, 1.5f, 2.0f, 2.5f}; + + bool causal_correct = true; + for (int64_t pos = 0; pos < seq_len; ++pos) { + float avg_output = 0.0f; + for (int64_t d = 0; d < head_dim; ++d) { + avg_output += output_data[pos * head_dim + d]; + } + avg_output /= head_dim; + + printf("Position %ld: output avg = %f, expected ~%f\n", + pos, avg_output, expected_values[pos]); + + if (!approx_equal(avg_output, expected_values[pos], 0.2f)) { + causal_correct = false; + } + } + + EXPECT_TRUE(causal_correct) << "Causal masking did not produce expected values"; + + printf("Causal masking correctness test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output_causal); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From c89691de02c7830d1765d8be5f03adf482a06894 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 9 Dec 2025 14:01:16 -0800 Subject: [PATCH 2/8] init --- 06-fused-attention.py | 775 +++++++++++++++++++++++++++ backends/cuda/cuda_backend.py | 34 +- backends/cuda/triton/kernels/sdpa.py | 482 +++++++---------- 3 files changed, 1000 insertions(+), 291 deletions(-) create mode 100644 06-fused-attention.py diff --git a/06-fused-attention.py b/06-fused-attention.py new file mode 100644 index 00000000000..b1283e8ef99 --- /dev/null +++ b/06-fused-attention.py @@ -0,0 +1,775 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import os + +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_host_descriptor(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def is_blackwell(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 10 + + +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 9 + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype: tl.constexpr, start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + offsetk_y = offset_y + lo + if dtype == tl.float8e5: + offsetv_y = offset_y * HEAD_DIM + lo + else: + offsetv_y = offset_y + lo + # loop over k, v and update accumulator + for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = desc_k.load([offsetk_y, 0]).T + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + l_ij = tl.sum(p, 1) + # -- update output accumulator -- + if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: + BM: tl.constexpr = acc.shape[0] + BN: tl.constexpr = acc.shape[1] + acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() + acc0 = acc0 * alpha[:, None] + acc1 = acc1 * alpha[:, None] + acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) + else: + acc = acc * alpha[:, None] + # prepare p and v for the dot + if dtype == tl.float8e5: + v = desc_v.load([0, offsetv_y]).T + else: + v = desc_v.load([offsetv_y, 0]) + p = p.to(dtype) + # note that this non transposed v for FP8 is only supported on Blackwell + acc = tl.dot(p, v, acc) + # update m_i and l_i + # place this at the end of the loop to reduce register pressure + l_i = l_i * alpha + l_ij + m_i = m_ij + offsetk_y += BLOCK_N + offsetv_y += BLOCK_N + return acc, l_i, m_i + + +def _host_descriptor_pre_hook(nargs): + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + HEAD_DIM = nargs["HEAD_DIM"] + if not isinstance(nargs["desc_q"], TensorDescriptor): + return + nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] + if nargs["FP8_OUTPUT"]: + nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] + else: + nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] + nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] + nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] + + +if is_hip(): + NUM_STAGES_OPTIONS = [1] +elif supports_host_descriptor(): + NUM_STAGES_OPTIONS = [2, 3, 4] +else: + NUM_STAGES_OPTIONS = [2, 3, 4] + +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ + for BM in [64, 128]\ + for BN in [32, 64, 128]\ + for s in NUM_STAGES_OPTIONS \ + for w in [4, 8]\ +] +if "PYTEST_VERSION" in os.environ: + # Use a single config in testing for reproducibility + configs = [ + triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), + ] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 + and conf.num_warps == 8) + + +def prune_invalid_configs(configs, named_args, **kwargs): + N_CTX = kwargs["N_CTX"] + STAGE = kwargs["STAGE"] + + # Filter out configs where BLOCK_M > N_CTX + # Filter out configs where BLOCK_M < BLOCK_N when causal is True + return [ + conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX and ( + conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0) or STAGE == 1) + ] + + +@triton.jit +def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): + if isinstance(desc_or_ptr, tl.tensor_descriptor): + return desc_or_ptr + else: + return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) + + +@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], + prune_configs_by={'early_config_prune': prune_invalid_configs}) +@triton.jit +def _attn_fwd(sm_scale, M, # + Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr, # + warp_specialize: tl.constexpr, # + IS_HOPPER: tl.constexpr, # + ): + dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + y_dim = Z * H * N_CTX + desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + if FP8_OUTPUT: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1], + block_shape=[HEAD_DIM, BLOCK_N]) + else: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + + offset_y = off_z * (N_CTX * H) + off_h * N_CTX + qo_offset_y = offset_y + start_m * BLOCK_M + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = desc_q.load([qo_offset_y, 0]) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, # + warp_specialize, IS_HOPPER) + # stage 2: on-band + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, # + warp_specialize, IS_HOPPER) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + desc_o.store([qo_offset_y, 0], acc.to(dtype)) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + CAUSAL: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = 0 + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + if CAUSAL: + start_m = start_n + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True, # + ) + + start_m += num_steps * MASK_BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + num_steps = (N_CTX - start_m) // BLOCK_M1 + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False, # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + start_n = 0 + num_steps = N_CTX // BLOCK_N2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + if CAUSAL: + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + end_n = start_m + BLOCK_M2 + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True, # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + start_n = end_n - num_steps * BLOCK_N2 + + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, start_n, num_steps, # + MASK=False, # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + # Use device_descriptor for Hopper + warpspec. + if supports_host_descriptor() and not (is_hopper() and warp_specialize): + # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor + y_dim = q.shape[0] * q.shape[1] * q.shape[2] + + dummy_block = [1, 1] + desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + if q.dtype == torch.float8_e5m2: + desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], + block_shape=dummy_block) + else: + desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], + block_shape=dummy_block) + desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + else: + desc_q = q + desc_v = v + desc_k = k + desc_o = o + + def alloc_fn(size: int, align: int, _): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + def grid(META): + return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + + ctx.grid = grid + if is_blackwell() and warp_specialize: + if HEAD_DIM_K == 128 and q.dtype == torch.float16: + extra_kern_args["maxnreg"] = 168 + else: + extra_kern_args["maxnreg"] = 80 + _attn_fwd[grid]( + sm_scale, M, # + q.shape[0], q.shape[1], # + desc_q, desc_k, desc_v, desc_o, # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + FP8_OUTPUT=q.dtype == torch.float8_e5m2, # + STAGE=stage, # + warp_specialize=warp_specialize, # + IS_HOPPER=is_hopper(), # + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES, # + CAUSAL=ctx.causal, # + ) + + return dq, dk, dv, None, None, None, None + + +attention = _attention.apply + +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') + + +@pytest.mark.parametrize("Z", [1, 4]) +@pytest.mark.parametrize("H", [2, 48]) +@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024]) +@pytest.mark.parametrize("HEAD_DIM", [64, 128]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False]) +@pytest.mark.parametrize("mode", ["fwd", "bwd"]) +@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else [])) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16): + if mode == "fwd" and "fp16" in provider: + pytest.skip("Avoid running the forward computation twice.") + if mode == "bwd" and "fp8" in provider: + pytest.skip("Backward pass with FP8 is not supported.") + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + sm_scale = 0.5 + # reference implementation + ref_dtype = dtype + if mode == "fwd" and "fp8" in provider: + ref_dtype = torch.float32 + q = q.to(ref_dtype) + k = k.to(ref_dtype) + v = v.to(ref_dtype) + M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1) + p = p.to(ref_dtype) + # p = torch.exp(p) + ref_out = torch.matmul(p, v).half() + if mode == "bwd": + dout = torch.randn_like(q) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half() + if mode == "fwd": + atol = 3 if "fp8" in provider else 1e-2 + torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) + return + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0) + rtol = 0.0 + # Relative tolerance workaround for known hardware limitation of CDNA2 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + rtol = 1e-2 + torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol) + torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol) + torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS = 4, 32 +# vary seq length for fixed head and batch=4 +configs = [] +for HEAD_DIM in [64, 128]: + for mode in ["fwd", "bwd"]: + for causal in [True, False]: + # Enable warpspec for causal fwd on Hopper + enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal)) + for warp_specialize in [False, True] if enable_ws else [False]: + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + (["flash"] if HAS_FLASH else []), + line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="TFLOPS", + plot_name= + f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, + "warp_specialize": warp_specialize, + }, + )) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE): + assert mode in ["fwd", "bwd"] + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, causal, sm_scale, warp_specialize) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + return total_flops * 1e-12 / (ms * 1e-3) + + +if __name__ == "__main__": + # only works on post-Ampere GPUs right now + bench_flash_attention.run(save_path=".", print_data=True) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f0d3a000ec0..eb1226ebf8a 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -134,20 +134,20 @@ def get_aoti_compile_options( return options - @classmethod - def get_extra_aoti_compile_context_manager(cls): - """ - Return SDPA MATH backend context manager for CUDA compilation. - - This context manager plays as a fallback solution for any remaining PyTorch SDPA - operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. - - Note: - - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, - this context manager will have no effect on those ops (they are no longer - PyTorch SDPA ops). - - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this - context manager will force them to use the MATH backend, causing them to - be automatically decomposed during compilation. - """ - return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) + # @classmethod + # def get_extra_aoti_compile_context_manager(cls): + # """ + # Return SDPA MATH backend context manager for CUDA compilation. + + # This context manager plays as a fallback solution for any remaining PyTorch SDPA + # operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + # Note: + # - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + # this context manager will have no effect on those ops (they are no longer + # PyTorch SDPA ops). + # - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + # context manager will force them to use the MATH backend, causing them to + # be automatically decomposed during compilation. + # """ + # return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 7e8eb1444df..c06eb4bb5bd 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -12,7 +12,6 @@ us export the model without decomposing the SDPA operator under libtorch free environment and have better performance. """ - import math from typing import Optional @@ -22,254 +21,183 @@ from torch.library import triton_op, wrap_triton -def _next_power_of_2(n: int) -> int: - """Round up to the next power of 2.""" - if n <= 0: - return 1 - if n & (n - 1) == 0: - return n - - power = 1 - while power < n: - power <<= 1 - return power +AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), +] -def _validate_qkv_shapes( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, -) -> tuple[int, int, int, int, int, int]: - """ - Validate dimensions and return shape info. - Args: - query: Query tensor [B, H, L_q, D] - key: Key tensor [B, H, L_kv, D] - value: Value tensor [B, H, L_kv, D] - Returns: - Tuple of (B, H, L_q, L_kv, D_q, D_kv) - Raises: - RuntimeError: If dimensions are incompatible - """ - B_q, H_q, L_q, D_q = query.shape - B_k, H_k, L_kv_k, D_k = key.shape - B_v, H_v, L_kv_v, D_v = value.shape - # Validate batch and head dimensions - if not (B_q == B_k == B_v): - raise RuntimeError( - f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." - ) - - if not (H_q == H_k == H_v): - raise RuntimeError( - f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." - ) - # Head dimension must match - if not (D_q == D_k == D_v): - raise RuntimeError( - f"Head dimension must match across Q, K, V; got D_q={D_q}, D_k={D_k}, D_v={D_v}." - ) - # Key and Value sequence lengths must match - if L_kv_k != L_kv_v: - raise RuntimeError( - f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." - ) - return B_q, H_q, L_q, L_kv_k, D_q, D_k - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=1, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2), - ], - key=["L_Q", "L_KV", "HEAD_DIM"], -) +@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["HEAD_DIM", "HAS_MASK", "MASK_IS_BOOL"]) @triton.jit def _sdpa_fwd_kernel( - q_ptr, - k_ptr, - v_ptr, - mask_ptr, - o_ptr, + Q_ptr, + K_ptr, + V_ptr, + O_ptr, B, H, - L_Q, # Query sequence length - L_KV, # Key/Value sequence length - HEAD_DIM, # Actual head dimension (may not be power of 2) + LQ, + LK, stride_qb, stride_qh, - stride_ql, + stride_qm, stride_qd, stride_kb, stride_kh, - stride_kl, + stride_kn, stride_kd, stride_vb, stride_vh, - stride_vl, + stride_vn, stride_vd, - stride_mb, - stride_mh, - stride_ml, - stride_mn, stride_ob, stride_oh, - stride_ol, + stride_om, stride_od, - sm_scale, - IS_CAUSAL: tl.constexpr, + scale, + mask_ptr, + stride_mb, + stride_mh, + stride_mm, + stride_mn, HAS_MASK: tl.constexpr, + MASK_IS_BOOL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - HEAD_DIM_CE: tl.constexpr, # Rounded up for tl.arange + HEAD_DIM: tl.constexpr, + GROUP_M: tl.constexpr, ): - """ - Fused SDPA kernel that handles different sequence lengths for Q and K/V. + pid_m_in = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + b = pid_bh // H + h = pid_bh % H - Q shape: [B, H, L_Q, D] - K/V shape: [B, H, L_KV, D] - Output shape: [B, H, L_Q, D] - """ - # Program IDs - pid_m = tl.program_id(axis=0) # along query length - pid_hz = tl.program_id(axis=1) # flattened batch*head - off_b = pid_hz // H - off_h = pid_hz % H - # Compute ranges for queries + num_pid_m = tl.cdiv(LQ, BLOCK_M) + group_id = pid_m_in // GROUP_M + first_pid_m = group_id * GROUP_M + pid_m = first_pid_m + (pid_m_in + pid_bh) % GROUP_M start_m = pid_m * BLOCK_M + if start_m >= LQ: + return + offs_m = start_m + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, HEAD_DIM_CE) - mask_m = offs_m < L_Q # Mask based on query length - # Base pointers for this (b, h) - q_base = q_ptr + off_b * stride_qb + off_h * stride_qh - k_base = k_ptr + off_b * stride_kb + off_h * stride_kh - v_base = v_ptr + off_b * stride_vb + off_h * stride_vh - o_base = o_ptr + off_b * stride_ob + off_h * stride_oh - # Mask base pointer (if provided) - if HAS_MASK: - mask_base = mask_ptr + off_b * stride_mb + off_h * stride_mh - # Mask for actual head dimension (HEAD_DIM may not be power of 2) - mask_d = offs_d < HEAD_DIM - # Make head-dim addresses compiler-friendly - offs_d_ctg = tl.max_contiguous(tl.multiple_of(offs_d, 16), HEAD_DIM_CE) - # Load Q tile [BLOCK_M, HEAD_DIM] - coalesced along HEAD_DIM - q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d_ctg[None, :] * stride_qd) - q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - q = q.to(tl.bfloat16) - # Initialize accumulators and softmax stats - acc = tl.zeros((BLOCK_M, HEAD_DIM_CE), dtype=tl.float32) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + offs_m = tl.multiple_of(offs_m, BLOCK_M) + offs_d = tl.multiple_of(offs_d, 16) + offs_d = tl.max_contiguous(offs_d, 16) + + q_ptrs = Q_ptr + ( + b * stride_qb + + h * stride_qh + + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd) + ) + q_mask = offs_m[:, None] < LQ + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) - # Convert to base-2 scale for exp2 - qk_scale = sm_scale * 1.4426950408889634 - # Loop over keys/values along L_KV dimension (not L_Q!) - for start_n in tl.range(0, L_KV, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = offs_n < L_KV # Mask based on key/value length - # Load K tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) - k_ptrs = k_base + ( - offs_n[:, None] * stride_kl + offs_d_ctg[None, :] * stride_kd + acc = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + + log2e = 1.4426950408889634 + scale_log2 = scale * log2e + + for start_n in tl.range(0, LK, BLOCK_N): + n_ids = start_n + offs_n + n_mask = n_ids < LK + + k_ptrs = K_ptr + ( + b * stride_kb + + h * stride_kh + + (offs_d[:, None] * stride_kd + n_ids[None, :] * stride_kn) ) - k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) - k = k.to(tl.bfloat16) - # Compute attention logits [BLOCK_M, BLOCK_N] = Q[BM,D] @ K[BN,D]^T - qk = tl.dot(q, tl.trans(k)).to(tl.float32) - qk = qk * qk_scale - # Apply causal mask if needed - # For causal masking with different lengths: position i can attend to position j if i >= j - if IS_CAUSAL: - causal_mask = offs_m[:, None] >= offs_n[None, :] - qk = tl.where(causal_mask, qk, -float("inf")) - # Apply attention mask if provided + k = tl.load(k_ptrs, mask=n_mask[None, :], other=0.0) + + qk = tl.dot(q, k).to(tl.float32) + qk = qk * scale_log2 + if HAS_MASK: - # Load mask tile [BLOCK_M, BLOCK_N] - # Mask shape should be [B, H, L_Q, L_KV] - mask_ptrs = mask_base + ( - offs_m[:, None] * stride_ml + offs_n[None, :] * stride_mn - ) - attn_mask = tl.load( - mask_ptrs, - mask=mask_m[:, None] & mask_n[None, :], - other=0.0, + m_ptrs = mask_ptr + ( + b * stride_mb + + h * stride_mh + + (offs_m[:, None] * stride_mm + n_ids[None, :] * stride_mn) ) - # Convert boolean mask to additive mask (-inf for False, 0 for True) - qk = tl.where(attn_mask, qk, -float("inf")) - # Apply OOB masks for both rows and cols - qk = tl.where(mask_n[None, :], qk, -float("inf")) - qk = tl.where(mask_m[:, None], qk, -float("inf")) - # Online softmax - m_ij = tl.maximum(m_i, tl.max(qk, 1)) + valid = (offs_m[:, None] < LQ) & n_mask[None, :] + if MASK_IS_BOOL: + m_bool = tl.load(m_ptrs, mask=valid, other=True) + qk = tl.where(m_bool, qk, -float("inf")) + else: + m_add = tl.load(m_ptrs, mask=valid, other=0.0).to(tl.float32) + qk = qk + m_add * log2e + + qk = tl.where(n_mask[None, :], qk, -float("inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.math.exp2(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) + l_ij = tl.sum(p, axis=1) alpha = tl.math.exp2(m_i - m_ij) - # Load V tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) - v_ptrs = v_base + ( - offs_n[:, None] * stride_vl + offs_d_ctg[None, :] * stride_vd + + v_ptrs = V_ptr + ( + b * stride_vb + + h * stride_vh + + (n_ids[:, None] * stride_vn + offs_d[None, :] * stride_vd) ) - v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) - v = v.to(tl.bfloat16) - # Update accumulator + v = tl.load(v_ptrs, mask=n_mask[:, None], other=0.0) + acc = acc * alpha[:, None] - p_bf16 = p.to(tl.bfloat16) - acc = tl.dot(p_bf16, v, acc) - # Update softmax stats + acc = tl.dot(p.to(tl.bfloat16), v, acc) + l_i = l_i * alpha + l_ij m_i = m_ij - # Normalize accumulator by softmax denominator - acc = acc / l_i[:, None] - # Store output [BLOCK_M, HEAD_DIM] - shape matches query - o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d_ctg[None, :] * stride_od) - tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + row_mask = offs_m < LQ + l_i = tl.where(row_mask, l_i, 1.0) + out = acc / l_i[:, None] -@triton_op("triton::sdpa", mutates_args={}) -def sdpa( + o_ptrs = O_ptr + ( + b * stride_ob + + h * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :] * stride_od) + ) + tl.store(o_ptrs, out.to(tl.bfloat16), mask=row_mask[:, None]) + + +def _check_inputs( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float = 0.0, - enable_gqa: bool = False, -) -> torch.Tensor: - """ - Triton fused Scaled Dot-Product Attention with support for different sequence lengths. - - Args: - query: Query tensor with szie [B, H, L_q, D] and dtype torch.bfloat16 - key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 - value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 - attn_mask: Optional attention mask [B, H, L_q, L_kv] or - broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) - dropout_p: must be 0.0 (others are not supported) - is_causal: whether to apply causal masking - scale: attention scale (default: 1/sqrt(D)) - enable_gqa: must be False (True is not supported) - Returns: - Output tensor [B, H, L_q, D] with dtype torch.bfloat16 - """ - # Validate inputs + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + enable_gqa: bool, +): if not (query.is_cuda and key.is_cuda and value.is_cuda): - raise RuntimeError("Q, K, V must be CUDA tensors.") + raise ValueError("query, key, value must be CUDA tensors.") if ( query.dtype != torch.bfloat16 or key.dtype != torch.bfloat16 or value.dtype != torch.bfloat16 ): - raise RuntimeError("Expected bfloat16 inputs") - if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: - raise RuntimeError( - f"Expected 4D tensors shaped [B, H, L, D]; got " - f"query.dim()={query.dim()}, key.dim()={key.dim()}, " - f"value.dim()={value.dim()}." - ) + raise ValueError("This kernel expects bfloat16 inputs for query, key, value.") + if query.ndim != 4 or key.ndim != 4 or value.ndim != 4: + raise ValueError("query, key, value must be 4D tensors [B, H, L, D].") + if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]: + raise ValueError("Batch dimension mismatch.") + if query.shape[1] != key.shape[1] or query.shape[1] != value.shape[1]: + raise ValueError("Heads dimension mismatch.") + if query.shape[-1] != key.shape[-1] or query.shape[-1] != value.shape[-1]: + raise ValueError("Head dimension (D) mismatch across query, key, value.") + if attn_mask is not None: + if attn_mask.ndim != 4: + raise ValueError("attn_mask must be 4D [B, H, L_q, L_k].") + if attn_mask.shape[0] != query.shape[0] or attn_mask.shape[1] != query.shape[1]: + raise ValueError("attn_mask batch/head dims must match query.") + if attn_mask.shape[2] != query.shape[2] or attn_mask.shape[3] != key.shape[2]: + raise ValueError("attn_mask spatial dims must be [L_q, L_k].") + if attn_mask.dtype not in (torch.bool, torch.bfloat16): + raise ValueError("attn_mask must be dtype bool or bfloat16.") + # Enforce unsupported features if dropout_p != 0.0: raise RuntimeError( @@ -279,85 +207,89 @@ def sdpa( raise RuntimeError( "enable_gqa must be False (not supported in this implementation)." ) - # Validate and get dimensions - B, H, L_q, L_kv, D_q, D_kv = _validate_qkv_shapes(query, key, value) - D = D_q # Head dimension - # Allocate output with query shape - out = torch.empty_like(query) - # Element-wise strides - sqb, sqh, sql, sqd = query.stride() - skb, skh, skl, skd = key.stride() - svb, svh, svl, svd = value.stride() - sob, soh, sol, sod = out.stride() - - # Grid: tile queries (M) and batch*heads axis - def grid(META): - return ( - triton.cdiv(L_q, META["BLOCK_M"]), # Based on query length - B * H, + if is_causal: + raise RuntimeError( + "is_causal must be False (not supported in this implementation)." ) - # Scale factor for SDPA - sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale - # Handle attention mask - has_mask = attn_mask is not None - if has_mask: - # Expand mask to [B, H, L_q, L_kv] if needed - if attn_mask.dim() == 2: - # [L_q, L_kv] -> [B, H, L_q, L_kv] - attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1) - elif attn_mask.dim() == 3: - # [B, L_q, L_kv] -> [B, H, L_q, L_kv] - attn_mask = attn_mask.unsqueeze(1).expand(-1, H, -1, -1) - - # Validate mask shape - if attn_mask.shape != (B, H, L_q, L_kv): - # Try to expand if broadcastable - attn_mask = attn_mask.expand(B, H, L_q, L_kv) - - smb, smh, sml, smn = attn_mask.stride() + +@triton_op("triton::sdpa", mutates_args={}) +def sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + if attn_mask is not None and attn_mask.shape[1] == 1: + attn_mask = attn_mask.expand(-1, query.shape[1], -1, -1).contiguous() + + _check_inputs(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa) + + B, H, LQ, D = query.shape + LK = key.shape[2] + + out = torch.empty_like(query) + + if scale == 0: + scale = 1.0 / math.sqrt(D) + scale = float(scale) + + stride_qb, stride_qh, stride_qm, stride_qd = query.stride() + stride_kb, stride_kh, stride_kn, stride_kd = key.stride() + stride_vb, stride_vh, stride_vn, stride_vd = value.stride() + stride_ob, stride_oh, stride_om, stride_od = out.stride() + + HAS_MASK = 1 if attn_mask is not None else 0 + MASK_IS_BOOL = 1 if (attn_mask is not None and attn_mask.dtype == torch.bool) else 0 + if attn_mask is None: + mask_ptr = query + stride_mb = stride_mh = stride_mm = stride_mn = 0 else: - # Dummy strides and mask - smb, smh, sml, smn = 0, 0, 0, 0 - attn_mask = torch.empty(0, dtype=torch.bool, device=query.device) - # Round up head dimension to next power of 2 for tile.arange in Triton kernel - HEAD_DIM_CE = _next_power_of_2(D) - # Launch kernel + mask_ptr = attn_mask + stride_mb, stride_mh, stride_mm, stride_mn = attn_mask.stride() + + def grid(meta): + return (triton.cdiv(LQ, meta["BLOCK_M"]), B * H) + wrap_triton(_sdpa_fwd_kernel)[grid]( query, key, value, - attn_mask, out, B, H, - L_q, # Query sequence length - L_kv, # Key/Value sequence length - D, # Actual head dimension - sqb, - sqh, - sql, - sqd, - skb, - skh, - skl, - skd, - svb, - svh, - svl, - svd, - smb, - smh, - sml, - smn, - sob, - soh, - sol, - sod, - sm_scale, - IS_CAUSAL=is_causal, - HAS_MASK=has_mask, - HEAD_DIM_CE=HEAD_DIM_CE, # Rounded to power of 2 + LQ, + LK, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + scale, + mask_ptr, + stride_mb, + stride_mh, + stride_mm, + stride_mn, + HAS_MASK=HAS_MASK, + MASK_IS_BOOL=MASK_IS_BOOL, + HEAD_DIM=D, + GROUP_M=8, ) return out @@ -380,8 +312,10 @@ def _sdpa_abstract( This just returns an empty tensor with the correct shape/dtype/device. """ # Validate dtypes match - assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" + assert ( + query.dtype == key.dtype == value.dtype + ), "query, key, value must have the same dtype" # Validate kqv's shape and get the output shape - B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) + B, H, LQ, D = query.shape - return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) + return torch.empty(B, H, LQ, D, dtype=query.dtype, device=query.device) From f129ebd7fda02f3c237a84ce54e1ababa1585aae Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 9 Dec 2025 14:07:14 -0800 Subject: [PATCH 3/8] init --- 06-fused-attention.py | 775 ------------------------------------------ 1 file changed, 775 deletions(-) delete mode 100644 06-fused-attention.py diff --git a/06-fused-attention.py b/06-fused-attention.py deleted file mode 100644 index b1283e8ef99..00000000000 --- a/06-fused-attention.py +++ /dev/null @@ -1,775 +0,0 @@ -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) - -Credits: OpenAI kernel team - -Extra Credits: - -* Original flash attention paper (https://arxiv.org/abs/2205.14135) -* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) - -""" - -import pytest -import torch -import os - -import triton -import triton.language as tl -from triton.tools.tensor_descriptor import TensorDescriptor - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - - -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" - - -def is_cuda(): - return triton.runtime.driver.active.get_current_target().backend == "cuda" - - -def supports_host_descriptor(): - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 - - -def is_blackwell(): - return is_cuda() and torch.cuda.get_device_capability()[0] == 10 - - -def is_hopper(): - return is_cuda() and torch.cuda.get_device_capability()[0] == 9 - - -@triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype: tl.constexpr, start_m, qk_scale, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # - N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False - else: - lo, hi = 0, N_CTX - offsetk_y = offset_y + lo - if dtype == tl.float8e5: - offsetv_y = offset_y * HEAD_DIM + lo - else: - offsetv_y = offset_y + lo - # loop over k, v and update accumulator - for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = desc_k.load([offsetk_y, 0]).T - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) - l_ij = tl.sum(p, 1) - # -- update output accumulator -- - if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: - BM: tl.constexpr = acc.shape[0] - BN: tl.constexpr = acc.shape[1] - acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() - acc0 = acc0 * alpha[:, None] - acc1 = acc1 * alpha[:, None] - acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) - else: - acc = acc * alpha[:, None] - # prepare p and v for the dot - if dtype == tl.float8e5: - v = desc_v.load([0, offsetv_y]).T - else: - v = desc_v.load([offsetv_y, 0]) - p = p.to(dtype) - # note that this non transposed v for FP8 is only supported on Blackwell - acc = tl.dot(p, v, acc) - # update m_i and l_i - # place this at the end of the loop to reduce register pressure - l_i = l_i * alpha + l_ij - m_i = m_ij - offsetk_y += BLOCK_N - offsetv_y += BLOCK_N - return acc, l_i, m_i - - -def _host_descriptor_pre_hook(nargs): - BLOCK_M = nargs["BLOCK_M"] - BLOCK_N = nargs["BLOCK_N"] - HEAD_DIM = nargs["HEAD_DIM"] - if not isinstance(nargs["desc_q"], TensorDescriptor): - return - nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] - if nargs["FP8_OUTPUT"]: - nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] - else: - nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] - nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] - nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] - - -if is_hip(): - NUM_STAGES_OPTIONS = [1] -elif supports_host_descriptor(): - NUM_STAGES_OPTIONS = [2, 3, 4] -else: - NUM_STAGES_OPTIONS = [2, 3, 4] - -configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ - for BM in [64, 128]\ - for BN in [32, 64, 128]\ - for s in NUM_STAGES_OPTIONS \ - for w in [4, 8]\ -] -if "PYTEST_VERSION" in os.environ: - # Use a single config in testing for reproducibility - configs = [ - triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), - ] - - -def keep(conf): - BLOCK_M = conf.kwargs["BLOCK_M"] - BLOCK_N = conf.kwargs["BLOCK_N"] - return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 - and conf.num_warps == 8) - - -def prune_invalid_configs(configs, named_args, **kwargs): - N_CTX = kwargs["N_CTX"] - STAGE = kwargs["STAGE"] - - # Filter out configs where BLOCK_M > N_CTX - # Filter out configs where BLOCK_M < BLOCK_N when causal is True - return [ - conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX and ( - conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0) or STAGE == 1) - ] - - -@triton.jit -def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): - if isinstance(desc_or_ptr, tl.tensor_descriptor): - return desc_or_ptr - else: - return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) - - -@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], - prune_configs_by={'early_config_prune': prune_invalid_configs}) -@triton.jit -def _attn_fwd(sm_scale, M, # - Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # - HEAD_DIM: tl.constexpr, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - FP8_OUTPUT: tl.constexpr, # - STAGE: tl.constexpr, # - warp_specialize: tl.constexpr, # - IS_HOPPER: tl.constexpr, # - ): - dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 - tl.static_assert(BLOCK_N <= HEAD_DIM) - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - - y_dim = Z * H * N_CTX - desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_M, HEAD_DIM]) - if FP8_OUTPUT: - desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1], - block_shape=[HEAD_DIM, BLOCK_N]) - else: - desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_N, HEAD_DIM]) - desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_N, HEAD_DIM]) - desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_M, HEAD_DIM]) - - offset_y = off_z * (N_CTX * H) + off_h * N_CTX - qo_offset_y = offset_y + start_m * BLOCK_M - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout - q = desc_q.load([qo_offset_y, 0]) - # stage 1: off-band - # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype, start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 4 - STAGE, offs_m, offs_n, N_CTX, # - warp_specialize, IS_HOPPER) - # stage 2: on-band - if STAGE & 2: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype, start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 2, offs_m, offs_n, N_CTX, # - warp_specialize, IS_HOPPER) - # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - desc_o.store([qo_offset_y, 0], acc.to(dtype)) - - -@triton.jit -def _attn_bwd_preprocess(O, DO, # - Delta, # - Z, H, N_CTX, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # - ): - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_hz = tl.program_id(1) - off_n = tl.arange(0, HEAD_DIM) - # load - o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) - do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(Delta + off_hz * N_CTX + off_m, delta) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - HEAD_DIM: tl.constexpr, # - # Filled in by the wrapper. - start_n, start_m, num_steps, # - MASK: tl.constexpr): - offs_m = start_m + tl.arange(0, BLOCK_M1) - offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, HEAD_DIM) - qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - curr_m = start_m - step_m = BLOCK_M1 - for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) - # Load m before computing qk to reduce pipeline stall. - offs_m = curr_m + tl.arange(0, BLOCK_M1) - m = tl.load(M + offs_m) - qkT = tl.dot(k, qT) - pT = tl.math.exp2(qkT - m[None, :]) - # Autoregressive masking. - if MASK: - mask = (offs_m[None, :] >= offs_n[:, None]) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs) - # Compute dV. - ppT = pT - ppT = ppT.to(tl.float16) - dv += tl.dot(ppT, do) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)).to(tl.float32) - dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(tl.float16) - dk += tl.dot(dsT, tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_tok - do_ptrs += step_m * stride_tok - return dk, dv - - -# the main inner-loop logic for computing dQ -@triton.jit -def _attn_bwd_dq(dq, q, K, V, # - do, m, D, - # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - # Filled in by the wrapper. - start_m, start_n, num_steps, # - MASK: tl.constexpr): - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - for blk_idx in range(num_steps): - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) - qk = tl.dot(q, kT) - p = tl.math.exp2(qk - m) - # Autoregressive masking. - if MASK: - offs_n = curr_n + tl.arange(0, BLOCK_N2) - mask = (offs_m[:, None] >= offs_n[None, :]) - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - dp = tl.dot(do, vT).to(tl.float32) - ds = p * (dp - Di[:, None]) - ds = ds.to(tl.float16) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - dq += tl.dot(ds, tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_tok - vT_ptrs += step_n * stride_tok - return dq - - -@triton.jit -def _attn_bwd(Q, K, V, sm_scale, # - DO, # - DQ, DK, DV, # - M, D, - # shared by Q/K/V/DO. - stride_z, stride_h, stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - BLK_SLICE_FACTOR: tl.constexpr, # - HEAD_DIM: tl.constexpr, # - CAUSAL: tl.constexpr): - LN2: tl.constexpr = 0.6931471824645996 # = ln(2) - - bhid = tl.program_id(2) - off_chz = (bhid * N_CTX).to(tl.int64) - adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) - pid = tl.program_id(0) - - # offset pointers for batch/head - Q += adj - K += adj - V += adj - DO += adj - DQ += adj - DK += adj - DV += adj - M += off_chz - D += off_chz - - # load scales - offs_k = tl.arange(0, HEAD_DIM) - - start_n = pid * BLOCK_N1 - start_m = 0 - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - offs_n = start_n + tl.arange(0, BLOCK_N1) - - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - - if CAUSAL: - start_m = start_n - num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=True, # - ) - - start_m += num_steps * MASK_BLOCK_M1 - - # Compute dK and dV for non-masked blocks. - num_steps = (N_CTX - start_m) // BLOCK_M1 - dk, dv = _attn_bwd_dkdv( # - dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=False, # - ) - - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) - - # Write back dK. - dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) - - # THIS BLOCK DOES DQ: - start_m = pid * BLOCK_M2 - start_n = 0 - num_steps = N_CTX // BLOCK_N2 - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - offs_m = start_m + tl.arange(0, BLOCK_M2) - - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - - m = tl.load(M + offs_m) - m = m[:, None] - - if CAUSAL: - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _attn_bwd_dq, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - end_n = start_m + BLOCK_M2 - num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # - MASK=True, # - ) - end_n -= num_steps * MASK_BLOCK_N2 - # stage 2 - num_steps = end_n // BLOCK_N2 - start_n = end_n - num_steps * BLOCK_N2 - - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, BLOCK_N2, HEAD_DIM, # - start_m, start_n, num_steps, # - MASK=False, # - ) - # Write back dQ. - dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d - dq *= LN2 - tl.store(dq_ptrs, dq) - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True): - # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. - HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) - stage = 3 if causal else 1 - extra_kern_args = {} - # Tuning for AMD target - if is_hip(): - waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 - extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} - - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - # Use device_descriptor for Hopper + warpspec. - if supports_host_descriptor() and not (is_hopper() and warp_specialize): - # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor - y_dim = q.shape[0] * q.shape[1] * q.shape[2] - - dummy_block = [1, 1] - desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - if q.dtype == torch.float8_e5m2: - desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], - block_shape=dummy_block) - else: - desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], - block_shape=dummy_block) - desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - else: - desc_q = q - desc_v = v - desc_k = k - desc_o = o - - def alloc_fn(size: int, align: int, _): - return torch.empty(size, dtype=torch.int8, device="cuda") - - triton.set_allocator(alloc_fn) - - def grid(META): - return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - - ctx.grid = grid - if is_blackwell() and warp_specialize: - if HEAD_DIM_K == 128 and q.dtype == torch.float16: - extra_kern_args["maxnreg"] = 168 - else: - extra_kern_args["maxnreg"] = 80 - _attn_fwd[grid]( - sm_scale, M, # - q.shape[0], q.shape[1], # - desc_q, desc_k, desc_v, desc_o, # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - FP8_OUTPUT=q.dtype == torch.float8_e5m2, # - STAGE=stage, # - warp_specialize=warp_specialize, # - IS_HOPPER=is_hopper(), # - **extra_kern_args) - - ctx.save_for_backward(q, k, v, o, M) - ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors - assert do.is_contiguous() - assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - BATCH, N_HEAD, N_CTX = q.shape[:3] - PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) - arg_k = k - arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 - assert N_CTX % PRE_BLOCK == 0 - pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) - delta = torch.empty_like(M) - _attn_bwd_preprocess[pre_grid]( - o, do, # - delta, # - BATCH, N_HEAD, N_CTX, # - BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # - ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) - _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # - M, delta, # - q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - N_HEAD, N_CTX, # - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES, # - CAUSAL=ctx.causal, # - ) - - return dq, dk, dv, None, None, None, None - - -attention = _attention.apply - -TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') - - -@pytest.mark.parametrize("Z", [1, 4]) -@pytest.mark.parametrize("H", [2, 48]) -@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024]) -@pytest.mark.parametrize("HEAD_DIM", [64, 128]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False]) -@pytest.mark.parametrize("mode", ["fwd", "bwd"]) -@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else [])) -def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16): - if mode == "fwd" and "fp16" in provider: - pytest.skip("Avoid running the forward computation twice.") - if mode == "bwd" and "fp8" in provider: - pytest.skip("Backward pass with FP8 is not supported.") - torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) - sm_scale = 0.5 - # reference implementation - ref_dtype = dtype - if mode == "fwd" and "fp8" in provider: - ref_dtype = torch.float32 - q = q.to(ref_dtype) - k = k.to(ref_dtype) - v = v.to(ref_dtype) - M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1) - p = p.to(ref_dtype) - # p = torch.exp(p) - ref_out = torch.matmul(p, v).half() - if mode == "bwd": - dout = torch.randn_like(q) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - # triton implementation - if mode == "fwd" and "fp8" in provider: - q = q.to(torch.float8_e5m2) - k = k.to(torch.float8_e5m2) - v = v.permute(0, 1, 3, 2).contiguous() - v = v.permute(0, 1, 3, 2) - v = v.to(torch.float8_e5m2) - tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half() - if mode == "fwd": - atol = 3 if "fp8" in provider else 1e-2 - torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) - return - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None - # compare - torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0) - rtol = 0.0 - # Relative tolerance workaround for known hardware limitation of CDNA2 GPU. - # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": - rtol = 1e-2 - torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol) - torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol) - torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol) - - -try: - from flash_attn.flash_attn_interface import \ - flash_attn_qkvpacked_func as flash_attn_func - HAS_FLASH = True -except BaseException: - HAS_FLASH = False - -TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') -BATCH, N_HEADS = 4, 32 -# vary seq length for fixed head and batch=4 -configs = [] -for HEAD_DIM in [64, 128]: - for mode in ["fwd", "bwd"]: - for causal in [True, False]: - # Enable warpspec for causal fwd on Hopper - enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal)) - for warp_specialize in [False, True] if enable_ws else [False]: - configs.append( - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + - (["flash"] if HAS_FLASH else []), - line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + - (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-"), ("green", "-")], - ylabel="TFLOPS", - plot_name= - f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}", - args={ - "H": N_HEADS, - "BATCH": BATCH, - "HEAD_DIM": HEAD_DIM, - "mode": mode, - "causal": causal, - "warp_specialize": warp_specialize, - }, - )) - - -@triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE): - assert mode in ["fwd", "bwd"] - dtype = torch.float16 - if "triton" in provider: - q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - if mode == "fwd" and "fp8" in provider: - q = q.to(torch.float8_e5m2) - k = k.to(torch.float8_e5m2) - v = v.permute(0, 1, 3, 2).contiguous() - v = v.permute(0, 1, 3, 2) - v = v.to(torch.float8_e5m2) - sm_scale = 1.3 - fn = lambda: attention(q, k, v, causal, sm_scale, warp_specialize) - if mode == "bwd": - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn) - - if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - fn = lambda: flash_attn_func(qkv, causal=causal) - if mode == "bwd": - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn) - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM - total_flops = 2 * flops_per_matmul - if causal: - total_flops *= 0.5 - if mode == "bwd": - total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - return total_flops * 1e-12 / (ms * 1e-3) - - -if __name__ == "__main__": - # only works on post-Ampere GPUs right now - bench_flash_attention.run(save_path=".", print_data=True) From c48420869d674c1e742886f56dd5c7c905b96e97 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 10 Dec 2025 10:21:39 -0800 Subject: [PATCH 4/8] new triton works --- backends/cuda/triton/kernels/sdpa.py | 594 +++++++++++++++++++-------- 1 file changed, 420 insertions(+), 174 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index c06eb4bb5bd..601c09ad983 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -12,6 +12,7 @@ us export the model without decomposing the SDPA operator under libtorch free environment and have better performance. """ + import math from typing import Optional @@ -21,24 +22,59 @@ from torch.library import triton_op, wrap_triton -AUTOTUNE_CONFIGS = [ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), -] +def _validate_qkv_shapes( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[int, int, int, int, int, int]: + """ + Validate dimensions and return shape info. + Args: + query: Query tensor [B, H, L_q, D] + key: Key tensor [B, H, L_kv, D] + value: Value tensor [B, H, L_kv, D] + Returns: + Tuple of (B, H, L_q, L_kv, D_q, D_kv) + Raises: + RuntimeError: If dimensions are incompatible + """ + B_q, H_q, L_q, D_q = query.shape + B_k, H_k, L_kv_k, D_k = key.shape + B_v, H_v, L_kv_v, D_v = value.shape + # Validate batch and head dimensions + if not (B_q == B_k == B_v): + raise RuntimeError( + f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." + ) + + if not (H_q == H_k == H_v): + raise RuntimeError( + f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." + ) + # Head dimension must match + if not (D_q == D_k == D_v): + raise RuntimeError( + f"Head dimension must match across Q, K, V; got D_q={D_q}, D_k={D_k}, D_v={D_v}." + ) + # Key and Value sequence lengths must match + if L_kv_k != L_kv_v: + raise RuntimeError( + f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." + ) + return B_q, H_q, L_q, L_kv_k, D_q, D_k -@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["HEAD_DIM", "HAS_MASK", "MASK_IS_BOOL"]) @triton.jit -def _sdpa_fwd_kernel( +def _sdpa_fwd_kernel_body( Q_ptr, K_ptr, V_ptr, O_ptr, + Mask_ptr, B, H, - LQ, - LK, + Lq, + Lk, stride_qb, stride_qh, stride_qm, @@ -55,149 +91,319 @@ def _sdpa_fwd_kernel( stride_oh, stride_om, stride_od, - scale, - mask_ptr, stride_mb, - stride_mh, - stride_mm, - stride_mn, + stride_mq, + stride_mk, + sm_scale: tl.float32, HAS_MASK: tl.constexpr, - MASK_IS_BOOL: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, - GROUP_M: tl.constexpr, ): - pid_m_in = tl.program_id(axis=0) + """ + Shared kernel body for SDPA forward pass. + """ + pid_m = tl.program_id(axis=0) pid_bh = tl.program_id(axis=1) b = pid_bh // H h = pid_bh % H - num_pid_m = tl.cdiv(LQ, BLOCK_M) - group_id = pid_m_in // GROUP_M - first_pid_m = group_id * GROUP_M - pid_m = first_pid_m + (pid_m_in + pid_bh) % GROUP_M - start_m = pid_m * BLOCK_M - if start_m >= LQ: - return - - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, HEAD_DIM) - offs_m = tl.multiple_of(offs_m, BLOCK_M) - offs_d = tl.multiple_of(offs_d, 16) - offs_d = tl.max_contiguous(offs_d, 16) q_ptrs = Q_ptr + ( b * stride_qb + h * stride_qh - + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd) + + (offs_m[:, None] * stride_qm) + + (offs_d[None, :] * stride_qd) ) - q_mask = offs_m[:, None] < LQ - q = tl.load(q_ptrs, mask=q_mask, other=0.0) + q_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.bfloat16) - m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) - l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) - acc = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - log2e = 1.4426950408889634 - scale_log2 = scale * log2e - - for start_n in tl.range(0, LK, BLOCK_N): - n_ids = start_n + offs_n - n_mask = n_ids < LK + for start_n in tl.range(0, Lk, BLOCK_N): + offs_n = start_n + offs_n_init k_ptrs = K_ptr + ( b * stride_kb + h * stride_kh - + (offs_d[:, None] * stride_kd + n_ids[None, :] * stride_kn) + + (offs_n[:, None] * stride_kn) + + (offs_d[None, :] * stride_kd) ) - k = tl.load(k_ptrs, mask=n_mask[None, :], other=0.0) + k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - qk = tl.dot(q, k).to(tl.float32) - qk = qk * scale_log2 + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale if HAS_MASK: - m_ptrs = mask_ptr + ( + mask_ptrs = Mask_ptr + ( b * stride_mb - + h * stride_mh - + (offs_m[:, None] * stride_mm + n_ids[None, :] * stride_mn) + + (offs_m[:, None] * stride_mq) + + (offs_n[None, :] * stride_mk) ) - valid = (offs_m[:, None] < LQ) & n_mask[None, :] - if MASK_IS_BOOL: - m_bool = tl.load(m_ptrs, mask=valid, other=True) - qk = tl.where(m_bool, qk, -float("inf")) - else: - m_add = tl.load(m_ptrs, mask=valid, other=0.0).to(tl.float32) - qk = qk + m_add * log2e + mn_mask = (offs_m[:, None] < Lq) & (offs_n[None, :] < Lk) + mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) + qk = tl.where(mask_block, qk, -float("inf")) - qk = tl.where(n_mask[None, :], qk, -float("inf")) + if IS_CAUSAL: + abs_m = offs_m[:, None] + abs_n = offs_n[None, :] + causal = abs_n > abs_m + qk = tl.where(causal, -float("inf"), qk) m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.math.exp2(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.math.exp2(m_i - m_ij) + p_f32 = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p_f32, axis=1) + alpha = tl.exp(m_i - m_ij) v_ptrs = V_ptr + ( b * stride_vb + h * stride_vh - + (n_ids[:, None] * stride_vn + offs_d[None, :] * stride_vd) + + (offs_n[:, None] * stride_vn) + + (offs_d[None, :] * stride_vd) ) - v = tl.load(v_ptrs, mask=n_mask[:, None], other=0.0) - - acc = acc * alpha[:, None] - acc = tl.dot(p.to(tl.bfloat16), v, acc) + v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + p_bf16 = p_f32.to(tl.bfloat16) + acc = acc * alpha[:, None] + tl.dot(p_bf16, v) l_i = l_i * alpha + l_ij m_i = m_ij - row_mask = offs_m < LQ - l_i = tl.where(row_mask, l_i, 1.0) - out = acc / l_i[:, None] + inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) + acc = acc * inv_l_i[:, None] o_ptrs = O_ptr + ( b * stride_ob + h * stride_oh - + (offs_m[:, None] * stride_om + offs_d[None, :] * stride_od) + + (offs_m[:, None] * stride_om) + + (offs_d[None, :] * stride_od) + ) + o_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m64( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=64 optimizations. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m32( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=32 optimizations for small workloads. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, ) - tl.store(o_ptrs, out.to(tl.bfloat16), mask=row_mask[:, None]) -def _check_inputs( +@triton_op("triton::sdpa", mutates_args={}) +def sdpa( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - dropout_p: float, - is_causal: bool, - enable_gqa: bool, -): + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + """ + Triton fused Scaled Dot-Product Attention with optimized dual-kernel approach. + + Args: + query: Query tensor with size [B, H, L_q, D] and dtype torch.bfloat16 + key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 + value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 + attn_mask: Optional attention mask [B, H, L_q, L_kv] with dtype torch.bool + dropout_p: must be 0.0 (others are not supported) + is_causal: whether to apply causal masking + scale: attention scale (default: 1/sqrt(D)) + enable_gqa: must be False (True is not supported) + Returns: + Output tensor [B, H, L_q, D] with dtype torch.bfloat16 + """ + # Validate inputs if not (query.is_cuda and key.is_cuda and value.is_cuda): - raise ValueError("query, key, value must be CUDA tensors.") + raise RuntimeError("Q, K, V must be CUDA tensors.") if ( query.dtype != torch.bfloat16 or key.dtype != torch.bfloat16 or value.dtype != torch.bfloat16 ): - raise ValueError("This kernel expects bfloat16 inputs for query, key, value.") - if query.ndim != 4 or key.ndim != 4 or value.ndim != 4: - raise ValueError("query, key, value must be 4D tensors [B, H, L, D].") - if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]: - raise ValueError("Batch dimension mismatch.") - if query.shape[1] != key.shape[1] or query.shape[1] != value.shape[1]: - raise ValueError("Heads dimension mismatch.") - if query.shape[-1] != key.shape[-1] or query.shape[-1] != value.shape[-1]: - raise ValueError("Head dimension (D) mismatch across query, key, value.") - if attn_mask is not None: - if attn_mask.ndim != 4: - raise ValueError("attn_mask must be 4D [B, H, L_q, L_k].") - if attn_mask.shape[0] != query.shape[0] or attn_mask.shape[1] != query.shape[1]: - raise ValueError("attn_mask batch/head dims must match query.") - if attn_mask.shape[2] != query.shape[2] or attn_mask.shape[3] != key.shape[2]: - raise ValueError("attn_mask spatial dims must be [L_q, L_k].") - if attn_mask.dtype not in (torch.bool, torch.bfloat16): - raise ValueError("attn_mask must be dtype bool or bfloat16.") - + raise RuntimeError("Expected bfloat16 inputs") + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise RuntimeError( + f"Expected 4D tensors shaped [B, H, L, D]; got " + f"query.dim()={query.dim()}, key.dim()={key.dim()}, " + f"value.dim()={value.dim()}." + ) # Enforce unsupported features if dropout_p != 0.0: raise RuntimeError( @@ -207,90 +413,132 @@ def _check_inputs( raise RuntimeError( "enable_gqa must be False (not supported in this implementation)." ) - if is_causal: - raise RuntimeError( - "is_causal must be False (not supported in this implementation)." - ) + # Validate and get dimensions + B, H, L_q, L_kv, D_q, D_kv = _validate_qkv_shapes(query, key, value) + D = D_q # Head dimension + # Enforce causal masking constraint + if is_causal: + if L_q != L_kv: + raise RuntimeError( + f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}." + ) -@triton_op("triton::sdpa", mutates_args={}) -def sdpa( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float = 0.0, - enable_gqa: bool = False, -) -> torch.Tensor: - if attn_mask is not None and attn_mask.shape[1] == 1: - attn_mask = attn_mask.expand(-1, query.shape[1], -1, -1).contiguous() - - _check_inputs(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa) - - B, H, LQ, D = query.shape - LK = key.shape[2] - - out = torch.empty_like(query) - - if scale == 0: - scale = 1.0 / math.sqrt(D) - scale = float(scale) + # Allocate output with query shape + out = torch.empty((B, H, L_q, D), device=query.device, dtype=query.dtype) + # Element-wise strides stride_qb, stride_qh, stride_qm, stride_qd = query.stride() stride_kb, stride_kh, stride_kn, stride_kd = key.stride() stride_vb, stride_vh, stride_vn, stride_vd = value.stride() stride_ob, stride_oh, stride_om, stride_od = out.stride() - HAS_MASK = 1 if attn_mask is not None else 0 - MASK_IS_BOOL = 1 if (attn_mask is not None and attn_mask.dtype == torch.bool) else 0 - if attn_mask is None: - mask_ptr = query - stride_mb = stride_mh = stride_mm = stride_mn = 0 - else: - mask_ptr = attn_mask - stride_mb, stride_mh, stride_mm, stride_mn = attn_mask.stride() + # Scale factor for SDPA + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + + # Handle attention mask + HAS_MASK = attn_mask is not None + Mask_ptr = 0 + stride_mb = stride_mq = stride_mk = 0 + if HAS_MASK: + if attn_mask.dtype != torch.bool: + raise RuntimeError("attn_mask must have dtype torch.bool") + if not attn_mask.is_cuda: + raise RuntimeError("attn_mask must be a CUDA tensor") + if ( + attn_mask.shape[0] != B + or attn_mask.shape[2] != L_q + or attn_mask.shape[3] != L_kv + ): + raise RuntimeError( + f"attn_mask shape mismatch: expected [B={B}, H, L_q={L_q}, L_kv={L_kv}], " + f"got {attn_mask.shape}" + ) + Mask_ptr = attn_mask + stride_mb = attn_mask.stride(0) + stride_mq = attn_mask.stride(2) + stride_mk = attn_mask.stride(3) + # Grid configuration def grid(meta): - return (triton.cdiv(LQ, meta["BLOCK_M"]), B * H) + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + # Dynamic kernel selection based on workload + total_ctas_m64 = ((L_q + 63) // 64) * (B * H) + threshold = 4 * 84 # Heuristic threshold for kernel selection + use_small_block = total_ctas_m64 < threshold + + if use_small_block: + wrap_triton(_sdpa_fwd_kernel_m32)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) + else: + wrap_triton(_sdpa_fwd_kernel_m64)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) - wrap_triton(_sdpa_fwd_kernel)[grid]( - query, - key, - value, - out, - B, - H, - LQ, - LK, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - stride_od, - scale, - mask_ptr, - stride_mb, - stride_mh, - stride_mm, - stride_mn, - HAS_MASK=HAS_MASK, - MASK_IS_BOOL=MASK_IS_BOOL, - HEAD_DIM=D, - GROUP_M=8, - ) return out @@ -312,10 +560,8 @@ def _sdpa_abstract( This just returns an empty tensor with the correct shape/dtype/device. """ # Validate dtypes match - assert ( - query.dtype == key.dtype == value.dtype - ), "query, key, value must have the same dtype" + assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" # Validate kqv's shape and get the output shape - B, H, LQ, D = query.shape + B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) - return torch.empty(B, H, LQ, D, dtype=query.dtype, device=query.device) + return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) From d3e602fd8f38b9d8d25182645106b3ba42648ad5 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 10 Dec 2025 11:02:18 -0800 Subject: [PATCH 5/8] Revert "introduce cuda sdpa" This reverts commit ca461fb9fe73d421c3c79a550445036c72cf0861. --- backends/cuda/runtime/TARGETS | 4 - backends/cuda/runtime/shims/sdpa.cu | 649 --------------- backends/cuda/runtime/shims/sdpa.cuh | 282 ------- backends/cuda/runtime/shims/sdpa.h | 104 --- backends/cuda/runtime/shims/tests/targets.bzl | 1 - ...orch_cuda_scaled_dot_product_attention.cpp | 781 ------------------ 6 files changed, 1821 deletions(-) delete mode 100644 backends/cuda/runtime/shims/sdpa.cu delete mode 100644 backends/cuda/runtime/shims/sdpa.cuh delete mode 100644 backends/cuda/runtime/shims/sdpa.h delete mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 01dabee9086..a85f3a7e6a3 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -53,7 +53,6 @@ runtime.cxx_library( "shims/cuda_guard.cpp", "shims/int4mm.cu", "shims/memory.cpp", - "shims/sdpa.cu", "shims/tensor_attribute.cpp", ], headers = [ @@ -62,8 +61,6 @@ runtime.cxx_library( "shims/int4mm.cuh", "shims/int4mm.h", "shims/memory.h", - "shims/sdpa.cuh", - "shims/sdpa.h", "shims/tensor_attribute.h", "utils.h", ], @@ -87,7 +84,6 @@ runtime.cxx_library( ], external_deps = [ ("cuda", None, "cuda-lazy"), - ("cuda", None, "cublas-lazy"), ], ) diff --git a/backends/cuda/runtime/shims/sdpa.cu b/backends/cuda/runtime/shims/sdpa.cu deleted file mode 100644 index c15f1f006bc..00000000000 --- a/backends/cuda/runtime/shims/sdpa.cu +++ /dev/null @@ -1,649 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace executorch::backends::cuda { - -using executorch::backends::aoti::Tensor; -using executorch::backends::aoti::AOTITorchError; -using executorch::runtime::Error; - -// ============================================================================ -// CUDA Kernels for Softmax and Masking -// ============================================================================ - -// Helper function for max with different types -__device__ __forceinline__ float device_max(float a, float b) { - return fmaxf(a, b); -} - -__device__ __forceinline__ __half device_max(__half a, __half b) { - return __hgt(a, b) ? a : b; -} - -__device__ __forceinline__ __nv_bfloat16 device_max(__nv_bfloat16 a, __nv_bfloat16 b) { - #if __CUDA_ARCH__ >= 800 - return __hgt(a, b) ? a : b; - #else - return __float2bfloat16(fmaxf(__bfloat162float(a), __bfloat162float(b))); - #endif -} - -/** - * Softmax kernel with optional causal masking - * - * Computes softmax along the last dimension (seq_len_k) of a 4D tensor. - * Supports causal masking where positions j > i are masked out. - * - * Input: [batch, num_heads, seq_len_q, seq_len_k] - * Output: [batch, num_heads, seq_len_q, seq_len_k] - * - * Each thread processes one row (seq_len_q position). - * - * Note: Supports in-place operation (input == output). - */ -template -__global__ void softmax_with_causal_mask_kernel( - const scalar_t* input, - scalar_t* output, - const int64_t batch, - const int64_t num_heads, - const int64_t seq_len_q, - const int64_t seq_len_k, - const bool is_causal, - const float scale) { - - // Each block processes one row of the attention matrix - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const int64_t total_rows = batch * num_heads * seq_len_q; - - if (idx >= total_rows) { - return; - } - - // Decode position - we only need i for causal masking - const int64_t i = idx % seq_len_q; - - // Pointer to the start of this row - const int64_t row_offset = idx * seq_len_k; - const scalar_t* input_row = input + row_offset; - scalar_t* output_row = output + row_offset; - - // Find max for numerical stability (two-pass algorithm) - float max_val = -FLT_MAX; - for (int64_t j = 0; j < seq_len_k; ++j) { - if (!is_causal || j <= i) { - float val = static_cast(input_row[j]) * scale; - max_val = fmaxf(max_val, val); - } - } - - // Compute exp and sum - float sum_exp = 0.0f; - for (int64_t j = 0; j < seq_len_k; ++j) { - float val; - if (!is_causal || j <= i) { - val = expf(static_cast(input_row[j]) * scale - max_val); - } else { - val = 0.0f; - } - output_row[j] = static_cast(val); - sum_exp += val; - } - - // Normalize - const float inv_sum = 1.0f / sum_exp; - for (int64_t j = 0; j < seq_len_k; ++j) { - output_row[j] = static_cast(static_cast(output_row[j]) * inv_sum); - } -} - -/** - * Scale kernel - multiply all elements by a scalar - */ -template -__global__ void scale_kernel( - scalar_t* __restrict__ data, - const int64_t size, - const float scale) { - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - data[idx] = static_cast(static_cast(data[idx]) * scale); - } -} - -// ============================================================================ -// cuBLAS Helper Functions -// ============================================================================ - -/** - * Get or create a cuBLAS handle for the current stream - * - * Note: In production, this should use a handle pool or be managed - * by the backend infrastructure. This is a simplified version. - */ -cublasHandle_t get_cublas_handle(cudaStream_t stream) { - static cublasHandle_t handle = nullptr; - - if (handle == nullptr) { - cublasCreate(&handle); - } - - cublasSetStream(handle, stream); - return handle; -} - -/** - * Batched matrix multiplication wrapper for cuBLAS - * - * Computes: C = alpha * op(A) @ op(B) + beta * C - * for a batch of matrices - */ -template -cublasStatus_t batched_gemm( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, int n, int k, - const scalar_t* alpha, - const scalar_t* A, int lda, int64_t strideA, - const scalar_t* B, int ldb, int64_t strideB, - const scalar_t* beta, - scalar_t* C, int ldc, int64_t strideC, - int batchCount); - -// Specializations for different types -template<> -cublasStatus_t batched_gemm( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, int64_t strideA, - const float* B, int ldb, int64_t strideB, - const float* beta, - float* C, int ldc, int64_t strideC, - int batchCount) { - return cublasSgemmStridedBatched( - handle, transa, transb, m, n, k, - alpha, A, lda, strideA, B, ldb, strideB, - beta, C, ldc, strideC, batchCount); -} - -template<> -cublasStatus_t batched_gemm<__half>( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, int n, int k, - const __half* alpha, - const __half* A, int lda, int64_t strideA, - const __half* B, int ldb, int64_t strideB, - const __half* beta, - __half* C, int ldc, int64_t strideC, - int batchCount) { - return cublasHgemmStridedBatched( - handle, transa, transb, m, n, k, - alpha, A, lda, strideA, B, ldb, strideB, - beta, C, ldc, strideC, batchCount); -} - -// Note: BFloat16 uses compute type float internally -template<> -cublasStatus_t batched_gemm<__nv_bfloat16>( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, int n, int k, - const __nv_bfloat16* alpha, - const __nv_bfloat16* A, int lda, int64_t strideA, - const __nv_bfloat16* B, int ldb, int64_t strideB, - const __nv_bfloat16* beta, - __nv_bfloat16* C, int ldc, int64_t strideC, - int batchCount) { - - // cuBLAS BFloat16 GEMM - introduced in CUDA 11+ - #if CUDA_VERSION >= 11000 - // For BFloat16, we need to use cublasGemmStridedBatchedEx - // with compute type CUBLAS_COMPUTE_32F - float alpha_f = 1.0f; - float beta_f = 0.0f; - - return cublasGemmStridedBatchedEx( - handle, - transa, transb, - m, n, k, - &alpha_f, - A, CUDA_R_16BF, lda, strideA, - B, CUDA_R_16BF, ldb, strideB, - &beta_f, - C, CUDA_R_16BF, ldc, strideC, - batchCount, - CUBLAS_COMPUTE_32F, - CUBLAS_GEMM_DEFAULT); - #else - ET_LOG(Error, "BFloat16 GEMM requires CUDA 11.0 or later"); - return CUBLAS_STATUS_NOT_SUPPORTED; - #endif -} - -// ============================================================================ -// Math Fallback Implementation -// ============================================================================ - -/** - * Math fallback implementation for SDPA - * - * This implementation uses cuBLAS for matrix multiplications and custom - * kernels for softmax. It provides maximum compatibility across all CUDA - * devices but may not be as optimized as Flash Attention or Memory Efficient - * Attention. - * - * Algorithm: - * 1. Compute attention scores: S = (Q @ K^T) - * 2. Apply scaling and compute softmax with optional causal mask - * 3. Compute output: O = attention_weights @ V - */ -template -Tensor* sdpa_math_fallback_impl( - const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* attn_mask, - bool is_causal, - float scale_factor, - cudaStream_t stream) { - - // Get tensor dimensions - const int64_t batch = query->size(0); - const int64_t num_heads = query->size(1); - const int64_t seq_len_q = query->size(2); - const int64_t head_dim = query->size(3); - const int64_t seq_len_k = key->size(2); - const int64_t head_dim_v = value->size(3); - - // Get cuBLAS handle - cublasHandle_t handle = get_cublas_handle(stream); - - // Step 1: Allocate temporary buffer for attention scores - // Shape: [batch, num_heads, seq_len_q, seq_len_k] - const int64_t scores_size = batch * num_heads * seq_len_q * seq_len_k; - scalar_t* scores_ptr = nullptr; - cudaMalloc(&scores_ptr, scores_size * sizeof(scalar_t)); - if (scores_ptr == nullptr) { - ET_LOG(Error, "sdpa_math_fallback: Failed to allocate scores buffer"); - return nullptr; - } - - // Step 2: Compute Q @ K^T using cuBLAS - // Q: [batch * num_heads, seq_len_q, head_dim] - // K^T: [batch * num_heads, head_dim, seq_len_k] - // Output: [batch * num_heads, seq_len_q, seq_len_k] - - const int m = seq_len_q; - const int n = seq_len_k; - const int k = head_dim; - const int batch_count = batch * num_heads; - - const scalar_t alpha = static_cast(1.0f); - const scalar_t beta = static_cast(0.0f); - - const scalar_t* q_ptr = reinterpret_cast(query->data_ptr()); - const scalar_t* k_ptr = reinterpret_cast(key->data_ptr()); - - // Strides for batched GEMM - const int64_t stride_q = seq_len_q * head_dim; - const int64_t stride_k = seq_len_k * head_dim; - const int64_t stride_scores = seq_len_q * seq_len_k; - - // Q @ K^T - cublasStatus_t status = batched_gemm( - handle, - CUBLAS_OP_T, // Transpose K - CUBLAS_OP_N, // No transpose Q - n, // seq_len_k - m, // seq_len_q - k, // head_dim - &alpha, - k_ptr, k, // K matrix - stride_k, - q_ptr, k, // Q matrix - stride_q, - &beta, - scores_ptr, n, // Output scores - stride_scores, - batch_count); - - if (status != CUBLAS_STATUS_SUCCESS) { - ET_LOG(Error, "sdpa_math_fallback: cuBLAS GEMM failed for Q @ K^T"); - cudaFree(scores_ptr); - return nullptr; - } - - // Step 3: Apply softmax with scaling and optional causal mask - const int threads_per_block = 256; - const int64_t total_rows = batch * num_heads * seq_len_q; - const int num_blocks = (total_rows + threads_per_block - 1) / threads_per_block; - - softmax_with_causal_mask_kernel<<>>( - scores_ptr, - scores_ptr, // in-place - batch, - num_heads, - seq_len_q, - seq_len_k, - is_causal, - scale_factor); - - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - ET_LOG(Error, "sdpa_math_fallback: Softmax kernel launch failed: %s", - cudaGetErrorString(cuda_err)); - cudaFree(scores_ptr); - return nullptr; - } - - // Step 4: Allocate output tensor [batch, num_heads, seq_len_q, head_dim_v] - Tensor* output = nullptr; - std::array output_shape = {batch, num_heads, seq_len_q, head_dim_v}; - std::array output_stride = { - num_heads * seq_len_q * head_dim_v, - seq_len_q * head_dim_v, - head_dim_v, - 1}; - - auto dtype_int = static_cast(query->dtype()); - aoti_torch_empty_strided( - 4, - output_shape.data(), - output_stride.data(), - dtype_int, - static_cast(SupportedDevices::CUDA), - 0, - &output); - - if (output == nullptr) { - ET_LOG(Error, "sdpa_math_fallback: Failed to allocate output tensor"); - cudaFree(scores_ptr); - return nullptr; - } - - // Step 5: Compute attention_weights @ V - // attention_weights: [batch * num_heads, seq_len_q, seq_len_k] - // V: [batch * num_heads, seq_len_k, head_dim_v] - // Output: [batch * num_heads, seq_len_q, head_dim_v] - - const int m_v = seq_len_q; - const int n_v = head_dim_v; - const int k_v = seq_len_k; - - const scalar_t* v_ptr = reinterpret_cast(value->data_ptr()); - scalar_t* out_ptr = reinterpret_cast(output->data_ptr()); - - const int64_t stride_v = seq_len_k * head_dim_v; - const int64_t stride_out = seq_len_q * head_dim_v; - - status = batched_gemm( - handle, - CUBLAS_OP_N, // No transpose V - CUBLAS_OP_N, // No transpose attention_weights - n_v, // head_dim_v - m_v, // seq_len_q - k_v, // seq_len_k - &alpha, - v_ptr, n_v, // V matrix - stride_v, - scores_ptr, k_v, // attention_weights - stride_scores, - &beta, - out_ptr, n_v, // Output - stride_out, - batch_count); - - // Cleanup temporary buffers - cudaFree(scores_ptr); - - if (status != CUBLAS_STATUS_SUCCESS) { - ET_LOG(Error, "sdpa_math_fallback: cuBLAS GEMM failed for attention_weights @ V"); - aoti_torch_delete_tensor_object(output); - return nullptr; - } - - return output; -} - -Tensor* sdpa_math_fallback( - const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* attn_mask, - bool is_causal, - double scale_factor, - cudaStream_t stream) { - - // Dispatch based on dtype - auto dtype = query->dtype(); - - if (dtype == executorch::aten::ScalarType::Float) { - return sdpa_math_fallback_impl( - query, key, value, attn_mask, is_causal, - static_cast(scale_factor), stream); - } else if (dtype == executorch::aten::ScalarType::Half) { - return sdpa_math_fallback_impl<__half>( - query, key, value, attn_mask, is_causal, - static_cast(scale_factor), stream); - } else if (dtype == executorch::aten::ScalarType::BFloat16) { - return sdpa_math_fallback_impl<__nv_bfloat16>( - query, key, value, attn_mask, is_causal, - static_cast(scale_factor), stream); - } else { - ET_LOG(Error, "sdpa_math_fallback: Unsupported dtype"); - return nullptr; - } -} - -/** - * Main entry point for SDPA computation - */ -Tensor* scaled_dot_product_attention_cuda( - const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* attn_mask, - double dropout_p, - bool is_causal, - const double* scale, - bool enable_gqa, - cudaStream_t stream) { - - // Select backend - SDPBackend backend = select_sdp_backend( - query, key, value, attn_mask, dropout_p, is_causal); - - if (backend == SDPBackend::Error) { - ET_LOG(Error, "scaled_dot_product_attention_cuda: No valid backend selected"); - return nullptr; - } - - // Calculate scale factor - double scale_factor = calculate_scale(query, scale); - - // Handle GQA if needed - if (enable_gqa && is_gqa_configuration(query, key, value)) { - if (!validate_gqa(query, key, value)) { - ET_LOG(Error, "scaled_dot_product_attention_cuda: Invalid GQA configuration"); - return nullptr; - } - ET_LOG( - Error, - "scaled_dot_product_attention_cuda: GQA support not yet implemented. " - "Need to repeat K/V heads to match Q heads."); - return nullptr; - } - - // Dispatch to appropriate backend - switch (backend) { - case SDPBackend::Math: - return sdpa_math_fallback( - query, key, value, attn_mask, is_causal, scale_factor, stream); - - case SDPBackend::FlashAttention: - ET_LOG(Error, "Flash Attention backend not yet implemented"); - return nullptr; - - case SDPBackend::MemoryEfficientAttention: - ET_LOG(Error, "Memory Efficient Attention backend not yet implemented"); - return nullptr; - - case SDPBackend::CuDNN: - ET_LOG(Error, "cuDNN backend not yet implemented"); - return nullptr; - - default: - ET_LOG(Error, "Unknown SDPA backend"); - return nullptr; - } -} - -// ============================================================================ -// C API Implementation -// ============================================================================ - -#ifdef __cplusplus -extern "C" { -#endif - -AOTITorchError aoti_torch_cuda_scaled_dot_product_attention( - Tensor* query, - Tensor* key, - Tensor* value, - Tensor* attn_mask, - double dropout_p, - int32_t is_causal, - double* scale, - int32_t enable_gqa, - Tensor** ret0) { - - // Input validation - if (!query || !key || !value || !ret0) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Null pointer input"); - return Error::InvalidArgument; - } - - // Currently only support dropout_p = 0.0 for inference - if (dropout_p != 0.0) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: dropout_p != 0.0 is not supported"); - return Error::InvalidArgument; - } - - // Check tensor dimensions - if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query, Key, Value must be 4D tensors"); - return Error::InvalidArgument; - } - - // Check that Q, K, V have the same dtype - if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query, Key, Value must have the same dtype"); - return Error::InvalidArgument; - } - - // Check tensor shapes - const int64_t batch = query->size(0); - const int64_t num_heads = query->size(1); - const int64_t seq_len_q = query->size(2); - const int64_t head_dim_q = query->size(3); - - const int64_t num_heads_kv = key->size(1); - const int64_t seq_len_k = key->size(2); - const int64_t head_dim_k = key->size(3); - - const int64_t seq_len_v = value->size(2); - const int64_t head_dim_v = value->size(3); - - // Validate shapes - if (key->size(0) != batch || value->size(0) != batch) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Batch size mismatch"); - return Error::InvalidArgument; - } - - if (seq_len_k != seq_len_v) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Key and Value sequence length mismatch"); - return Error::InvalidArgument; - } - - if (head_dim_q != head_dim_k) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Query and Key head dimension mismatch"); - return Error::InvalidArgument; - } - - if (value->size(1) != num_heads_kv) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Key and Value num_heads mismatch"); - return Error::InvalidArgument; - } - - // GQA validation - if (enable_gqa && num_heads % num_heads_kv != 0) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: For GQA, num_heads must be divisible by num_heads_kv"); - return Error::InvalidArgument; - } - - // Validate attn_mask if provided - if (attn_mask) { - if (is_causal) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Cannot use both attn_mask and is_causal"); - return Error::InvalidArgument; - } - } - - // Get CUDA stream - auto stream_result = getCurrentCUDAStream(0); - if (!stream_result.ok()) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: Failed to get CUDA stream"); - return Error::Internal; - } - cudaStream_t stream = stream_result.get(); - - // Call the main SDPA function - Tensor* output = scaled_dot_product_attention_cuda( - query, - key, - value, - attn_mask, - dropout_p, - is_causal != 0, - scale, - enable_gqa != 0, - stream); - - if (output == nullptr) { - ET_LOG(Error, "aoti_torch_cuda_scaled_dot_product_attention: SDPA computation failed"); - return Error::Internal; - } - - *ret0 = output; - return Error::Ok; -} - -#ifdef __cplusplus -} -#endif - -} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/sdpa.cuh b/backends/cuda/runtime/shims/sdpa.cuh deleted file mode 100644 index 5cc941f4120..00000000000 --- a/backends/cuda/runtime/shims/sdpa.cuh +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// This file implements scaled_dot_product_attention for ExecuTorch. -// -// IMPLEMENTATION NOTES: -// --------------------- -// This is NOT a direct port from PyTorch. Instead, we implemented -// a custom Math Fallback using cuBLAS and custom CUDA kernels. -// -// PyTorch reference implementations (for architecture reference only): -// - CPU/General: aten/src/ATen/native/transformers/attention.cpp -// - CUDA: aten/src/ATen/native/transformers/cuda/attention.cu -// -// Key differences from PyTorch: -// - PyTorch uses high-level ATen ops (at::matmul, at::_safe_softmax) -// - We use direct cuBLAS calls and custom softmax kernels -// - Optimized for inference (no dropout, no backward pass) -// - Simplified memory management -// - No ATen/c10 dependencies -// -// PORTING NOTES: -// -------------- -// 1. KERNEL CODE: Adapted from PyTorch attention kernels -// - Math fallback implementation for maximum compatibility -// - Supports Float32, Float16, and BFloat16 dtypes -// - Standard attention computation: softmax(Q @ K^T / scale) @ V -// -// 2. API ADAPTATIONS: -// - Replaced at::Tensor with executorch::backends::aoti::Tensor -// - Output returned via pointer-to-pointer instead of by-value -// - Simplified interface for inference (dropout=0.0 only) -// -// 3. REMOVED FEATURES: -// - Flash Attention backend (requires external library) -// - Memory Efficient Attention backend (requires external library) -// - cuDNN backend (requires cuDNN library) -// - Dropout support (training-only feature) -// - Nested tensor support (complex layout) -// - Backward pass (training-only feature) -// -// 4. INFRASTRUCTURE CHANGES: -// - Removed c10::cuda::CUDAGuard: Device management handled by AOTI backend -// - Removed at::cuda::getCurrentCUDAStream(): Stream passed explicitly -// - Simplified error handling using ExecutorTorch Error codes - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace executorch::backends::cuda { - -using executorch::backends::aoti::Tensor; -using executorch::runtime::Error; - -// ============================================================================ -// Utility Functions for SDPA -// ============================================================================ - -// Calculate the scaling factor for attention scores -inline double calculate_scale(const Tensor* query, const double* scale) { - if (scale != nullptr) { - return *scale; - } - // Default: 1 / sqrt(head_dim) - // Query shape: [batch, num_heads, seq_len_q, head_dim] - // head_dim is at index 3 (0-indexed) - const int64_t head_dim = query->size(3); - return 1.0 / std::sqrt(static_cast(head_dim)); -} - -// Check if tensor dtype is supported for SDPA -inline bool is_supported_dtype(const Tensor* tensor) { - auto dtype = tensor->dtype(); - return dtype == executorch::aten::ScalarType::Float || - dtype == executorch::aten::ScalarType::Half || - dtype == executorch::aten::ScalarType::BFloat16; -} - -// ============================================================================ -// Math Fallback Implementation -// ============================================================================ - -// This is the basic, portable implementation that works on all CUDA devices. -// It computes attention using explicit matrix multiplications and softmax: -// 1. Compute scores: S = Q @ K^T * scale -// 2. Apply mask if provided -// 3. Compute attention weights: A = softmax(S) -// 4. Compute output: O = A @ V - -/** - * Math fallback kernel for scaled dot product attention - * - * This is a basic implementation that performs: - * output = softmax(query @ key^T / scale) @ value - * - * Supports: - * - Batch processing - * - Multiple attention heads - * - Optional causal masking - * - Optional explicit attention mask - * - Float32, Float16, BFloat16 dtypes - * - * Note: This implementation is for reference and maximum compatibility. - * For production use, consider using Flash Attention or other optimized backends. - */ -Tensor* sdpa_math_fallback( - const Tensor* query, // [batch, num_heads, seq_len_q, head_dim] - const Tensor* key, // [batch, num_heads_kv, seq_len_k, head_dim] - const Tensor* value, // [batch, num_heads_kv, seq_len_k, head_dim_v] - const Tensor* attn_mask, // Optional: [batch, num_heads, seq_len_q, seq_len_k] or broadcastable - bool is_causal, // Apply causal masking - double scale_factor, // Scaling factor for attention scores - cudaStream_t stream); // CUDA stream for execution - -// ============================================================================ -// Backend Selection -// ============================================================================ - -enum class SDPBackend { - Error = -1, - Math = 0, - FlashAttention = 1, - MemoryEfficientAttention = 2, - CuDNN = 3 -}; - -/** - * Select the best available backend for SDPA based on input parameters - * - * For now, only Math fallback is supported. Future implementations may add: - * - Flash Attention (Ampere+ GPUs) - * - Memory Efficient Attention - * - cuDNN backend - */ -inline SDPBackend select_sdp_backend( - const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* attn_mask, - double dropout_p, - bool is_causal) { - - // Check for unsupported features - if (dropout_p > 0.0) { - ET_LOG(Error, "SDPA: Dropout not supported in inference mode"); - return SDPBackend::Error; - } - - // Check tensor dimensions - if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { - ET_LOG(Error, "SDPA: All inputs must be 4D tensors"); - return SDPBackend::Error; - } - - // Check dtype support - if (!is_supported_dtype(query) || !is_supported_dtype(key) || !is_supported_dtype(value)) { - ET_LOG(Error, "SDPA: Unsupported dtype, only Float32/Float16/BFloat16 supported"); - return SDPBackend::Error; - } - - // Check dtype consistency - if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { - ET_LOG(Error, "SDPA: Query, Key, Value must have the same dtype"); - return SDPBackend::Error; - } - - // For now, always use math fallback - // Future: Add logic to select Flash Attention, MemEff, or cuDNN when available - return SDPBackend::Math; -} - -// ============================================================================ -// Helper Functions for Causal Mask -// ============================================================================ - -/** - * Check if we need to apply causal masking - */ -inline bool needs_causal_mask(bool is_causal, const Tensor* attn_mask) { - if (!is_causal) { - return false; - } - if (attn_mask != nullptr) { - ET_LOG(Error, "SDPA: Cannot use both is_causal=true and explicit attn_mask"); - return false; - } - return true; -} - -// ============================================================================ -// Grouped Query Attention (GQA) Support -// ============================================================================ - -/** - * Check if inputs require GQA handling - * - * GQA allows num_heads_q != num_heads_kv, where num_heads_q must be - * divisible by num_heads_kv. Key and Value heads are repeated to match - * Query heads. - */ -inline bool is_gqa_configuration( - const Tensor* query, - const Tensor* key, - const Tensor* value) { - - const int64_t num_heads_q = query->size(1); - const int64_t num_heads_kv = key->size(1); - - return num_heads_q != num_heads_kv; -} - -/** - * Validate GQA configuration - */ -inline bool validate_gqa( - const Tensor* query, - const Tensor* key, - const Tensor* value) { - - const int64_t num_heads_q = query->size(1); - const int64_t num_heads_kv = key->size(1); - const int64_t num_heads_v = value->size(1); - - // Key and Value must have same num_heads - if (num_heads_kv != num_heads_v) { - ET_LOG(Error, "SDPA GQA: Key and Value must have same num_heads"); - return false; - } - - // Query heads must be divisible by Key/Value heads - if (num_heads_q % num_heads_kv != 0) { - ET_LOG(Error, "SDPA GQA: Query num_heads must be divisible by Key/Value num_heads"); - return false; - } - - return true; -} - -// ============================================================================ -// Main SDPA Entry Point -// ============================================================================ - -/** - * Compute scaled dot product attention - * - * This is the main entry point that selects the appropriate backend - * and dispatches to the corresponding implementation. - * - * Currently only Math fallback is implemented. Future versions may add: - * - Flash Attention - * - Memory Efficient Attention - * - cuDNN backend - */ -Tensor* scaled_dot_product_attention_cuda( - const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* attn_mask, - double dropout_p, - bool is_causal, - const double* scale, - bool enable_gqa, - cudaStream_t stream); - -} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/sdpa.h b/backends/cuda/runtime/shims/sdpa.h deleted file mode 100644 index 4db08576ca0..00000000000 --- a/backends/cuda/runtime/shims/sdpa.h +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include -#include - -namespace executorch::backends::cuda { - -using executorch::backends::aoti::AOTITorchError; -using executorch::backends::aoti::Tensor; - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * Performs scaled dot-product attention on CUDA. - * - * This is a port of PyTorch's scaled_dot_product_attention CUDA implementation - * (aten/src/ATen/native/transformers/cuda/attention.cu) adapted for the - * ExecuTorch runtime. - * - * Computes attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V - * - * HARDWARE REQUIREMENTS: - * - CUDA-capable GPU - * - Supports Flash Attention if available (Ampere+ GPUs) - * - * TENSOR REQUIREMENTS: - * @param query Query tensor [batch, num_heads, seq_len_q, head_dim] - * - Must be Float32, Float16, or BFloat16 dtype - * - Must be 4D - * - Must be on CUDA device - * - * @param key Key tensor [batch, num_heads_kv, seq_len_k, head_dim] - * - Must be same dtype as query - * - Must be 4D - * - Must be on CUDA device - * - num_heads_kv can be different from num_heads (for GQA) - * - * @param value Value tensor [batch, num_heads_kv, seq_len_k, head_dim_v] - * - Must be same dtype as query - * - Must be 4D - * - Must be on CUDA device - * - * @param attn_mask Optional attention mask [batch, num_heads, seq_len_q, seq_len_k] - * or broadcastable shape - * - Can be nullptr (no mask) - * - If provided, must be Float32, BFloat16, or Bool dtype - * - Additive mask: positions with large negative values are masked out - * - * @param dropout_p Dropout probability (0.0 to 1.0) - * - Currently only supports 0.0 (no dropout) - * - Must be 0.0 for inference - * - * @param is_causal Whether to apply causal masking - * - If true, applies lower triangular mask - * - Cannot be used together with explicit attn_mask - * - * @param scale Optional scaling factor for attention scores - * - If nullptr, uses 1/sqrt(head_dim) by default - * - If provided, uses the specified value - * - * @param enable_gqa Enable grouped query attention support - * - Allows num_heads_kv != num_heads - * - Query heads must be divisible by key/value heads - * - * @param ret0 Output parameter for attention result - * [batch, num_heads, seq_len_q, head_dim_v] - * - Allocated by this function - * - Same dtype as input tensors - * - Must not be null - * - Caller is responsible for freeing via aoti_torch_delete_tensor_object() - * - * @return AOTITorchError error code: - * - Error::Ok: Success - * - Error::InvalidArgument: Null pointer, wrong dtype, wrong dimensions, - * or invalid parameter combination - * - Error::Internal: CUDA kernel launch failure - */ -AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_scaled_dot_product_attention( - Tensor* query, - Tensor* key, - Tensor* value, - Tensor* attn_mask, - double dropout_p, - int32_t is_causal, - double* scale, - int32_t enable_gqa, - Tensor** ret0); - -#ifdef __cplusplus -} -#endif - -} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 0896b3b6a3b..b274ecf3675 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -34,5 +34,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") - cuda_shim_cpp_unittest("aoti_torch_cuda_scaled_dot_product_attention") cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp deleted file mode 100644 index e2677878ea0..00000000000 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp +++ /dev/null @@ -1,781 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -using namespace executorch::backends::cuda; -using namespace executorch::backends::aoti; -using namespace executorch::runtime; - -// Test fixture for SDPA tests -class AOTITorchSDPATest : public ::testing::Test { - protected: - void SetUp() override { - // Initialize ExecuTorch Platform Abstraction Layer - et_pal_init(); - - // Check if CUDA is available - int device_count = 0; - cudaError_t err = cudaGetDeviceCount(&device_count); - if (err != cudaSuccess || device_count == 0) { - GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; - } - - // Clean up any existing cached metadata before each test - cleanup_tensor_metadata(); - } - - void TearDown() override { - // Clean up after each test - cleanup_tensor_metadata(); - } - - // Helper function to create a Float32 tensor filled with a specific value - Tensor* create_float_tensor( - std::vector shape, - float fill_value = 1.0f) { - Tensor* tensor = nullptr; - - // Calculate size - int64_t total_size = 1; - for (auto dim : shape) { - total_size *= dim; - } - - // Calculate strides (row-major) - std::vector strides(shape.size()); - int64_t stride = 1; - for (int i = shape.size() - 1; i >= 0; --i) { - strides[i] = stride; - stride *= shape[i]; - } - - // Create tensor - Error error = aoti_torch_empty_strided( - shape.size(), - shape.data(), - strides.data(), - static_cast(SupportedDTypes::FLOAT32), - static_cast(SupportedDevices::CUDA), - 0, - &tensor); - - if (error != Error::Ok || tensor == nullptr) { - return nullptr; - } - - // Fill with value - std::vector host_data(total_size, fill_value); - cudaMemcpy( - tensor->data_ptr(), - host_data.data(), - total_size * sizeof(float), - cudaMemcpyHostToDevice); - - return tensor; - } - - // Helper function to create a BFloat16 tensor - Tensor* create_bfloat16_tensor( - std::vector shape, - float fill_value = 1.0f) { - Tensor* tensor = nullptr; - - // Calculate size - int64_t total_size = 1; - for (auto dim : shape) { - total_size *= dim; - } - - // Calculate strides (row-major) - std::vector strides(shape.size()); - int64_t stride = 1; - for (int i = shape.size() - 1; i >= 0; --i) { - strides[i] = stride; - stride *= shape[i]; - } - - // Create tensor - Error error = aoti_torch_empty_strided( - shape.size(), - shape.data(), - strides.data(), - static_cast(SupportedDTypes::BFLOAT16), - static_cast(SupportedDevices::CUDA), - 0, - &tensor); - - if (error != Error::Ok || tensor == nullptr) { - return nullptr; - } - - // Fill with value - // Note: For simplicity, we'll fill with float and let the runtime handle conversion - // In production, you'd want to properly convert to bfloat16 - std::vector host_data(total_size, fill_value); - cudaMemcpy( - tensor->data_ptr(), - host_data.data(), - total_size * sizeof(float), - cudaMemcpyHostToDevice); - - return tensor; - } - - // Helper to check if output tensor has expected shape - bool check_output_shape( - Tensor* output, - const std::vector& expected_shape) { - if (output == nullptr) { - return false; - } - if (output->dim() != expected_shape.size()) { - return false; - } - for (size_t i = 0; i < expected_shape.size(); ++i) { - if (output->size(i) != expected_shape[i]) { - return false; - } - } - return true; - } - - // Helper to copy tensor data from GPU to CPU for verification - std::vector copy_tensor_to_host(Tensor* tensor) { - int64_t total_size = 1; - for (int i = 0; i < tensor->dim(); ++i) { - total_size *= tensor->size(i); - } - - std::vector host_data(total_size); - cudaMemcpy( - host_data.data(), - tensor->data_ptr(), - total_size * sizeof(float), - cudaMemcpyDeviceToHost); - - return host_data; - } - - // Helper to check if a value is approximately equal (for floating point comparison) - bool approx_equal(float a, float b, float epsilon = 1e-5f) { - return std::abs(a - b) < epsilon; - } -}; - -// ============================================================================ -// Basic Functionality Tests -// ============================================================================ - -// Test basic SDPA with Float32, no causal mask -TEST_F(AOTITorchSDPATest, BasicFunctionalityFloat32) { - // Create tensors: [batch=1, num_heads=2, seq_len=4, head_dim=8] - const int64_t batch = 1; - const int64_t num_heads = 2; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr) << "Failed to create query tensor"; - ASSERT_NE(key, nullptr) << "Failed to create key tensor"; - ASSERT_NE(value, nullptr) << "Failed to create value tensor"; - - printf("Testing SDPA Float32: [%ldx%ldx%ldx%ld]\n", batch, num_heads, seq_len, head_dim); - - // Call SDPA - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, - key, - value, - nullptr, // no explicit mask - 0.0, // no dropout - 0, // not causal - nullptr, // default scale - 0, // no GQA - &output); - - // Check result - EXPECT_EQ(error, Error::Ok) << "SDPA should succeed"; - ASSERT_NE(output, nullptr) << "Output should not be null"; - - // Verify output shape: [batch, num_heads, seq_len, head_dim] - EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})) - << "Output shape mismatch"; - - printf("SDPA Float32 test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test SDPA with causal masking -TEST_F(AOTITorchSDPATest, CausalMasking) { - const int64_t batch = 1; - const int64_t num_heads = 1; - const int64_t seq_len = 8; - const int64_t head_dim = 16; - - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - printf("Testing SDPA with causal masking: [%ldx%ldx%ldx%ld]\n", - batch, num_heads, seq_len, head_dim); - - // Call SDPA with causal mask - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, - key, - value, - nullptr, - 0.0, - 1, // causal mask enabled - nullptr, - 0, - &output); - - EXPECT_EQ(error, Error::Ok); - ASSERT_NE(output, nullptr); - EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); - - printf("Causal masking test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test SDPA with BFloat16 -TEST_F(AOTITorchSDPATest, BFloat16Precision) { - const int64_t batch = 2; - const int64_t num_heads = 4; - const int64_t seq_len = 16; - const int64_t head_dim = 32; - - Tensor* query = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); - Tensor* value = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr) << "Failed to create BFloat16 query tensor"; - ASSERT_NE(key, nullptr) << "Failed to create BFloat16 key tensor"; - ASSERT_NE(value, nullptr) << "Failed to create BFloat16 value tensor"; - - printf("Testing SDPA BFloat16: [%ldx%ldx%ldx%ld]\n", - batch, num_heads, seq_len, head_dim); - - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, - key, - value, - nullptr, - 0.0, - 0, - nullptr, - 0, - &output); - - EXPECT_EQ(error, Error::Ok) << "SDPA BFloat16 should succeed"; - ASSERT_NE(output, nullptr); - EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); - - printf("BFloat16 precision test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test SDPA with custom scale factor -TEST_F(AOTITorchSDPATest, CustomScale) { - const int64_t batch = 1; - const int64_t num_heads = 2; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - printf("Testing SDPA with custom scale\n"); - - // Use custom scale instead of default 1/sqrt(head_dim) - double custom_scale = 0.25; - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, - key, - value, - nullptr, - 0.0, - 0, - &custom_scale, // custom scale - 0, - &output); - - EXPECT_EQ(error, Error::Ok); - ASSERT_NE(output, nullptr); - EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); - - printf("Custom scale test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test with larger tensors (closer to real-world usage) -TEST_F(AOTITorchSDPATest, LargerTensors) { - const int64_t batch = 4; - const int64_t num_heads = 8; - const int64_t seq_len = 128; - const int64_t head_dim = 64; - - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - printf("Testing SDPA with larger tensors: [%ldx%ldx%ldx%ld]\n", - batch, num_heads, seq_len, head_dim); - - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, - key, - value, - nullptr, - 0.0, - 1, // causal - nullptr, - 0, - &output); - - EXPECT_EQ(error, Error::Ok); - ASSERT_NE(output, nullptr); - EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); - - printf("Larger tensors test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// ============================================================================ -// Error Handling Tests -// ============================================================================ - -// Test null pointer handling -TEST_F(AOTITorchSDPATest, NullPointerHandling) { - Tensor* query = create_float_tensor({1, 1, 4, 8}, 0.5f); - Tensor* key = create_float_tensor({1, 1, 4, 8}, 0.5f); - Tensor* value = create_float_tensor({1, 1, 4, 8}, 1.0f); - Tensor* output = nullptr; - - // Test null query - { - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - nullptr, key, value, nullptr, 0.0, 0, nullptr, 0, &output); - EXPECT_NE(error, Error::Ok) << "Should fail with null query"; - } - - // Test null key - { - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, nullptr, value, nullptr, 0.0, 0, nullptr, 0, &output); - EXPECT_NE(error, Error::Ok) << "Should fail with null key"; - } - - // Test null value - { - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, nullptr, nullptr, 0.0, 0, nullptr, 0, &output); - EXPECT_NE(error, Error::Ok) << "Should fail with null value"; - } - - // Test null output pointer - { - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, nullptr, 0, nullptr); - EXPECT_NE(error, Error::Ok) << "Should fail with null output pointer"; - } - - printf("Null pointer handling tests passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); -} - -// Test dimension mismatch -TEST_F(AOTITorchSDPATest, DimensionMismatch) { - Tensor* query = create_float_tensor({1, 2, 4, 8}, 0.5f); - Tensor* key = create_float_tensor({1, 2, 6, 8}, 0.5f); // Different seq_len - Tensor* value = create_float_tensor({1, 2, 6, 8}, 1.0f); - Tensor* output = nullptr; - - // This should succeed (Q and K can have different seq_len) - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); - - EXPECT_EQ(error, Error::Ok) << "Different Q and K seq_len should be allowed"; - - if (output != nullptr) { - // Output should have Q's seq_len - EXPECT_EQ(output->size(2), 4) << "Output seq_len should match Query"; - aoti_torch_delete_tensor_object(output); - } - - printf("Dimension handling test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); -} - -// Test dropout error (should fail since we don't support dropout) -TEST_F(AOTITorchSDPATest, DropoutNotSupported) { - Tensor* query = create_float_tensor({1, 1, 4, 8}, 0.5f); - Tensor* key = create_float_tensor({1, 1, 4, 8}, 0.5f); - Tensor* value = create_float_tensor({1, 1, 4, 8}, 1.0f); - Tensor* output = nullptr; - - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.5, 0, nullptr, 0, &output); // dropout=0.5 - - EXPECT_NE(error, Error::Ok) << "Should fail with non-zero dropout"; - - printf("Dropout rejection test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); -} - -// ============================================================================ -// Numerical Correctness Tests -// ============================================================================ - -// Test that output values are in reasonable range -TEST_F(AOTITorchSDPATest, OutputValueRangeCheck) { - const int64_t batch = 1; - const int64_t num_heads = 1; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - // Use small values to avoid numerical overflow - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - printf("Testing SDPA output value range\n"); - - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); - - ASSERT_EQ(error, Error::Ok); - ASSERT_NE(output, nullptr); - - // Copy output back to CPU for verification - std::vector output_data = copy_tensor_to_host(output); - - // Since V is all 1.0, and softmax produces weights that sum to 1, - // output should be close to 1.0 (weighted average of 1.0) - bool all_in_range = true; - for (size_t i = 0; i < output_data.size(); ++i) { - // Output should be around 1.0 with some tolerance - if (output_data[i] < 0.5f || output_data[i] > 1.5f) { - printf("Output[%zu] = %f is out of expected range [0.5, 1.5]\n", - i, output_data[i]); - all_in_range = false; - } - } - - EXPECT_TRUE(all_in_range) << "Some output values are out of reasonable range"; - - printf("Output value range check passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test with identity Q=K, verify attention weights sum to 1 -TEST_F(AOTITorchSDPATest, IdentityQKTest) { - const int64_t batch = 1; - const int64_t num_heads = 1; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - // When Q=K, attention scores will be uniform (since all positions are equally similar) - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 2.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - printf("Testing SDPA with Q=K (identity attention)\n"); - - Tensor* output = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, nullptr, 0, &output); - - ASSERT_EQ(error, Error::Ok); - ASSERT_NE(output, nullptr); - - // Copy output back to CPU - std::vector output_data = copy_tensor_to_host(output); - - // When Q=K and V is uniform, output should be close to V - // (since attention weights are uniform due to identical scores) - bool values_correct = true; - for (size_t i = 0; i < output_data.size(); ++i) { - // Output should be close to 2.0 (the value of V) - if (!approx_equal(output_data[i], 2.0f, 0.1f)) { - printf("Output[%zu] = %f, expected ~2.0\n", i, output_data[i]); - values_correct = false; - } - } - - EXPECT_TRUE(values_correct) << "Output values don't match expected for identity Q=K"; - - printf("Identity Q=K test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output); -} - -// Test that different scales produce different outputs -TEST_F(AOTITorchSDPATest, ScaleEffectTest) { - const int64_t batch = 1; - const int64_t num_heads = 1; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - // Make K different at different positions so attention scores vary - std::vector key_host(batch * num_heads * seq_len * head_dim); - for (int64_t pos = 0; pos < seq_len; ++pos) { - for (int64_t d = 0; d < head_dim; ++d) { - // Different values per position: pos 0=0.1, pos 1=0.3, pos 2=0.5, pos 3=0.7 - key_host[pos * head_dim + d] = 0.1f + 0.2f * pos; - } - } - cudaMemcpy( - key->data_ptr(), - key_host.data(), - key_host.size() * sizeof(float), - cudaMemcpyHostToDevice); - - // Make V also different at different positions to amplify differences - std::vector value_host(batch * num_heads * seq_len * head_dim); - for (int64_t pos = 0; pos < seq_len; ++pos) { - for (int64_t d = 0; d < head_dim; ++d) { - // V values: pos 0=1.0, pos 1=2.0, pos 2=3.0, pos 3=4.0 - value_host[pos * head_dim + d] = static_cast(pos + 1); - } - } - cudaMemcpy( - value->data_ptr(), - value_host.data(), - value_host.size() * sizeof(float), - cudaMemcpyHostToDevice); - - printf("Testing SDPA scale effect\n"); - - // Test with default scale - Tensor* output1 = nullptr; - AOTITorchError error1 = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, nullptr, 0, &output1); - ASSERT_EQ(error1, Error::Ok); - ASSERT_NE(output1, nullptr); - - // Test with custom scale (much smaller, should make attention more uniform) - double small_scale = 0.01; - Tensor* output2 = nullptr; - AOTITorchError error2 = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 0, &small_scale, 0, &output2); - ASSERT_EQ(error2, Error::Ok); - ASSERT_NE(output2, nullptr); - - // Copy outputs back to CPU - std::vector output1_data = copy_tensor_to_host(output1); - std::vector output2_data = copy_tensor_to_host(output2); - - // Outputs should be different (scale affects softmax sharpness) - // With varied V values, even small changes in attention weights will produce - // noticeably different outputs - bool outputs_differ = false; - float max_diff = 0.0f; - for (size_t i = 0; i < output1_data.size(); ++i) { - float diff = std::abs(output1_data[i] - output2_data[i]); - max_diff = std::max(max_diff, diff); - if (diff > 0.05f) { // More lenient threshold due to varied V values - outputs_differ = true; - break; - } - } - - printf("Max difference between outputs: %f\n", max_diff); - EXPECT_TRUE(outputs_differ) << "Different scales should produce different outputs (max_diff=" << max_diff << ")"; - - printf("Scale effect test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output1); - aoti_torch_delete_tensor_object(output2); -} - -// Test causal masking correctness -TEST_F(AOTITorchSDPATest, CausalMaskingCorrectness) { - const int64_t batch = 1; - const int64_t num_heads = 1; - const int64_t seq_len = 4; - const int64_t head_dim = 8; - - // Create distinct values at different positions in V - // This allows us to verify that causal masking works correctly - Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); - Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f); - - ASSERT_NE(query, nullptr); - ASSERT_NE(key, nullptr); - ASSERT_NE(value, nullptr); - - // Manually set different values for each position in V - // V[position i] = i+1 (so we can track which positions contribute) - std::vector value_host(batch * num_heads * seq_len * head_dim); - for (int64_t pos = 0; pos < seq_len; ++pos) { - for (int64_t d = 0; d < head_dim; ++d) { - value_host[pos * head_dim + d] = static_cast(pos + 1); - } - } - cudaMemcpy( - value->data_ptr(), - value_host.data(), - value_host.size() * sizeof(float), - cudaMemcpyHostToDevice); - - printf("Testing SDPA causal masking correctness\n"); - - // Run with causal masking - Tensor* output_causal = nullptr; - AOTITorchError error = aoti_torch_cuda_scaled_dot_product_attention( - query, key, value, nullptr, 0.0, 1, nullptr, 0, &output_causal); - ASSERT_EQ(error, Error::Ok); - ASSERT_NE(output_causal, nullptr); - - // Copy output back to CPU - std::vector output_data = copy_tensor_to_host(output_causal); - - // With causal masking: - // - Position 0 can only see position 0, so output[0] should be ~1.0 - // - Position 1 can see positions 0,1, so output[1] should be ~1.5 (average of 1 and 2) - // - Position 2 can see positions 0,1,2, so output[2] should be ~2.0 (average of 1,2,3) - // - Position 3 can see all, so output[3] should be ~2.5 (average of 1,2,3,4) - - std::vector expected_values = {1.0f, 1.5f, 2.0f, 2.5f}; - - bool causal_correct = true; - for (int64_t pos = 0; pos < seq_len; ++pos) { - float avg_output = 0.0f; - for (int64_t d = 0; d < head_dim; ++d) { - avg_output += output_data[pos * head_dim + d]; - } - avg_output /= head_dim; - - printf("Position %ld: output avg = %f, expected ~%f\n", - pos, avg_output, expected_values[pos]); - - if (!approx_equal(avg_output, expected_values[pos], 0.2f)) { - causal_correct = false; - } - } - - EXPECT_TRUE(causal_correct) << "Causal masking did not produce expected values"; - - printf("Causal masking correctness test passed!\n"); - - // Cleanup - aoti_torch_delete_tensor_object(query); - aoti_torch_delete_tensor_object(key); - aoti_torch_delete_tensor_object(value); - aoti_torch_delete_tensor_object(output_causal); -} - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} From 06d11de8d997a211b9d238b4041e931c8253b9d6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 10 Dec 2025 11:10:54 -0800 Subject: [PATCH 6/8] use math decomposition as fallback to support gemma3 --- .../ci_commit_pins/optimum-executorch.txt | 2 +- backends/cuda/cuda_backend.py | 34 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt index df87f35a69d..156ff2f3c82 100644 --- a/.ci/docker/ci_commit_pins/optimum-executorch.txt +++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt @@ -1 +1 @@ -d03e90c2cd9048e6d9a75285c0355f033cd016fc +0123293118efb08ac4ffc4fefe9d330201465c93 diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index eb1226ebf8a..f0d3a000ec0 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -134,20 +134,20 @@ def get_aoti_compile_options( return options - # @classmethod - # def get_extra_aoti_compile_context_manager(cls): - # """ - # Return SDPA MATH backend context manager for CUDA compilation. - - # This context manager plays as a fallback solution for any remaining PyTorch SDPA - # operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. - - # Note: - # - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, - # this context manager will have no effect on those ops (they are no longer - # PyTorch SDPA ops). - # - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this - # context manager will force them to use the MATH backend, causing them to - # be automatically decomposed during compilation. - # """ - # return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) + @classmethod + def get_extra_aoti_compile_context_manager(cls): + """ + Return SDPA MATH backend context manager for CUDA compilation. + + This context manager plays as a fallback solution for any remaining PyTorch SDPA + operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + Note: + - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + this context manager will have no effect on those ops (they are no longer + PyTorch SDPA ops). + - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + context manager will force them to use the MATH backend, causing them to + be automatically decomposed during compilation. + """ + return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) From e450cfea6906cf924fcfe9af1159b0b1c8b20b05 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 10 Dec 2025 22:03:52 -0800 Subject: [PATCH 7/8] gemma3 support --- backends/cuda/cuda_backend.py | 37 +- backends/cuda/triton/kernels/sdpa.py | 305 +- output.wav | Bin 0 -> 960044 bytes special_tokens_map.json | 33 + tokenizer_config.json | 51347 +++++++++++++++++++++++++ 5 files changed, 51658 insertions(+), 64 deletions(-) create mode 100644 output.wav create mode 100644 special_tokens_map.json create mode 100644 tokenizer_config.json diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f0d3a000ec0..9204ecaecda 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -68,7 +68,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] ) triton_kernel_mode = mode - return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + # return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + return [ReplaceEdgeOpWithTritonOpPass()] @classmethod def get_aoti_compile_options( @@ -134,20 +135,20 @@ def get_aoti_compile_options( return options - @classmethod - def get_extra_aoti_compile_context_manager(cls): - """ - Return SDPA MATH backend context manager for CUDA compilation. - - This context manager plays as a fallback solution for any remaining PyTorch SDPA - operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. - - Note: - - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, - this context manager will have no effect on those ops (they are no longer - PyTorch SDPA ops). - - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this - context manager will force them to use the MATH backend, causing them to - be automatically decomposed during compilation. - """ - return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) + # @classmethod + # def get_extra_aoti_compile_context_manager(cls): + # """ + # Return SDPA MATH backend context manager for CUDA compilation. + + # This context manager plays as a fallback solution for any remaining PyTorch SDPA + # operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + # Note: + # - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + # this context manager will have no effect on those ops (they are no longer + # PyTorch SDPA ops). + # - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + # context manager will force them to use the MATH backend, causing them to + # be automatically decomposed during compilation. + # """ + # return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 601c09ad983..e3c444f093e 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -22,6 +22,24 @@ from torch.library import triton_op, wrap_triton +def _is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def _next_power_of_2(x: int) -> int: + """Get the next power of 2 >= x, clamped to [16, 256].""" + if x <= 16: + return 16 + if x <= 32: + return 32 + if x <= 64: + return 64 + if x <= 128: + return 128 + return 256 + + def _validate_qkv_shapes( query: torch.Tensor, key: torch.Tensor, @@ -64,6 +82,131 @@ def _validate_qkv_shapes( return B_q, H_q, L_q, L_kv_k, D_q, D_k +# ============================================================================== +# Non-power-of-2 HEAD_DIM kernel +# ============================================================================== +@triton.jit +def _sdpa_fwd_kernel_non_pow2( + q_ptr, + k_ptr, + v_ptr, + o_ptr, + mask_ptr, + B, + H, + LQ, + LK, + HEAD_DIM, + stride_qb, + stride_qh, + stride_ql, + stride_qd, + stride_kb, + stride_kh, + stride_kl, + stride_kd, + stride_vb, + stride_vh, + stride_vl, + stride_vd, + stride_ob, + stride_oh, + stride_ol, + stride_od, + stride_mb, + stride_mh, + stride_mlq, + stride_mlk, + scale, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + """ + SDPA forward kernel for non-power-of-2 HEAD_DIM. + Uses dynamic masking to handle arbitrary head dimensions. + """ + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + b = pid_bh // H + h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + d_mask = offs_d < HEAD_DIM + q_row_mask = offs_m < LQ + + q_base = q_ptr + b * stride_qb + h * stride_qh + k_base = k_ptr + b * stride_kb + h * stride_kh + v_base = v_ptr + b * stride_vb + h * stride_vh + o_base = o_ptr + b * stride_ob + h * stride_oh + + q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd) + q = tl.load(q_ptrs, mask=q_row_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32) + m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) + l_i = tl.full((BLOCK_M,), 1.0, dtype=tl.float32) + + qk_scale_log2 = scale * 1.4426950408889634 + + if HAS_MASK: + mask_b_base = mask_ptr + b * stride_mb + + for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): + kn = start_n + offs_n + kv_col_mask = kn < LK + + k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd) + k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale_log2 + + if IS_CAUSAL: + row_abs = offs_m[:, None] + col_abs = kn[None, :] + causal_mask = col_abs > row_abs + qk = tl.where(causal_mask, -float("inf"), qk) + + if HAS_MASK: + mask_ptrs = ( + mask_b_base + offs_m[:, None] * stride_mlq + kn[None, :] * stride_mlk + ) + tile_valid = q_row_mask[:, None] & kv_col_mask[None, :] + keep = tl.load(mask_ptrs, mask=tile_valid, other=True) + qk = tl.where(keep, qk, -float("inf")) + + qk = tl.where(kv_col_mask[None, :], qk, -float("inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + + acc = acc * alpha[:, None] + + v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd) + v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.dot(p.to(v.dtype), v, acc) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + out = acc / l_i[:, None] + o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d[None, :] * stride_od) + tl.store(o_ptrs, out.to(tl.bfloat16), mask=q_row_mask[:, None] & d_mask[None, :]) + + +# ============================================================================== +# Power-of-2 HEAD_DIM kernels +# ============================================================================== @triton.jit def _sdpa_fwd_kernel_body( Q_ptr, @@ -463,57 +606,122 @@ def sdpa( def grid(meta): return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) - # Dynamic kernel selection based on workload - total_ctas_m64 = ((L_q + 63) // 64) * (B * H) - threshold = 4 * 84 # Heuristic threshold for kernel selection - use_small_block = total_ctas_m64 < threshold - - if use_small_block: - wrap_triton(_sdpa_fwd_kernel_m32)[grid]( - query, - key, - value, - out, - Mask_ptr if HAS_MASK else 0, - B, - H, - L_q, - L_kv, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale, - HAS_MASK=HAS_MASK, - IS_CAUSAL=is_causal, - HEAD_DIM=D, - ) + # Select kernel based on whether HEAD_DIM is power of 2 + if _is_power_of_2(D): + # Use power-of-2 optimized kernels with autotune + # Dynamic kernel selection based on workload + total_ctas_m64 = ((L_q + 63) // 64) * (B * H) + threshold = 4 * 84 # Heuristic threshold for kernel selection + use_small_block = total_ctas_m64 < threshold + + if use_small_block: + wrap_triton(_sdpa_fwd_kernel_m32)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) + else: + wrap_triton(_sdpa_fwd_kernel_m64)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) else: - wrap_triton(_sdpa_fwd_kernel_m64)[grid]( + # Use non-power-of-2 kernel with dynamic HEAD_DIM masking + BLOCK_D = _next_power_of_2(D) + + if BLOCK_D >= 256: + BLOCK_N = 64 + else: + BLOCK_N = 128 + + BLOCK_M = 32 + num_warps = 4 + num_stages = 2 + + # Handle mask for non-pow2 kernel (different stride layout) + if HAS_MASK: + mask_ptr = attn_mask + stride_mb_np2 = attn_mask.stride(0) + stride_mh_np2 = attn_mask.stride(1) + stride_mlq_np2 = attn_mask.stride(2) + stride_mlk_np2 = attn_mask.stride(3) + else: + mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool) + stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0 + + def grid_non_pow2(meta): + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + wrap_triton(_sdpa_fwd_kernel_non_pow2)[grid_non_pow2]( query, key, value, out, - Mask_ptr if HAS_MASK else 0, + mask_ptr, B, H, L_q, L_kv, + D, stride_qb, stride_qh, stride_qm, @@ -530,13 +738,18 @@ def grid(meta): stride_oh, stride_om, stride_od, - stride_mb, - stride_mq, - stride_mk, + stride_mb_np2, + stride_mh_np2, + stride_mlq_np2, + stride_mlk_np2, sm_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, - HEAD_DIM=D, + num_warps=num_warps, + num_stages=num_stages, ) return out diff --git a/output.wav b/output.wav new file mode 100644 index 0000000000000000000000000000000000000000..dba3cb22f842b883bf54215b4dac08ca8107e029 GIT binary patch literal 960044 zcmYIR1-KQ(-=5jsleaHixOA7KAS$USC~bg(3aFqUA_xX3AYy}}Vo_3xk}6;zDuRGW zqk?odxBJA_eD8bq&-1a*J?HG~?8L8s@60au_qyw@hRajMpl%QL7&-QZy6Hj)iDPfuPk{iZ5F zNdf9z!|bo9Gx*+%W0(3(ZB>U=4(8MleQ6jSM}HDy%283YUI}G|>XJI4j^geDwCTsJ zQ$#UZ$wT>h(Bvw{tAi3%aBqn^q|Sk!Qd9@u(r|aFir~I{#eaqPtx&fPXj%h(c^s(( z`ouxkd{E3*DTzC?Fq$M$6E&-$-75G@LR*?}FtR-57dGaRENX+=?L;e#$G{Ole|6An zCiqtwJgW{$3Doq1qN$Ln4k)dLzBmdUS1O@qxzbQGBx34wa0jp41sW>Ix(f;i&^~+0#K^KCTQ_1vsVF@Nn)%RnOjQPNDnTEUP##av zqAJEwU)&4ZLs^yLKh%w5J_dBJ31-S)InX8xG>NMuwA&EAXtF;s-ykHz0eudl zRyqDOaR$6AP(ifW4)c$ItF<9r4#rpk&7&3;K%4SWGYKUM*V7UuIpF10jH?<()e!Tb zoSuTjlz{SnNQ!`l9>!4s+1d{(>(GXBw6src2G9Qm=Ssk(Bhc$S@SEd4hq6-%t;@qb z|3d0L`*aBsyct(3&|3sY1@77g*}sPRXCQ|qxK;_&G{DCqv~wQ)Na))+j5-3oQk&x_ z-G}=um4*3{!zHkr6{t~&S#mVF_+1LB>7acndWwPuKFpb#R2@95gz*=HPH744;7DwE zh^&Ma@N85B=T*UrTTwp^64C~`S`+%lJ}Qe|Vkq>Kb|8+HNwp-<*+M@OcV$D)O!*^I9?s~f}p@PXy!FY zdKi{43nk^avjFtY1I_%HK`z>-w-V?%6Kz$7WY$2vOw6|mq(7FB%?p^5592CN%$D>& z58568pZ~--@^NU8#Dm~hDQI^RRLQ~TS(HlwmHtTMV_M8(Z`oU-fxHfNdZqn;00;m!X`C_*7g9s z^U!Mr+RFh|n&N1Kzjlx##toF0DzM5oqSdCDcNKhhhc)U5Nw^XAwH;c%5hDtKE-8@H z2VobQfg<#mbupHnup-&GS`A$H4?j{J+G-_aycQ&=4fq#G$Zk2tOBoGf zZe7uOeY8*n3dJzKn&2>fQ)NhQGe{Xd7_G?-C{IaPAVx${@T+dZ8pP2S>FmcHCHS_X z<$2&_8KmGWc#sEfSA~5mfSmjU9_65B2=`xtRoQ{#0=RJwJfdtILVR!@zmF#Fy9!IQ z8KuWy0dhe5i)b?p*)T!d0{lfG2c;;zf~&OL^cY)UGfv`;gJ|V=!fu^IPrK0XK9uaj zl|0ngiWaZ@Zv;ot!+G?68T{Ri@}p?61V;qiKMG5+8~4N#v&sO!$|1{PXb5Xu#cb(g z=sSy4nVezevM zT<-$j1z<<=pcgGs^Jd(ig;8o~vkJy9Fso#6kDj&?%BUgKz{)tU4xYQPYjsiX#Uj;E zM#(jy*Ky3McEZ|Zq31O87(ix$C8>6)lN*Vdo!hI$v zPI*d!wAIE{Mw4Mk%+-XRXP|aN^iF!yhS8FQaL%!(q0KntJ13#Rv_6bnX-OC#nV3fu zUbRw!M|oAXM!SqM7++C049vk!w5QO!4>~}f zo{mvE;5U7GV^}THi!mJOWFu}#$H6GsK(rQzrd7r`bZFT*$oxUHFG26xiLy$#mX2TK zkp}Lz1IOB8%qHe_7PCl0o0Suh7G)@mdbHE@+L`E^`9lrRy9Kn2y2v=RW};N!z6Ka+ zH5@{m#yoRCjkd5WZ6FtwFb{9m61`VJNn_ksA9sW>uQ(`IfxazJmRzA6n3$~w*`(*H zg|ZuAVcO%)Ab4{Q8c+e=)`n~~LOVg|$W=(Tg_+mH=$c`)$>>o}P`w;Bg6(EQGHamS z%it_^_HsfKENGPms-1!s<)Hj1IB^bVW$67H@{l9Yt-oO@BH-mooL|9x^wn1&M>eFA zqo$WX4ckJUWaLnU66QIlFwRT(6-BKmuCNDMnE%iwGa;Ux!yoxygc^A`yNY>R=w0C# zvjoNmsqhc9ZAJK#;EtK-hy7{^c_~CqKV}?H^m-bf25Q;p@l3)O(QnhA(NdJ*+eZDf z81Xgq$!sYMNoK53AHOQlJL9~%pllL!D;qScioXm{iSk$rM-$YgF413Qfrrg-G=fFA z8UCORxYQKimBGWxXoWtm4r=mffjhIIN7bQY4RCKO+|>;CwZ&%71XuS%gkC{vwW=yX^uW8{t0IHQ?uBo{0 zLZ^I?4FSn9P)_X&qK1Y)#wQNw?t*3xe!0+iN`oI~3iPf3?eb8cJSxBl3m^{>(2)^h z9;izBxB#w4a8`o5b0H^j$e4w?3MJQI|0~dQPQnh5b2jAZG$iFe$Yd^hB%NLKK$$#@ z)~=3rJuken;1%{YazB9AOkuLSeT1vL~X;J{|ZF%HITv=$Pd zRl%`#;BOV&N7-P`Ru_HIX17mxhnnCqt;Q|L5AJ~sG=MBo7J5JqTH)^&NJK~6$J~&2 zwuD@@z$dLh1B{y%Dj9srLhTN)Fb%=|G|Z9`)&MP2ic&G(RFu;4*TDGbD@;&70_sw# zD`O1Vkc2YuhH)O_iR{GTOK_MrA`diAM;m_J!8())#vBn`&@n3V$U?15^qhe{JSpTj$$JyX38lbA zzgH7-Noy9v94pW>QK523f{r^=P&)?MD?(e9K=G;=c{0j<=tBb?nVnH5YbG=?3nQYP zVLnY;UmM!VTswe%nI%>MFPXDuqZLwv^fNJMf6r&$wMv)ATR!Wql;xiLh$lY}O zs)4&0Hz%QIN+2_aa(vTcC8OOewCEtyYJ#3?pgwh;wj~oax+zD`>c+xr4EJiOmIA;ud z3Z=9{ClGD!#<4dMjnX<4VpOO9CnHx9yv)J9=MwWcjamGSaphyqWr;a4mOY35X`N1? zPg*nP{U-Xnh>;Xwbk}h0GS07oRvgO-v_QK?9dj|#ODOZK3vDm8E`;dgIL1%}uB3s> zw5~*>g80tGn8FFp(zY|Zq83myY+Rw_+u#lJqiP9WibO11j8fv9&GA<=A%6Vg~N&3QJNMvOsIkIEEQx0Do!t zECfA^L7Q|)V0J=w3Na_@;Bj2Jf-AIo%snsSoEdR8`b-8@DsW#M+$TaQQG;=OE+}#W zH7h_R>I-c{G3c#eNvX+|6Ou|B&4_`%nRsv*M;x=FUdBN=`Y>{s-n;^(C5e7A5;ob( zItALLPjW&1AmlFselt5aU_Y}lf68P&cvgX#6oZ2Fg?7RUlz`(zvlu;7#s%s!a>$2F z#ZX=u=LL|y|4W(H1=d zBk`v2H#ed+A9}8WyP7BNXI9z_9;F4^qn%)$LrgdfiV?-8lr+LUXE5V1M#n6LxSWo0 z)5n!TbC|IbdnPtboP{WHDask^R6$t`yvfJi)p4~1Tq3&aKsHDr+9-NkQkq^g4lXeM zWt2+GPiY7z=FAcYT%f)&z9cdw(a&WZM4GAR%shz|Dzrl@ei5ZagCfvmB2@WknK@TJ ze&?V^6W3|WE}(xu=9Yr75+mhYh|w_P&@gJ^-KS79ALC)fRfyJUqdC%R30uVMGXqp% z^g&4_5^19)6Vgt6hf#}(a^j*yYAP@;>VFnEQVG7Wej@5)4&?*=GQc&)dNuKjQ5hwG zF?(%Llo1SV1idz+)Oz4FGr{KII(w~x)>@!mI{Ig9S0^DK2KYnY&B&HL2GM3UT&s_I zjHNVOA<`idb0$vCID*z9gfeD6v=odBC`;ARD>I#D_{_#xWk>}h;40{W*%>)Qy(NXY zY9he_Mg@#LxiS&KSa=6@&qgf^JR%}bxhe$xsQp}5aKNQ2s7*|=7_%t`2L#$@{FnN_ zK9bN5X~cYq8cd5v%?qFoM<2pf;%1CvDJ3koVWW>A!u&VE$1+F*aciz!*yw{9)mhZ$ zOj9v(j$NP!9eos`mis^XIt@uYiP1CKi9&8@AxbfNe}Yme=$A+~k>gU#gML(pmeWgK zM2}n(Vb(}A#l}HeGqNWK7?n}anTt?*NPo&pGQO$(j0=)b+v9yB@PL?M3*>YS!O`0I zO*E-`V*bp8YoM$#u8`uDad$(=N=@*eR-9M^*TI-Gu|Ik%{snQQVE!Bhy;2t1tAT@2 zN(yRJLND|V94XN$j)?X-0PD<^msH&0U1N+&@4~Tru_xt%$P6VSAGE0e9|YQ9L`hDu zo({?oWo28m6&4Qq48|a|3gi`WDB28WlZ-VO(@_ozL0{$)#CB+bc^9Mi7`RKT!rY(& z2eS$Cf%zcJ*7#J$WN zDg9oJnz%15(LQ;}$d=JN*F3A>ej;T2N4cP_rj$|#*qWEgX5c&<=kyq~3!Z%p;fx1$ z)WE&zXouJm?J)I?nTdyP@tnR%fP*|DprnT*Fhgf#Pwq2HCqfy+FUIH0AIJ-?EnEa2 z^KeFf(h9^O5A;O*qcqU(({>c%V64Eq7++(X~QJ1K`eXU;*M6X&8_bJc=#CSpzu zm^e2xbxOMi&0&T^yFtB80S%b}a$pht8-Gv?n!W2k$i4t*~(H!m(Qag`VuV+i^f zdQnO}<$+l(?F6x2+5{q({6%_FD`U_kj*oc`F#*b+53)wP%vgfjM+?pzgjfKhY_3-` z*5tv~yxf@)C=ooiMf*ZM=By}Dag2)^M~siM7e~uPRH@?vWjarp+FmbOWbRy! z`d;ixZ^`@UH#kP}#^VV2LwnA(i4@R~pVV2}bVi5F9_XjDP=^{$4?rw{m{ueFt_S~6 z7d5KmtU)5S^r8($q*W6##H^VfzbR@q#dkAA9PNQ{F-9eJ)dFY4ESlgq5BAX*cQCeT zg5w6nrHoM+wba5GJX=#cVY6sCD#5A{FJ&A;3ld038)-%xN58>OdJ+0xQjD3l10FJV z^wyTBvGga@U9R~MBVWPp%KrKCu>C zWHiFql(S$=>5YWZE^BdC1&C2jA=hBUOUdM&L{P7yWm1FQhSE#P3P4(j@iUsFmJ%DF zzx8yH`%IYo(Q{MGWEBseotnfxYq8S~K&)WR7#N8ik-n9(J%1J-6FP6;A!h=4E;^5Sz!Qw`{U zU2r)AQpEN9>L??2%Gi*|0{KfDk{36Ot z9pPGo=M%YNLJQ28jMxJyM%z#PkJ7{3n;wc@hBfJRS=;mK)H>=B>Bz{9UY=_i)N&#l z^zxn-Q}4VrOxC}M-^@FCH`heY;2?%Z?WUGeJFg`E82=M7IfsL21g!vV1Z@&|$AgxG z(F##H#;wdvui&8N;U94q@|xHLV_0G&5!9l5F*>A`PU<^^^!EV+YSWWZ*hi9LY7x4|SCKM%%zr@|t)KvpR33 z>xP6EX_~MFT+ObQD6ff6X1M&Tfl{u#&_>ip2`xx%oD=6@l*->6gBO3&T6jHkM;YS? z9?V)j`DS**JdrU0J*-C^S`Eqr^^F?Lh=kb%Jvwb8EdygP#sI`2sL`Y@<1nre(9Y6B zQ-bJMsI8_#4G3EA-KXuDWcq-M+vXe7twCe>rmz>dmIsW zdvR{;*Fzy0LDIf4%BTL*7F32-b7hyAKdm%1mNtOdBR`2zGbW|4pnlUTdvOYF5OGlA zEaVl}w&-mbf70LZ)5}CMQO<0JD_-oKJ0D3K+F37x@}fRkaq^9M7RxD#^a`}8q&F=l zBWhBZR3+~@GomAmN%`x#_E0)#jVT5E$C#IS1KXmW@v6rwmN72py|glnIw>KnL+PQk zF%G2UdAI~4Wug;|ZoPF&j-Q@}UVt$zWy9;4_p?uW6t=)w)1R?38Eeu9~(7u7O5rsecjqDc=*_5XVhdO99NhI!Br zr-3)L?c8U}T#x%H>L&EGal+RV(eh#gdJkBBg~G zDx(tKLqEaZxRy$5L3_cRl-|XQRQSm_g0e`Ab+UI)C$@PYBV{{`$>1^WsDX1$^3xP9AiRim>1dQK`W`NL?Iaq zQu922Lj7a>##V@ZFm|O4_w15~LVIQb;&K?c*m0DwlzV2lzIFm<^iRYscn|NVZz68t;a-$O z+C|DN<%I{ef!2-^?jgh1V+|rr+#$&<$J0>SSH@S=0>=EjN{dcTG6Tr~cRhr_i?N7g zQ3{wdRK*=U_gbYV@UjM4iIl`|MpHx~81?ZW|F7#8M@$rf83*}Hi$F~y(!$u8aUwHT zo>Tjnb1^UFRerK1uIn&&X7oxM$9RVJjQB|+Lc@rR({gw{Mf^tLGs+ zee-BYTR;qv`(GKkF%D&Z?MWotVSedF{FGfU+GFg``>y}v?qFI(+91x-vnAeKDe0w& zgYwE8ia8TkSGYn+kI7kbM}aqYQje{$MOt6)H!}=ce6BTd*9`4<)r1aH-x<+)hy>#R z?s+1Q=)0-^^cJKUDa1CY?Tir_fig>A9#TJ%O>$>^?Zi2ebMEhCq`)ka2s$&_W>_8P zdPZC1qOB6YZb0tItK1*Kbunh7%u8vV*ft|m?(861#dw#Nh%pA^Ra!B6A@Tyaq9Kxl~PY#r*xCH-cAYPljObED{YB4dPWA^mwP>0 zW3WE%PofKc}UnttA&pT`vdsFdZ+dq_-v?Xf2pKRs!v) ziQKtV7y3_ban06y;)A$1BSu;f;?C9oSI*3VHESjQXdj3|bC<5C&x|Cg>9iv?P|nDq z23l_d**32?U8@nXF*vYCc$`((SsL3Fzeu+E?NdgsMQi9qL-qrWR%9Zi4g?R zSH`-YNAa?5)}pM_BhZR49$^jIY>#66MV+BtryU_z8JUr)%t{%<(H?k_J>xuTCQ%@2 zDR;OL@nY0Ro8n=~j3(&e7&|gABsa-P&pvoDKT%qGZa!<^t(15&Lq5}*Ft%XZTtD%o zi}?f38Dn_InZ!6~Gq@Z16uv1-M8Ft@@Hgu*lV!|FNv7m^Qpr^`&ySIQ^mN{e1w9Kj zl)i$wfR{~pcoEk}$t}uECGgY37@48crW48Fu1`j$T=Sw2^3ZDf8J2MOH4l0SS_Za2 z-@~1mrOtcn{A!pGL%+TqMuwEg2bmQ7sWIj+e-am-%@1im?LY7tVkY0#OvM zqIwZ4Z3jIgBV?|&636$vq!-(SAt}sPc(6alFO0LijKzy}89RFUKPgG8z=Jm5TRovS z@p8}WQ9k`SC4;trI!n1@luOC;&?d&eM9w{)c)1Si(Lzyjh`uuBq?KXb;l-P@Bg8wr z=-2ado;=Y`P=c>}KH@E0TlWxH#>-@!~q-iPU6T1&O}7 zpF1%SjFpy^^rrRWJh-yJx$qzJ4bqstml>j`?UZy6Q)GU6-9xfH>MJ8Q;$4iT8JW_8 zdgtU65ed)F(&{t%r%m>#N~$wX_oz=#Krc-MfIFDz7x+o6oQJpfiW=^%Xp*o0ADs~Y z=9!m)P~&NRh&s|gQIA>XAzW;oD1(PHaAoy8Vuj;~B0Ow@`!+m}M9br4JCrk0#j|dV zMm;Y=f5TNY&W-kuQ7)-VN%XK0B3aZ5Mnd%2o{#Y;#Mq2~}HIpYRK8@x_0Ld=+X zC~dnJb&+O_@4eiX66r+=9>uRqD5a56swa`8FJ+GM=w-kBO-h zj5o+v^4NXB5mx(US{ipv;^ZJ(KIik39cFk3iIulEM3lzIizd<%8CTnFDP( zHG}#?9#TG-W00HVB&`-{?#To{J^Y;YxCY5pd}h)dD{Yf!zu2;u<#^Z#TcID~zEaB8 zb^g(=P)4ZXj2y4;xbScbdIb-8=Q8({gbBDIF5mwrR5^WIJ5uvf26mEn_kB& zxwC|x(8H^)#{uLqIZRBHJ1=Mx$W`)!y!CQK>Lk5`7sY#s6umfi`Z7}E$|9pU?qOoS zK~B&QP;2SE$sL|?m4s*)wV$<^Z_@M8|B!S1_WB~1iG-2s#E*E@!>So4@M$9tqv9Bt zSrp*g^IvS&i$CdU*e1U@TF+-urih+-F$`lFV!zx)<7H}G&ErlQY6>kS?F;3fc8q1@ zCT$39Ds_mO#3+j^RoC+pu4^$SWA4B;ab}jakX?8y=N@Xn4AVmwxO&E#+>cMBqAtGq z)Dq)NdK3?f@DK<3BSwXkk>tc5$3R(Rl)%h`YZ56aaYsBo8{+5)Y#iDJZ zjIm$ZC@(JZw2an@*m@b{lk)6&XK!VSaSJ7!dpIb+UIfRS!^5Mz%#w0MY4LXG(Drb3 zg&bpa$f(D|J1Nclqf8biteLl_!>672WB~P7{xe>@O%xQ9wVJy z>^ot6$$f2n)`w3=z3CjbR?JiLH?}Uc% zdEbkeeF>0V@0r}7xPYe_FRNOK`7+n#(+L-`tNAcm;`<^(;9{D%0_x^~c6_etDE9as z!qxp~^LU~i#+H9WTX%soCqUOz&`$20J%%xuK$RkDBTBwS6u2Jm4_K!@$L9h(J^UG- z_xuV^2`s>qf}g1`)Ef1U`bAB`mFLu2HC-)J^VGZQP4xtR_f?;$0cwTntzK0V)U&Aj zv|6s7Qs1kW(cVkA^HVhyC6m$qG?c!rzEca)>hEeM-qP?1kgcyUlD&8u@Dh4Ej=zJ6 zXWRe8H{Vsjrw0nKr#ne(#&e)Q<1ZJ|(g<{I4JoMx8rOn^q(c_CkG&XlIF3*5JFY+< zez8q81ErdPTii*1H~4-xzB}Nl(QbH#xHsNj(i>0v^b*hF*A(%n_*lFsmWWTp7h;m;Ls5(-8q^987TWS@)Pl5gg z;LY!NqIfR&|0AfvIFLJqAHzuB6ZebNkoNVE`_*C+X0;8oDaI@^R;GtLI5z>)4f?rZLRSE$SGtDxN$ z$aQ70SbQYf%KPO{vavQyE6{4{Z|FbkmfqEvZTx3+_f7L%@Y%jj{t5m!{YU)!{I);h zZyV?uxGgX$a7W;oz-NKUf#rcEfsX?p21W)t2SNc~;Ie<7e~!PgKgqwzH__M5H`nN5 z4AV=sG;NSv0Nu0PA?{~RnSHOl*t*Gj%j|8gk57$Tu{p6GvE50JF$30f=34V#od0PqHMg02 z&8_&AY_+qp?WgP!PSBm>z6-6YBzwsZw3hl_eVuW|SmLYc|II%<@KvC5uy=5Na6|Cx z;ECYG;P7C*;EBLDf#HFkfog$G{%QW&{zJYszB_!k`CKF4IB6^Zb%q)*8uN`!#u4L! zan?xkoiO~qBgSQ8nXwJmpEJ4{x%v`)qF!D9KpUiOlF!Ti;thDJBaq4@wZWa}o^aMX zW1JMHzVnN1I5X|v?T77G?fdb$!hXd*WzV+v+B5B8pi{El#9n8uww|!sSjpBrbD>!S z$1CwZ@$Io`v2)RQG#qUeeJIi>@=eA0^3@en%HOWA%a)X%E9+Ojt9*O;4;5`AcSh&M zUXC9(Ph0u6?bKJd3cvQ7Hq>{EZ@vFo;8M^Itqq?~I+L^_X;sq2@JHdM;klu_(Bk0a z;HtogK*Yb?Khs~szr|PT`@}cQzt#T-s5UlOFZ4#JdN?Oj%z{P#H>dpL3?!kJ?kLxM|04iN77QqaR1} zqSK-;$G(rXk6(+8h}Vesh@Fphh^>fDjolk<8mk$7KiVgfQ}JPBT=}kwys{o;q4H5> z*=5g_eo$JY)m$u9W5r=nCe~>&y{fN;|AK!(U_@|2 zC@b_|@EGJ|Ww2YQW$5S7Z=rp`!GRn6YoIBmzOQ`aeINJ&zD>qG#t37W;Tk{s-t*n> zFZVwZ7z=GV>hB!r71$6s6KD`L178Qy0^R%_eH)F7`YZZCeW_MMJ0ymw@$M;mvAy3~ zYGv7fnHS8@V~3+VB41W4tk_(3ds)rW)5QmhI~H#!no(>Q3@==f_eFkb-qQu63S_~L zdG+$O+-13&a?D&acYfZiyqog>Em%>KQ&tk46rE@8v<^EvMGrXs7lW%qUxjuc$l4c- zhSrDP41OIb^!;q~(<^C{<=x^1HPCJ8eCMojSHLT@5nIFXEbXn=+WqnGQlyxuBOO}+}S~RzGbC^!e!972iV-=ElEu%SA)~OMx$g%|re`zu=hQv~YUT zuF#s`FM$XBmHZL?jAn~p+`4MD^Q`-{Y7H+t!080L^`2O+yGH-Og20sE3!#i~&6LS$ zIT@ES+hkVBoS3mRV@T$E8CTOENHbG9CqER<4rhg01xNY6Gg=xC>310+BdT}NC+X{T zzY#HVeIo*00yhUn`Yr!&f%(Bsp}T@R1F!fi`v>?+^m~k|TB+=)es|LB@_0e~m&lsP zobsZwr1GNT4W)Kb*OFk#p`xj!LyEpEU08IuIH%Yw99`O=&?-7xSX3xV2j=f9id=gq z|8VZ?oGk^zudOY9J^%Kyg0e45ABnXo?`-ckCyM^YX2Tcy#Md!YeyVAyI z{E@ahqagjxN+&AaS;@()l%A6Da>~5au_+sq;^Do4+~8$@5_Dy<|6TuRqmBQx@8v+P zV7uUt!J2{GKxVL6Xmh9{bTV`__(Cumm=pNhU*Owq^fsDkhU}?2IJend%&qa?BYH(z z`JZKNBTXv?luszxQM|c0gkU(; zM*>BGKYe?A)AXu(RGT4Rmi5&=YNA_TO?9WLqwa9=oxC7VYL97seT#x-QdV+GO7oPX zDYv9$r?&$X^JHcqqgBSmw63WO(!NZ-A*CYpXLx-0r_lM(gP3u1UlAzLFt{P`SfF~K zrSFPwiobav9rp4C|8Za3*Vq51?_J*l{SjT$_R4*Tw)Tn->}2zJY)0(f__6rb*u#;j z(J^H;%9@q6tq8>ymHR8&7ws+isBB32-10`nwTjvmj4xbN8YwI*t6JdXuFb#w>ayY) z1$~P?%o~;4w?O9hDcM*ws^YGaD@Ez$1EZzZOKvsc2htf=-Ss2Ek-^TvhQ>T!6JNu? z%uqkyJHaAxT0bP`=rw%DRf;y+wh&o&lb>pZ?j7P4H&we+@9oP8ya%nC8SENQK70;`_p)_)R5kqjifGlopqs z$lq5ssd#9`o%uf%UMyTzJf|#@mxhNi>y>^|TB)dUWMyQq{doCL@dM@f6t+NsdK)H>m&$<;$mQ$9={;eRRU z1hvpJp*up!p|5>@<52j;aCY!*|KDn9ax}0g&?+!2@S$p;RfI#K*{Xs5n3^l@*1vVz z8ZG=ABIn%a)eTxpxjS-Do{UPjOy6R^Dvy?Qu#ZK446G4D+>?dFB2T#u?Y&}4bVbEG zrEi!%Sj36tZL`Jx;Oz|}cerinenOZB~XCYnf4Mgk_!OI~PUG7f_ec+bp z74A*Gi1Cg&8oqp;hSPEkG}4E$vWfii#=k!6_{N9TD-5eM~!hS z#UD`F{vKs5ooszwWOP#9;;7pd>9yQ>zOwQ!f(KN; z_$&U#_5e3E{Bc=^|FC@A4D0WjxA`xb=_=RAv>fDyt=$yu3HLE~p6-Vcc% zYIiZlSQ*+e)hfQvciy(GQQB=z7kOGe?=DpXk!kgH7pwOASN0F8fn4gYa}4y`0y6Qf zOm#k!@2Ih=z9?|3;A%bfjH_!mIfvycbrF(2KxEXj;(br;^7IIV68s8kJmD=um!~Rs()yCufd-rJ<#U+U+&$X@MgnU;1A?}k0 z^|@-l_}Q56>hf#-jQUIbt2cGC^*gk#ZasahY_10DOVt)}NFGo}MNN5BOh!KYnvA+{ zYFR)N)@k+K<+4Z=x!KxWpa#QbmOSJx6=M*&PZuMAXe`9@(>v4waYFSL{X`>GBS zUn%RT;o=X`LCu%*fjzz}A3=s?$Xk#-{;l4^`#WokxA4Zn31S1X-CgPzWRbJgP~arH zfm)@C@9>P>YOxRNP0Qq`$U-lRo$3p`ud70>01ZO0|KF%N$TTb54#F>|sAqw6gL(D?XjknCHk>VY7lX^x>S07-UqmYm9 zLToS?h)tnt3=HL8U>Pfbnau&pV5`T}FfkctCy-;05siU-#ejo6D@Fl_i31aNX^1K2`2-^ zx*y#8SG|JwLM{Pnb*p#?^X-j!E(3kO#`{U0z|ls$17xGK_!7DI{lK{T0WazW)T%xZ zse3@%#^P1%uzeiJR+;J!sy>GA%HVR28Ym6`JyEK?_zPoB2fmb!w+iuDalW_VFy5aK z#Cw?=KqN??6y*CvRWBj)ItW~=DKL)nz-#sZhgpO- z$L&Nt2k!;r8;b&v?Xzf|&+hVVcYG`3Jn)F?kb8hH39$i))E3mgjL+X7jl_%>0L!TW zW^xj5PD=shxR2s6=pP4ebt6!*zkvspVtjlHUu9roSAo0fpj%k%#rrs~0exBljl2lX zTvfC2HqM{X?nh{29NN2#c5~GmXmc6H_%ptD0OMK(l;$b& zV?@N`s-uVckn3i^AKQZq+~>x(o9t1yfIc=T`-|!nz5cwGm49}dj2CnSXLN1To}i;yM$f_$$)b%6AgtNB2^T8O#e)yrtL7m%uSNL&Rt z(-xmeK*sjq7vKLe4s`|tMfX9%%E8w=@ar**wv8Btk+j5>XCSwwY7%g-TOegsfF5xd z#%*{b+hdpo-=NbS=X`5mGsrF9Fnbv2*GW*fnfL_MX$4Es6TIY31-^f>0Mt1Fz0C(Q z#&?u8hom$_oe(tnERfP2kg3aR6D-d$pp$=q{~D~!M)ff?bvuraAZxo|bJyYCA93wZ z+`~Py$02=uI;*W%47n~-{lVSWfkHl_-UDJb5!hK@AfO{)4|W4d`vW-IMPQYy(8g=v z{Vd>_Kcg1knfDp2z(~}38@T5>_=dHRpGP5Ap96{e1M+eJw(V=6rJPd)v(SJt4iG=# z-i|+&Y{~^tYeM2ps;@v}6z**% z*72RK!OGzetal>6$UXW4YKME1`ba*cJ?9*@PH9{8$K4^W1q|j7Sr4{)IneblT}NDo zJiwB{V_i7DQ*j)#i z@1UNL-C)@cs=J&4ZUHRn>&||R`){LnV4r)Ra|*qTQvZl?uFyUdBh9TcPqcLh=o8)B zoMy(OPG?!)sAX+*1I7V!h#2bs)_KY8sb#}AI@&SJ;6M3+<#%GnY!&>uQtffoLR;>v%6d1e=~YWYvli{ypO)oUt0QK@GZ?( z`ca^d_G|Qg<8M(M?-^)hcUJ2|Q^o%DDPm!yzvxi+=Ot}}3i#TB>)GZT@s`eOeyI(1PF8lU+v9Hd2KkuLLUgqs@^`QhYihx= z32w94tdhI5cVavAp8l2Q9cgp4&Z>R*g7bj?jzEF64mtTY^-*Yq^_+D`J6Q6J_I&uO z*tO%SeC{i};pjc??v_yyyo*dBeRe;^h_8#(Izs< zSP`uo>z&lp=@Q=R-(OKr-%vIoo}Bt#v^;!B){Q0w_r~(971}iQsDG#ZfV;vi&_){5 z#dzNUxifmR{#vX^EC{}4{^Xw>`ABqhzKYEfe+GLZ4qj&TgQk{RRRTNocJgLX5V<5L za|J29*6I_wFaCby>%b>U)8+|7n<83Uf5*zSUUqZ+FGLEznggAy@;;gB!#+3j5%*>L zdo3=yxGe+Ei}}_TNMx({v+noWFn6F@8;Mwd#;e+wwIrcqkUEt?JwD)TY>Ds+KiZ>1vV0dt%IGR_C}+-e~=4fT z>_*BGA83zg{k6m5bJ0(}qZRlL``yrO;qODy&}sj$zM&GS9(*|qjz_64lDe2Z0= zd#yRvcU>7qRZt z!@1je(!K27j6CRe=4^7Qd_$Y9H2?zqoi2^e#;?X&Bggm*#{^Jrn$ghs!suf>g_VPo zdWGIZzf~bF}ZZE5OySVjbr{?N9BD)(}(~ zsMpe6?UeSsHbZ+{dr7-hYouM2E9FAjQ{F9G$=b4-EEfBab^apeBTKp;@vV*sa~C4D zSy*$PtX@&Wkkj>2ebh6Elg6uuR8O3>RUu`lAl7c`si^y}`w8MG(=kD_8O}(jj?==q zY;Ur^x8Jc}v+uGS*#TR}x>Y^ctfQ8&uUUujx!cOM_E|Yr5>_$O>}^&aN{>OJ3amNy zbYRO@unzW?vkZ9KWA2Y`D#lWQi1Y(-Qlw%1Y@RHX)wPkJ!xgOq*42jVllAHPE`2%H z^M1ka$$DSCwjKhd&T6~0dD;5@w(|h-wdlCZ8ftepS2%oMI?4p<6I>uKhoNyr~m%8SXr+g`6%KPp+Rt{xAy};|^p4 z+mHcXK*>2|BuDWU?e$o(XS{w1nE?OafHlY*=Hi%%V;1&3yb5IC31k>Is3sVDBh^WD zSDm4S1JxL;#NUm4;ck@Ni_a%f@&a;@0g%93aNb^3!hGwffU2cpu7-G`9P8Pc2^tKA zt$PW(FGc}v8iv1d$QR~f=fvk2$*bViSTzLaZzA{LJ_ir!c~)$V&$S7a5M8r=a8$AV%}i>LRrICPp*|2e|u{TbQnI%J8fkt=RNj(rBQ#kVjYNA`CX zWyRnh|96FANFbl=3U?Hb~>l6V| zs|u;T6*KMvt9~0WxUoQGVwl$tn8#7|2s1 z?48(gGFfdDQCZLTv3A;5D4#VNi74!va7^nMR#GAvUuqg*cU(Jk3YtE1g6Phk$X zsucICGsMYv_PWj7)}kHeG6;CpLHBL!su}6NsOGr0qs}?;o4Xcx(xaGTJ2hI45F@3n z`pGOY!`%wIepsz?3xKG0g7Iw{zHo=Cj+hCu zIdvL#VXNGu#=(-l4*o8Z)6_|`u;ddx!O``pB%RGJVaHz6rljd1&ROS!Qa_^!~vkqHFwWMM<%f*SC0l@q&t1 zV!f4NOw&Kbs#3mxl3q{$*X|sDM*SE0q^OqrO+{nt)3W}d{`T?U*igHqebs9v-(S5` z(nD!CnfE!{DhA~rDZMXmYivsCt@2~X2o4P$4m_2-MjsGfNmT zeg{5`GY`r&<{oWo`~z)Te6V=b`9f7OeoA?R9;~8=w3qCXSp7)7XzhwRk=x^UnYTLy*zXlq!`)uaZubc%YIbsNvi`9v zS;ATA{@`p@-Q-|5?4DIYXBXo4cjRO_TmQybESnpR+yP>#-Q2liH&FAOJ?aBdSz>p$ zy9Ba+%pPISb{}v`oUP7%Za#cP4Sg|oa=mTT)<2WI)OPEHJ>7iKIpofECyPxgS$=_) zj1=)RH0d9pUG=neqD1tPr`>Ed3MN~asy@1H# zPBGbeL3MYs)EW1(y+`fRXUe;c@Aa?rq4Gai;{NVf#4WEo_h=8g)%Dkn1HM(r?Jk%t ztVf*Q?l$MLJ5Vgr>l$43Drn!W>Z`6msHTe?ZGm{r_kemv zb`&={O|a?Ca&g>e}ODo9*H1F}X=wgw>J3;-c0~e@EV=&2o-8 z21e6V9C8-JdZxK=XkX|h!R6W!zX9FZ4b*O&s%GBc+ID}-FQz-c_%_G}!PU+I-|NoT z`WJ|Q&d8^&KJH5AoY<}Pv)|GmaNlzK%O1X3zCrq}dZxZc^s;Y`A9w$-@tlVHyVKWq zv+<6phFQT#B14Q|I z)e2elWM`AFR1VgDRX@le(1(W3dU4I|t;gh9*~h8jCTX?BS@DR`N88{GwBE%C?>9bk z+89;TL}Q)m<2Dfe?5^=bd$Lni4bggvrP{B~XyE@>?EBQ)#*MZVZ@Y)|Ia-$d75T*( z_ffe`|Hv)V9@Jk${F>wJR_|lY@U-~Vec1lWdC$5VSn6Q8M1NlYTb>hRoW917PEEIm zouS^wO5zp$Z*9MQRWDabdKa0dkCHz*zl$j@R zBg;Fud(3iulVkZ7+6#>@?BCS$BHLZ1Uv>}bPl>jAXLqz!)1KgVvnLDJX(2w=!s>)v zBi}|eQ`?^5&bGHX4`XlPkMJY6`*LMt<7x4iUhMv%DNQ>4&HK#3@=JTY@rcYZ_S>Jh zwZ#b6(3;8C+Ox92?dQ~TCPv%)KecaDw`eT30^mD(cv9r27C5Q8*eCdm2jM3y+w8K!S_SBp#H zM{Tr#)394x^F;%vTcC}b?_VbB8}sa8ZZGu!(6^tQuI4x5A@sad3^E7F&(&D{HvM;Eg?6D|f9^0p$EAE#Hk><<@2=TN^Kb!PFnb zZlw*t0xvj&-5J^eWbdQ(-A)T@l3fgKd`wf$3j1?8+5f8lPxVb8*V!Y6*oU2VVivgj zk%-u<-&=QhqMpXMa}PakJfTlnTO@Q<~TV8TD((iM65bc|Hi%9j##hS|C+;vY2WVaYUsey z`xp<)C$&4>X5s=v;LTouOE~B>_{0=pjZM$3;-RjodrvwhV)!@J0}lQ-4__DUA>EWUuxq6$@Y^TsTGoAPy8Eo& z6Ea^%U*tZg{?@XbhhfR?!8+L_=L>C-s4G_@cT840v42wP^PGUZ*%>Ey+1Rb`PDL)# z%KbvGrKK9F?qN07I-m^KwV$x-< ztvYJmjhEn&cDi$jP`j_#|KPE?f$A?gNWa?%_%>(*WG8t-+~+QJ2b-VUA-Ae^+8HSq z8-0AW0!#d9MpM;3e#Ne9J!JQCGG!al5cBH@ACoKNPIFip&Hdi)XuSY=81Bqa%cU!? zXwUlZ1cvpMJ{`Nwwma{;yX;3)mfKUm6`Av5Ia+kr@kUW?qfAkEszt3B{o6SZ!>On)Q)Wc{)}Qf^l#?AFe`?jh%t zYAGhlL&hGtM~=n%{Aa4n`P`Z3c7qOF(GF-a?LAoo`^cxue_dUTaDc-*W9(zr9_NH} z%_PrP&5@TFGev#NsF9G(XF(T0kFR zlp8PlUiCHc4+vBZwMmL3_e%Rbt$pUO%zBkBWDclwEptug1ndNzkkK{cSo-LUJ(=Ay z8)e>;@ltw=^mxjNlrhPdlbRS9 zk9-mB7`Z>TKQbU*0P!vHE*lAs(Q^D zE3^NPqq6{y;%eLQ_{_S;0|A1&yK8YN!3vb(End7>i?w)hEpEjMrN!Od3WPx1)_r$; z=6}Ec%0-P7*xlLLbKd8E?$aspzT>WyvUWG!HT;;q{-*lK%nua@r;bRtZQh}qPq%~B z$RpS@Ed|2-F)(qpumxCUgjU4-z}fmns$hVlxX0i$(u7TpD*lLu(vR^$Xi^cG*nhU z*j2czOl0nuho@PqG^(0f!&_}h=1&zKCCT*RH0ZGFhvK4E(bo5rSmN-0hDr}wJtt=J^9X4N-U?`3|-=w0Dz z;#$)cdKXer;e=kiSwN-w_*G(oVV>S@Z)z@f^fB(Wtug*%YGRya7;C;{RLyq`sG%t{ zfQZSzu^s)-T-n8S3YO)U=f2L{omZasOVQhcOT}06#e&tjYjSUW$^R^V-u$In{@naW zC0&Y!ln06jIWL#3aAv#qhE7PU)r*?NrsCwG86B$RRZg$!uDCSiuH%boxPBN_3!AO{ zDJ6)<_;Kux@S;Fle=m2v@*m2^mK|^<_&P)?N@4s*ZO}X?pG}%k>Bmy!7-cm*i|y(yC}wVXMOYg7rlY3)_{nDm`B^ zt7K>4hC)|?z3}(E#6nF$`@-*wcb2> zMPF@cHFM=vsV5S?v;3nUPG3bOg_TxFxA`J=btE2s7QE+e>+a@m?>XcB6!6DdC^3T2 zX|0cws-)kn_*2Hxiq;B8QgVLZ0k~TlfFwBXK!Pr%vR$T-M`wqOe)aFZK+fuibP~D-!lFo^xo@r zRw}zzA{XZr`-)OaUX^Su3zX-SmzFmw|GKP4NsF?_#pU_e3Mb~RE-1(7s?5*Xl+9TslDweR?^1bG7ydJ{I9mN)G<0uwR z32pW@b6+X%?fh8&p*+zO30B}1$~TD|!w1KiG_m68%4%lE%(6<$DpX6CQs*Z8VN0=X zwzjs6GrrK>(#)o5q8id#z99Yv3DZExo31dswVY|V<#K8?iBEl-_(O8X#M!nC;~;Ia zejQZ>K8JSTXR%|jMo&TKh&h}t{ycERm+5Zn%qyK$e!TQ-*<0r|_cz{KKGr?V7j&Cl zd&-ka8Wy!E_%}Z%=h~M(IRidseQNM&@aG+QbZ&=|+R+NU2WzdJH0V>7o_+Lh5Hy;9O3$9?M=%Pzw|x{yYu=TrBv&p?CiQF^F?+yH5b zms8EOdknc2#gUZMKPjA)0ekZY*7~Ny`exb;phb7#4ty!(A`a{uv;wk5F-hn7UGWOx zFaGzwUfvJxAKZD)kn4`e3U%I6*H+JCPdDd}&XVF4g<}eCd>Q@aV9xsNq-^wKcJ_Cl zTj##W|E@6P{YChsWKw^dhNo_d*?n^XV*nnMo9%CtcW&$XU6n z)L!Dm53rlNj6Nc6(s{a0rlywPtXpl_knc6H8mLgM+SnQwYD8Dmlc;duYjWVXSF8vyWz3jk~T17YNa-n`(#?G zS*z5|NUyLVO-gB)$XY$7BZkwO-^s?VmqNJok1yxgMr*l=Ut19x9ncKMJp zrgUamC70fJEZl*UN$?H?TP47@4qe7nwm(s4M{gJjKwQ5qCeY~x; zd4_&8vw*mb!g&GmVQC1gO7AG!kRX;uv;oe)op!Z$ns$q}i*~z)(Hvx|G3m?)CdSNX z9zri-Bo!h?;y&1QEs{Ju&Y4*y)*hh_xx15f>f{2zUP zc-wfx?hmdfhT|+V2 zEBe*O9FyMSu(ffVNx+i=NflDwrrw5C*{h_P39W6UIbZi3os6$Rx~o}AE#)_58ceV^ z;d#V8ayn&$w#RhJO4SCIB%Rzw>?WsDuc=4S4*84zj(SaQ#yg`yxuQ6N>%_)lcj56m z;yCMIw{n-b61E34qFTiZqq&jJk;|c1!DGP_q02!k$Og+ zmz3}LBhep&D7@QWdpdZJ`Dz8_$E%RJ;h=g*=zfOXK98FhN)-mnqMX8(3wSP1=Ba_y%{Of5S7h z1W5-TA{+DJr^zbhza$BLrOJ>BH2^lqFZPnW(i`!Z)I`=mmvjnnj_ZWxg2->+lGt(a z)vP`MYchIa9XfV;IVI(x0C0oyPDhY{>fJ-D1_I=@ACJho1jG533N#v z-~^r^^U=E4b^HvGLOrG40av{s&}*#Ac&1GUo|n}|)t zyg=tS#TEgBFd832R3bg(7HSXmk;rgl@_VXO=NvGX!)=2T*IsH$X_-#z)~2JR`Z74(48Gv6+|# zJAv*%JEEnKfBu7*A=Q>)2KEDb5zT^sS^-(uW!Sf$P~!4Hxfs}tfH**$A{c;!sm}kw zePS1}o!C$D-1x-!?)aN{b+#KjmtDxNV_&nk*zIh8AY-iX-=;t#`5|lNrgL@q?))YG zYvHvpP24IDg|>A>dLl1ThC>FI23IqUZ^M1CZ(c-fB#sc%i3S9RKf>?e7olI>4IhsW z!`tDHv9<7gh0q(&zCH?bpR=eAxQ_p@%J>pIf%ug;M)-*4BtsIUfvin-A?uLi;h8BV zZ;_kHuH-XfA+!p=;632iIfl)_CSi>*5{sfNur#-Tm6#5!LkqM%ypng3Wso99fZnbI znb;v{fe!(`qM7n9%o|n%?ebAtBF%=mKyB%J=|3P}yi%r|BG-~LL)b> z9z+yvQVl5!evfU^0jVx9BR%1@H~=|C5r}RwfpHuRdEI<$i2np%zW_@%6rY1n048Ju{xfuGYrrJ41iOhH z1_G-ZwiBHO?281~1qWG)5v_Z#KY> zLLji3Lc(eRa;6S^_Yc5I|L-jCy^!Hn2QFa;&__04MJ&K7{0ijAL!j%AKz>{U2>Mx& zWn18#zW@n97Tg}V@lyzj!hRf_1N*_O@d1niz1G-V1OUy#rqiz+4wo3y@QQM$b@K@?Z_+1L*jgbESjh}$WJOn1IS|B7& zq6Oef$deZ;3xWS_1^IhtRhDlnKSI8AQr)7Qg*3FW>XgIEMo6HWsTY(KWT=cI8A=dI zSKcE%m261>+s9S!s0h47qQgU8$k*yXU=4?4!*phZeg=m!l%+Cd-a z2W*zI5S^<&LyxJyBa_h?>SKgcx?@k2AJJL=PmZ!`FKANi#SF?+6hQ{UEbb6o$H%~t zFoUugCF3CVA~6ui^c4c7t;m!qI^rTag({6FGRM^_R0qLBiK+o+k22Ae(kYW6htDDl z%=cmb{t$AoN-T{HWG`;W~?l7~Xory%{9-;EP zfWCRF9+Fd}%kox*U{0yekk#A}#UKqwd&I}0+het95!_&Cxii>*;lf9D? z8FvZ;5HEQb-^eGaH3VouiU$pUAf<_W#1_~}IHvd|x{Uq}OOR)ZOGGXFe^Gt_w@h~? zx?dioEq8^r5rN20yFXKD&OXP3m{=yo$iX2q*Y7)9b)~TFW!kd+E(LAIE zI+o}NW{ZpPT6g1HNv-(_a19^EAXS$t3k#J7>`;MMrXly`3uq&`1Sr%WrK3t6DGQmT zeuw{yE2s@wpzIf?D_bE|s)MDYd-1jEd)XkZR3AuMak7-Ejzo*_rZB&+BfEqVqF=0y z*vN-OBXW(LDF3Nc62ov8>Wx{@kMv~hYgjR)aqIbp@x#&~tS(ZDc2EYX|MD}%d0d@X zC26?uTsjF;>)OOP`G7Qn6{6!~ttFE(68RI`1QbHH_%XgIUYBjp59NHy7T_wPG9$m^ zp2s`K99(sNp|TOk#$5b~`V!KDA^Z-$y>wskfFoozOpVVgjnzt$C{>h3sSLIpuf!xX zGO4E;;pZX4bdh~E5A>%@BaPqd*Xyn^SsF<@TL0Ndn(muxn?&7Q-Ac_k8l!ts{WRs8 zY5MWn2b#^$8~;RJC9i>5;3sAu(~BNNwji4jO@X$YjZ{GgAS0nkk_n8-E#;xao6R?k&!J~MURzHaWQNf*s6zxzv?p-t zU+7J1=wW0U-X3xV8fgyv^lPjbyMSy}Nx6kwNty*#lzqZtQDnEW7vsxg?_=|$A7kH# z&xKos!hy|!xG&dl_Kgd4^X~E@zG3d58!7Kp-n8U%N%is{oJy!)JV{Je@}&#PU2(kJ z1F1##Hm$SPN!XQe&)U+oU3WoKQRmVHj4Okit_`+3j7AXu8<;@2rY%0 z!cJir-y72JCGsZq36@OjG$C!waKrSwd8u`YJ(9>I*^{$j+SM{?fTN9LknNDAyLqAU z8(2k;(k^7iP%OEQ_<_tOE|Jyg0F?$iq!zG;s7h2P>p|w;5UBnS$YW?rR+s-1YKo6J zQuxd^=6>QFY)1Tg*r(?KP7wJNjAA8?&?%n6Ky)OiOJ?jguZg zZdCUOwc<0PO<_;+BzBGM!p-MWxQSdBSUm;#>Qb`$8~T`7OmAg6Y1`@Q8#PH|6($IBswJesopD6FMF08+;v{7FZlG`n&l0 zdQW5Mg+@0NzotMg1mGmvDQutlb&&4%eEdoEp4~R3>sxWVE4Yab6_?3wpD%;y8 zpHI~!cXC`X{h<8}){Z*#U3xzAlHQB|CzEXN@KgUBUsfO?5{=zv&vFQVm;E*Vf^8v8 zfNZQQd4qYPZKxk(m~K2{3R@T3zfGKxl$cyM`JQ7jh|8{6R+;}bPBhuUo-vIx~s2A|4htLeOnW@7Jpl4Gr;XYRaH24#whoY1J1xen?`1P13wkbL? zIyWMR&%+wOqA$r?$1~jB!gJgGucwMza%!EUN}H6dE9_YKp&+@ifAOECqrH1V9k~zk zZ|Hks84$=i@?W~6?vA-f!k?+v(mJL-PFQ0eq}@Sm!kS^}xQ?tuj79H@-DB(hiZi+V zoKy1F4i(2N{AGSD_bC1#uH|Nny^x;N4?52H(R|m!S`OLn+bSlQl72~Eo5Ck1B(|^* zvaB@KG(0n`)Xmd2VS14r@TSOdb)dQmF<~{pjByjcfnUOUV5c!VT;&_Uyt0M*b?LkC&_{d0zagWO2#6Vxg!{ zAztVybQe<|KD3LkiS?lF=w1Pp(?Ro;*{6SMwj|A}FgoK*#)q_fjylGDR1wAiBm56O zfbd~E)mr@9@D2~FeMH>ewvAvfBJ zIi^SKXCSR?kiheQlWSz zypoLzrsXdvs0?Y!*5bFWJ^?XqRevC_X+zrix{KQS+CI7^rpb<}sdY2TE2U>NO!>|l z(450_pt1$cXIjW@-^-rL1UZ}dP^BnM^P_4;#si@5+b+gI#?bk6VyW+z>#ewb4o0VLjE>M#`MR^%YC z?1<z!7dc*4>{_k!$#K`#o7Plb@@^XhzYdF)2;xfk(@-d3Rnu`stynj()-f?(LHrf9+J za*U4aPgpbwt&_@9LMhu)swW>z$ah?_KCmR2ou-+lPX>d5)8x}%{vIe{pE^IG$73S|r{LI*!@S)I}@Q&ywHkZ%a0H zKY9bZ2YW9&HUx{7NC7v;M0Ih>XKJ^DCw-S6}k zcrJV9dYk#C`8xY1c=x&Iy8F0!-{$Z(_M*}WA4(6V?}G262p>eeBYACm6K>n*NOO?Z zQTmnCd9SSz&KYwQ3tm*?M6)vhNadL8z6z2DOD5O3l+ulQVZZemtt#h z5C@?fkUxP?httNp1Z#1 ztmmwEoOi6(?(6428m8j)U=_cK?80#LXL=F!g1k;m)I2d9wQR5#JL=eLf%D^U;vlkC z-V9t{3F1PZAy1X>gwF9<;qAeip;-8zc!ogARy9*?pzc!2l%)uY&tc;F3zj#IDTxJ1 zv4l^K4YuZ%XQn5HGX_@IQFlb^OFevyswX5uQfCf?E- zkC2VBPNGB)&vUc6KVW+E3^;;;vDFa+Sa8PpF1s7Kt~&cUYqB_0XPk@xtLQ+s;a8{^x~Ap<_7e%-gzbsj91rZ9tw!r0%RR8*u=)?Wos3A{ z#7(FHsjlS7<6!#I1m-V8!98|eo&b%JD$o_FsN9nOlLkR=W)JLNdWh}C3xZX+&oAYA zaShlZ@g1@H(b=p8IgXL6$e{`!@V8Hdff6Y(UeA@9{1KNjxJ4P%oI~ zhIgi~tTk-ot*1;T-8b|g;uIPIe_;&fM9Yw->Ur@wS1UFn!%2s8?36MvKEsLZ@K; z{k?avr>%Rb`*%0tMg8+af5ZlIth5H&X2-D=(EWQ)Cc&NlE8}AGPnK2I{@@mBrkld( zsRDc>@D{VM1?V01lFSRG?8TTZ+BP~T%EbnUe#RQEk6aOGfkc- ztri;zb@-|r&+cR8`0e=8c-Q#w*!F00^nUnW$PrZhH+_dZ&)m~ogIx1mzqpTjs|Alo zMRud~T)m7v#+wm|)L2?!e$ovzC0Wnf8r!`#rzPEZN8635P1#^tI3C}ODad~1TPc-) z7_S?<7A=lCW52QKTsA*X2#6cxjmQb29g}KMOsV!Oc5A}1gzgD%6I93jgsF}r_HNd< zre+4S_9sS1H6v-F8~#1k7aasS!~vxS5V_NUYrmlERE7W{R9C#mr}Htk8TS`>1dqgP z#LvarNB2a=gntQffmDCcTjrkc`nCM8vV~=_vb)Zs?gxQ4kzd$EsiOJ|bkZa|llqTl zwA1t%mijiEy|Mj*b(z_#tEuTrlVm+`|EiQJET zS7DXt1Q*do;x=P9=9_YSe{rbVM5|M2j$phyXDOz54sEMYhN{jDS2uCe%xymmcvoHP+Hv3qDQ^bPV3Ixy{21lCT7a#;Rd zt_c&`t#Y1ZkPZkpxI?TV{tO;TyV%oc>*(Fc`_S3oKz}lry&k!KcMdJzQg*p)P5CZY zTW_b(8<^z>l}Es&_QB^7hiQv;lCibrenPv%`3aTneasCFJDII?KOp3X&=I&FHGo~) z!Fbh3-SDdL@TehfW54Adv9IH=;~MT)eyDU8IZtiVbvMfTdRi9+;Gm^#ly`vI%#Ra0TNJOJ3ohf;Or7xgI2L!&^UO%m?2 zk78jk`FW$KqoL@O$hYB@fq_1qXO8Q%GhWuQ{A#IK5-4t1a=LVVx$IdF>xDb82HgR@ z*l~D&>bOQSw6uJ7xRZ7zUUHnb%rwr>)?%JQX4a8jLP}V7r9D4Ab}Q5~=nFOs--@8{ zyF7{JM;=A$#yEDIU{tDLEcvVUtznDhXGfcat4VC){G@(KQTTK{khmgglH;OHHn!7M zV=&?){xf*)PGNW84%bV4CYOr_m@QRRdq5NM9CWk~s%>Bb66R5U2G^AHuy4UjcQCRc zG&(TK$9Z9B3f?{V;m%^mNU4?IpK9_xRWrrVex25CgUaSxK7E7ciY5ST5$M9rBg}jul zi9f+{*|c^IlcHU(VYK_HHrO07Gny3Y>Ho|3XJBZ!f7BaEh38HN zF_rGF!FAv32blZVE+-gLdL$oD>5+6l@r3=B-EGUYyfmNCkI<`BL-H{85!zTG`KfXl zrf#URL~RS5vg$B3bfFcoK44-Ma66bj51=O$jr_H6f%`WOD{{DFe+oYd^!1&03FQT) z)>5ogD0x(j6*}|F@;?<=N+joouw9rWhtNUT9_Tvx$fvqTmQ6`pQ=X>3Og1G5=100L zZ4-^G?W$?VtR=>)hq(=*@qqz8%(vRF59Nf9z+*ZdS`^9-wT?byPfL5zmDE*czqV9o z(4WySFzvJNOstZcozft+a?)Pg0&53TEpx;$&Ujc~O*4nShBZSE!s_CxdtqPn>Vd z50>pK*-~6o@Hzi%{<`7;rDNRJf<}RWiTx)04N;kyLJu+4v85-!NbQ{VEVV-dXFG3v zr$4PV=}>JmssOtsEBvW&b|C0?`?mU91xTO{PKP>$oS{eIm$9w$e(q8_$}5bR;K_Om3OHIH9d=oas+PEkk9)Z^r6|ak{^#i}*q0fQ-tS@+>gM^n?a# zXW*Y7BUP|<*gRq>wT`i9=W9o6S25SAN#sS4p=aWwuvx%C9pYcdyF`hwCp0NoD|p|# z%Cp(!E;E-OF8NR@740vsSTwJ2Y3c6L8lFa8GWu3rq$I=ZuhX1kTIkA6KHG}qwaJrH zhb2+=m6qLxbe+uX(uAli{Dbm~(3qVWY8yD|@8{PCJ_b@l$3mAvQ^WaTB6gH}Djh~A z633Y5+6_9rey)Cp@i)_Pi_h9Kph!(t#-oH)ZE0=Ng1|O%T-gBp`N~<_Lin6m{1zaPk3E?LwIsf7kn737CIB^ zAM%IFVFxlJx`Z7oPF4rwvq_S=sga?dSP+A9@s4hF-1%(&3)!1GlO|Y+n57% zBPN^viL8ak)Y@`);S@KX^~DFXmtw1<+d_%pubk%gcsIIIy>a)HvO}(##mCEs6mmsZ z3;4V&*W}{Lp}asyYzemOZO~4ysTo1{vF2INCOu64ABo|F!!FXyCi7jxSU6DN5!D(W zC$-|YN6W(1gYAQX;L!jZcch{YDBnWr40j8Wcusz%t=ePS*2a0p^_Fz= zdh;90G|L6cd1D{rLOo&Fq~D=Cp!!1o)!@h^M;5le{{;zEYaGmp|gK6S2efM_tU+gfUl62aB1;N;k-~hz`z_y_OJ2p^Y09d2z?Jl zy9}|Wyaa8J-yvque=`sDLkwZl4)gDpQl+u;qhgv+0IqhT*8;f+iP? zxwpxOR2_U7o(tLJ*U*V|p`*}=*klZa#2=%c(sSq=%qz_T-D_Pl{Z#!qol|#K>(K6E zzQAg|BkFZ&Jt+MXNkt)2w;7I(gLLz%~kmW=XjDw*z9 zil&rRE-5Yk;L30h_NRtx@U@iQ$|CF)a+G#4kBq8yux)+f(uBVfU)b8%ds+ILvW@li zt#t31ZB#Kj8+M}8xvK2sXq}iL5)TpKc*qml9cmwa6Ya{L=eJ7r6cYBzL98EHMknZa zT}{J9<7%L!9o7P?-`Xda^&6LFdvI;JD_)b&oK$n4jqT;Lm0W8}S={5B;oS@<*YXkj`Ca+r}ry zevG7s{{wq%lMoiT1HW zHqsFiErRU<)A{a119Gx%jH$@5*uKQt+NL*Guw4%OX?4wZqP6<@geq-b96E!dKF9O`RowB;Vlu@bRc0j0&s36mgU%IFg$Hd0HyWu+M}}1n&nc z!8;G2v^1)otI7`BbXPttNaA5;}j`fc*MEl zm0}BHt711|cj7;=rTlA{g7%c1(h6u>cP9QJt1@?WQ?;XwZA~s?1nLH4OmqA!O>ACV2wN$cATT67$@>ZzAe8fzM5?ro5_D0c^cgm*&28l*&E7= zoc33a^bQX8tc!juZy0RjoFDwt^_6$2yMcRy>wVb+SGVFjt|z61!PVYIvD}a;c7uNz z-y!!G?;xYm9@HnzMNN@DT`THF>SyW?8af!38@n2&8p66m+E1E3RA*`pArmKYhQRQ9 z;Ii8(w^8oGeYTI3F85X@%5T(T@;R_iWU8;x=J;UpN9r{57xV!h>pvM*n=&kCEHA9{ z>^^JI*4Bzzt(N8~S?Mi`r9QKCeMP1n>R?da=d?F7z?w!Reup)e`=C>>YvK&_g0x(n#STUe#x|(8lv>dkpT+*hrf|XV zMWJmtovr2H$Js+K0+C>P@Mc)>pA0PUeGYB#ybC_?EDxvp;@;~1P2T09$5)p&jqDdz;Wn^8%FpCm;JH44<`TK&Qu-79l&MYUGt;&IX!pW;yqFoIp*2IH8?h5! zt;f(^`iSI$pW^}a(=*im(pS7UT)bVXvQR-_JU2k6FV zju}FlXL{bSOQ-6S4Q^d?{TrP|KZm)Z+poPyyEWtBT|(27sFhetV_jmf`r5jLo2V_p zrr{>KD+etbZoe!$7P&4tv;m{BQ2sC0E^)GJr=}&p65E1b3}%^3fey5T?W`;#s{6K3 zYV-#*FJ_IuMlyXHkWM@wX$#W*mHu_oKyI1X!5t$vL{Ee=r8%4}8em@w+rqw=n8rhc74NYkuiDLch#aI@r%NSUYIB&UOmFPDQ6*hUsd)jklqp-TM3Mg* z7E8;qD~2^}N0Qa97WH^gJ4%|*uT0gKys-_zM+n{UpEO)Zr=QM$QKo9M!rN3tZwNC~ zWu}9_lNAs3CB7A(p}(smqb#EnuPJBZ$4OC|7XFL=AWe5MX*Avml%0V^}k8Hzf!nRSEe14C4=6h^mQ6q^y)+Svs6=sPn5E8!lXe4rH>`d}^`5$U6D7t3(sPqHU_BorH z`c3w`dP#X*93=a%u8H5D6!RvckMwDt!S*k~Lj8NT4|S0<87-mq#BFMCv>WnB|0VPk zE!ArMb(jz}#q&b9Qxzg#X*=+J=qNOu921@ZyOvn!C39K`MP}-2hgYjUG@l)< zX%uWj{VjKr#&AvabG`Sd1bIH16YHJ0F*sZvravE9Di5J9p~+rz(lRekxrlngaBR8q zJvtgWgV|ye&5MFAB7pahP4T{tMP(j)Z@xR)o*!UZ8_Xg$>!x_tT6EkvWCq?246vic z`E(MeGfeUCu?}QU$RC-CQZxB6@_|SW6(^>+>RUSr9p&NFa`i`fGVw~DC#__*%P#2y z^Hgr8G@#RDy?jC&jz_5_(h_299Dfa6?1p?1J(!z~*}$8#j2INYYIF)7ZZKU(F$=fxVaNx5AKW)|@pqan z;t(}gieR76E9_593vneFUW@gvnMJl>n=p%{ zW$?*^pp42Pf(+(lg5z?C=c*kV?AK9N+=P!pRtW{v5D7#71dqmj;M3yjRdl*!AV#SF zidC3B;%~~Y#4mCdDu{>aFY&)L8@R>>K)#+}4#juSgQdQ}oNbjBFl6+wwjTEp-vTzPa;cF1hMi2l6t9ve6-KD0 zn;&UNFBdQ4qQs!<(8@{-Z2%qfU&O!2d14CI7dYMn$OWkjc~30DzE;4wn0u}i^c!&zmz}F6tJ~4#YLeVHjB6w4d|z{@A3W0A+X0(0n2YLI$R!!4O9kW zAA!LeEw+Ta%mp&W5t@1NjnqwP2iisYjaVb+suEUH>VoAUvxVklAE_xg+nussDIqNK zKCC_RN%{+m$eD;d?#Qv8b4dO;zqFdfFI!rya%a1V;^n zO_Yz}bC7k)XZ2^Sy>K1fPE_CyliiUS$~2&%I*0{iI`VgHt?`f8Rs1lygmKQ&#NmtJE&FF)Qf!{`()Jx(a<_-H0FO>el3dAI`qxb|pi+6_aFstbT zPd$(mxIc{dBX^;b&7(bIc5A=jBr=o8;@08&kd{g_c)nAB_*#sWv#qpuV>WHg_|KXe z+#UKM-;yeq`eJw3rG~FU-|H*$?@=A;iZ0OUV;9L=_{#XdS_gj*(%IUw5AP_^NG-ep zoQbrIxgP1R{acuaC-EU>bv%;@V~qqoJ{kR0tVOQjml3J*ERsd>Rjb4MZZTns$-4THZ?v=cZYWAj5w+-C@>Tq}CY|4n zZcw_a#V9Fdk>kJ}Qyu1%Rl%Oz8al#z#P5h-!DoW3q#EQL;v>J9_HjqaSK>N2 zHL1G#5PK_B#7Ch=cv?4B7$aByUqeG|z}$;}$J`O7pkuHH;u>P3&;}nWe}@&R&*c_G zES{zG`dvw%Q*fq3QjJfb+AY$@p?iM*V;|VkcR^bHI6mIECNNCDi))TY<*W2ezCSvX zx)ObF!1;7#58e;V4RNhAc8Mw%kJGogUp1+4w-WhZ!2KYRhs7o2a=8*rFNZ<bnU~>E+Hbi{71H3GMLv2y2iIt!)|B0QZ*{jT9Yw3D{)AWtbFHYg=Y1Z4u#eiNXv7ty!mn2!kXp6?Ux2n0taK|Sz~0qqxaMd>%q0~g z6_gt4R5U8?Mz^D13FDZ@@s@NI?xNNeT}AfAe~R15&lT46b0|#EN{v{8kPKYnC{b0%{ zP-viX+Nx`%zGS965*bO%<1sM1Tt{+IqcRFP2#NM>^sX{lA&>;HZJ3c7$~|7zitySU`$4Tlk0*5Pe3ZEN0eN!vP}Sg{ZX_w*hEeeKE5X$*Vqlx z?;-LXRS_Qp^X!t|z%Dddtcjgbddi&m0QHGsB@34XH*#3CC})sTX_H!C)~l11W6h0I^j2*#`;SqEZ5c7@DCD>rJ9ihSQ0^bqGMyaRC_F%?- z1!faHcnYQwmEj#c3H=Mw<9T2`xU7DS|BNh0U!mZU!&vkk{wGpP{e^fgOe5z(N>W1> zung&t+!q_hKb04#+r%sKC$*YbrmW>h!wF|6Ilc0;m@AA|_Hp;%)UO9ZJz+7lvZu1+ zIf326ZREZcLhP@clZ~^D#HRc(c{YDT{YTubyakfG1p5i?N6jL%)In+}{gzH*7Bk&7 zgJ2fXp1!~YsG9UbihPWM?-9q9Jlls^dnNxQ{y>>o-)cDK+&ERJ0lUc`s+e)dkR z24{->5wF9o3R~DakyvCgaII^2TU>;Rkb$o+{wmg2^5yZ!N|nO90A2>!QG)#WR3-C$k9Bc%79a5p|U@rR^xc(dJ|IV{qi`t+~@+T&v3_c0h zlUK+hsy-8>_G)P6g=UxLp>~m`Ov`GHYCdWDGL@OV)Oh+hnL+gcW0Mnnas<*J%=qKv zH_}DvnlMCKDee`|3x56?PYHd+GXAzuOW43Q6pwOud9Bce4R9;jY;JSh6Q3XDB3ZF> z!P)FP@6>SA{lItBKUg@DNr2w*P!7IB$+On=gx(s;FP^q6j$CR2A&Gh1I`xUDzp zpJ~>^OuM3H6gZyGQCTE|KZ9!m@=%~iTY|&yYxP%{oEVgoN(k&AH-TI^Cl3X`vjrW4 z9ssw<8hjmi^Df}CV2}J4{*L$!RiQz3KzWHzR=%JwAUhuv!J}(Rb0HVbX7PPGdT-Gjmm2VD@)bn7t9A`A6_g=H~`0* zv*z`t->tPR3#`0uqq)euOyAgkz<(= z#ah!s*}t(a0*bk__7l^VUcro`e*v-v&YO`>O8H_{Ia`T{y#*uxnVrZt7aK!f{1S{O zbEW^aG~DV8@ctem#*(?%cJTjJB`Q$6;JmD>V03(`RFv1C4fLqgnQ?rT03tnL+J4dJ%4aJfIhrEMB zOxdb{*0rtlfXDddN%`>XBc)m1NAl3(md-APe+AF*ZMj8shS-Pv0xsX-CZD!e(r)7p zi$5XDa>Lff_|6uxHcpJ$Z(CY0-_YxcM#S&hJ1~=~FI9;B9Ge+_%j#qL=d5_k#1g zUtpy-H=O5M8(H8(U4Q#;t-L-RLzR|JHLU$mjn|M2dRlkfhYT|eqlO`AH zXh;MOr+j66koXT|1f2P}iAoq?eP=suI>m5!3Hm~*jc>)uknwUTTqCr?S1as|o($`v zgMu>x$9-pgOCm);Y2-tzCt2())|ZPBE!k8zTH71A-7eZj`e~+zHm~hZTQ($5q_vaz zp7nvTzje7m_&=7;0!(VF;lgnlTi*p1cehg9-QC@-K(V65p}4y{6nA%*BE>0g3#;QZ z@#H_l_doN@%1kDc+}tEr&N;7ZCEU54*c|jq?klyNDF@wKNyRB=DOcsiYG1j!L`x+@ zF70Nprtw4mVXia`;{ga94dG2W4sKy?+fDX>tt?l=enZ@3Q=GTC-nRZsJJ!qWBE9TH zvOPG&QusgPOyIfnyR_H$i#{u@et1ioEo)?OW7{ko zwPu-@q?^I3Iw3E&^7yy7i)9tcJRa!pNevGPHV^g=&GA%G_9}&tN%t@PiY>}yun*Wp zVjTw+l^#7Z=C7#kNJA*)3dK!wWyET(I?j{!-NJ3#8Md0O6?2pwLn&s!Y^7~B&ZsSn zt$IIrHk*dCq+#kuWE{ntK7G3$GI;Y4eS$j7CNs17M(jxOA#^d-?OjEi^P6p>7_j9< z0#GJ1+7`>lGn?pEa8TAXVvy@y4Jp*%cC!tMsuJb0mvC*s80=T!9jl04xnk@g!bW3#8fdvT$$t zukbr*Sn#Ww)Av@%<}N7r_qalpz2!ZR-OJO+jGNz1qXDE5))xmqA>;4o!W-0T&b_*8IHz>O@Brrx=pHtz5K z(mt2G$u}cd+wJo{&1&vdeV>BUJukiQQ|G47Ng1asqN{Ox8Boxy{VEdp_?s~+6CcD) zh_gA%*%mo&+3#=!dqhx}s_Y1ICplD~8mR1fp5+Tn52xyTjJ`^ovOm09a;mw>t$c4= zXUB0z2bbjP?HuR$;h13m?&u8%Qz8_YzcY29%>PsVL;0ee)&!$~T0&*pXZce?YW5E`mCJO}Ijj0yRB>JWd$ICF?#i z=~OjZTdT=PIq>M4ko|C%EJfd7)A|2|f5c1THCuaI70@v^iqpW?EQOin+H3>*2iX?- zf`aB0;}B-_S|fSCDty(gG1_pH3_^9+mz~RXfve~PZz5wPnRjtH**3^Aq$vq`?-!B7 zGX*_>aq3*9B($WId|v7$b(Ho=4mrCV4H|hJ$tCR%UkQ(u{*uN>3&Q8ZE0D4vhUNwP z1s4V?`agS{`|bxy$%W95-$I6|X3R1=7c-K5$WeR`{wx2OXT@63g8V14VhP@0CL>ez zIh>J2_5Ipjt-ZEgE2Mwak0Hx(m{|@C0H-kcY63s86Plcc~uqVtOB(wcY7x-1FK{CixJ( z8$VPV_u#UxPxl7-U0^CRbC|QBoQIhVCJ$SbozHG!MuKm>mn6E(cWk)wd$an?^pj&>#APmi84@Om2GkbIZ2kKB62CYH^}fW~rU*Fj zvoZJog*rzqrDhH}(IWrsHEey?iQL~g; z$Lxmm^4aEc@H#f(=QwkK*%H*;dYA($Wfn0Lv6e5!3n=w38z-S}-)^imHXBEvHau#a zHSQWuj8ua&bDIs#_SlB`<^l6QR7csN^dAjI;9=_-XbE|cnlKGY`A^6gE=qPI=aA=- z70jSa>rw;2@7YM5!s9aZvPY=1)Cp=kuKCkIZfQ%^ph{3NR2mX6Pm$X}Fq;l%d|fa~ zvO_%U^X%<%`^ zW@KlknsMM-i~?ogCB{1oBjbM(7&OlboyZLbPEW9;x09F2ZzM9po%>7|OUNXdZdUY-ALi92vxCT#;X+EpA&5fbPh$gyaJ42wMA0$AV$ctLrLa39|x)8P|%AptnV9iRsk#WJ)<$ldOqIuILZmMmz92nu0vm z5DNBRt#Vc=Fw;w*wVV_6*amWp4Zov7E8+2o@KKViWRO!zgX&S!ss*}C6VN@Hj*zxok9+ zj=Y}>{E4FAW0b?AET|M^QNHCsrYMKBjH2K!6h>+1!SP52ohTkTAxZcvJ3bkQHOJ%m zzgMHMeml0pj{mdcJ>e%kBf>F(nUv757Y3+`EdIvKAfC3LTbdW!8TownpU1l>zAGQT zs|3Ec1jbRyf(%gzsUmr?g@y4ffNhQNHgjVOvf+{Va~pEtm0VaR!t99pxpyw?Uxe}* z;k`zv7!1CbL`rr5sSlYbvqZ zE7^XQL{5}ZzMrM}zj}Zyer%r$`@~}%k3^6vlkGYOxFtSiFr z$b&MElz$%lmlJiRI36WXKT7@N(Nrbs;Y_Vf)BxF}BAyMPX{d~MBFC{9YEUtJGEysY z;2GK09C$=(QRH|<%JYBqf&RH&XyD^D2g=il(&ezz?Z~oGKTA1M{~}ujllspx zcVM~5F@_i$xi$pW_rJZ0e3rm9hQxcepLKwNo5es0hj1)>r~w|7d^+mEH{`SYz%j`{ zjmW^&Ck?)CIC4xV zga?8Nem=p6M*UNNuG_-j(Lc|>A~=?X@W_v289Ac)eqI;y;TRXfvjoUGMSmXcY`A(f z!SSyJ0!%YpuOnCQ&ZrwTaow&0dQ)k<(-c38;}Zq(%#ULkgCiC>P7x~bYizj-$13*! zYh$DiCZpaJ!?NYEbTXciD^sL)7yWsyD35P1j5Yp>@2HBU%j26OSFOmarSMmz9!IVi z`LTus@Mw5^Ixm(ihU-fKyqX*DN6zm^*?vO0^jDO~dmO7YlxI4Q2?0F`gYAo486tH? z0-KmZ2}bI20c=G%d`k(`#{#I;(MTtb)Rlns7RTT**7p*$sR*6+9PV$2aSynL_rCmO z_?-Spq&oZmcY}xE`P@Ky(`jtM3)~gHWBK=}J0Gn(`1u9eja9h241_;^I5M;5;rBXF zb2nkk{xj-AJj$~&$~iy&id@071dW9_xp*>Uxd)y4|md>CfjRHrhD^i>q z;RqGOI}*0^5ptmRA_-|L?o5&U(=%*`8|%)FX(j13RVHQx`CS?jId8?*QQFmH0)D5&R0zWsO{Hy6kXpQ+|(QzcNFi54J~`7lx% z($^S$ts_PeGnME_bu_A=6_per))Pm}p_upGOAaDCfmKnATua%_(_|(XD&;_Zf=gUG zY~H|k_B0+Lt8yq=l0e1~aO4J!*l%tyoYqkKPm4rZOJoW3O1_fI;kqwQw=~L8Sw>5m zr-m7?$uU+FY7=xu9mtzd@ts5)>IAVH`SYF7YKo^$nqg}KIgzTYHzI3O$E_jMDsv9{ zA0x~MOn2feNbze(8%UB*$r(fn$PL-yVgJ`S0BXo!D}Z~=5NO8k(`CsCR(7;n_u$;@ zNf^jb`;}Ti90Ae&A(2S6UL%A7aC)DpCa z%YgNkPHsZ3*$m4>IdRnA^cJf;(;95q;pQxIhVcM3q$zn#KWbc{cEjhrlk{0l>HXFN zVug9gTA}X-8Mgtk!1!v8w1&_Waf0$$gXzlpA7&Zr8~FfQsj})&oE1D){4{U#UlOwB4)_z zAdl@UwZQs?>P=Q7WF(lj(L+Qvb2Rmz_0|}tabyWH3QU8(+6wKixq!{aEFceqWMf{UE1BJ_;i!vc;j`+jmoZ*U-mpb zlPyi=5Sm!ynX|@PDp#WDhs`YVXJd zr zoF4V9cF31FuKiBT&{rA7K%UP7hRhE-7a?#tsOiEra5eK=m!Z=*V(bUItd4vUy^)K` z2rH9(Lw2Q~k{#($^naKkUWz2Ja@az_*rv&t{eMKM))E};!}?JqEqu4OLML6@`p$Gh zPM=CFM{09nW+SrmHmU)PlWJNvYcM#nSM>R8AuE;}Nmix$kTtFCR8D$|m8lHCRc)vl zOC2GmST&ejWCwaP*g1#5O&lz{G1GWOZ($6uav1-ipQsYYXs>bsy6bCtI@8X4ru}Mu zWb?wwZim8eH<~98srr$@*r93-4Kz+)%d;;J1Bc1U@uv)13} zOb)ZQU~Hu}vEMjP&M?mrEV2R%Tlq02z|d346Yx9sh8D9ib&sq}t&{zFaVxL+yLF9u zhH;IudQs#yjj@X2lRNaON>{1{=+5nwb=D&45Os-`={MvHeLNYBWQCUC=xW+oP1A;% z9?NS$CjcFg4X(!AAdvJ#J@Ju+i5p}qXtRGaiV^L#7(-JF8C#8XjG8|+b+bL)19!_& z)K&D9HxYfwTbT1oGk(WiX}q?`h(^Y$tQRmB8*zFUvp)SV*~GX-M$w1SOHCsen-7g8 z+6A-=n!~F%#4Lz5L=3f-SgPMLSx~JuSr@eF#_#GJa}ykZD?sgOKpeo;Y9*1ET5hpg z9*o$71!^?c4(aoav(_$LuWDIK^*_zaMhCJx=*Sn1V@T55L3AXK!1K6=$VJw-av2Mt z<~eFS!X2up^%&ZKVPGQ3#5~f%SY>u=u$hl|X^GHq?1lH!Mcl#>`hb4-0KG3W2Nb$o z#5P=6gBX7phdWv?_>#X9?U2RR)%=E3$7!ZweI$1yeQ^Nw7D>MA$(>M7{?O`~D%D8; zL%V@-f~Pp*Z^(0?gqo;Z&8%bQeQT2$vhu+Jln>>)#d47qz)P)cR3nR0&&*766jJWm zo0pLAH<~V|!w^ahKtIGH7SbMKm{ArAn%TxuP{c&RNfLvs{agBXV~gMpwI#S?>&GE~``iZW|VyfANC3#?n% z`$g1!GH6X9CXtzFC6)xa#cM1g{vxYdS>{(Vo=PTukSnOW#sJeEN#``W8mG;2)+w_S zMoG&PG*uEbyd!AK{ew1v9kk#U7>OHgOopms0`9XTtcuV;ZGeiW5|Z4%fwWo1nuO7# zr38`@K{dKcEXN(_B1UdL5*?Amxq)a4`qfa7tUsH(z)b9f_VhFB4jDtwBJ*Gbd@)+r zPq204kds-MEKViTFR_NxAPknG%29VP0^i?SPF94b=@2v~rNKZ;L9SpmGMzYv5xA*f zbB-e}5PtX&S$ZkPDnCM|8y~%R2ndMij!h-b?E)Qg#X;H)1Bn5aWzyMoEm}w6fC8bLIqV2kz;gtt=@eWYc14Ie7>6FouHmw{R&Y}85Cr%rHF3WfEv5?rVS#Xz*LWNMimS@| zxWn8S{t@4X-^XX;AMj4T6cXa{vqPBT^h)Xrfw3Q)<)8Ew+Cz|k3TtKcG<_112n!nR zK>In11l;i;AKh1VH9=XVj8ewS0r{I^S6V4Gq#e+|{0Mgo*9v6@t_Mya(Po6Vx_e*N zsLVimqYSUx=DQLqsO}(s5!N}6#(qyMo^4;YXNd+;#P8X`2PGU{u1^8Sw?6_bithSed;K3439G_xQ0SW zTX)Q$@5Ve_QAZF-AtSi2Ojo)*J(4O;uVKbCam$uW6^oT`c7Z~de6 zlBXlbc(t@aKA_mu9;&9E(wth7whKJ@KjnthFxboGO}S8JzfDxswy*W@2*a`;uKY4Encm7n)3aN|?{p}q>f{=SO7*Pa+p zLHDf8t(n)d3VE*jR)q#ASFFKIRd8C?!a-az;RDo5eVpUCFu7S9hj}r@9pirJ-s7(1 z?-%MK*EY^TA%Bn^%jFWL3S;>W!dK`7xARZ99b6f15ONwH(Ye7<4PpD&z*E+SFC#32 z`)-4n*Y*y%)*XfB+;8l8WC3PvW{d<&)AlBG|iE5 zC9|04Q=qr9#9Ys=vDJv475@eaE6&7$u~gJJ(M=99ZiVyvr+czxZp*CbuIAYtm@lQN zd!eZPm*tT`A=+oy%h=QHP3@cDdaWYP5PI>kY$6-SDNH}6BYTN`&wb@9fjIgAXXsrd zA>_l+9>DS3BQ6&|ncvBO6Q%%i#X*Qi3%l78$2rang ze1NYgbQY@#HN^2)%V&R=PWP@HefC`JKMO6%p>(s@u-m$2&y( zAC86gnf813WLsT(VaGySHv2qrp-^7j#IF`U^BNl^wB_^iqs0TlBtBU%`P;&uf(yEt ztF{I9ey$y^Vo^Jy2fCU&DSKPtEjtcr$Sp9jz5XSkbK$|# zdg*icOOOrK^-gt@SzFWoN@G%&ri@6bmr^1%m#1hT*<3?EvUiF{MK6PAL-OL)^ae>Cl6i`k-urv5R0%ey7d;fd z-FnfFor@d}M_*)Qrn6m`qO^@3%4}k{@z~JvDn$t)z_uDesgtn zCOIbC3yX!h4OC_8s9I7g7JT7-;jZPb=|1Q#=@0ljgqDZdP}T5;Am>l?(QekgFe5#^ zP{x=vHoZ&w=+xNM*!0@L5=tv;xVYA}I5B(9{W;_FxUw&fLCOPhQOgs`pUJ1M|G}rJ znQPp?1)C{@FqJTre#TndIxbdBwokDabyRY`b*+j!6;CE`aV=u2#?Oe#j(mvTjNf{0 z3z)*zL9-~Sk5-v{@muV-FKMwWcz8MpIAyP&F^Kh5%biuQky^l?`U^g zX8TMotBmKBZ$V&4V0mC%Ah*x$E$!Wv^*O6w*2Q#3=DxJ!>D|(*e?RoC`X@ep59Y9k z@U5NK;&SCp%-gHL`vOf8R>hF4S1zL!^Yl&6pI#|#sJn;FoYG`6#|4S6=ubEEMVptAwNMM~tzwv6|`M zl!o%<@Qv`yP>WFe;P*g*Kp}7Atf}DBjm>%i22F4GagXka^FQ(|^Lf16JiaU=vsLE5 z-si!M>Q|~XGuS=^D#ywRuEe2HPwgYQ-PALytNO2UPZC3y!tdm5;nz}=@Ez&6+)$aK zKQw!>U!&3zZ25u(ixoazI8VOmxh^JFbdg*~LeS^yISn^@iidHRUCia>y7Kw?#q2G1 z1N#-!l#awp^R}^6%R+8mqA}9kZB-+^$Tpp6)i*0@i=~UfoW2e|#a+%_4A=a-9+xMV zXE-tlJ7bcD4b(+y(u8oOa1-dw4g}xh9`j3ho77gBrEdh~(#7qyzj9EnOy^1mX`3Ld zVDCW}a@)Fr{_SPj#SP?}+s3%cMQ@M&Ep|g}(b$ZbFR}4aNzP4TF8(6&Kua^b(F5GW z{Lby*x8t7CjX%XDg38&P8g7k6i>0!90!h^e)gD?aZHO`7e1cv>d(2zkgpy#N9#S4E zq?9XsE7&PiI+W^P5`69-7Mdfyl;%N?v0dIQ!55>pRz>BIGE23qCG?iMsBKc~Kxt4L zZTj|BEwdXa4FnYXdLSC2e{gt%y#CyY=0q}_(+%uMb|=G(AL;?+;$T!%EF*8 z{Kn*FRwHN535VuhN~J?=Z>Bi&J6(nG(`D!l#Awn@p0Hk_Me8-{f$^AWtkBCD9O2h< zS`GAO+E@#ocpd8ya)U1+Yrhrp z2l5f3X6hG>X8JjOiy37V*SnKdw4&raEgkOItEVb1Y<3=2402z@HHN{2u34j8S5}=-38i^ zb8I5H5~Ihz8wu2JN?pCA(a`$NEwoNE>DDJ|fKrgWVtiB@G1cTxaN%rYmTRM}i(D18 zHdI-ws3Ard>YY}{Oe4<_bF{ySV)}FW2>M~@dr@l$zd1!MKp%#_ArZXbred3r2-YWw z@#Es=Rco1=j%1*0$dAd#v?fMdJY4N>)pkZ9BOcVsW5f=1fqK)7V(L=Oh0?Be)Bx!^ zxtMEj*09f23KOI;5V>PN%n?>@vV~TVyFiu_TgWndF#I>M+eoA}fe9ZZRdP7yZ@nn_M7zOv%$O-Wvfz;63WUGM-14>*u4#ps z0Jz3&;7!dKJvwcTfKdJcn0e zGyh|vzqp6E%=eexFz2bpL=PoPKk1xm%tAt2UuJT2D76D~Of06m^;+RGF&)%ZW^sCi zlE`#ocF9H6cVcU*HS}<^qL-2@gZ24_+DS)sbDDlb;kdJov8m6ff>8~FWyMpT4l!++ zki6VDM|fGp0C-G~TlJ;YNE)?g%HETj*?betnp}jGHCI2BIv^^-wPu;`r6vKI@aznVQH= zB$|ejsc7D(d}hZJeYKiwKr2UBkX9Hs)fJ)upN&3E`^{NqvHB-^HTRwpV_-Kx^w=4 zVlJVa?RMs>K>tG5)uJB1(x2ZDRW(@2JxhEY_d%+kIZQB}quKM{)65PDcgaeDC!umt zd+GM_{LpZ|l&y?eHFGapBw?)D)t!sF7gJo{5xPb5Od`EaJ?HNgHOST1!(~p0>tY`k zh*o5xnq!@IN`Hh=@^#1&<^!|1SRs}4g-3@Ax;}D^q&Hcm9HpI;jKCk`E0$2DxUgky!`|_M_C~c*1O=Ec)$H@%A3&covuyP=FnSjO4p1h{HL>$Vy?xxy*EqR!DwJK?FjMCPPuB5|HSD)mfse2#5dP175R*Wz(! z`!+*BP?g!_R(ana2@gy{Rej;BxuL%XV}(=EBGAOM?K{}X8R;Kv~qzg z+fMGAdNAQM}%eH&Aj$NbXl0oG?E0P2um{4 ziKfAt_90|HX$#j0g!abf#c*TNW#6dZ2)(0ji?1~)P)Tg-YLS&rycK5~mGsXNa&SyH;o;CcZ)}zO*sh@U(1_lu z)MN*6yTae8NoE6L57RSnoGEXAAL`D2_b(S5&dcuqm}3MlyjFVz>%}cFy9L`4$C!Df zfU$beF|`q*SGupg5T3Ag1JQ_MsZ4j4bL*+pX0{Hq!|ygmC)AaF?r+4|#OtQ{Jxnih zus7kPPxUE!93J`>kivpiBYxaa=vp+4(6Klje2d2u;z=0{adNc z+%|0s(VaR+mIjIZ3-`(r%oj`{eVzK$s9|dy9u5b`ckQ`xh&yH7Q$Ewj=oYd>wBr9m zBW|O)gPpF;G}n@MEY|9(_he>iSBVO6bsgsOh7PjD=swyDy$`pOZ5Z~N`D{ydU42M) zBib{KJ*Dj#$kUDH=J|$;zY|NP+`d)) zp}4pfN-f8Uz;Lo~Y}1VI;wO?;6NEp5wUEkS*S_%0rT?r#jK{2j)ZtOgerq4*QmP3P zELon(o-rO%uaq2gIg&z>#TS?mqs|o>DVKSNFD{j( zUx96vrtY+A+SbTDnN9jpW*=EY>cxeLDaLyx2j7OQr=25unha*XJ{mtPg3V)GG@W!F zO401(K&G8`lI%c;)D(Syb%cIk9)Ul37*oYqZx!dJX{)I_nw=SH*tBZ41L`ceFOSo$ zF-td-d#R7deNfYz(JbjQhA|T{zxa@@Wj-NRg3UP%=j8*qz>5*Z%wNdj^aN{zS&Xl( zb)tGui;ed5D6?B25FL%hR1flyk!GDB zOJWA=5p*s4F=yKc`r2dodj=RHzCtS$KlgOL;K=K-W3|iZ^R26avypG3-ob*0!ouf9+&TnjU zYn@SCx~euYW~q<#cLYk2B#bg@im|~;){jEdU)W4BzLUMl%f>UKgOwMG_4CveWY5k; zGW~w4A6dy1tUBaKxMhwYk);$d4=LCEFr#%1>Viz`rg_5Zj%0+hm=nK1Z4<9B9k@(I zg?pr#>V~>DKwTz}(%bAKnJ?U*+G}e!Il!vO4bqEge=2U}5Aq^r`)`p&$PdgsDu+0o z`2l_HOe-IJ*&4B*yav5&B6Pc(tilX1 z44ah&kI;K?d%qByNS?V(j-o!2_srs;3?75dda7B+m<1ioucib_nP~PW20?8QYp%BX z8b2@_`-Xt37fCLg$vIYGVhDPr6g15xpx>H>8Rak7*RkeYbEaXKpFnVMg4tKxss{!5 zTJZa#tm&Ae8VD73E~1!anqRGLWC7xD=!iyG*`beJLe@238>5Vkn439ZPBD8}F{fG3Tuf}W6yzH$u^i?LD4&`VGojKcOjLkYZZot_<)OSi1`X}MQ114$ zx*83jR(gt&sw3uZV?EU3*{r9=9Acv-5RZwsmX17oyY(CBoDYyvuou4Ev(`r_e?}mq z>!taZ*#U_&H;rGd8|Fl-pZOI!=nt4{inu|}LP7izitN9P|B&N<51O&YpfJp_9*{k; zKYv>r;oKN)J+jP*6VM$kBXUCvm1HeNz3v3f-4?4PQJg%E{YW>Xt^ULg zqB{8*pPG;CmQ*4SybxY;GI)7Mp>a6|W%(6TAYLO?=qM5%zLE{B3xq;#B@?L{L^GTT zG0?3%hC-#ObrUqqBJgLl0KpDfu5><V!>Dv8~4OXh)QUW?2V4?>#7cy{!di9jMK= zLszJw{!PJb@eXS;;l+NH0Zr{Sv}5gw%}@(I#H%4J`PHlnf?s9K^Swp(Xj9C%w}Kky zidEaH1m)OjD9%R{(bPLYq`!b}QIFD)ZYq)Mh-YqQSr_OqFr|Vm&jvQzHR}@UVtvf1y)?#wWY7=u zXBVJ`*Kq%61&_fCV>{9+CL6Vk{E@6amt#FXr&b-8S2cRIVX06BUBMaq61E}mG%KIPKi(>RWv<4dI>Wbd3wrz`U_{J@b7nY*lNFc}%nu|= z$I@@8)l^aHG0A|q_7`yqdiT6$apNei91$kwAa#jaLhY-rQC-m6-B2rmm01?t(aY)+ zb+4k!Gv)QtVkAzs1cZSmyI~>hR*&X7)jyh2lV%o%taaChg zq$sYmIp9efNY*slYKN4i@-68txTUS6h4Nvgnp#8qO>bt*1|8!Nc#G5EX*o!x(4*k* zt;{)*_?`=V<<D#8TVT6_cHop4Tci|xg8 z!aDvjw~CEr2Gjkig~%&_UW051C0Hyqo{B;@7 zf`1`{^pVTSuH-2wIeWqNyVTgEk4OJtoAO+)BX5+tN#DX~xre>sL*Y>PS$KGOcX)Sb zNvKM2MWC^Nk5}}zcK6Q8o!K`tTUN5?wy$__ntV!6q=xa$>`SAb$G%87mb4{lSmLF) z&N1Cwyv@OO0(q^xQY(BuP{=>VH{I9YKPPZ8xF~!`*42SJWiCMy(I?C`_ow?Z{n<*~ zS8g_cP8cI5*jm`~*$N>c`Gc^BH{f<6=)cHcpdt8(JpSvL@mI{ZRx9X6Ix=(ERPG$+ zq$dd{geI8vyUdAPO;(``(%q>yWKQw|aRW-kY~GxvvTkHP%N*xE=Q-n#4R^rISbZj0Z0>v#Q$3+k zQlo6!lWHWcjlUQ(&^b$3N)NUEQwK^{0$0Eai}FRRt(G!m6JcK#f+`iCUPOj zMmf0h{Cd8*@LZt9!Q!9dM{$T)PHZQB1IIYX73A(Q!|3;9UxKzYjA%C23xoZ!4yv#< z<}0fJ`8QRHDaPjFD)4{u5BUXrE&gBbJRDiYmLTT%+)ciY z1eDUCicFJOskbx=&wr$)(oLy}^eDV1{2+vcGt7=w@ip}nb&t$!oRKr*uZ%$EE_Y+! z?jWVaf&1BnUt_1DPsSBXoRL&1DLH9SLbkZ1r~$SU+*2~zoUJh7PJtKRMc#X!lHM}D zT7e&-7t%$woN?bOhMdd2Y=|o$%okn@oLFA0CT0`kFweGF*e>)Dl7zW%8csz5@O9=d zx*OG#Y=E<2EqEzMaUIM7F7FiUDsh)ggy!Kr8^`bE`wO21OSlH|{T|Fn9A=tB2ILLNm9J3A8pVFuVdJ@RDBj~H3;hv|D!sT}#9NY~0COws|OP|6yc>?^~ z6w6}@<`5@OXJ!58Uhm!KPYfrj!_8zmhOcYy7j-lCQNq88jgkr^B_{5T zFN$>a*WxKwpq?6ksnaAa7!0WXP5$ISgJ2+3QYxzq(FPkstPeyl>OQS9xsZRJfIY7y z>_mNggfr`lpa_G6AN+it;p=g=*bz)Sx+>aQBBA<&|x=2xuHb{UP=^{&A%9HKr}wAb%tm093ykRNC4ebE$)*!~f>;k^A~r=FyO|xzn#^bB z2*#4LF>YMPzQQZIhh9!kp$pL;;pl6Lp6W!>231p}v^HXa?$5?@oZMC}fw9z2 z(uwe?@b!=pY>EtxDF1x#LeDmL@vM0FaraZt9$!pgdDu{#W?$+k`&b<6+!j40Zbtm2 z_=oWw;uGT%qvtr5itpHoRCa5U_Pcxql*CqnXn!gHDt~G~3x-42rQu2+4e6ccZOcxs zMuKfYP-I`S_qehAVE!gwNH{6j#5F=UA(wEQ|CKMx}wsG!gE zyY&l^MI45sb|c+^X#q-BNA4f)2^Yif=f-o5xs&W}wmN%?8G(C1N#+ne7L=cAsB`D3 zxkz(KC+Ct?$wO#gdC~qI52su~`0K{ODX;{7h|2HP%ZDo zHMc*?vavcrsis_(FUqszZSq{yzjtzdMU|U@g;`uVDfg2PNYBF2;n%@D!TDgGe)c4J zI(d70xA|%Yo&|$)Wqk?}&kW@%hfa~({-$y9l2F9_JT(Tna#1ko$lXsty1w=UFZNl=wlLSJY$ zT0(YEqg%jFu>h^_5pXu8Lq(Y%EkoJ(7ik_F;A@x)8_r%(wyrC~Kz7fqJjS)}p?pyO zgi>rRACV2IyWB?_E#(Mr4ebvSfg3*3Th=`WbL}?I3wJ%=Z2yvQM^!UFP))g}$dz~( zdnCS8;@gBKiKi1+$4`%$=e%muxr6jgVxC?@DG2^-D5!u;vots&^f4Te_Q~1QXuXb6 z2ea<0hJTC4PdVrK8p%hQT+ua?W<=M0Y z`oJ2vt0W-xG&^zAng;h~f^{9ewkM{A*7F#+V`TJju0SX1A$~!>$d5Ys8%7V7!j<_C z9JF8Ig$?M6-Wpu^71}4&P>Z7FTSOV8?2(7a$#Q3DhO{tzDZDy#FqDN6x`I%}nBME| zVeU4WwcKB_c6$o?#swBh_tf^paJss1*Um-~q&d5ErI1tKdtx zgwS&iOtf}Ul^x3E zgC=FKa9q46J`!Juv%p|Bg%^T=mh?OB2OES+Z!3inC}hZu18cdOk)of_k3(PB4Gxwu zXjh-KW~0m|!}BXar}+(cpX}s*^w#R5|9I27kFlOz@M56itwDOWjt%t+zT$sfC=W9$^k`q?Xc_ zP`4l#UMm$)MFdjU(ASu`p!yFn4Y)euCEHU+QRhIY@{YUifT?lZc2p?K-9{Pav`*`v z)UR?k>2UadxU|M(BnfW9G9;z@4&!Izm@v{$?)1U%v`!x$SgY zx;K>G)ya88u>!YtZN{Ho| z(Zoz6uXa!tB|X$L)HoC$UK|po7U7*zb7`V{Q;t#2s2j8gMpr9}I?F8Q2ilrCQeEew zE=JdjZjBt4Wwzd82=sWbUQY-8T=oWcg^q;ENeiXRN?mQI{u2E3QW%}v1- z1=_}JW8R@P7~+=+eqpen@bP?Ia1-v3y7e9;IuSk1Q5vt;HLjUwiGoyRW(k{(m-yvi ztgyIONn9#xV<~1c?W4L;xv4nnD8@5%)Wk}p44>H~;vXaqY&Tn4Gt4G9n?2A+H$=X{ zZ7N7rr!&x|y$B}ne2g4)AyQDk>saGW$tYmfhu-HZ?wHNBvHBh;7v3ouYBP1EHc}m| z8R}`hzMey$uJ?iRqOm+h$qGLTP7M6+o8#~6Q+&e%G2y?|W=PR&$W6BWqDPH!O)R#ExEiBQcLMuj3{^v9}|1Y zD5@=efZom2#hpGIUy%RCALFYb+Yy`%-o}$^xw!QFKQWJ$#2k}s4dl|YAuXv zT+#m1OPUM8LQBFp%U-$*#3HE3-@_T-0{y`i zCS%M%>vM(C%dnZT);n_tRPt`J0@6HH!J!gkCl-HQ5)Wxd+G~qrBWT)L&?fA=?c`( zyF-=3&x6-c<2mUORKSglBIIQHSD~UU%{e;iPBb6WGdkp4=U6P*`N~u?qPKopt)kSG z_JvD_9pN$IDdDBkFezOgguI}+`cy-=l+opX@oC zk3G)Z0J-EKS%r9I7BN0*Bh)Fll8lj8%M5bnEVZeA+{kTB#c1y?sxR%Kb27g$3jGZI zERp_;DnN}Vv*7StYCS=sMuJgBpQ(LS_kb0>P<^0I)kf-5jI9_w@DVvF0==_J%wpy* z=69wxvSC`|4jcti)D2`X?6EqaSLie^7^WUTP6`|+x(dQ*31n7i)@33Kd9Bmv0=RR( zXTBg~T^PFl4ah7hMf5?sP7W*6sEIU+OGX3JZ-$M2$ZJ?_okE7m1+c>wq7PIF zq|8_7N$18`=1w?(PGhWcG1^Wv929FX#!jHOvlTg&elYkQrye&ult+>XBdQrR)eO zzK=o2y2!5Eo3yUMZ1qMV@gQio}?^#?|6YYp)d>9ot3 zJ(x4C%O65A=Ow-v-GPj`dgMUM{lAPQ4c7TB7)NLRryO@uRX>6s(P2ItGY&?RSu@64m z7kNvU_a;$s|V+V2Kgt- zGvtNAHGVamqO@WMTkZI#+<~ZYj3cI8_A)W0;|6h;9X7r&KULX{yeofLkZJ_d1Fgb= z@Y&FJb-(h9c|~$)KBGFAW}At3>`?K%bDV3HD=wOfY2n)BxGwtn)zoa`hmr%lVNbvn z_~w5adMCw7osLoPJiC$I$hr8rz-NEp?0gx1Fkgk+#uTJ& zR7vusA?P0H%zl;DNxh{)IF}Aq8mN2qp6G5hg0t^7oKJ=`|3dj_Ez^X(N*4eUF#^@j z5Os?#427VH=2xOLnARS>k$xW4YC`X(CF&;hnQEKMiN#obFF^Z!92X)`1(w?aQYLzi7Tbb2~7se`Z+CO6beFTLQ2`T|u+OVwo z*%*$l%_u#;eq1x4$9+aAEJexN!Xw3Z@Z=Q4Vj?Te6zhmZIg zD>3Q-%>~BUGjca`ly7V;>(Cr!T^C%7oGA;jtSggah#TY}vIg{Yo>A|qQgj{KM~;N}d7d zex$NbbV^7##0B9+*gs!~b4XRBGD<7;H@Z(IX)r(dR@MWy@s2%?qmJG7#nxJuy=*~R zGJEQc)e7=J>25fP-06rIDUFnKDp&;d`Nk~b96Fs(nWs=st7OS#t&GwBZK-H+2zR*$ zOd`B~yOQ~c8)kO&(h8xQH=Ni^>SQ}QiFwGj=I`)Vgh`gCmVuTYmT|&lehE~v*U;Um z?Bscz4H_BS^h?-%69F?V-WEeS-Vc#)6Kvk-nA?wL{ zN3iQpXvLI}(nP*5UBf&1CL?Wu;k5*p@o7IX*jX!0FFw-2!(_nr=$|F;?lFwOi^l zbsu(_#o7m*M1^6US&+Dd`a^rvzdBNXz{Gk%Q_K!%^7V%^Tr)a1^&8!~lgN=TXyvrO zp(A-%J*{>Din#~apCvGQ_u;iV798JV=n@B^w~(f1=w;C-uL4Day@<=xsPDLdUw)-- zROYIY_P>IW@rYYl=E;Al7qt55MUI2-BD8N&y&okU zXBtw}z{TJRTCNY+BDBM;Wg3&uz&lpWDSdQVg$>Y#TrMIWh; zfu7lZ{R4jbpI+bS44+$c^^N-Iew%Rf#@OpY?KQONrf8?MoY4G!rIv!%as|z%scMWC zp%v6FKskG_o*g-SQ*>S9(X)1wMeyc>(iYSN(eb;9bJGSWcLdC#=mw_{6giZPN1u5* zqM;sY&5dv(F9Scw{CL$J&_AsUCH_u8rHmq5qEkEu*R>tjeG`@Z-Q;pyVL$vVp6rU> zRw5(tTn+yBe2ehY=@^H9$VBL)O~-o;sH9~QvFOJV;MjfuCp8%#p8&}{6S`@`(EaWY zMrmW{lf**%ATPSr9iawuis*qVPIqDz^jbPXv&01rlLRn7%OL{37%z=!P(mAG^g>O& zF_c5{7zcn}7>^NK1wFY#h|o_)SA3p^%FuQ&TQ{MXI|91Q`JmMu3k=P_`d0lY#{Vbu z8_q*b?xOw$F|43#7J*0C#VMsMIopk3f6#-RNY2D)cEKq3z-e|6em5Pi1k1_O zpXRs^2?C~~ih z;PD>@XZIMo>fM2k=}5Fj7rqCfLlYxpb|)U<^HnngxQYC5?295UKw+#eqG1;pvjtG? z=5R*N0OR5tbWqM1JLS@bIytgc^hj=v9<>}{O>P8fHL0;vBSH6{aVGZ)>(IXG>% z#3-G{S1xq8SoF#FCU|(>W8AM}W zjAcJ4)$W8E+fTfnKbV;}fnm7}WJf35`O(M?C*y52u^JJ58Z6`gh#cs{sTl1?I8QAB zOIgD?bPsU?4DOW}eVu4SJ|^E_)Mt~oiC0jai>2~o?7st-uoKnS0^}Ru91@9~aJW55 zNYF`K3r_e&#G470`4z%K{xJXGKf1}W7|A|REs8_^!i~u61`hgK{N@5);Z|^sTcdOT z4^-Z+B1iWkJN<=R2H@Bb?{v30D2lABnnE73Sa$1gh?;?Cv8ESip=00cy z{)>9ha$IjU;3Pg8Z&BUogc{LPoZ7#F+k6=+O?z?IuAHKEk!G0$QpK7xAmS0d5O zN7N^Ls7Q_?w?WPMUudR0H~XSe!4o;*J#*6xn1|3C$T}-NHHP4;*3d&~iTx)Insy&i zQMe05fQ~@$Oa@|PgIO58mjzhiIpDWO1EsPZ5wXH7L@YfD$_pP9WNY)x?mi$)UDo{ z4bU0MMlj?m^C;06DhefVcfKRK)?<~Xf%oZ&xQoQ-W#F04V|HIMBT?&Kh&x>pue~Hi zzz6#)>anC*g}jb9dJ8tZ06nL-<~PD-_8|W?b`kT8YQ!P44pa!@pk1}i=tE31vlHFO z8@N05jLC3m$_r&IFVP3u4f${dO%T6Ji5yrDSp4TqLPw64fYHr)BEWCWt?t#N|*YHWqac@1d0 z9x=LMjQv<|p5ksrfWOnzEQ2*ZJCKMM@t7m{3kf<`al`{-Al9-L#3#hsT#Um6%!*Cs z0n`mf5+k5}VM592z4;T0Rwh)j0(iaKh+IICwgwLI0erz$nF;8hj6mf5!N;aV*!V`! zWL+~%?7`k&-pq})qcY~jOvFMwR3s;3Uo;VAJE0cW5_N{dn41%c>3IBE-0vM&H&-FX z1!#-K<`dj?qvNF3lZFdxFmJuBgE zRz`Jq1g`KF^k1%+Q}F-OG5-eQ8p>iFB2Xndjh_vM)`%DSU)4}a%rt`J1tSh~Ybx$l zmM-5^L~+*In@Ti-npA5k+mwGlPJ z_t<6bnukcQ>Ctb}Rm~SbgU*DeR$sCJQ46}Vx_*NcX*ZlR6M-DLLmoFX5ZQTv5I#eF z)uV_F)Me;|R3+yFb-M_-y1Ha`)LyrkEpRSq2-TXVz{?FXFJLd2PLzTwjtG~JQJAGm ziREN9Tzd%mW*iWL>ruTe0*@a%)c`74#SsGs$U&&9x57+xVU6!bq?--&nZWU7Xbw`VR0e89ZS;=j3uv6*G&X5v%x*>(bosI&I$oMX z$U`{8$T+M01$~{HK(VUGE%#uz8HZEIPO>?4Enb^&bE0ygidvo64R6g=K(>}57ZYuP z<^D~GR1@Ut_vzh4)=9b&QI7nBxR^q=#2oDeG~phyf_4DfeKukM6=_5wDtckhooV=R zms&&pY6#`URW{S=!AZ=5mC-OiV7*@gJk3iyx(K;|h&NKOiqI@rJQ@x}Xgl%^TJb4yf`-s4WbFxD9{x&`;l0gak>$hptx zYxRE6L0^Ro{u{8Djd1=PkDPrSPU>}yop3ZAgR}T0eIu0lsv2$36Yro8(LZT@^&&tK zM<5z_s7cq?C+Ks~`^k%Y$Kq0ssXn-l-6SW(@k zA4D}s)UE+1SPwthZ&cOm>hINb?I|iE57gG$YGj(#;F54&OVt}|J=EN4HEo7APn(Np z`&-+n1+`>#1(2R9utBMEf}AXE0>b4!Y8F>fe=G{bW(#%of}Hpfu?`rJA2|6%VtqbM zbV0Z96V;Gu!)E8a++@Cz;6%4S7yPzs*&acozqY-#J!BheD`XpJ^#PYwUYNrl1v=$0 zJBqymP4r8&K|O{#`7G)za?PylavV?`&MtFVz=J8lnmfV1*bf&QU|rf{n`20T?j zPz1Myv-kzdk|C@Y&cTsJ|ZXW zfSl?K)SJpe7je8<(|iVH_9{jiW?f$3VcTn_dRLvP7FFNFFJ+@LOR1s=N;}kyML9!$ zD}MmC?B)MH-@z!q!)Sk!@5tZrmI%DwTm0^zJOfzKzH+=gMxFu`+8}wHJY1fJ?`z6M zfeUH~T-v{KoHAKCuH*(w`6ZAN$G|Q>rx!rhUmK^0{=lnFCT(z zx?I;8lPSnH#68#m1k`@^IC}*?a@W{zY&!hsp0Ur_MD{4V8?jiKwZpw>Gn9`6^!R(w zrRa;)7)nOAUyl5Q6F_0?qkWKFpF}mnih5OoUJF&h1g#tDHDBQqF;p$6hLt<$xph?< zDwPyB^o`!ir{rDoQh5bbot9vf=Ew{1woG1xh&v-+k@w;4pu7=pOE6*!fdJ|X)u)Pb zEF3;!Wj8dSTzJbZSC;$8YcTS;QE|JgKEXy{1iBu&p3yCBZh+aH=TNi z4&hlk1|#&Isf?J|h7Mx{SB$I4HAZZ8;;L}5ToU_^eTm;L39>RtpCg zW^iy086n8Dr zgcC;Ba5eZMZ4KQB<%T21hv4{N)nL(JIFK2jp;S~oI4_tlR4tS%JU46suJ0uLZE7hQ z%2Bm9^uX-KLv$aT!pY(_xq-@#u1fx_wTbbeR~y4-gJRSlbZNFRQ<$cV8(8Yc^m^c4 z+tM9@aqR@G{u25G%`i3a2#Q_C#&9ROru;>|CfxrD!)+sAsc&6qePMOlVxVT^u*ue+ z)~onQRqG$iWy@Gg8A}S70M&%==->2%{>Mvh9XA;FOJ%=8QDY=qmi>U3FUP!~2LSUv z3QVHvGk9)Z{Tq}MXe}YfO>a~XN z#5?&N|67};!vo_0(;PZ*=irA~l>SZmfqzY;uE75znW}^_d_^~gdqs8hJzH}Rxp@8? zA1^!<%33B^QYxw+{jKuhPoW8+ z3ZbOnw&18>_h7$Z)8K*N=-@XfkQ4{%CL#1CB!qj1KZPHMH;WTQILyLX$}Jh9D#l7l z=q&vbKO?fPh=SBnN|I{GlrmTusXkQGv{5+0bjLZfDESvy9WUq^K)C*a=EpnE$v5Xa z^WC6LmdP#W25~QN9dW?CEei=|ER`3NTa9z0)oDZ&M z2iX}|Vf$f6D#S*!@0cU#Mn*y%tqtx_SLz-1`=iihA8zi$39g<#Py2#ZZYdC9#pT1m z$nFqpil@SL!j|x=&_Tq{>`=?lJbc^}dKY5CWuX-IIqViIi@(IlVrwZz8ia9aF141b z18uuc$}7#mnWTZV3H_RmSSdFu1yn1}Mo`K$+M>22pdQ^43NyXvM$Ae?#z9vL*8}gaCeC_UHEe+z z32spx`X0O-4>50!`pb0f5xrmc5>ahW}KV*h3I|; zy{9wmM68<4*~VZJ1ejaQKiKtUdNCbGZ=^J2dmo|D-4}J&NFyG2vNYALu2fiMmi!Q% z?*GJm=p#=GUkdHTY}0~uG0Pq!!j58wEe$pgIYKo zi!SnmcsDRA2ca_a322k=*c&InP3IHVu)?^{JHcX8ANP#v0v;T!9+(Zt$~m zA_|UDbuov0^ejfic=>P-R&yNIson4s;)NK@LbvdZKf>>YuTeLQU{(GH_kp{HXl~8* zQeTb8WenTz;rB=EJy+X2-B5vkxo50qmi(fzc{LwnY7GgRxy_p%Ln8!Uo|h;%yVIbsd|NosXPhDL87UFjU8`^nciFqXK+;z883Ob zLJG?)J*>^FeJ%ZkxBMAyA+v?hy z+x_;fjyI0kj*5ncq;IbA0l52J5zt8kYUfvf1=6;)hhOdJ!(Noy}#RqOpa3eS!ZM?g^i@deH|9F07l+JvV z@hUT%o+sUwvg`MT--G_FNvAW9dv|$Dfw|RL)U@N~9!BAGUba1OERT5Q4oBRN*y?=e zSm3!FzUr7Etyok}^^K4)jlNu#a*P4e2LrQ=zhLxxJtBx%06r zBch)Bl)It3syo@e)BR6GZI{#j#d-jGnOnF7oGB>Q2Cav@$d7L0blVuY=}qp2z`@>T z9rXT8Ydu>>`#<(w_JVe&ZIrc#C0>}x4`;_Q2GtmzS?Tb|x?|=sJX(3RpE^u=52ivL zF%<3^Dj(bt`0ZWc>ke(xO2NT_mch+|=7Dd4XZ~)1_WtVrMuF=7S;5V|3V|}1XZbVp zcu+Vs0gJfu?j#>NGmxUet>j^Mk%hEvCjPNy=CyN@*-#mTrmn z!v#ZYgU5Y0e92zPSIzgr_uJdRp8ynclJ~5)t2f4b$XCic#_#p`e5br^JYzlQ(^E6P zrfo>Sj`tnjsy^LcF1RM>6F*7K3?BD7!0hJIh1*uYwSaw_?I4sx=UP5l3Jd3W1Ky+P za~dP`Eoxo$WtNw++)h3Q=C2IgDQ~nRhEo~L2L6wbV(n>v=g^%$TyG*yM?8sW>i+I- z9+m3;2_Mo=_J+1YsL#9;{QO9)Gj8Y}PiKEIDnl~+nKSGnz81e-sKnplXfSIUK%=y( z5MZluCOewT%WYs<(A&vc=xMJ86TO)^Ngt)BXpOZp+6CpCD#?eG(^4~~ozzzB7|!$` z3kG~!f{OxogR4EW!W9Ftp3C7bo+je9)I#BR=>>y%{F8%ueYZSA{7*8|eM#ODK5G`%v|t$tKGWZ){>Kr{6%R7M~1K6Ww~I-S95C8&#NiyS)G#w?4cb+gi4O zX}SStW;=fw`08))*L%Xm0j=AdxkV47x6^m%xj^o@*e=Kpb!3bRy%I5W8LLkXYB93i zi_m4rN4_&Y8b8gMsL~DvcFCt@=nM4g>TPJ`?=css>-FO5zbdQGR%;p=@;>#Nd_|2{ z6lJy2Qr;;$rPgwDRJD3ZOQ9`TR=g)N(nl#4zJyzmdu)>OOXH;B(i^d_^c*{PX>ov@ zUCyWe1lGAH+#()mbM$@SS62nvrZ@VL58!@D!b5f`oYoiMwU%e9;hGLH1`}esfN}Q+ zXQ9#T1f0|MvT5vERHvNC6cV`>P)+{8c4pVJaa=sR4mk4#&=hXU7Gzto9Z;7U$KGYi zu$fF0n~$}z7nw!OD#p#$!0+POCqSScL@t&NU&0#917yyF=|Nzydw^Sx#>z5^ZiCLLK`P&R4@t4Qdp*P)#ihH%d7@Xk; z@*4G&$lf`&mAj%nhH=_<^UH`n76B8(u9A!EgRUy{VZp zT$v;G0+X@4Xct$6{sg(uq;RgVFI+`>p-zEXM*vDAy@~BqEmRpEP!s7A$Px0glet9j zVk>fO*^2ZfY8Uz%i{ZA|o;Z&>{toDT+(K{gp>bN?}cu5nCwd}glltt z&d=I854!_r+IrkT4)qOK!zBu*g=C>6)SM^rL%25FF;oYhBPIqycfAJcD;<&Vr9szn z9{8PSp>!~t$|O}}B|pHmD?*l|VyW8D)6Y(%peyhi9Tp$G0bV&N+99Y}41*)Z9k>vz zG4DWk^FFZinsx@`^cy+m1o#>MEqxDfk`9K~gQL#{?*``t9tFPpjzN{%;alRr;42Ww z^xpNo^I3dXf*V5Jm3_(-bO<|B7rAY4i+wEIv$V5B@|}g&d{=%fF#IZ$K^>rWqldW^ zXU!?{M7W^;3C#*u6JJY&vQuI7l0cK^fotg$%$z!j5dP!~rv~c?}lv7j*<2?8k%x=4N(csZk0l3rn>sT6blN(n#tkMuzhT4+nSnG6Ow* z8UBO*LjHdOdHlPAKJVy&;uSr=JyD*Zce8&cv~m+o2i4Hh!It8h9d#!1Qnc#&6v5ep zmhs$rswg=Zm8Q02S}E)Z)(qJL4}zsam&0YC2~ma|ML*;{)@)8|ME!^w z5t?(3vyW|ur2@B;j-dLP%}uX%PphUa(e@c`LouG1NyHW^30^zTxK(^1)E2YbzS{n^ z683+r-z}Sk@_Y)n8l8w3dL#CpWFQ;nV+3|U)#xwyc)T#@Bu%`ZRLR2ApB2oUtqHDnExD{z*qS{VV552 zzZMGnwZK;YC*LdIu?$OQuFU4X;^FObbz?WR0UD{TqhG{^vaN)g{in#4t|qph%nHh< zZ&s(t#iX5L|FBbR42}HKVP7aYv@>)JRmR*>fqt z4qGk!H0BXe%^0V)MBVO<=)-R4R@SQ{bY->vGB`U%6`Og&AH6=+;z+q<^0Q0 z%=WkS75|032;Z0L@W^PR7YEC+B)F0xtuxSW?}$U>Gio#KVse2saSn637VAfCXB?3D z%jxb=V%Y)g>PjdsodjBOrv5|Asn$}zqN;CTPd){$#`57?!TO=+!L6b3q1~a;p}N5$ zp|YrjZ42i0mxty*@2lW>l2JarVn!368mcTWQf*W{ZnGm8u{vsI^kJM*H@a6j+gh72 zh3U=4I&FkJK-vKI*ApPULJ@ z>5jmLA=V4w_=pv4bZaQ0xhbcs9b&^Q>a7q3Z@cPp7Cr~sMNeb zhvg-7Y2xVpsEWJ*{@@&Vece%Cy@yOLLAwf1i@M4-RPH-LZN({F2dm19>S$3ZO)Mgh zkUolS!DyTsDiq2eT;-eT8;?ADRocvy(SItZ8Q$0aY~hz$9%`6y){!1HJoaGB{pbg= zr(^QChuT&Gng4_+Y$U1|Wm=jT$_R}NRS9x|)&3oUHU7i?ox$H>N={TynunO&mZSC! z5%(i)F)d=7$KH>QaW`|;vu@yg)I6hwT2E>eo)F3vx*04G8WtAhRcc}50Wq2Oa6>JZ zY>l0>T{$BzM(l}b5wYBv&*8Eq38lEfbUuUacM2AG^$W?m@voYUg8yy{M6zOq&xf_~2u z)S}ml1I2k_CU|QaRETm&Wu&sw0q79y3BL;d3Iu%%y*E4qGs0;tQ~pbKBu`1b<~2Oy zLQ!gVY8KzcaWk5YamU0(r)4h|ozvahy2J8{Z3jHnDz%7~AWjTd5toGC25b3m1lIds z`+oYS2UEh&#fNGRGn1(*DE1-|hWldl_UM~2j;QhxFYPle%ejB3HRgP+u2Nn)CvFG( zqnfx*ItnlBK1Lxjfu6@r;ip>j+J`!-Ig7a5t~AG1`+eJ4YeC@}djaZAI`YO5dUyDW zSI{~>e_{!#jnqz}wX5-EdzZ%zJ@+R96LXNH<7@W0{4*jq+zik=qLIC@XE-O;bzKkZ%YqlKBQK^-HH zLg%O(xES~3+oBRW9AbhUeI8q1-NcY~VFjN;nDL1wLWh%FCHPw!j0TC(vn zTbfCx4wKJ;E}I9;Ss`FXoa9f`$s{0(s{$h|pbvQ%NXBUV^bvW4`a`#4O2Ub~KYI<{~T_^xGE%M{bArPts8A<{{PU34=DhSV!gB;;c zBKxwSN2|i4>;&DPsm0W#C*XwCiRlDQkpJMW6$uQ+LaGbqSy}W0M$uL2p{U;ejkClg z`Y|+qdQx?8rqs!5U{`g6PE=AT?Pq5zC3aWt%z72k_ZzAeL$AS=`4NaAwDW>^Y&q1C80}mcL z=$&2xJ7y5K+hXN^^H7(uG_+(#ujUSSm5!#0kmbp1;9P75_i+o*GDCrXD{U?^UZGdt z0{ie~gF>e-ACO;#sQjqk++k$qHlv|$TZDOpPX89LwYT8RTmYT0Ufgq>v?p@W>|mUz z-#~pX%Yn85u!cE-+-M5i(MK>GC^8c9wHoVsF`(uSz$>&Xm`sPzrF{Swxa&Yze+0Jm z9Z+SO34_D~a)a>`FBL#ZugN(o{vX|yykTt?0z z9hE~eq4K&z+lE|jjJd}sj7obY^mML}6Pbo|PpIn*Kt559ZNz_I9W0W5GQFx@v3=|G~JKnDH*4zUl;rO&{9Xa&`_IAERggUwPL7}y*@ zU?myL!GX~AXk@^N=vDmClk_+E7y{4fhQ1A~%eHzCJ$DvP4p^y~$iR;3kHC2Nh@9s? z@F33Wntl`fz_MUG>;$vsneone1op%OeB1-H1vJpWqF4+R)-f;uRsrQNXT=uW(C3(& zku4+vXI&GOlQ}^4yaR5$FdVwNgQL)u?2mnGI5~ivM2^RGcgIggpgY(c+I%grI(5e9 zhWNf29#ZtkuAbx`ldIxWNz-U;42>b_#-#TD6w7}RD0ot}a5V7s>HU%Tq z9|&To_yao|3rty-)A4N}&9(tWJ_VUw2{RU_uK-v-myDmNsYD{PDGh{SBjkKz%{Ca( z5oUXQ9E0yCfqgay7=zxxY*YX~)rUuaG|qwha0R*1Zm3>dfL7*OV-3DO1>D+2JnLs8 z85*E@fH|&%znTv;-4wX_-U8D96W9z^V1e`F)sBX<{Y+fPb6nR=-1E1@EsWG%FkoJx z8}tU4`hS7v-GW-n0x&qX6KjC`T@Lj90DAs!?nGAgCJ#o77GKi(vupI_iRpv-~V&VhZQ9NGMohb)plVW%l9jbrhQR-9FMc^crXUW1A*NOnDhleMsG!opF$M$#~r?a7}$xw zT@5buJgfojftG84cpr#YSR89>FU)}ocqLg5)z$ILeXwGT#cOMY{~ifUem%SVdsK@bR97FXU7i1nt3z?FE;Iruuu<+nFtqXLl+&z9v({Q{`+M_~4Tz5Z%{sC!PVp{Rw;n0#~2q;2r@k1Otv1iN{hH7Z!QrSM17< z@VtA$m04!)2io!l@aOlzd-;P`Rti_)#we!X`z*)vx4>hb1fyd$;_?pA*Z1*167iqD z;{W`>8k~#>%W}c~gggHf*Y+E)<0&F99bW-EkKcIkdOTofMB=I-UXN%mg;8pZyO|f9 z6b%{e3q1ZV)C0g^L5zIBt5*>}Hr%~JLDFz%^{>HmKfdl7SY@j8<5njRp+ z0*J6IH}af#ERAb*Vn-D47Yu3`3ivfad<_}mZcSU!wd zJ|OT5<1P-zU0jH}zZ7?|AEGe^%p3*t=sHF%i=*-#|0zF^_Z7f5$%Vgpk5_UW<8dBi z=*9hthE|jfqxb|f{Iwr< z=LVkXB*yeS?rfHNZWh0&FB%9fIUVf z@B;UesY~A1TNsRTSzZaHg%04D{h%J9+8oQpp%b4#)B>M%n;Zl}afEyr9kC8v66ncI z?N-}twk0*(?5f5|uizULmYb7y`Z2S``o>BL7WM{J)1;vsuw5#r=2v?I^F59K1H1ys zui_TaqPbPeRJ*7iITHWxE7#fb!a54{Jeh827E_Lf#{@t5n}{Cyp1GUa!{xIi3lsU> z+)u=gO-YvOqwW{P$z}?Xi>u@OVo#2!;X26QqH~yDv5=V8+c`MSKV0pZRj0Fj=XmR4 zJ{R*93hmqFG0LHE3{ByVM!QRU2Z<+Bk>!u5b@$8}?ta=p;e zsYngbvZ)oreU)!wf^-CU>tX0soJR)#1W1NLMteO~8KMl9J^Egt4OsSoKnq>bk?zZS zamM+o9gtIiP0Xda+E6Lx+ zFny#lT>c@AQZm%$Y7QeF4yte2wbUD`2AQRmTugr=)m46mE~0aiP3>z`mNTfb=x@%T z2(VYu^gnuv-V|ubYlxA@uPX+zC@@I2ni3 zzHO3z5X$h`*@jdkc>pYxdfGLxXa45O!n=Kgd`InSlqB;ByU4ukK7BHU?g;#>eN(K zuFZ#c`(rZ?TZp5K(Of#T*Ei}rnPcQfeU&ajyV$4RfbMQBXfi^jo~f%QvDNhk)W7mB zdQYeUbxB#kR+7@G3*=k)c@G2fWgk_N^Oyv@pl?u})#7v&tr?-How&{BR(6N{m8u8F z)0fO{y%3opeE@&U0pII#R32>@x(H|JOrtaXQcs2J)Exe^T$#(zBDwk6cc*^DSoWv>%UsCU{*;~p@w1F%|L zV5UR)djc^--%gw|DiN>fgW6jns5GNuftVF>a@|IhXJW{s%t~}FyMsBINSsjj5?94I z`mTA$-dJwt_$I%$lv4I^#WV@J#dDN{+9{R9Nq!pNAKpV>peA^Rn*>&npwCy5w88R! zW)G2OJ7`O-3#g_J&1lMhq1?hDB}nEITdVQ%DpY>b%v@B6yUdj4-=XW$hb&IV6DM`2 z{7ki)M-`5sjW?E)L{s*uagbuoi8N`1%(<8k*VGK~dWy1>^)k#!auE}wbzm-`&QcR9 z%=_gbdN;7#DljXMQ&uPMFsHDyEY=IsWs#*FlCqHr#BusMxsOYS8+u2x4|CnPPrg=0 zl1){M5tfG_6Y>xT>5uwts+eAZdqfun-N3r*D@ zfNwR9uu<9JY@cBs=B^mmkvF%bGxf<}QWOWba-2C_UxqW$CNQaQlEcg~YD?lPGnHyh zTqZJ6n<;=EXgv5QW5JsY!XJ7wT%Fs1t!g!@5Q$`Kx(oPZk8ze+M%H3P>X?a$#NJ(( zqUa*f(fVyD=1{$tnVni-_AqS50<*1t%qUH_G=?MRxe3hE1M?%1q3t3XksY9tF%X#b z4|*m}Q}eVI=$d{85_JL25{<#gJPW6271i$~bE%%itx@$l*t16%zo7%M7HpK~I2(BN zPx?}L6XY?j8tsVz#&LC>HVpm%B(kM=fI^e9F{;iQMQSKACNwtv;` z+5>&I(G=PT+rjACf%EZhat*;jZ#f^+5(=C5g+~IEhM=K72Gyh&mSzIMi)lCwB<>>=x9V8|kyO9a>n) z37_W_C@2iW$7E=`liE2oK`pL6*NQ=T<&PeXXpT@vsmqnh@>c9hxr5t)h&~i}?0@6m z49C>E{wL^g^!B;EGcw6cE@NccEO;%y2_BZt5IS#jW{Z54txwL+ITLgEqY6hfww7jY z6IHaGvKG7_I2XDou29~qb&#R{O$1OMHQA-mpjgBfLj5ab*=74--D`OZ2i_1rk?lma zF?Z;X)RyXPodgn$pqp?Qk!B5%jmjaw9#ncWumAl4DR*1gvrX-(jd(QVPeUx`}V zTk<7Tk8ZO&dAp^Rb+mPg)n*-tI{tZXJ)2BkGdn}~zJc07?kzWx3n^BmxcW*9sp-%V zR&){=`>s;iaBP?j)bh9SggtdL$EKIaoSObAqgeWn^s?#8Qq9!a$^HJ+PYa~o^>AW> z`iFQb9CuBLxtk-H^G`N1TXH0C39+a2^U{598(&$^3r}?5ihmna@;*XMl_U$%cJdq4 z!isQl+(conWs~)owWu|_^#`BCq{0otqhuG0gs%rHgxK(O^e6g(sdt0!$vG?w;UC-3 zS;@6OBHn#Aa#!@@sI5_V-4T&{9UGwObA@h14urqqD_zv<6NSj*+$`an<%_i&YR*?I zlTlk-#4$`wD5@>NKJTQnQ*Ve>#%pb?Cd;LynquQ{kI;$WrqIMtcOY@shyNCb2Wvqi z@TjkwzqU8zIh0;Ht$6CG-y{Cy_?_$Lte@3?#w5M_T|50zdTO|a4%a?gqo{##$#M5! zojue&&vBdmqV|>Y`Rb-$NeD8@N+FT>n*x3-%+^79_4e{s#-*g!#3W=SZfG{ zEJ^GgdL@&HN}vqo#?h$GRfm>Mc{m^Em4P;q=HbM4Bs@GA_K);m%Ut5w;`@pY=X(DR z&z&?&=Dw6AzaA!^`93d6`!*oiNS~a+O8=1km|D(9QHe2s$Hm92j%?_9%FU-1Yih89 zzhY*QjL+$VG7tED{!<~Vc0lhyc45DBan?N6YVgPXX=?(d#GC9k$~50<+2t}~%W!mP zTIihQRI3^s>0&7O9dEYewidH=65_e8Oazp$#*o>V5}b>Rw?#QlIMX7=xmLN#It~fn zxx)NYRH0YdM>r~4-Z=f%g8Uf19aSB9;14Pv|469r=ooR(+1oY1J_gH^tyH06^)D&f-oFx7wSCevtgW}rYb-&NIHM4bQYFbjtxU|GS`BQ#o#HPB_ zZe_%0ZcJI3@>lA*?@QCJr?*TyFMl9L@~dOjY?typ$bUTMO$_HaW0X*5da9>1NiFx= zl~%#)_gkUYG(+!2{-(A9$-Y?#S&Q0#x#~ugbCq}g5{9#9%vQ1#{^Y;uEf+iy>Mvh1 zAJhBz($=T8S@x3lQr6MJPHsDN>+i$QV>Ys=(?(0?Dz}3RxjwqaM&8YS7yTM8s-pFa z;Ql1}J;DfC|{)6R$0-g6#Y_4!; zZe(Dc=e6kbt%xWO+=7;OhF!DWaTVlpGiS-S+*r1&@X+1BvEA{>1{F<4)wi3`WEuS} zczR#Cp7wgcdUZ9QgjumZI5S~?i@K`Q5Rovicy=OZ3DN$YC+WT;JDgMRC2>k%7mm#hQ$hFlVTmW&5$dFMuh;huZ>ue_S#{sewzrJsIQDzNK|ZKa%mt^C7Lf zZ@tHrddjmid0_JNA2~l){>~(S^myqqj?>wG6uz8)V*aCf3%Gv@QF?yQs`MSn-e0HF zs-zXn+~zCfTN*knzEj>1$xIC!8`(2kuDITD6=S2FpO{Vh?vRb8=d9*zZXL?~B+;J@E%9ykR1D06CssfEjo8h(s^-0zJD9yn zL@NJ|3@GWMN4cowH4C$mmNbFKZnYa7xsKqq4dBaLG7vFstdX{rY!PxU+`U$TeV&5? zVxOn1C+rzQ0>6%Vhk9onDC{)S2g6~cTG$uV0$sh6d`~h~@g5uQDo6>Iy&*lW3pWFMFNPV99@0m~P;RH%`EXzI$$ z{^=u9%Vciw9P)1ySHu0v!jje@uB}l^V>`#5j2h&u%@?8?sg1%yu#V?|=cDhJf0$HS zdu2?crZEv1ktX)PotB7y9LMa-tnutu@{Zm`dKkJXuGM;wT`VJ_X2!mVt&=@9D$(BB zdY_qN6gCWEBV21rS}WM*3U#f;h5Kv+HVWslwoCtY$-39tn|lrZQ5=1i>cUhP zM)MJt>UP%h1H6wO)DdI9o@5Nu7OHiPC{+=w%jH5Heb)jn1NHr70wsbi{J%3@-e~{l z)Mc4^=A@s^y-DdmlT$oJehx~~GFqf;7Gs=>&B(T<(1@IU@?6S2-%(mfP`_s$%k2F- zG3Br~Uuukp2@Lfd4`0w8>)F|@)@1wS=vUdOY^`G}x-M~Zh!|;H;DG-N{!4?vEq{6W z6mWqbsO?N6W{vQdEr-JuVLGPUD_Z}d|1-a+_rt6_5?z-1meuZd&cQL$qXcVf`)#%u zPT~KNQT#cfspX+#wUC=T$gGAZ-frOejx%R0HEaP}1xFQYCt(Ix0PMErK&58q`-91G zob!?%vJ^4Z2%E#8>{*4nMK(6CE32g!(u!dD(DdLy-zVR%ttBT zerI@_rS?cGof-J^Ic4LoL0%%*(;I0yU>)ihn4inJEZ60n|3Q>(kOqa^P>$4{8FRe$ z^vm9V0!knlZm;w=K9gtpKCZaPWzqLz=S97DJ`^Z=w6;H7H)Q%I`YVUGhB^5V5RGTx zQdWhDCH}ot_E$sB{$Krwr7rh)<3QdmfG9|;TqYBe#khv57vsd z{mx;oH^Nrb%&wa+iT6|oGm(11)#6tQHLTmM(eSUS#b*~P^P4P(tRL7l+#vQ7QJ?a` z>meOJT@T5LsPexdOBgLRhf-SX9JXSI8SUTg9UQvkTb}6(j7aU6cF+6c&#Ux&zD=ow zXLQEIl;^*FDNnp_J;%&1wg>#gxUt#S#NLRz;7W9KB!9~DmD`#1yiWrWnRPr{yrTDj zzfrh7@PJS0!PY^JL9WW~822!HLH;d0+o-P0fV%MNup}-E_ma-aqZI;LZOe%9aGWm= zwehP~-ddX9!VF}$@)}#yTEt$5$sydM+nC9e+aQ55n$Deb1RRSZ3Pn=3m%>^u2D$=8 zOi^7(L{SW#6RffI{4Op^7=p9xFvo3sKAU9GxNr1IA`#k}wZZz2*TOn$9!2$sFq=Tj zx{wqH-=QOtHEa{_2B)PT&1{uYC1ZN}(bT@ajhRY@FVme~=hy7-wohMw&PlE4nGHkX zn^8lGR*AWqvtf=ZR;%r$+Cgoh9Px~nP6q?&J$$7zhNcbkvW&b6m+$n`0J5 z-Eud!J!9KZC$-(`cr8Zm1fA(;;&Qp69GZytC0bI33EpEL1)tm3ZMe9J6TvMmU^ zl&xTfQSPJiHk9qNV1tw}bze$!VvOI95wlisn4Zh(E%fkw_f?8G7akryJ!G()TdV>{ z#5C&wW;u~&qFO~8VYJtlndP7y(hO+#Y3{kekoAppMHa|#K0{6SI=KkDhjz7Vn~%v< zezSNf>|x06(E4E+ef7{f>WTeHBd&v3TAU?b75j?|fLoo(U*$fFIoP{Z z4>7xNi+#haWA@nPt>5&Ok=3}ZEmUPF;CECvs9*K-${fQVJe+!4>5|eV%}mXcP&27O z+{(DC-&e$C`ML31->=I10SQNwHfpakKlioGmMdRB-|g(}LwAFL*DY9=%F529<|dnZ zEtw|!60WB7@%t>5c`Jm6y2B1eN*UgT90>V_UCaQR+1^j0>`4%S+N^qXK`)^Ppht0t zk%UzD@Tlla#j}UzDjUIN{1wtmXykc9UV#I*gLx_sLmMeB3vfUDDX)Y-eTI;+AxB-~ zU97Z_%T4aF1%W7hES7Z7kPWQrO&}hIVczMJ8t^i=#IB~F*IEST`j0CA1O_VEz~Lrf ze6~>cq3)bS zv{&R?SMJc3;yoJ6MC#+<=pv=HNlT6|^?O6I>|bo|aJKWm#4YYBzJs3X-h1L(=y}{_ z-|#>Ac&D13+jyrm)fT9;w9CdhXQ;H=ntdBDFkNOZc%=^hx-1S5%0Nr|*5K1_) zwfv6%CiIr4$ZLIvLnnsL@r)G{xv9`6DF^SB+Rij+T}JU!g`(nO{tC3gmvSe$b2u$* zOj)&+Mgmh(&A`pn`TknL6DfJpR;0$Jj!LYMwgR}m3-o)SqJKczqm(QutA0#~AN0N7 zkC-^|SBKO)dYJr|s|CL=vn@1>%;B9c{uMHa%i&zKC#czs*2bUTx6}=P2CFU8Hd@uq z_53n+1m8(&DfDs;39bEg{<^>s|5i5nPyX(m^dU(|>KyQ-4GwNi{giYr zt;Db4e<~$?`}rvOQLvW(iJH^N&E@ABy6lkBo+{pBzUlG|p@XZPP>)`P%JCL$xw>5s zGghkWjc(>{2NWb zV;(finp<^QCmIWcX=&93t`k*@E5Cwa7^MdJe+JH_qy_33p=Lw6K5d)-ztnk&?*lOj zX@5LPeN);c-AH1BiK&g!)HEqjTIHN~+$s4vw5SV)To&EDDL8QXJZ$fRrdw7fiydX2 zHZH5L(v>sJ^2|d1hwDH_Pu3xs6S9rXP$=@R$X#JlXrgbSXPNi{E-u~qKBO@zN|uor z>8$+H)6&z-TiJ6-F6OEt*O01;uf%NN^)|;|@@1wx&V=?d>AUxD*-z|ITch1@vU62? zrnFPHsO7cZMi2E)THe6V)Z(eT0yOYY`4`%@d-e9}Bc;3gE$~Ho;7>GWrS?nB7pRo@ zF1bU}*5vldw}3%(E7#1U_9C_{KT)dd%HnGSkC*MPcH%j%F1ZWr+GXsgL$!-U@Bv%lP_)-@uN}Zqpz_} zsj5c$!-I9yElPq?CU8uV;p%rKP|H6+S*5Ji8VCO9$${HSKOjHXtAm0ww5I;<=!XlL z_08PI7ptWa&+N2)K-850Z>|fMD(2*SOVcp-Z7uHL^UK|Y(&7~1sC>|!-L>0u$Mw!# z&Aq}^(tT6z?inLTxmENPPh9W|mPKI+bXl5-<*{Gs5le$FG8zb)DZs$wf=b&XxRypS zi>zaIKIkExcXB{i^9A&jX48v6mKHQu7>b@1?2VD|wNc-o^=v9w3^kZ$P#O5lyZN%xeQC0ATh_$_Pz`7fRfqA?I9Fk* zkUU+eEHxD3xkl0;;_!KedBQ@W12-M26q(spq&72w9pa2+4>Lp9OwI)B0?F?U@UZQLivEr}msJ!87C3HE$;uW^wrW}h$(QquWN zJ~GgiB1iR|@-+?Wm(aV5baZl_i_ud^f?kBag+E~hoTCu-rGD1w4QWbw!k~WMxY>Liu z;?U_#gxQhbV(8Xorn+4fKB*JA&g?eY#<~gD=;l!6Vo625qMDIA#?*0U8s$knpmtpB zK=|Jd6mr=a?Sk&Z);Fe-`Q25{h_wsyhoEs7Z|p+f-@saBuE$Boc{-f5vukjx?W|;i zRgCY=vh*N)CCA#CjH&WOx|2O;H51R1%!(<#wug{*bf%a|)QqyCZ0F(6*oEQw+z9Tk zv$1RcgwN02)v|F!inB{F0eiPN*E&a^3#Xy-*n}0a%38_yHqwL?cAwNSSWS>!Q<4XU zJ+uy+hs3uwPlrYA*DyCWPyUHt1hu~jQf2>f>xTfMxmM)NO%ta#N zyz((lWZgzl??L*0cwPMkH)u{({?02Kxo)sa4Re(-8gk!QL+v0|WvO+PjiH5Q&In_+*@>dZx=I6Br+0#eQZ1_zJ4KdKN_z;8 zQ*-*CT-d(itgxC%4y{N!nUADO%4zp9m-A#W!InJ5zlnc4kW2<=NZT3a%!} zuL)8oc%yr*b!J}IDa@|+2yMtEdW}`#p?!skvEQ;SY=a-7u90pr-xbSK&M!sQOKvUp z64t7Iel5)nH<_764tEZt8DE=x1rPKM-$9<9c-(c9J7sicLad^m(7;G%sOt_;@-hOSK(=vG4@eNX{5)!&exo`HxIba>ygbY_U zM0%47a{>8YQr8R{{;;9%O~tN5ZW^+ym&Wqf^uyvZYlKmi+?TaG=1%v?q``B>%y{CCfQvs3+`m$hL z$VWqFZS5Tw@WvhZvz268>WV z7pX1Od{-L>Ng=UTaE%mgonfGZgxX z(Na70sEv8CvQ64;%&``^UIqJcH|cH99{*i-h^wCQh$eBPtbOcEV=O7=s21Vl!JBL2 z+nF>wWMi;7H&oo0n&vy9jFrQ*zD$(7UMVH)ppC^(_Gxgof?P*2Ib=BFVA@h}e!nIAay)Ef$Yae0F!jjL5Ai91)oWOk7Be?lsdI;ui!h@$V z!l=vEVDoXy4aaJOJ?izw2x`l1=}LR0^Bk(4^PGLwHKC`uosZRZ`(Nq2_P}f-v|!HL zUg&Orl27?_xvK;#xyL&HrDY5GY$n=+tjc^@k|mfH+A_76d%yNXC}6J(hDZKR?g^)j z0sNw1MXswnE+xObT_~VLSZ&=JJc3gAGISq*R6Q@ZG5edbt~lec4-S&@vSZbt{8g*r>Y@G16cXv(=WSY6cSrq>P}#X=Jad&#`{liF4x^vM%x1KFOl>V~1VVna zYjsMquvjGkA9r0D782v{BYk1Y8~uD+{tSaHD{Ni1!FSLgWTQ#rimH(-{-JHg5VyT+l8`Rd>`MAk;O}3X& z#WlqkV~v)F>Rq`Qx>g)whS)Qt@_GXQ2J3~sTs5XNB0q)8$VP!3J45`aFk%SX7--g0 z_6vE6x|=`37POXfE$q5n13Nn@M0#7<$SWuhud`Q!(PPu)Bs;mo+_qV6I&`6nK=XeO z^NDn!_+C#Z1<~b7U#MutHOqX)D`HuX@yrlnD&>CtYKbx7X?%?;9h8B5t<||`BCwmp#5p#eWwx4W)Ug$od zGgh!Q;fuTm{INaI6*$6f#yL$5@L7E93844e!V%(!)dGBxuJ&ZOV_gE1p(RwsR{%TR z4*Kx};fs8mEech>7*ZUHhW$xNQh}U@ll%fMFYd5`%gqfFr*YM|(PCrvKk+Df+N)v; z(+PXTd4&eR4YuM71CQr88rzrs>=5$7*~dlz_bD-dfYKTaU*`$fEtn6iYgsV+vV!dq z0`5QuW-aP|GZ;2^facEvXT07{1UTD&JKwAnxXdiKTG~75Agc-`^cp>Aj zK=Vj?O_)6WVU985DQC@s+i-d4BG+cBlD2R~yahkyeEe#z2EUSD0ad;*@QZJPb5#{; zDpko!c-Zs;`hGG{_R-Ly?h9trD{F_91SS+#&CsD3ZdV4wryx-2*_d3|^=!^%16o%Q z_i}?}oBQNYCzp}hg!-3UP$yXhb?h$8WoxvZWoRc9N&f`&TT|JlF8UoKD%?#}zP zG1@~b3w>ifXJ%MFa>VHdB^3`pL0&A4@#K+9yDGYx3PF((itz=79?;!BFet^hsvIAHxJLNz=O z_lJ2U{30!cAMzs-<>mP+t`pK^oFP5n$GJQ48|8nY#(7Ze%VrWTfb%;9PJ+9cQeg>zhp z@a2f4Y}OpJrkEP0&fs3?JC#RHmvKq)WkVOr^)HDnL^H%l;s5p>G%twbQ@&vwSURwQYdL&}d;E zPKka%E$A@!1};L+&^vEoet^3XK)P+g_S+AxV_rIuelhf@~>QTp%mLgN+s*1-fTu_hBO0Cy0!fcd(4^P4ARIt zVt1jH>=@eCW-Ma;M!T`?a+YHE0g9oLJqsM=c2LggVYhcCzy+v)-P4(2CEGmydJ5iQ zBvfG}E7BSbJ(+HFzHyDd*ISwq#y}`#@6x{NL)6CleC35&MyaTj3uyjVKuzkYJJRNx zJK$SelND(ws{pM&oqtZZNh|rA?iFxq{V&WGW4v&+#uiUMQbp>(pA)+B{TZ+9K4J$`iq{h=71@+j2#ZutJ#OlypKSh8+{O+iqncOgAAB6B7l`-AfNh+rOg2lT z{$)=vN7Cue6C+NJF+Xq=I`xly1({o6y`fX{G$c+CLKpIn;TImq-*k0iD~U_E8jPFk zWd-dt`qMZ>eb5N|SAT~c)PJ?pR5BhLN9jQ0iQUJV1il`)nj}9PAajTp`&rq*HLomo z6Y7h%#4o~4VJ$QpS#BOixz%h{?l#=T`jc90L+3H`-5Lo8fsJ+>wxhEh-sE8%1NVi- z{0b%u@8YI%n}p$97x5wAP^=4m^n*ff{)Vs`Jn3uvW~it$P&rme3sMe_WI0$5GXNYS z!D){EShtR#zc_{7=smc^a}kk6p?<#y=hz(<9i7%HTx{hYd*C)YLzTgziPV7irP?ZjfL`KMy@1jNiMSz8w5Ky0&KIc@O2pt zJopmKRxi-W)&bgrt~9e#I7R~{{~o@N-;Eb=IqYPfraS0Nn%#;uf7y@B?wDs3we~xU z%qZs=ZA)4EtMMG{hhcQLc}qV^(evubdW>r6rIo&#ri#j2tzdAhc1{_sEz=x*3p@n} zqAq3Ux7CoH>WqQfkB>lmj+3P(LIrV`bd%2_cjxCrd#VV3RhUn<@@eb}XuJJ%8nM^x zJx){Ll%nkE)*QI*4YIz{a+uA;Sf$|}@y_}Ht+fu&yBW+hVEPwnvV-&_BUqNyX4jFuOe=Vf zZDQK6Js6V-fccvN{%z0T50}P1WHZ21usTlW0&F4Zv~FeBLG|BzP>DL0ow6>Vf?X%V^ySsDQTFK_J zeVoKPWChA^ujJ)VihYDxd+qNfs?qD-nR>a#r@g-V3l`nS&K0v>;Pxin`WB1 z#(qy{I>dTGGdh3Mnr0t*)@nsN(nI!jyQR(3#ZDT1X+DQrcDPx^E(H~T$*N#&z;|VE z>313V%wJ|Mdc@Fae<-cV))&;sbNU>wy37za+vH>KHt3caaTP%ba4T*bz|AeCGV*_An9lZmtEp%xR1WSxjy^?=9@E zvjeRxY*r|jw01)6iC{ssw3gaBQ=eQ^XD|&(0lL{vavIv7=mBs~N;C6~?Q9V`&wl7M zq<`B*=z69fM%Dx<9*);oa?)-~!G$uQ3oOolZd5)=OCz>U}c@444(64MGtC2m5EQGQF$L8XmSkIgTz&xJcBp=gzxjCOvrEs5XZ$$VxE6h89Vr}&8^i zFi5?fvfq2b(?M&UHqX1zT_vr%+0TU+}Sb)BQww7l(vid1|hGFU3_~xJIQV)9q@M%`b+zR^~7A{ zWNPLNWt9KSeexDP*~O%NBBjHmNTN8NPz@hJnJr>rhv&h)+)mtXO`&y@?VJGU~MQ`r_FZ!KG{Et zUB?f%GLf&TFC#dn`k%|8S*49ZpVBVmqqsh)o0T!Vpxq<2j(EX+EOn=t0}RL+#sske zdZkqMA@|hYrtjt}8oj-?|F>ByWTml(w$$Eb9vVN8J@470d=z}z4)=T2$8{zx0&l|{ zWv9`=d)8VNeD7^y47H95q561k6nl(*wmYMb5NR~Ggw19pIt9r%vOQphE>k|SQA}m> z1DMZR+$O6&dy|Y-k4h7m@x~Z&6rI5?x5tTvgJE(naP702dw93hCT#+D&XYITf|*RV zk>XYibB}7$Ri%=Y#o2BAmVfCwb040b*ln=SGI#jx+5x^fX|7UlKD|(IO(rR=tkNUH zLjNoNoZ3syE+@zR6&fe^|FgukRvMMCFO!up$Mu3?wF;q&6OPGx*q1Lm!zKw>uZjC` z+Wu8|rVZzAkr!&&kf`KEVlQck;uWr{Z^M_S9I`4zI?2hjRp@`g?zFpW3Y}-=7nadj zr#^d;Ow{W$C*AJ?16a2-N4w+n|U0KAcNp(FZ*d}y?J;zxp*H&8ySuwNODSog^S-P~tA0rpyDrw`Ky>hZS zm))z3@wDZBrq!3+PG-KTQN$iDoFTwwQ|j()_{efK#?=T8&>gMeuALY^FERJE&+d~( z3B7|mKiX~?J{vnne4i?l*AXS-O%Um-E>&;F2-Motn1Op`@6NL zrVy1PhaXhR*$d>F{#Btbl^@1u&lWS%x<{rtrN~M<6@IzHxPJ_tMEd7*W%0Ge!PBC4T(ad4!eV$b|k(|yBc7khs;z6;R_es)((3h!U!W_@b zw7ldvIn9n>I|bIbqTvgh;Q!CNkC|=GqDA@sa_f|7Vq0OP|8@AY#MWWQ0%>HP`%G#9 zYq*zp@@lW_P4XM=dcwgBCxX@NAFjE9j!+J$3?A@u_DS$chCbNjlKGCVP3Ba69NW!3 zOkK}xHy?4EJq=T`SRP-fy^oGEUdmgHqw*eQtZBRYvn9=8!H%J0iQ7<&N%Zf`eUn21 z#Uilx!YoohlUZUhrXejDxZ%4dh55VFd1Ps5d#!ii6*D?2-XELN%B6Udm5x>kZF~4> zMoX(~-{x}rb_W8=MCpioj8Q3YPsr^~qyH%dX;;rmu8)6%)n9t)t(er%ED@5zuAxU$ zE=dsFR=z3Mts7B-vMErSlnR{*gf(Lo3=yEE^*ALV@&>mc<$yDV>m^sT-T|Vq>Gu{+h=}JFiH0Z_eZb`j3p?v=>hMFG=X+Kbf#-!^391aTx|CYPR71W&E$lx$8n?QT&zEj_?55N#CS> z5e~S98TZ)C$_sZveqHdR^IF(uZ56Jl{j34u`TQ)|12@h~p633^B#tTWeUTu!E0cKo zkeOn=wD);?q>W)_3pYhy(nntM{m{#stBgru9wRPTL8!?b5v!;JxZ`$F;V@e%kScuT zk6{PpkDUltp2^Mw<+?n8tCLn+y2`6sB>EJ>UM4>C3Vh!(7^c+3aRruyho}Qt>_{5r z!sNh9BjDbr|3xlY)9q%$7w8!NC6zWR3AyzyOtiPV*1&qnSGG=Yx%GkUY;F~tKYJOo z(`TY~GbWE7FGuOexqr=YzBBtWm{VHCzIDbM&Fz!o6RR`%8m#DEXf@Jf`Hpmiw{_}t zPg!HCJx08&G$%XkbzDKjaxOh$9_?jleEkimj54j-QmKC$UCp z#gw({+xf*PC=Ie?fOUsoY+hhK*^X0!->=RU_u6NjqhveH%2&n?^Pc)ev_e_{C$P`?$nLUQ!OONI931>`XYItkv}XbXbdkGZjb-jQzZsVO16SDw z&Mwl?x&f{824px>-+IZG0{f+d)847;^k%#^DZAQ12m7WO)rw z@3GK{dJ1*hO-^T|SPvf5FL2)$aKQcwtvSuBc%mZE+Io)r&SEY=CoC^gU+c66 z7KlJItuXFc5vk>2`Z*(@2l(2tal+XQzdylkx)Rj7owvdw3#)iao)k%Z~)~VjiajP&~}Mn^5Yr}`KyL@(hR5ndGIgY zJ)FVs*U&<%1TXT+c&9jYM6RIBz5?;~2WjVkwpu9EtAZ$lL}-s?g!Xxz&K&9}lx z3JP?WO^HSi6wXd|~i+#T9Nh~hWMlZI=uBAvqc%#SD!;FR2lR61ua5 zQC6XNC8B;9lzbNGCzin#p->^rh*Z;+@p9r-7_QEU?-W5)X*7m9R>Aoj}P30d5W`(K^W|2?6M z^aP~kN1k)x6A{mu8^5{##b5e!hapc{afjTvt_B@#*@Xhc4(N2n|DeYRAT8%7TNbs%;Ctygp(yk8cV#?TZahzQ zq+S4Nh2m8Pe5VBB&W9^Qa0ecDh{D}?T$3Kp8lKvX+VAHl{)kui@aj2K8(-p@Z>Y;p z$mtV!bU#P$@d9oCDQfN)+%6K))>82dELukzqT@UMPDP$lasMRbJUwm*%8Nrurs4Z0 zB27YB3iy#wBLClWc+plps7Z=h35Pn#_P<`f2KtAxXo=CNqk1UqzYv*Kk=kEqwfWHk z)8nBSA|VUfUU_^k4@xW=EizqMH|zgezl9cQ;7L*ucRJ!dz2{U>1|n*KqHcbpo{~^T z5?*`olYVUwKPh-WT`kr`S?Fj%BJ!FS5n2TG5Qb;RW+eLZpNOF!Xaz43ziH?t5}lHW zn~b<7GoCa(zAE79^CO?Rkk7Ixr@EN0lt)X8M(-YiUN<+=5YShbKnX?TdpS@7xe;MS zkn{AOD7_acgf@`le{mg)Yt!{dzu@^&P|j({Ljv;i8ZARZ30rubhEG2rhTfw7?&JTf zsP!L+fFNq}2mX%3+2K35*x!JD>NBXYesT`rlRZ%8+>6hz;+uC6eP5CPhe%n$(**uE z_Le~|T>tAAqyE>I=EwVGQU1lz7e}CFr0eB2#O$gzYN8K*W?(Kn8t?Q(Z(0X)tfr{7 zDAY?4)KEM0>m7jsEP+?~(Z_O_6aJ04TS>$Shq0#+^4$Qlu<`i63(5@~YJ55!D}|ZN zJmA;H;9P4m(-?QGjI^?$#w#!*5eW@ZYD4jzZn$C|+%y+KDSH!s_FyfcLqRvdoCD%& z9nx)&`whYO+u-VS4P+0rehXm+SO@p(hcn(`n8A*~-Dcr;d!XamBR@lt${?iK7r*;J zlead|J}q$aRtI-(2HoB&_-1*enIGqHf8li<{H~1oa|G%!J5tY%QcCX`)7#`JV8cG3 zc1*-#JTy;#;GKJz<=%kC?N2D5#sd%a6^h0Qct0K0^9J?(Fx{0EpFYBS7g1+7@#!|y zaYvl1)^lb6+qV$=xUHS3_}(+r@lIzbR9`1DEd^aV;rc25ORi)7T9Q4dt(M> z^|4kHoOZXeD%!W4hvX9fjDN=|$k}7(CU5~B;EO5)(Y7CR`sem0Fm+?;e0vfTWOIT4 z`OAEzk25(djCsob!KveO`UH+erESg{Xb+}6t$y|=#N{HB<(%(0Q{Un&VTz;Ysti?Mzg-Kpt+2yc zi7dyN{Y7A+_K|hmWAXxy@kPWK_>)UQWuYLy4SFUyL=%p-Mc|mb8t2zRa)VUGu6rbB z!=3scocs^LdHw_VQ?G&({4$m!{ouvWf&GFpGs;U^``+2ar)3*d7$9zB%@EZ#d#W|o0pQLHlntBdRCH>xnpu{SgoILhxprWQbH zEJNQ{4cC4F*0luA)my>E_a3{4bOJ856FEoPAs2~kV=|t7z}AOf%Ty@ROXv%WqXe5G zYV$&&d?r(w)tJlpG>(+RiG41VS`VQKT(u8E`@ROesVV|NxE5Z^N61{V1#ZDL+06Kt zlLaaQ8jnXX&lho8-x+$>Z=L>5eQ1B@z)oU);2V}uKeQg70u7u%_tVp~5IsO!QPwI5 zJ-&KilS`(dFEYa5D>97!HmA`4aN_R3V@rqHs5n-07TGph z;Bd7615nt%1cYb^vj^P4{?K}S&P)Ld^%c9F?Ljg?qkbOF)HAZP*f6#=e0(^#i77x5 zT!pXY2lfy92|m+b$SWw>zhhRh703*D)YajJkZB|v$HM`r2U!fPM;sf;6$kI+G@FmL znSx|JIlvX=G*Seprf=L!4rez|)7lK|^a^0Lzu-4^BXH(l31|Pu&;vDrKDdpNNd@=G z1QNuuzuOJK<~V>@SxZ|}XfQy|0`jOW9dA6=594>C6yni8aInz!w?Z6FC13;*W3!dZmhRjyMKflFU|Sc0jZL9rnDAIcMx-cxqI# zZqkReFP&*tH)G6|=0x+Nu~R>(UDC=z?YgG^L))#^(TeL6j3+3Uvs!H}pSE3(G2R=s zbVt3Qj)zO;YG719X|Y;1oYS8(OVLe0Y|b#|(r1?Fys{2jJ*|&+ePAfAVzldu8Br;A zHCYIZR5$hndzTdFZ}7+XQv4dO58qkHEY=kNgV)?uVKCh5nn~@YNunZdmP+Dug#&@8 zq?A!=DHV|fX}GvX+#oi9-|Q+u;~NQEgj#U4tH>YZws6rryji&laJU%_C#}iwsforO zR2g7g>%*DBZ*2uSq%#mu-)K9!*F0^WMXZ-KTfu)tgp-_PbkH;Eo3so1BfW&N8?KG* z^*8!ephAD?+jUO=hWw-`t)%8sBjLOorp{7gln+6daxdsnz6Dn+1C^5MCAA||RWczKb^vV{3;fss>`ktB z2EnO90`4IPM!3a53vI@E(HnA~JSGR=ZPuIYfy?4fvK+X}T%&)MVpJ!+I7oHT0Nzqac>7)2qEQDHrBMuf7sB!EPLd9Pwr`bXZPx(S{aXJRy6dx9I zqu^fHn6C*;^fm4?Cj$T60D56L$pd(^e8MdAK6K$f+UtPh+G&-uKG9J$C$x!5n!k)b zhRZmlH$%MN)FxekJ*kz` z&lwZVMnH#;fG^4}dnX)C?g9^Vm@Q9!V>HhV6jg0*0QU?^I3Kx0wDaY_&s+vqV;{E* z>hj-!5d6uxcpVPXWq3F0zoK{_{_O?fe@>++IYc(4OK|9R;6L6{Iwn>VF9>NunD{{0 zEr6v8hsH3t+z*9cV;Wb2f5AV7_F^AC6Q9WS1O)5;OSUwrs{L%>x7d-2+3_fA!)F-hpFu0_Vam(oqN!*6AA@a` zwMw*dKbRLN>_@t2E;nbJ8Z;E5X+3!VE@7d~MBV~<_MP13dc!YiDpwu2==1Dh(ib^R zhGJPUE|$cDvvQQ2gl^GnsKFm6Ztes(iJu|V5LY7BiBtiO?3(llZQnv0A1A&{N5YD^ z#TP;ecrSk8r=!Pz4@agA+&S(aAm27}i?{{cd2SuzzAE%c>T#pEHSqKLOvVyFI{_HC zTUa&y?OZ{hACJ+0GTm?bjowB^BfFvM!*oR(uH91CsKAh`AC&coZGZ3sob~<=W(pn) ztPS)CtPdOx^bUqA@020xPj#r)R$pqwnr#s8)iIaZ1is8lHVosp3#0Nz)a+!g6Yl>v zka-r_hq+K4;6bk=-{0VeJ01OgVQvSoiu-wszXm_|%+fLR>6zu)a%*`ilu25^dHxra zh&xD+#esA+JK2R^mGwGN}(e2n*!dDR$doYcGPN!lT8geGXi;L3GcX@(7_63SOZ_L<;f zMESwM=s?lHhyV%B3Jy|os^3)+idhTIY4jy#@RPB3SR8oFs%XVWpu}DWGm{ZqQEtHnMaxQtFycP;MwdLFJ zq4!Cv#C+lzVK5wiPx6iUSnd^fkedmmr|(F44w9K7i?wrL@0deA&Y(^2#9Xx(dhL%qFML9) zoe13NM$XM01^)6J+Z6dOhM8?XI24U#U$J>H&#%E1=KcJ0VLaOQJgKaF70$F}T@i4n zy)I9Z8_74M8tCJTi;skQf>(IPSBLjgdoYqx$ZF)fCzMj^5@g2h@MmnW_)&kN1@$NS~$wBAEAu&>p~J{^QI826-vDLpouG-i%+zZ-@VM zQS`Z-Fa`B<22pn#@%30ClZFa}IVk3JACSBJSq{2{)Kuv4&r>hCMg z6Q5x;J1tHSJBax)ejOBApcMOH_BNj@!o^`VRF9M*KY{jL16+PK){U9lBj}kg#~OMr z*lGV@U%mo%aQ?#?$XA>#zOqUJu{@ksrs4Doyd?V~=M&K{)iHh0+FfllG71~-^gU2L z8mUjhd~TV(SzoFz)NAQw^{3iCt&Y}8YoNW={%G6u0mcZk5N%@BusdSZ{>aqAs-`q~ zg2=6oHnyAVkN$fiRJY~=AzOj_fjaS#HSnQ&!|cc2L09Y>?0_ee!8B)YvTo9y{K3lR z1nRyoT&bVJNzH^V*at+yGyXpRh))8q;VR<%Hs-kl_-Kq<3*fN!h>XWKE1>T>1*SuO zwge)o5B#9(u@RWrAIJFw&bgddz(|+1gVt@ZleSnrFq129Y4AOKNRQGrbUNJ#e$Q+; zsy3#faNhOO@8)Eb3^md)!<_?7X4$x}-_*ar87;(EX*@J4qc>kn^IH$CIq<#OjWehf zI1iixK4)LFxR1aFPv)kecRRqH;VyBjz{XgKel{H^?BhZ>oqQ(;G1i~M`gI|hkMTX- z_xm35|C0!aV*(HEchDT{L34)CY0j~*)&?dC4NF%}V?m0U$%k%vt}S+>LO z*L~+8n2$K-wl{-U7Gmc{e4hsUA|LST=V?t~r4`dMeY5~-KGsytZ1j)Wn3kZW!Pdx6 zqo`&Y<}-5zR=?xT24)Sjwi#g-$7r7uC9%lNNPE%m(5nfx{{{caz!_eUDFU^a=g4y} zG7#sV2fh60p3kHv5&UE=?6~B z4=|n0~G59Msrd4PT%vz{P!Ke9d*1!yQJk}fip*da$tf~5F2W_E} z5QZM0C~&7iGYWBViB`hyK?6GlJ$Fv*J{5--ofq0(JIGx;TNC)IZbn?c<$i;)5Y7MQ zC}t?|e{N!OVbzm>c~~3FF+RhqPsPmJk9Z7+65Igp5ImmCKtHtu-w-P$7Zj;CpoRCw zn!6sC0lm;AjMtsGL~v{_qsRAfuaL?M)Nlsk#fq{Ldfu|+JR3!ZvW3CUqSy&3j<#PO zJI*(8LaW;o&^mis3g(#QP#Zr@iQYHwnhOwznaueHuG9lp@16AxdJFpylXVb@^}ubK zVoWj~V!fEGZPYS?Y5rVaZ`?M5##KWwvl&nHiTYcuGggSJwC~#gk#rVdQd?gek0&!3 z-C1E5cZySr6fIiZ-6?J@P~3|aio08J_hKzYio5$_>*JP8Y`&NO_sp|8nwy&|=e+0l z_Kqbh@06wLDAX}d)wA@KI9u0{d%plB*$^T1$}KE90B}vEsXhP z*>q58=WyTh#rfa)RQ>^1lWT+>%RYK4MnyAb_^GsjW4wURK9KFl&SCelxvZJ{2CLys z9ETOrY0%w%5he*1;X2=x&%>@0-Jz_#jaihBX`0?@xW z<7j+4HH7{KE7Q^3S$-{dpYOz{2@3_8R;9^FZ?$^FnhX zZm-~{dRh3$H|AG!<-w@V0s&$y(KBAzXr}E@C&X6DufabKg)Rj${l$GnymdTVJf}Tv zJXO42&n0g=ZxvsL_k{17Px00Fzw%cNb`Ad%?HbF}Bz-wpFVpDvOb_lAUteeh(xEPL zR@QdbR?;>R%=jJf2dgbKLgZ&9o1ex0jw#06!Jg7iZ-$HTNQ_yRQBTp6I|O%z>!v2= zr{*Q%aPfq=O|)B_mI9XXmV@F0b9+;m@5X&(?xCKi2$V(ZjfRN&wNQG&kKj{ieqgw7 zqUVYGN*>{UpZC4{lDmt$rKhdi>)z$=;F;=a<{b;ihMU15k;hVVb&%1V=sO-N^D@%`p`}m`l*I$=Fs3xU0@2*raGu;nmdAQF zO3z{|a0|J|{AcWV?+Y7E6U^U=buE)EZ!O0ylK9Ge-_$~w0GHYBIA0E`Bsmvx%1Y`T z`AU=uTY~F+TRo}n<+;9`Q#ov|IcHAJ&Fm%F^RxG6AI)KM+vIJ`EAP4MjRoj%XL+bP z(x?T!`*mhHSK4&Zyx2Ozw%PuZz+a`~Bfp}V6CVl{$@|5rxJkTlZ2&NtVm~2K| zH7xoaY+%*>wIi?e&@9q{b& zjSGE=4pp{-62F!Dj;YHJGnuTz?OWh2Uf6ZpbbtGt;S+h${X9ZQduk6jta(QHe8#Eyn4na!ePjCNq-LUR=p(LORBzbsQ>X*aE0O zoQg`)dl=6r#QIB)NVDK8U%K0!H#9df_hHV|?3y{pGh1b4WbDiAk$EqZ%=s_(rpN66 zJsgW(Qo3oq<44F+49&CVtJa_GSNiUyyM)hJuh;r<^7$_+9=k4(x@@);Z3WuddvD&&GkC7Bh z@cqpjMVD=;JlEk>FhJ80hfWtJn+MR*Rd=9|eVh>`ZkYb4+x6w&&WL*vH#DTRVxzOsn|HaLm9}&4w|WcE?(G1)}^hstkPtyNG1$ zfO=z=s0z0pg?+?^5$&zYlIVj6Xano-`%ox*P=UHMHbLqUog6wJyczg2kl`!m|H=E_ zy*Y1w?r*u#+{PZacbLCY;9xK_R5*Glc0zLyyWy*~ge}RBHSH9;;d&|M_+&3=&$W%S zHMQ-vjk;3)LCBlUHAIjEFNRilhcW_G;OdcQ z!3P0PpryZ}e~GV>*W+H3Hzx0CZWs4sPk;Z1fIaj+JU@Iqni4a#LBv4jA-9{~Cp0k) zH181W+veCSIbPe-?2WC{#aZTSCPgR>UyMU+Plg6j^#xI#3}d(KA!Atm`^hTw6egGX z%+BI!@E!T#f?r4h+vF5njXrY|5NleEt9%?41)ajG%|}h4gms4yT;xymm-f&1YTlyW z>+Z{W#q+A?jmUfHZsmRHmjmU(`@=6I<>Wl&lU|i7!d--CLswH_b4T-M@x67H{i^+n zeWI;|rI`6wA&Ec1E$4<|zP>@Fz!Pq)@rzNActN}&hl4)w9b%X_7(}!AkHR`rb#o8% zcjo=3Yo^5}hpD14m7k57dK1l3BvIYip%qih#RzF)xJ>Y}ufEsntKq%tZQ=XY>+saV zta3DO6I`|0_;&`kh3iMZmqyC9m0EgxqB2_qJ{oI8w|G!IDTXYP^{)M+eZT#IwYWH4 z=*HD!UxS5n1%#I=8%W4&hk#_^Zqu>HP8HO&Ko zTfy#n3-+KT$oBD&UIgA!>$PsiWa2E0GVTk{fy(V;}=Jw zb8NCsssR_O`nC1!>%aYzjr$<47jcf|v59I~I zz65UxcSZLa_aVI4otJN505t**PirtHjhzyAqS32OCx0oedm|e+6*wTDq z@xCp^6-?fll98O_3OO3tM9}tsvUIinU^#2v!;hf-nEG?F-{o*@v~;y>5Hro?#QbKJTS5oo z6Sbwv(b#CUgx3%ySc-1oATF0lbILHVq_RClZEl<6`XxgY56+veVq*VoK4c$r=lM@X5{(A*LhhBAW*KRDX(?cBY2WJ%B@Rq^n{P?-K-UfH zS#B>i)EKXI(WmMGt)JRTo*Vkz=f_!W=xZEu#qvNqp1|&4eOy#1Y@T2p>8zBbrc_NG zo?1Fz*`!Jd!yOH+&BVp#$3iaG1NAHYhziEf+5yBmKC6GipLI4E42S*~(QFJA(S9Rh zc%U##LJX|D5rQ1=o5;g}!%z9T_!jx@;-+}_xRuw_M9b=5h&jy9 zLNB3~`H^|AWwx!p^Ob97QlVrzajEm5MdOPxy;1A&2pPZUYF(wV)Fd>;f7WxyUE2HH zecK+(n*NCYOfVxt$ab}{-Y=d;2I$$m$$Zt8 z;Vhn*FZm4+hgTEdxH`I8CQNhQw|?S$=z)FoyIOValv+b&5NXRrBsjsC ziK8K)$_Ucs*cFINB(Mxqn66FMC7v4hK@MB4l!!HnZVkh~G9ddL{-wUZy9kq~)h;DyhMi)1EOxgKwS8y5E_O8kYpQH2BKEgTKrg3F zE4jM#lDJQQr_NU@Ah$kOt}fS#9g1bd$}6p7*~s7wR#S|3#3IJaA22@^ms%TJ&DO!z zSyl?(AOq}&?R#wZtg5I94lW;al+23nfG+bKZbyxBhNgu;gkVuKQXBD&nON~W0xRnf zGo2np-X>~7uk!w@wk%#83ZOIUzp9BXy!1<({!0)rMLzVGe z`an%q$0?^`hTKnnByE!VNe!hz$P<5&R>_58fmml`zK5z$!4-lAPk(Q$0>!Eah&*>d zLRv&`hBNpD>>ktUw#f8a*y_+Qnc3f%w}{!zMjN~V>)%V>As3LnL0H-g_3N8>i+Itv z1zOF2plN$bPjQe6dr=qOFMi2mMz^hEfvbV%AQimSk1(in6r6Y7{<|pi$kE10vF$=MCu7lXb0jOXoM6Sjl z+hwC3lZT+B?2c?#5@=CJk?%_(58-tlAdlLL{F!`-ecE`cJMt=*D1j#EzK8?QrT5UM z=>5=BE<-l14D!Gqpg0*uIjMipYL&=rD9park==x{bQ1H_#*x=V)*%_0@sfrG zxra`mQT+oRog-cnE<2;4B%BSU?e9>14uEpz4Q8b-h8@?*U_AlX>QHbyZlOg|)w{}0 zrMU)Kh4QCt4!tO3gqPoPcV9KhZ1VXp6U014OX*Ba(d`S$vAQ zi9UNCiij9`%y`5QPm!~}>i&rT!1+9huZ6^7WIV^<=YM#OHK0cRg*v@4Lx8ab((sPJXx5sf@4}18Mn&FyP4?MINQZ&8W|>@&O?8?i^uH% zK3GZuHTsCeShm0Wjh?}7pUINPKY>Zix zjX%MvoNDwnhQe*Sw2@+zFx;r}-U6c20KFkx#t0}!KOw&n)n02d)(17vRvq;z`V9Ra zJzcjMmGSI<;b|?#>&JgVjiMcz^3BLLr;;PdKgq4+Ytn^?Q#~Bn&&U+cMc(NU`p6Eb zW|lz-JDO^Uyi6IinvDu#Y(N_OVjaVex~d}|EU`bSqv(-blOSW7jU z!E?N}(N0;cRf#3YcOuJTIm&Rgt5yp&xtea(d%+8D5ZQ;R&o33uup8+Ibc~p#bd$S5 zn^O|&hAVL^Sp@8>>Y%n=QfDbkbUSp;Gl+kf`ka}sY&wD6e-UmTGV0S1mncmaW0zpg z=mm=LELtJU5L2M8Zxf$Rq?6a7LMuwE$QjK=@A5-w_7lB_IlXWAOSg$`I9=?es`S;I5eVnobKn&4^@+)SdM(S1N2w0Xs#BaeHp#s~Sza@M&1;p{zVYcsVkF1ld<*bD*9V~mqMDrb? z0{@<^$DRVIZ#31E@)Mh&vp7utj*Mv@XkamNCRH0*-6ZIWt3aK+grey4$aZfg`Vk%D zJK_(Z(06Oq)y}bIvFXyWNG!Y|R5&;-csB4n@NeLA;FAAMXkBP}uw^(s)FUt{kQn$W zxCH)_*5Jp$%D`t|2mkCqUt~WH1!_ko$(=PHQI6@)xrN@Q8kRpTLoGCN70tv^;s$Y$ z<(=pf(}Z5!A9Q(gtZ`bq7CSBd7EO;%mfpx0V;5D6R#%@6oyt_A0(@^SvMzqEslV9G zR@9+8O1nxYj&mJwu`bm)-!a%;)waMgUz}rVBm66zg3e6iR)AHpm7!TTeHl6o0y~*f z$j{y8S|Q{1UZ@Fg?9RwijbQVT9c_&blLRa<+wItqC+Z1~ey%iS2bAm#! zSV##^h{BalEv7e)_a}SOm01fvSa@K{5l2}I*}B=gI~2zaXHG)ZgeuO@@PF=M$uTFQ zII%U>@~1JLJViGDJ17u0(r@T8%rfQ)ba)xyA;!2i{2}C^S_ymjg?usI%9+9Keur85 zBIcg0HcJPi9ixpS8^hjk`*7<}-C()EYJYeCE#D3A zUH5P|gDhLgydRNa%kvG4v{9ECx2dZ96Y++9p=$)Z*|MC?oNw)2Y@@}pd}YLJPLjV6 z4~%;1OxYzBm%2v=gl-0&dF%SV1g?gQgf@k4>HBqpD8-Br&kM6f#!|!aoAq~V4RMY5 z#guNE#$V$8Vmfl|kc0UhI+l6ZwYCHaaNhswr&?kDiG#!Xmfpww!4}|t=W@6>cZW;j zUqf-di(SJG0THkd-Gv5098T|(QG@XYYv~|X4g2DBd>HCC|H8VHB5R}WqYL?p><$O0 zSEw_&udmimsY%Lfc|mNl{87pazmd`dv1qmEna~KiuV6anEh&|G_gHc@HJlp=C{cJ*m1Rk-S@lLaq5Ib>6>?g7^(IKv@K4o;Q zlbkJ8lPIZ99Jhg$5O{`#EarMmIki(c_dupq0-)p;T@tHFPfgj7R zpx+`VdeImGzWZ#{xI938#|CN;GnReJUFTXrc}@#%zJRcwpU(H>f8#UwEBpn%D8Ga& z#vQ{R=NMMB%b~gN2gScHUP?<+n^_+clc6TRMp&a$)~R+F6fV# zZ{it5eQkvCqkfl2R@Tccq^aua;159Oj+Tc63P%HhnwU8sNkbwPlpV5L9wl{>$Z+|n zH_#?<*2f0FSk77*#EEy?n%m}C zvP_M^Nn?t78Y%1fwfio>|Bzh;42692p7gT?<^d z6M~L5wqs%gVHNXR_4YA6ENV-C8Em3;@}Z@sHlMzP2KMRw(2sA}*=V75QQpYFfxABMTY9ULA?44c9i zLqe!vu%AEO*V=m=GtutcL%F-$@4Tszd-5b>4!ug~jfhc|FI1ARKpXW)xuovZM(N)}d;Ni!PYt3)b{#vPy@5Q% zTX()8)o+yIN}pInZUgnyQdQFvk^8#CcN90-(_CkhmL;nx zol*)WUvfHaGfXeoMbyOjM!lqVS^ZUMrj>yzlGk4uZH-p(y779bl&g(;iqdRF{;=>w z2$&|DE#^+b55h9G3iFQYMovVHXx*V7iEbW(k-8>6&TjbQw?VNofV_oLd%sHL~kt{tw_n)Xmiy~(w4~R@JoLe_okd3IS+GNxbJv>4lPo~QzK1h9J`W| z^Q9&CPOh0$C~<^i3OqZH@Hg2_^i!%_oYqRqnc)h-vc4~#nx5lcEpRmYMr}hL;lHsL zPhwL~6^a#+i(M36U0&^`Vewjc0;8nDy3u?@LbZZ(qL6xRnt~ z|M)&Rr$lrtu}DaEohh`n!1N+QQ6m4Iv_e+1xjml^j_*CAh4v-3G?e7+o;~$*-lu`- zS3k9K{}nu<&ZJG|4GCY8f~hT1cjw!aQp$DFwa@iF;WtMe>mWXlVzul@311}lbN0c^ zjybP9=R^19LV90%7dO{b%kr;nl)a;CPwJQgiwb5Hy;r1P!9)3fPnaV5IT)-kjC1nRHi=r^aCs^T%nnuNv4-z7In3?~$|hD<%VVN5czzH6W!xQGbw z9e9Uc)BaPZs?GER#!2ETnBigiHv60^3*Nv@t!j+a?#q0vs&YXoBdwS8aBYA7Kn34- z-UWGVAjf;bd(&GYaM61YIp49Kp6@SuDt(ynn*01~#?!#*cx!8O8&PmmzUFBuDG!qx z+sj%;3CTQVZq1dVs~g?rbl=jP>lvrgw|#o{>Q6b6m|%a+oAlz+Sih0CF!xMeWlwisJVZy!!jJqI3e0F+$r2GlI;(LlLAG9g967x1$~jgP`}x?+19Zi>CPW;?C(2W{tADoa-V0D3<)ON|cog3^`92@!K=IPeErZKFl7=cB3?oY?k zr=~|fzssHE@2*Uwt_!QpX~J(}5&IG6#1uzCs!)p}Df!1FO>;Q8CPu%|2yZb@UC#*b zR{!hZ&Cq~Q)reC{h?O%QQxAjz4mru3HY#m1@>}KY`^=lU8`KL>zt+<{SA*!xEhCS9 z^$m+fr5vLNS(=&zC0kPqX>y5kIRR^?41PUV+I&(RVagIduvy&wcnP?XZYTCJ^Y|>} zaViSC5aG;4>|#3cC$)hd%Fg4yuoLLfpq?$$T=4c0wFCM-Eua>NWkOT6JX8s*fp4OP z)nnm4%BgT!UkBgz>@vBvvWADQcrj@vB(jA!{GiyzKd>Fl zdA3b1H8C+sN%-U}Z?9q7Z)+>s%(Q4Hx?ru((9UwynqazNH{kr+&h**Tp1LnI;X6{N1ezbmeBxU`Wibishi|EG$pX|0 zCO>q(A$l~;5VuK=SfrIB>mrj^nHU2{m=3ZMYANrD)(-EArv@_Meqp#r1pf7W%HHZB z{6BgJd0WYKG6#BWzOI>_f+KVO$t?b`F+~0P|0-2C2zK>J%VT_PGWM zU8yVar^HU}m2z0wBX^1%4poJlyv6g2pN@=EyV6C>CrqVmg{&7W*PxOqWqV@lVJmJQ zW=>~cQ-zg9pj5f^cKqK)ZT1_wfp!T#dqog&Ww_I9oDW&+@m2T(VgaTWpUdfJzi-7O z;y}^H@8O+n3N@Sa@vGTfRO6IFv?xY))Wg^Xts}2dUD-)Y4Ds0N`dy=hKERkn>@-%A z8{!|dA?gCTPWVc=een0ttJok(RyT!@M)vzJc`|+9Wqio1k$cFSKUh5SqtwRzO?YF* zm8@@bre>YX8Id<6`<8pQ$`JQ#hb#_@*-_BER9t26<}B&Tm-K`6rKP>Fg!IJ91snRd zyJz?c2R8)gM2kf$hc)SY$!-iHM=}3cGMzn>`=nh>DV_MtvELNsJFx?~Bg`8jsIQaq zM~L7TZ;tPlZ*Ryf)l~`->EP@yy z6-tsJ#6UMF*AdA}hi<+jH-#U<5aup?38sl@kWkS)&2rf>!ur@&);e2gz}fkqm{V*+ zRHFXHG=%fX1HBh=A5(M%F}5yX#y4h8&`;tss8jG|vND6>n}|~66VMw;c}ci9w8lju zjiqG7_HL-7BcB7qg3a^nITm*b&)&Rku@1qp;kV&VZpO1QvqyHp0GqYem*t+4z9r8Y zIu_uqswHANlTgI6*u2C#*ml_d-f_`3%DP*4#%wc&Mmq$WdR}__y0>}TN7@CCMWV5B z;cD6+#;16WP}w%lbu!?Ddjv8-bC1PiSQ~X!`Z#`oJt;2Ps z3wl}dU%CdWEpvrJVvhBa^G(9tgig-twvVPXHiMX_-x7u=Jl311?4cS;%S>AW*TZh^ zu(4O8O@&QM$TH?#rpkJE(@aZo?yjwJ^2j90T0Wtm*np~GzR1@l=9_(DdBhQ>2^PeP z2hdqDKM@z|QX`qI+&20IQIoy};@#hjgUTW-Amx_Pe^hQjk5E@C80#%vM6FIu#1!X8 zmiw3b*XP*X2eZ$3j|PV4nqytEOGYo}{urF+y_h*)e*3DV`=F=gvy;+%xo3E4{>@?` zvy?i39%fFk^){bNmYr>rn%V}!XFCtNhbO)}?xgS~-_*bkE$lxpSC9w$OX&HH&%{*c zHT%g_ciPCL-;!R~;+Ap37V)&Om2b{3h*ycNjv2ms-XDTFp0|;PG8r2We#S&<0yCLy zVZLack+3K6T=LQ6nF+USmrbjgYt&G^6`!7(O|uCH?WF^c_*iU;rwez5pQ~@;+uJO} zTdtw#(*kriM@Qi=OVgw+iGt&{*xz=XTEY~x`bmykY8%DPj^}a>j0-aIDu|?aHIHWl zq>CLxH8HM`5_Xm4>ET3O#Eh3}7v-zr$H7;D65+(~h3KXIIHwo00r3n%!O+tBn?l+ubIc`7(Erp3C19 zP0muTaS6I(hWQHJHda*1_M^=ryF!EI-LWl*do2v5hL%N_si&#e7Rvc4scLe=gcpu) zEDuc2O^l_5qi)Omgf$Y7X5&t0xA9M7lO}gx!VFZUA}J z7);u^6{x>5j60HDt`n&e{#6-;vi30^!iTkZ%5iZ*K8ASMn}XP`qb3?@V;uv#+sE7m8D= zGEuD~^2ZSpdM|Aa4l*P{ycPDobQ$G#;`n-5zmxV9$y61yZ(uKcz%%XI5h^K#Z9 z{6*A&pGF{1BXBw}*0(j-J=94atk%<3l76xgJ6E`6nq?W`nCDDM$a0u$OUzFGGt+>M zGPlfSg%YMt!dOs(Cg|H@e`@Q-CzLDoge??T342|(E_4jg5^`?i1d4CR04zs~Q;f{e-L2vHGz|h=Z zvv1_x&g`DMBKvz~zqZRX!}`TSh{Mdkb309YZRZkNCpWQAah?@E6NmI~qxQ((K{@y& zP%&6E(j`(SxDM|yMM}}y)18FMj)cVKNkP{Ur($m=?z6NP2aCU(^K(zAwVGLO9X{=E z9q1Wc6)GJq6I-nmGo4zA^Q(a8B8IBB8&dsI>~*_wKMOm>YswU4cWvyvD&k3K|3~mKKic4pQ zBF1x4f2qIJo*P|=zSJUAV;!P>($vsK_HFu@El4Qy%;tNIUWX{Yf>fZ9HRk&*W zGCR=Hz&Tl*!@U+X)+PRy*vLLGL39-1AF`yeFMLb--e>a854QB53ayrMgH)(u_-1HL zbVd9VvrH_OSkSfFwb<3c+QJevPZx%X_k}fVZDzjKQ|=c2+gI5))&Df`TWF8;m68$H zuOf@%p-*s>xX7+MH@Mn3%(i|O$uwS&t#c6L$g#9yD}lE^AT~`prL0r7Dmw{XU4cAf zX}upgi`KY3rtj@VZ0Qc0t&QcSc+S+;bjH-oqzOoPF)Bz+x1jO*R*$NMi2!9s{ox(v z9^&wQ==-F|w1A%>YQR{58XUiioLdMwyhPNzE>t{vz1ZSNFtXNNA&}p*Bd{_&0KA7k z!gZn@y!c9QHhnt;_i%BcK1hr*K@b?l9Fz*kXGD=Exmxq|n|z$9dL% z+1}PZj2}owj7NSxvOUnzdoem!ib=0x{lh3^lGa8ZYxkL(;(N!)vZ%IP4JuD1>Nxxld(!o(FcH;z z>nqe8b*M55Rn!lZp0TuWqj2fqId|(|!0!pTqucJKbB(KQFeM?!>+mTZ;X-U!dmxLNylmTWRZl`yxwz^H;a?hKeY)lkSK6AxBs{ zHb!0?E*Z)fejPH&JB=o6UGpGYK5GH%H^_Y~6>gY1Sq52Ct%c3&_159E%go=2Wo)Ymsh^|n%5?+PzJ7Ze6E=)1$9FnI*cMja@ew-O)JjmA4|rLrZqI;;f? zhcw@`;NoB%UwMB@;EdZJDI9q3I~!@@UlaO1cstM`u*)6skM_NkD+_hl_2OL1E~=pL zEqj}pVq0f@IlN}=*AH>y&?Mg;H?N01WG(#FEU5qx1IwRkN-vv{`cO&JMN%1F4 z8}XjKQPS8%I^|6AoTOeUOA?p6l3X8cO~nU1D(tj;u~q+z8ZHPnKuoDQ7!d>I?do4f zHxhgvlWH#K$nPj`4>_9Jr&<5B_O>Xdc0!PCOXe7*^#01Pv0$u*yjyM=Bf*MU5xb^b zh&_e_-Z^;rwIN!?-{}+LeT*B*JX9|~Qf#rKYAK~cFcPcbznWJ&)F!aR;|Mp2Ue7-1 z-xjie*$`apefg!TFIRpYzhN3G9yO;Rg4%)GD6CC*>l&6sCC#v|bZ|l(Jiq($tk@>C zuaZxi3f=fv&u32-ba}VKg_VukW;zv)x93dPh2iFM)|1xt36~QqBwchYxBV{cW&Wn$ zfHHnbx2x%LMfrRXppYFX@fIREh|4TI`HknP@_7;+_bNOh21^;(y`<%Te2NB(;xTiMnK7vPthyn|VU+gPPK@(hhl{d^`3S zJh<$5LDX&9q`J~>5IjyrTLmk^zp#bQfE*< z)7%ko)prVstrFi@Punt0edtx(O*TL@AvbE9;QBK@z5CsP|p2{-G7o)8b|0`%o9OmM+E8 z{8XX4`Memj6t`2>jZnG|HfKN?Tb28Zo6fePpMeK9L$7CC($8u0jW5vWCB-j5C-Mv) zQ{R#^=#{8#ZNpAwBdG9^!N#hN>XdeDG3>#<5QX4ox?Wokmd!`xb%f}W@c7__K&`+7 ze{%Rupg^!}Fn@4-Xh3kRe{gsr6nvASwZd^!OeRI22PcFZU?h9$_xZ;J3u_dK8~50I&FKsGB)<&kGv^95Nla_iKO|3cm;V_w|O%5OS4U7%o9w1 z3)T6C$P(4Z?zAcQ7}XLnh6B~-YlRacB+$AgLg`kHnu8eaVbmx83;(fCsM%bGn)EKX zFOwPry+SFPqgznFlYOB9^ur|$j5*>UveP{gyDf&i;M&-TSW$VRR3%bBIwW!ll#`i&{&&Qpyt;L!Vx@8aEw@~H6FsL51Y=%l_<86;3d z#vsS7fwRz?%uoCRmO>xBvep5W=6k`$C>NUo3dpkPyr_t(HCJ?PWJTn1hzzHoRyZ2= zgtvv4g)_pX!f(RAgr2||^nUngl#Lx#qOnBPskR64Y`YPUParqYZe}}Xj32lNC@(X) z68scll+Z^A@(f>vJIQuKo!l8@`C5bc5GA*PZ~740uv_5MwSoHTN7T1B0nsFfoJ@_U zGnjX5V>l`g6?O^Bg?mC+=wcdZx@3B6dT82U$`B}_1V0vpoTtzxk6_NyM=72fKxQF# zXoc=+C-hEHGLw7=zl0Rf&uSy0*p@6t`k}2{4GO^q=sr7u2sQ~s!{L<0M_Dmb2CmGKXE9nY%=Wfux{XiXtB55ti zNE4VZ%=hf?>>?1LcCc4Kj{L-msPn$TUW5zWThyR!Vo$NBQ5(L29mIA7>nlH7merY5 z^oDBeSg1=&qT2m6^NhI)uE<{0W6cEx;3rh?WI%_ym+pXT<0|zFG-E;XD3pWcNFDCW zr=S?F4VrF2%-~m{ej5#ya82+gKN}~FpW$lyQeOdWa24qMc57`kMO~^ERCl2*tjZJc z&f0;6_E6p;x0GG-acQK~3;NjdQU_^(G+df5Etl>|#pMz59l37oZ*ZkMt6l zZ`OMo#XuXL4~6SCL@JL!>3bLqgHfn~XQA-@1GjzPTkXX6RjA}&hwnq#?rdWyXsd#b zL*Ymba8#Z%|De8q11gu+F_XYin~L8KKr4KiIfdJQXp0ZHJ;LL1L0dVD7+*W^0@I+J z&4iM6F&HEo)GzZ<_g9Ac0FB%{=0Nv zDh&pL6ZDxBWsUNUdLI;oH2tDJ5ZCNvaC(mubwG)%0i~Z6+SogEA!Y(xT6LyAJA_>b ze$D|@Njh9qq|-h%-6%(IIS{x9p~Uc-VoITcR-WrtP}w&?{w_VS2-Ro4 z15qc55ipAwcxE>Jntp+G_#OH(bhlfe&>jo*bSrROihy5}2Rhx!|FKtFQ)MXzo|YGp zT^|9RsGZCt{)LXVKdMyZ_*Vw+U~skK#x;0a4>YP6s{RpO{c0C;m>Uml>tvr;H zFTpLD4^C@M)ul$1N6LO>5m;21@5wrrtu(HfkFy4Y$z*Y&}2N_N$>+L7-*yz^x}ZJ8i(is|%HMNje$oX+P%92XOB_24>42)NdFWdO=0o2;)Op zsxaJT8A?L^20}x#AkbgltyA`CGDd74whrTokOw#wnbvS5D1dXXKILk5=#{YsV zya=?AT53ZeefG->IXr6v>wXW%TOSFAp;}~{oq+_#*R6ODhMuUA@up;R4H(3szXIw z9n`eyV7ipUe+omN?}EE=KHN;0Q*_dYdE_;DA4==Lpi|$BmY4+xtUh49RwYY=e@jEN z{t5HW1!4=h&_lspt&Lg(Clo-hp?5xrUNaGH^3Cyyi$Z(NfE}Czh4o{s?v6uYywliV ztVKJ4pA6m6K(tnO__Vi&4|M}5jBDeHYlvGTqcJKrI$)+6imPx2yl8hoKXV@RB4}`+ z-A?`LWCL|~NA%En;Cyb5?+54e5UL3tf$;v79smFS0~TnWzq%Pzfik=nnCT5csP2SY zUoe-t;$yfd!X6h7+(o3{_^;IR4&Ye#$^a@Q*iqgG~&UaT2F;ae^}I`bs`ez}HtU{1-gm3kH4y z80E=$TqUA1j<+U`x(Hg_jAtd_`EjuLGmsI!gZDX!{_YA)lq zTlo7R+F&*Qw=KRiz6mXI3{J$Gp}_t6_?uwxUjWPXE-D-zpylr4>tB50cWBG6$DNKx zWZ?IY|6{fLaQ?mc&5d?=hxUGqN54Z0-N$3T((JwX{MjJ&e|1mbz{XKQ_V?kM{^}o* zj>mlY-?Q$c{aDVw9qK5Cc-c8^jZt5mY$*4rlMaJz-yR5!Vlrz$2cpe z{@3&O;W|HzN1nwwzlhr%T=n#McbyQD4FV zpbvV<1iarWyl*F*|Nqbf2;wp^EB=7kMQle8`3nrFa>Stc0gQ9`h?C&fuE1x>6qY#h zfnENGk{7Kd-QfzG8bQU>8PVFQ${eybzYGqs3~^R3r*1P`R2!-!-si9}T6f2fgR0em z#5)_kwO#t%_%L!4(IuXQ7Th2A5Y;HRQ4THZq(0G=tgPh|w^e;7Ef3!?3?|d`h5p{y z1r6OvwY%N`uXr(j)B3NuBtL`8P^U%b#vZXXp?j@j9>(3){}$MIT{vMwpN?!rQ_~u9 zt;*<8q74{mf%rAL7*$qmDy*a~U~Lsq{?&iwrt&@5#elPTq0OBICg3e)GqsRy!B$|# z5CLttS|qkto1>K^OK|0cDfH_2YxS|V+KABK<6Kek_EZ5ZH@|3o)%URrdPB`;a7Ibm zPNW)VRT}%7{YrJzxSjH^N)8Bcy_M>*c1lz&9`8n;&}zj?$Jgk6Kp0ynT~`ikm64N)5cfd# z{#ALed?gXF zR?GOH98y}USM=-)T2wM!k?c&5WkSYS)Nc+k zMjMB51-&A!L+f=8_naeJkm*L-_!!j1ok3mpO!5xujxIwbVIoi9cyAgt5bv!{{zDF- z-=luS6F)^104Z());G_Ie5i8mPu0S^R>9bx0$-2sh^e?HS73EHk6K7PCZ|#J$exr5 z*Jf4YBK%MOAkFYJE=*knyE_M0_Hyi&KfysH0d04Xv|trG6l-;k=oC*f=4i)2mrsNG zI7Sy|F2H#{9&Zoz><7I7s$ebDQ84~SV~u-*7>dui4pmxbjlB3oGKrCK6;?<6*g&!; z(GVlTAwnkig4$?7{m(S)3U}ewmwb-?zaAq~0;;k`q5jMXa_>pxMb;B*wk($SlM?K7fhXg*1V9 zw*i%sK2&2>LG71-y1nc0Qy@`WLSt1qooGt7#F*9}%)cHO%>?2U#@k`>!dUl2;1c5k zk9IHUgA+&}Dvs-bh53EI;iuUR4v=?oA}&JZ(jU9`0HPo2%N_ zlO~gEkqhzLzG2Nq#@#_M!Ly)W(-{4UN4QO1M4H@ZoK69#6zf5sSstTsmbgl6!{~G+ zk6~;c;jUN@eVq9?51Qg$suI(YIMIN7O9*5g)C8rVcwS~Z;Qk(mKer>L=RcR)LQrlj zM6U2a+?kK!Tr;um(&5yui}y7k_9GcD57Gr|A={`u?zPvE>ew827{h9dyZbg&>J8yy z-}?Xh7mD&Z(Vnd)Xs zHX?39fj5mfMdiV&nTEAr9{${Up^YSPwylJpUjlxbr5KSZ1d=UrV=&>mJOd~G-%tu~ zLmp8J++P-91@43P;w$R&;rQ$%s4z(T8g}Bz)>~9bPptd+a{}s#PpGBVn~A8b?wGHj z9+OdPw}P8v9sGV6O4SqA8`L3_P;;Y{D?Dk+~;gyqiJu4#| z@Ee)W3_sE;RvWF1esGD+ZRRwenS0F&c+C^=s=u4_%#{C_;c!*{FEJoqKL)S=-bhL~ z`rmm9ci|1jM&lNKpJ@c(ZCY2a4hPMH@HWkdKg+|l^{QSU-osan-{|d>fi5-))qNIo z1sr_;m|3h27}3da{_I2c!A+qi(~7OfMY%WpJfWJn6D*kk*ft-;C1O4?tGH3fDcs@z z;_q_(xo_-P_}-s_4&psM03EDD(EcXl)jmQ7&2prlRD%LK8J_PCFrv9|uJ*GNv2$F& z6Pu2woE>`I@Aw(kn*+`C=1KSn7l1$OS$zsT;iKr!l-2(8nMhHms!P;y>KkBnU)4rh zYxHYIX%)49G(-K4j+Ur40p^-0ACuyNCC-($NDbs_K=)2-f5RcYt6h;iMwMhXu#dPp zd^KSvTn#gc8N|23A)&DFj4#N4;F@t?klGMMmo_gmf_?!+V*}IIBm6om)dalD^nY~@)GnE z>M8Ozf>68@Rs(Yz+<+(P9(}pi2pR>@Kr|Za+TW_Eh1FN81!s*k@NYc@WHKFkbtjcx zN>*h({0=@!6Qy#}{n*-An%EMsr)EMQVT6~3YXi%@6`mAn6I~d)FZEH1X!rFe=4`tY z`GQ)__}G?w4M7qXitogz_(!}6C2J`>K_6cCYVH$qIyN!C;IuXY_X`s>&0ss#DhTI= zrs)0@HVdLNR1RzDFRc68)ES@&`#7Di4EIpmvE2FGX*vVWVbFxmb68>vaTuJ_8}J#B z4Y7zD$$e+rviF$)U6r0jZNdyN;d^lc7Ac*pi)=)EYFrtOI@VHU~`RwawaLVDSRrF68aI`6nq>k6}lB# z95y4pV|8RnNv8|Ap}I*wb)K2bz2ggtZpS;vLg!7V&-vK7)map-eYB&Ocpg5OL-|KY zZgH|DnZN0VP#f<;#ofx@hVIUDbjJ3WJ~KboUooW9VgiL@c(jWNYKLxKJ5|2Kh7CdJcMoL3_iv$Kb!R%r7 z!a4NSoB=#5J&}uSK<$U~W=U==KV9f0{uW0A=U?uq;|Pj#L`j$`)DkiYek5glf&bi3 z{u5Mr{V+3@K@qeM=*Sv)W}HCJ_c)xrU*JxX1?vCpI8}9YK>o!IyodD$KHJx@7w$sG zc8wNN=cz|nL=J>=hX(*%O$xS!f-|Roov(&3m`eFl zQ*Zee_*VsThhB%%!8LKL`cZFUogki3mw<|862^$z9KDw#+!)Xo!} zK1U%(AMiZ(@kim2`h+=#=T?B~gHt0BqtXE@AVZC+E-#%%ne=e3n76Nn}{XEzr?U`LYO32{4Qh-oX7ri2^HWE zuJhw$=!mh365!wSB^Nu z@8)$T;rD>!f#ZZ@vtyoPj-#C;pW~X?M@%oy!ps?oq?x5i-BH==NSH!i8`jScW<4;p zT+AhC;V)1XasTZBe#<%ZuWZ~pYvR6w_JJXR+y34D%>L~DdcH}%L#cgyi+neHtpl+@_wa?taw$Q5 zp!czwk+bPK+zx(&s5;I#U%H;Vs=I5t!!Fi6z~z7=813xmSO^!x0Dp{|%4TLZPzA{y zc1DXc&jROep?w3!G#gCeEJj~5U=1erkag%a@D-@YR}yB3^)L%%=Xuvgcg+3OQ^9*3 zUYUQ#WrP2%?kVOW-D{nHJ8lT;fY=s>!_XnRGg9p$6bId5Z8-NYhTlqGIR9V8J>xey zn{<$m&Xr1VW$h&Z!umo-V zoGd@X#v@%AaTi|=u2R$npyQz%?t1mA{p$41FX zS{d`4y?_cc4f!N-n#185;rfp88REL;TI&>?qZ|#vTYt>=;Ci!3%t7iniS#k+zA;2E z0*(GiB_K!SsY-Trp_ao~Wp1)pkmcy=Y(MTJpHqAwUUM{YK64&$O?IDy)8v7;NPLz= zJ+Vqs*`#lYXA+vm5B8RKXK>yS#&CO}63B!)F$HJ?jogDC^gVhUR@4|K2H(>%^bz>+ zeENSf%q?m@d`~yP#po$Bj+q8e)NW)h;-od&9BpjS*J(AiK59nwG0xotcnBmybE%6+4l?0b~8P9Vs;Zu~3J`7e}I#l2Z@V+a7M2qQgk&viKK;JGi zY1j(Lv&zAH;A?)KpUtlV-+l(ul+HoVL@k_*n_FQpH(r~5qm7XRs^446S-GRMHFhF0 zJA6Gf2At!kc=hdry@Pdv9|J9d@qtTlAgkdY?tkj<>gx+X*5?5;SU6Z8PrYLvJbA-gMf->p1JLWp;p5tDG(J?SCte6q0EZc2xtB*CSl9QQQTyb%Yc+_#hS>IXTH4dJTmoanxiBscp#ZQS3 zdq2P}bED&tSOt2Di%5Z*O`l_@((B=`K9H|1++{NfUO2rUVkf|DpczK`j1Y13c8qjY za6NXecXW0v7325~>^8a#Qw_-xZ{Tkfg}2Q=)NtUm7on@aZJsjMn7MJ!u`G(H1XSuM z^ylc|gRNIjTPx?1rbL|4AK?mND_l78Jv2VjCvY_E2_A(Pd7a?$KbgYyec40FP~1<8 zjt@1F28A|6o`l;6_}~=3`ilsp^T$G6lr%~S;v@Bg_qqzW?t9XB#yEA?QZc(@vbd7h zgl@t~zKrPOiZcYY%#!u0@|M`8aE8$O$jR_Inbk9!$B``mCI&)UWz5J)ho5ziYI$kV1{5i>*{&@Lp;_b6%ZfC=wWG zd$ud?Q$Btl&p{{LiY`PRMj{q4Y_+y_LD>pCb_MqEj`Gmh-e~D?i*T{vpkOE<`$5bJ z<&v_-9!m9iugwfTwwLp z^xQPZridF{=iReB^*w1lqW8I{ipTVPahLNLzSH(A*P}DQSXvMqA8r*( zRo)sgvLW}!QN)`k;b7v@w7b&QN;@O%$~5Paz9d}qmUGW`jN%?rb?k^%N~ssy5IGfj z5j}Imf+YNVi=USFy;g`?&~bah=u81-avct5Sszv?aZB}#gugOpRr9jh8GCS?pAl#csD zk@vCjzACX=esB1A_`AP-pik(v&+SY7bMe=c-*baoG{5wYT*T+}-cDMY=6<>-8Td5a z5<0|(!k zh0bhsT)H@?XPCPTJOw9v%6goh#hx#ow{E}VC2vzzt$bSh*uRm{;cVfx(Qrk)wD#y5&j zadBcd!a(xKE9D>kkW^8sB^5>Lg%MpGObm|=Eew>2t&|fiiA-j*c^f3YO)>NA;P0T>GiSD-D#~S{12ijDz#wLsF+C3K6rAXzvKHH=&j=F4ho7`=qNPsl_FX`4JUJqYtK(TFM! z2k7h+LH`T(@)C4wR$Ft8Vn$zMg;7d#nzi+r$kf@TJ;lAYg8oijA%~T6ks-?7=$B|p zv}ZJ4S}Jvmjsf@dVCcSlAn+!-#kczBh+h-GxBfCSNJpE}gCj^NXnl9yt+ziUT27#?mx6FzRh>A*Q?UnXQ&uA}47V~FzmE*4K zITXj`6Gtb`N+=dr)g9pl1_+?mSgIA?9sD;i&^OHbrA)BU1$Ep`$mJ$o2h;r>C%8I_ z`P{{Y&bYBOX1Y+V;WB;~KgnXcrL!TFq|2z9Y%yX2bC762ZL;f-W0~ySd1@Pz6;AW1 z$f5=AXbrWJ&80>+E0O%yx(1)d%-nnWZ>kjCoq9q(Ftei5Gg`Z+9sw$IS1YfoT3gN1 zCL?jGsd`1a2VbC+$nMCH(A9_$y&;W=os?45Rnl|iTJ%#mV_^5MHb1P?o8d00YojxD zo*(LZ5chYw#~G{s)hS)CjP=tL^yH=U5nuIX(fa}KpJ9O_zJ7tH;cBsbQg?ZUe!whk ze>HOe6UjIxCRuYn;ma4yU>Ubpw_3 zM?_xval&x(Li{)}V*>5%;PScaI5YE)A@Kryn|{a?+6cBubMmZRhU|=tt1)IBJ-@kC)6A>jzj%OmlqV}# zFW|>l!icE6u~TZGkCtXC7u0QXN4382A2AcXo{2?OE3Kk zADo5Jf}yH`TfRr3j##I!NV4iZzFQfqVD6aj^2#k+i{;g+1vx2_xajPOCp=G zP8l5i5PBQ*1($`2L)$b=nQr{H^D{|&0G?!T+>O0OJCN^ZbU*tbNvt{vv>U3XcXleet?`5d70OkpKb{&p!TD?e ztfPxj2c0svdB?b8j04h{9yu7K+DHlEF7P*!Q69(A##$ps=WC=v>~$m<6W}M52uA2Q z^}Kpo$sg+!eHD1%%l@ZrXjia9q_&jZ%*9O;i^n~WFO@za?e?S%Nz>v=I@VHK?Ed;e zxiK_GgTju;Hmrguv>zGZqdC-i1y+x4*C+4Lo0)?`amR93dv{Uq;kZwpfV+t^olt_= zZ%6gLiWS)t${*?zvO~XXl#?x;r{E zYsu%9sy&tR#P&qzhQ~%LM+B*^+*2y9>`@CCh0Ik}dSXAB1NWMW;tc0F*AjO-Z)VSC z_W|cbA;jJxhg(DS_HvF`!$|#b(MY<;%jjZxx%$FrL=bdqc&aTH)`*K8UmY*RLgE2o zI%@7U%wT#xRR_!<(<*H5L$cZmyBU}T=a6^yhMEaJ>r|-7=dps zx>MzafKZ#yh1x5Dn~B=`8gSHlglXonZ1mm-SRa71_9R-OVv2{V@D%Q@=a4xL&qc5f zSAvzh3Z1Ia}d}P6PQ#&Xjcv-bHTIoH^Z0OMaL(I1ilO86(brkQRTHfU;?Y->T?jGhWFG@^h zGTF`!?Zz7ELG*XbE0vTBs)h83@yR+)ZbP>B8FoJs9gYfF(Z|b^{UdwT6Z%V{IrE4g<)do6Yj)-0-j z&{i6qv{_mUeGbw=#rYCk3AQs6LdMAp_~Qn^ z2ocN)Mpmn$wau;qR1(}MJ7iI2%&4lL*PbC;be7H-&A`|6VUnY`A`OR%DIbT#8pX`AKYV_;hGa;KZM!!C^I(e>n&&0Ki zyY0#1F5sxaccD{=RMTOk=*^&^yKWq^eiIz^4}FWiibTBPOl{y5ZNQyeVS8@PywV#lMfW5`jqo`_hFEFc%GK;~);LtJo2w&q2 zaExpLoc|nn!~&ET9mtx%YZY)~;f@X_Uny|tPT4tdDvW^-@d)cLFoZXnFM(9|0{WZ; zCB`G;sd2#=2ZZsJUch()1wmH$l#6go--H~zN=b6HI*|;CiSrLKpC&} zS0*bbmBMNvH3=TgzcpL0VnFL@mcm$EhAVS9aO~XZt#5@x^kDd&3dknEOf{p^z?J$g zQW9mTCf@fuWQ$$3f@An!gPnvPsaZu=BC5_iemTCszkbeLp*a7=W7ogZ}q0jpV zH^nA!z*=M$F?p=>$!1=&h56hx%zTz+GQijWJI!sAf7@iMuodPL2mGscJy*Mu6wN0#4;XsKCOffDGivJnn>_%>pxr0i$a> z>i5<7jw6tjJQ^Qe3*7LrU~P7RCTkHfkyr&T@M`?M5MJbK(2;0-LN0K6?h9^;NJy&J$0-x7Y>7$ao;=|M7e@uzhaZarg;8fr}bKuE=Mw zoRjd%OMrdb9bJTlc;$ET^E$xpDTn!S80_sb#6h5>H|-o)Uthp?HNhKM1+MD~Fkrjd z1;Dm@j(@$x``Y0>Q}C+`_(%JZR6G^G{$Nz#>4fpyj=7hFukjJzyDpgF3HYBsV5^qJ zDCUJvX)EF@xU|2abo~x~;Bq@V_zF?Hwvk}pkHFWdgye$EU|P*XFMKfY6&7PP!Y&9F zbU}>MHCrOq+Uvk9>w;&y9ju@u_^JAVX|o6n6%%vn1lTz)%$j{xC&IQC5L2zec0(|E zQs5q)5gyK&um-l+AHj}k0q$-pQJl))PIc_%s`7p5>y$HA9!ZwhrM0n=$_nBhIst3> zsqB3ADp{L(PBkRlMrD1RF<0AaytbDi0q+b`i?7Ae%nJJsG0C=&S$oBfXFifC^bcrn z+JG;28T>Wfo@eX>E7`L8+FyuURCTD)J|NvZANCKL%uCb-BV-7;Qm5@2)&tu#_W>pT zK>ZRg^YdJVnB_n+uhBQif%ahLyP1fOxJhoXcbZ4YCWOywXlAr(n+5g!Msce)5Yscn zakz1OCi){YfFRD`No}@uYbb8lZ;74YST@7z*$x)Udux_a#wcn&wqC&Ly9Ig5iYJE? zX~|~HT7m@Pm;%I=Ft6DyajR%;+_H|?-_2#V%Cc=5Vu(b=lry z&9U#m33#ow)~svQ)BU=xpTX{{m^I-ewaxlq&NMq1B{T&|D;G6F?`SkOuj)Kd#BIo( z=x^)==6wgqcRZ9~7mOpwX6cSQMJ}+iCPQ5`6*{!~P<%?}59BrsH+vYT&40n$DP!KU zdw_wO)eIwR>LR$aqb#^%VxOpQKOy#m>GsL|hVPsod(#P{7y6AZxQvdn`eSaD0K0U|fzSes0e@f{lAb2nN0Ftd*mJ-;J4ftAex8%c zK{O{b1Dov$2f-|q0S4?;VjWiR60q;f;KVwHbH26>hb=o3eD4nvYw1Mt6j1%2Si$ei z+0c!Qr)rZ&?MA@){cuhh1)kzsymC?NgcCf`y4re>PH;2tAcI$4PSdJk%`=o zpEd>OA`4u9CA`Zg(!ZeOJ4#WE0dJe_;MFj6XL1vGrB{&|@EeJ6x5(D?PHG(Cq}I_x z@%{{KC!9C+fy-pV#}>qi(*(NAbJSBHb}95a?Vm^rhxEuos;0wpfiM^wD57U4g^lcAPLX%!ft>q!64$VtNCp z7rPiIj9O+NeDo1xhGFX0j9QqUI!>UroD->@1xo;%Cc z&0^-QvkutL^mNU?Nn2@>0~MsiV~N+AA=&)+sgQ<*00 z54IRm8VkVtrk_wwTq+)dTKtU|=Lm?j zTaD0iYJb!{>Q1#Ox>zR_Oa3gUlP&4H#K=FSt8gT`juF`>t(O*ITr|m$L^-WoOztkP zmjiM)#RC4;7Jjs0eIv5tZPg^mtSohv-Mh!Fhv<;0~VsMtTIYTbdxlyEqsGDpJ;eQ9+7^7PUTZ z+djBZZAJ22SN0_9f@}0DcrFQiCH`N2IqL40XU~IGmwcudE-LwSltU3rRzbndUCA;zwTIO1^0SD+=Qfqiwqy;bPUF=-!B7QxFn^-n! zuyj!>0soaU${VGh>O%JULcJ~cl$yE8YJxk3l(OuUr~jnMZR@0#f4>gZjM@5|(#!XW+Vii`&61;JRQ{ z{${shZe&B|>2#(z5Q(*PE4%_3jQx33cXZG_)Gw%3?_y@&C9jiDk(~M#HCsg}3ig9{ zPyw0ir=Wt!30BWMD5N7yZng=O)hE%R`-b%n3oiC6cC3?7*^ht{JwF@*-#}Zu9-8B# zj32t;VVJcBdU7dLWs0V5q7uvtmB3(hN9IA-ISESff6T2$0{RGVv=_vnnTq+eJMs_-_GP1s(bb$0^GK_uQt~HxqS922se7Sy5{$37 z3$3>%;bt8L-|1htgJfauFo|G=&WC1x9k`KSkxVub^XfRKfO~R@Ysty%1NJ87+fud; z9N1?eU$-Jt04bC2;1Ez7oH&h21*>`yx+HZmf+H}F=l`El#K4WDI@Zq=>}XezT4R90 zJPa%E1ruW8*-S`Rsfw@IlC6nVUJ1JVifm!*>Ur3#aBgt18rIjc zNAW!@Fx@@mEL1v8>?0Npgkrq$PA`H?)wgO*^(>Hxqu7fcgN^8mjf&-oeTuGyFUO4N z|ucDoJy62H@T^Og(BxF5oW6xkkd}Lq1Yv8}^hcBCSlJn4>!V()LD$sCn^(MTPDz(u-@*_X)qhAV0Wv=RAOo}L*YuW3$OnM zR@NovG#F8z;M4Gsxs6?GKT=p8FmD;2^}^XAD;tMDlkm4NcC9J+`~-O4Z^IqzH%^^o z>M)WPm!XR4gETt{mH0MviMs2Xz~)uev1&b>dqb61@=m!A94VY~HZTRQNb{sWI0KtY zH>CE!x^qbLr6F)=xFWrl-0~p#IDDAPC^Hnl;#CQJohw>9y*ut6Ys?jJRX>SZs|Y$U zMQ|UjPamL@;W}{(dq_)UlwV+fA+h{7?hvn#QMZvD4VRXJY+AT6T!A;i5@s+miApiK znXmMHjMieDffbNHok+XEE!;p~qrcE@j7?o;ELQd@{)=6EFrI1Dn|No5p9s3iz9y=Sm z68jJf#j=1S&{WzWy^{WhF4!fMm1C*-Jm{ch`OO0 zVG;kLM>7N`)ll@+60vJ+!#;l>vtS$clBvvSyrOQ{(P{WKDD-dK7j7WWZ#Bm0U%Cgf z?7QP0u@5z~ zj5Ws^d5lMT4eaVOv?QS0JHUK-gIi{K#Q`>KK)xdHlNZ4QWs*EW9t-!Eaq>WUIFOS= z@*_FF(j4RaQgOixV=U717_BmN1nZHV_X(aT)r=k3-?yS~bj_NEEk6wcFsxi|G{uT;O_4GKy3EU0y!Skgt)?yjVvpMKFZp6vB6DhLSSsyyJAK)Ov zBFU%{@+Nz84Y^KSdz_~wxMG}%8S2N)^gdj3PO}^EnJeHa)d?e*fHUJ2{6@|&JF!N4 zqgUyGd(0W^a7~bSa}hkMBDh7ZBU=F>ScM+S3)FhFy%~z3D0*7C%@@XEqddBcDS94g zDIUSuu&MSGo!L@q5^}IEA$M7=<3D&eG$fsc-Qd?{5wnZU#DU@pak@AW*_^w?Gva34 z&&P=E#HwOiF)Az)$_jDN7IZ{n=XKmIGjUIFin!U8m^=B9fxd|L(woq=y@Sy((SiDl zxPl&CI{PP-L~vpPlbko78>5U8sIGqM2jG)GOwXyS(4{`pc4#xT)kq6$0oG?7aCDEL z154`9v|I21>7-3XHp^G_s=63XRSnc4>Owe1-Br55OK1ras8yvWTz4vAOd6;!)aLk% za!{;}(dz3p_4nxMW;bxJ2}WB8x^nUJ-27UQ`@< z+yqucFPx1#xfnN&zXp}T3#d%`iuXjp(H9))>@Fk7>OV;@hgB6alPvY|@J1(jh{bTFC$TknKB&^hbLi)^Nzfrh;S#x3U!+e1CvXJ5V_tlO zq}7BnX_uZ4`IdW)`G$smi)amjnzM=35fxQsy8!yFyV3Q7pD?=nO;OoYr&YQE_HYK( z#XPnR_YSURJ^ALS-Es(Lh4o@uj7)Dw8R!ufV3dY9;vJvG#o|`fdpq#fPz(yqg-`rN z{te$BsZv$BxA60+fL+T5M|K4~TW-RGB@Hzl^Y{YvVt>#}9D&Z=SoG<}n7L4$B!UAw z%%}>CZ5(<=qmZyTNUwtNc#8zeZ5WBy+Ij6BGT^==1#Zy)wOUWKv2c_5g7<~BtMDiq zf!~uc3b`?Qcd?H8Be^&aG|1yHR&%xfT0f0O{zD1<5-P_BYF_O*zEUTB8ET>W+5x?^ zG2UpdpVH>&nT&Gg4xn6x;rufm{p~3hkCeoW_7ZrrEJp{s3mgWvL%S2U7Xs&5hw;;( z;(3MS%bRdBO2$6cjrq=GhKE&2oRD5_3=}I7aI;($K~A06w}1o&8K^JEJ33+cAt%3Ap!vKn^O9 zXYe1ZVf76KU#Gv7iY&}Z=)hC~O1QyH#@&}kUTQ__hvl}rT2oEgoM7Fx3qT9B*Upc- zOn>_XahE(zjf9454(gHF+zH{VW1LeK?;?e3Ge-y^{wsF@_aYBFgDs6JX%iCL68VAL zHgJh+Vfu7hPt*&yv5wfi@|bJkINC^Wr#ChzGo3jbPw%rn2LHblUthzic?cTCMrKy@ zySKm>=QFrG`>bQ=R=DAo;v)LmzpMpTL3Ay5qFetPtK5UX{g`MyOM1iq>&UOvA-WT8?W^ch@YSn`+NcxL z2IDFGpWFQweL^Q(wz@J|nH{JOR-^8`h7843=pa12l4g*zk0x5887)PSJ+z#jNbfA51 zJd6Bb_bk9%l+4D^ul_Q>n8{GcbhNITz0EGrbY6$!l7eT_5C6V_XIRr}3=WtdULNN?e7bRd1v^i&l61+0;A)6l)jW1_PO;v7U-Ub#n$^sl53b zHGFZ?iSO0X`e}|bb@Y_)!Czn&Iv*kY;f^FDCt zy5>yGr@TO{XP6nFEnR?-FNvLTD|Uc5^E6ndnK7qB##r3ACqaR^$@mwDVkdJt_Mpzj zKN#incpj@w4qb!3(6^^Ue|QE)u?SSyH=v|;Tl-Aa{Ay(Z>wQ;I)qsl4=lM8RRA4}`^d!3g4KSLW*9H}OXI+5 z7=#(?!i>F#JM|TuZRdbr_oG)}{Nhnv?V$6ZNBf5=OxFS;o6IZ&TD=x4q$J$4Qm9o> zA1F|&Mu41;#oR4K@_34kF?(wQjZ7j^frtg6sy~Mn_ZPCzH^W0Iw=KX4Wfm0HOE81C zp>OH3mVxn<2bimB_Q9+=4~%sSYPjFX%gtrh2Ffwpc%gU2E_obxsKYvG0Q@y>;f~W7 z-fB|~4@USU?mWef{FuKx%s0k;-EBB9W5cFpY&KXZDGA)(|H1Cl6#K(Q%>Vo5O6*7< z&HG3qu4s=!>T?@B@q^eW?!e`!E6~Cfb`zW*O@UZUM*6D)KjUVYKTdQwOJlyCBPQV0 zw?I#%3_2vws4l2l-r$BjA3Z&EW0;yOfeN=hU5u%Z8}$v=2WL75cHuktR}b9Gm!U?> zj+D&{n9%{a70zMmF*dy!wOfDGDfcj!A5ra~m|2SM%Q(11bO6FV5Iw$GSUV5ky|o3J z@@Bw2=HuqQ*Dgeq#C_#0kmE(zvpXWkdNBIuPk}Zo&~6t3y7K@E{=aZO<;I%}b2el= zH=1B4zG##+4;f>P14d!2e+zg)Ds*B?jVkEzRWfY7kdYocbOZ=nUULvo0ia9PFy!7| zK!2q$&ePU-UgN=@x`|!j7j}k-c@bk-!m5p1USar{>o^tmKwmu${sC_6z_YDvU=hKw z1FKJf6T}bV8&QYqiHdL``2`H^XXvQ4#|&+ZT{Z{mv>VJ$j1M{uVde z4yf{8>U;He=qh*8+v_h;A)8t{>;^wDk42n!vvCqm0(R3ED)a)zTBwtw_|9EXMgI%z zV;YduH|7iALWS)C@G2TYULkonm9$|7u^w=?((#k|n}W-c$8pLr%Q4zf3MxmztvC@a z;&|rx;>_*5;EXvwImS5m;11v0;dYjAY7W6!9(UGf;(c+Jco8+vB;hu{h@Xz`m=6`@ z9dK0M08cAO4Z*#jJh2}8L-H6ojo0u6(6sV!l3k(> zS1TyvpkfQl50p#t5c#t-Q2HTtgHQQ)$&@xpY2*%aEvW<$hl?^T@0Oa!a!9k}j&dci zluyJ);Nuv1h%{8nuI!XMC|S{8_@SJV-=b$cO8udXL1ldyNw2E<6>jtC6%p&cpgBNK zi{#f|L~bPMEd_$`l6lKd;CV+oF$m4^B1b>>FV`&h0asSfO;3xssJDvefoBWy-`?Tx zL%sPt(>#5gc6>OlrdM{Caxk9J9@|^oT^!lNubumm)lu46)!`5h@ePYV=dG zA-Y^2!D;P)B>58Xjd+AvQ^BYLzt}6<9ytAGQnRRqfk?dv({?I4D|6(s(paDeP0?v0 zq+Ky7dNlfPR7D!j;RqAGgkE_R8royw&f)%%jo~7ZyWw16E^;x{Heeir}QYai97Wqrj zl&#ovx*5NSrQ~WJMHhTJ&Mcj+jk zn|?wqCXa(XlZ52g@n%+|vsMER`hO|=<#y7tXqM>UNS^SCP&D{F5DY90vx5d}@ zyg?Rs&{fgZ&H3KtcD`_w6F-aFg>U>X9>_V5{7du-itw-4oNONMAzP8NP{-W_qp%BG z7Tu?gcruUSRyV{hY92=&zh0f7zLN7u&!Uy0N5GnABDX`8f)4^k1GNKPeQ*6up{1S_ zWPI-fYO3HLTezt1LXr#gDHWJ$dFkbkz-lg$ z&PQ$oqoFi65lMCHq0{wQ#UlJ(Aam8&np^}O?u^VGG9O)^tV~XT8od+H za+63yr4pyXqj^N_q3hBOFt>gQyG2I4?wTeJcInRD?zCRfebw8=T`q2qrvX}J) zUhKyqW4B&hJ1Ga{Sx9|d9Qzg?9@_;q{_p6xP}R^mU%JSz;E+(UK<~gI|Cqo6=&Wu9 zJ_UnVRb2FNa9*g651?$ILx2xm2@DC{4)l+F4(9}uJrpaeCaY--Q`>KCwRaJ($(fXg zeMWC(MG`a0D!zzv6Pk&7N*!gVasu8Kx22x)(r^W-j&w7$FiJ|Vj?a^IUe)ArcgM3V_0T;r@wnOqk z4!$7Q4gIn$0xPtEkIy2m6uJ@*fSz2Sw@{^-d{jDm9}w=7WP7l-dXtrL_i*8vuBWds zNjSk)b2){4!c@@~JBy{9+Z|b)W8nky%FzXyj6LwwcrQ*D@(6FCP$|uaf$rr+ucIV- zZ0~Rv?*-*?Q|z?)h~vy2ay8f!qvw6iVZk<$yS@YCk zc8c7}IHP@+9?D&<-{Egcb+t*D(CWx(rC4;8lr#DVNM#1e70!nQg+9?&(W}zs;5|Ud zqayzB?nuL!8SE(Mi@M|;v0kc6s;~Y9mzA}~TYU+56hrI{cs`rS++e)s1)?>MJ%+B> z8*ojkazmKQKtO1wD?Nvqif4C)ng)EnD9~#jif@We!A%_WL6i>l z6Yc0CB^x;~vVmAHn=y1-tzF?ZN_{&av`V^S*A3rQZb~K8w~;*BnAj%$Rj`nLQ93Er zRNqCD^sSMlMmYA#NCx&>%iO4)(?6M|jaKFWREYDyOC3YJwa3xJh(Y9VwkrI+b}=^9 zhHi&fU7fi^RiRU9ohgs3sVdB9Xu98!f4Iu@6=Ef$)9a~$Tmw6&c!!GP7FrkGC#eu| zhVRbTCX@M1^jfZ^zz`{XDLxI!i(kyN;uxj~^VO{5E=lCK-U%6~4Rl_$TJkSnIia%C zlok-L=setPrMuQCQIreB-U`1Qp8_e?*Z8r4eQ@0yD}4XMIc{^u1Hb9>#1=D0G{U_B z5`8$FW344mB*oIsX-f|W7a+Ye}wkDd($=<|dOGbvkGCM*R zpo-y6Iu#u9PxO8!k-A2j+&f^fE2wLxc= zNoqjl(K)({en%-U)(w^AE9o<1JX_0hhB}D3m3fwKc(j`Iu*f$m$@rrcGym1b6Tj3w z#scE7+81cXJpFGgm$}BuZ2geVQncyT*Ahuql&BFa!&SHcR_hR9D%tF(C3Ax}Z;eGbpN0rKqvFa&d>rrRq4(1S*%BiLd7PORGu^5(Y9o z=UQJ{k(r>x> zzXm7Iiv7i33&=@5LT8x--6F>`m0~}LDfADraO5J}fk}+s(H9C=jOvobn8bKU#XWQeR`PSq`Po> zN2?S%oqdSzdLzbVNmL4V+P^vdjPNBpyEw<1uWxYX(;wIwc+o0DP1aVWsrhGyYbYD1 zCNY1DOO^htz+8o++y*MeF*khBv(+MvxsGN^1=oQ{xr8fWg=|Ylr(L3<)%) zKk|d4apGP1K0j9PPusCG&JN0YqOXz5u}Q6>%#H6FK1jtGCtSsRI~?E0IgdSkfTF<15S))2W!x;ejVvd!$;jxxa%ZU=qYJ_Cf} z4%=PXCe)66cl5W@naeGOtOtfrwQwuvVSZuYCOJ>orT;~4(KkB_DhJtyN_j^~qo{gc z%&ab89G1eyOYPDq-^+NvYS&nAw1BXb+Z1^yHX{z}Etxa)@xZI3SZFerQ+0SYNz;ic zOnY>Z#-yx4oeBnfz5SU72O**VOb{p~msv-$c?_Go62ySd}>* zc|?qKR8lHYwb)5UOY1x}*$k_LXkOpOJu#ao&En=p$_TS1pU_ghKt-hOl);sWW##8t z$;?!%qdZa^Z1+*~7Ua8xbKqv=fGb*jETe$F$G_qmgcYo3RYF`_c0LP~=S6{UoP6h%@> zBvev~(jZh)N~u(o2&sh3GEWzGe&*p!?|0q%{y)#>nah;-b7C*{q@i63jP8cZI2r_sUMnyZ7=Ii#}VUA^--dM z+3kNB1!Of_2QGe#^{P6Ln3W~w4sukzM9k6tYNa*P5>E`)+RW3oyTJK(A>-?#>fK7x zTt#fl3B>aL7YyTr%5v*ZbsDR}L0c1I%C8`1>N#s386~z8GyO%ff9xkhafs-dy{%(v znc0qb4!4?*ssG@)c?wVBgX~uQ4f?zR-VcXZ8SO)^fgVH|o&!GjH+)te#JA)&)}zM| z-FvF_I+A6Sr3~`^xgH=gG!OaHdTIRr=uy6Z+nyPK?6(Q1Rt6A zt-*Me47UEH{ZG=$uGVNV7)hONp=Vp3Ru5SNxJGBay)|v91ud4Y#T)1r{PRC0cKkGD2^77C z$hWhxUEk8fgGB$k8mn%qU0<1Oc7z1ew@VjDTNNL!T;wfc%%}Sdks{d4AqXYH*5yqygCszx(m{LJKk#T@b!vY_aaFx@Skc(G}}TrGM<=o zPW1c~tsI~v;dw6QYvaMDao?Mf!ka1m1JQW5A_Mc4+15uA_gQ(>b^sD#gr@f_=;tfRyZYTv={2)kPk!}nh3;%T(KsqzPYb&aZ| z$OG-7NW`z?Ecgk1yhC{f8K{9KgyH(HoY_VE?g001t;~a8^Pv0+e4Vg|W~KOha4Kv>w6f2#b-%B#+6)OrqC*^hNNjE?3&Z3_!? zh2Is_b%|0!+8t=>)$s2CEi;hF zdi3)W^0ywI{XkjCO%dmQpHW|gdv8E5oY3Vr^yMAe`7tuSfh;CF@L$=^(O;-(6EgV| zR{H{es~2c9@j2MXI)vtY3JuTE;sm8)P-q{=S}Kd+Ll!h0M|+0z8j82#Qf&1MybI^! z!#4|Wh+7y@E@GW;MYdKU69eFQ3)basz<2RUyksZStNR!a+9NY2J!(#kHIUAHq_)_) zp1X$NwE-=zpf?&4ah6tC_}Tphclx6(Ly*csR_6S`rZUXw>^v*{zQ*cHuAbWX>|KTb;QiJh zXx5dzy}|gubz_hEW{!@*`dZ86mqQCtj`X1H==9T zB{F;onSko@yD@o~I#EMos}5O#?#F}iCj7PLAv<5;-MAR;KM$&)HER(Up!r?2H^P|r z2yK{$MG0XUjw87@pqGDGW1!^qP`_@Kzvb;5?af~LNGR9>YUMx^KT%dD;Y&Rd-rUPr zQ;U)2Dy;73=;AK$@pJL2yoIa#!dV~g zPWHHGQ(^=h9Y!4YS&Zbf;rDio43|c~&Ss;B0$FrWDpS2MM)f=bB>uGHA(T)Rchl-y!*Z(bm0K$$#-6 z{TF%pRY_bl+$#!cTbv{-N!NWIhc2|BEZu z@mT}wSBqQ<-;j~vQ*`GCq|Hm~zEy6se-P@Pb0zS$bBJ+gW^%=_NHW>etOY*%27gOu zJ`Jo2Gnofuw`MQ#`x@? z{9o7~(R0aQ(%82n#}+smn3S3l>mU9qIfDJwE$UBX1+$WI*52NZX_0;Lx6Hckk9^+; z>bM)(&t%>=oXQB#H(tX(%=(h=X}g%k5R8yx|kh|>#aLL zBFBuDrc-&-EF!;1wmDiGPt4HXT4m<4(a+pyYlJWSCyJezEB)g6$?wwbjr~fFbFl3l z<#qjXW_0q2*neb;{!E*IckE$h1o`?J$1ldG#qLS2PHjp*L4=TU;)~s*T}OoT1=`Te zhsFnHp>ozdtCf%!v%)xH?xwYK%omMpb2n(ZG2mI<;M+PFKat`1b#1)1Q>zb_WfQnb zO>JtQ2hwDh^Mvgqe8R5SZ*l(3YIuK75BG4QfNyg?$X@d#*YjYuA8<9m@9j%`>;`~~ z$_3MtBC7j1{K#InhwT#_G4?65!749zzT&>h)ykc8>;hMDFFTs=I&XC>082i{VGzSC zja~4t8nctkadD#Jgtgt;kj%&I#jhvE_mi2$cs<6ndgNi=mw7(DgnXw@6U*igFbay< z$haHy-!g5ZejV7ceIQXn~L-ZU2hcNJH>gsHffu zc5Acorv7WD1GzLm%{Yxy`XZ~gnr~fWyGi}Xc1(Sb-P7U3uwLyxN37n5yw7`n_f7E} zB@^mtd{$W3aXDOb9YgJ{ZQWEAn|du{*(kheKOyf#0Z|a%x4nn=MQ8h2@Vx`@)!FWt zLe9g6M1_Bs`Qbe4Tm0S*f-H3#Q}oBRqq@tOZZ*a~{1Ri|DxwnnVHOy_fVJBV4refa zD1$&Wf2dE_@-v5tgYyUx#zv-FrPijtNc@u=7oQb##`Z_bBd%B~RveuXYaVe&=T`nw zF}`$D*^jY7sot4I)-?MZ?^D6efy%62*~fy7f)Rhj|5U(7cJ@5qK;I}&(3!9PZmdh+ zm+(ZUhnq!)M01Eywlb-L(tX(+!jscbTNg*EbDH;@zfbV5tl6R2**iiPg72}5^t`*3 z>s4|amZM!u$xrw=J{k+i&hZd9_4?%L8)9GLc*a%Xe#{f}EhL}73D3jscJ4W@D1I$d zorfJyIhWbn+J945GZNsBZw@4`vz^R)l9Tf##wrEe$Wr{sHmW1RnE6kpBqrpI#Km46+mL8RW{jN#Gk_1NT@LnRL}gN{M!RdTadY z$m|NG?B253Wq*`!joh7_pued6>e%J}*x!XHsEx8e%Jycr&b~EU4~+?Z798p?@~n3D zunjX?Xcpmnjmj$}cbD!g^_BK1n_pfLeJwLnUFSL#sFzp2 z=FY+{b(R+nuYEK>k{t}Z?ETQWSNTCdo_;koHa;r)M7U4o3+2A@ITanFn-XQ2K59Gf zxSYcUz3bL)5N^1rVcUA&*ZwE}sjNpl=WG#kt>#J(PNZVvbuxMW7jmGbMto#V69J&ZPL*xs1@5R<)r-*?{LM5T0Oy&U>2 zuq~^#f0=u=W3!TDPD$@ftc!PwXAz^jyEfjKrYN?QO?66im_l_j@fG6wIsHuLHSKf# zV5XC?27j7oQ=JpT;^!j6qWdD{;YLw68O}!}n&I>OX1aZ1e8e5Us-kDvf|Bq>`^9gI zT92n3^=5DMpxAx!lZ)k91!`=<@3kMg3LQS00^^LL5%grkj zF0Z=0yX5}LVeu};ZuJ^>o^P!`kQL6_Oq}D~tc`(!goq>#a_Pj{Fn#md^>VE$)5!`AZAVTwSstx+FHpjJfw^ zcdYkN-8Y-Q)bzVLp&FAyH9b?U`{Vs8ZZ7UovaV!vX{P*g<>%3t#nW$QsKfHw+Bzz%hDB*xn)mY{`kVsi!(1>Th=`?I5k{3YyZvF z91K#6tj_{n10#d~g+e+1<-Jj(dCkSO?#jO=w>0=1;T`uRhK6nBPZyUIe^oX*QY(40 zKE?i#r%@mg>=WvmJwB^`=wIIj?>^so*HlkC*BE@7K4xA|W+weXZC&bj#@n~`OU41~ zBl~X0RQK1OD{hzfch`^jtn{`%)W6Au(qE-2Qm50S@SPZ`jbb+RNh&Y-P27yU5bhtf z!lBse%AL`&*!F0zXuCux>Wug+^GaSSUUqrN>Gb)YmH#AqXsbNAImha)Z=y79*}7)a zjRi|`GLB!(hms9TiZAuKbnT^g85f5{K1<9^_Dhe_XR6<;gKX=ZgMHq-jx{$G+8SPL zkg5AeouBfy1mCxRA8%WJ{rX7q{9Kh;MMmd4(y`QvbdB^xP*XEKU4c)tJJkx-j@Ic{e?rYJ1r0(! zJ1>~mM+&aQFVsFa{er)!QOTE)&$VZ5gS^voy5(9q59S`t$_n)krRR)Zd5)o2k=21ox6dn%tR)ym&3JxKij-y*iexkzsE zu}J+$UUnn1g>&-OuigGTt2PH zd1=s<+GSR3d-_RY{miz4#yImet6XVk&vv%(-j}nz<_(3p^*^rjX^jQh1Dt(~7KyXP zV=wy7y>kBL;vE&e64R9x&XvB}S>I(Z3C#)oPPWtI+D4q8xOb-KZGV&OPn^ws zzu9xmE9z{0G3$}_mDcReZzFEp_ul@VGBPwYP&V3+;winx`@O^P{uzAFuVhX49&ows zGj0Fb`kMDEADE9?%gBn?&-JvUopZJQ4fQqqaAt1LDZRlGe`kzmrOC&P^?H23-pceR zrp3|J;PiXR%fxFwn6y%vc)j!~^ZxW#ct#~-UuEu#{1nL$g>`IM*Yd)!zcLsZ8s1Vd zv+};u(Pd9u`RY>pD{mB8B|DOHj2*6*L(6g(6^^d+SK;x3wb`9LD=a+`i>-*fS$R*T zQMo>Li1@uHz@^7ntr=)OW#4bV#`CWCZU3L4cd{?$?#kI4yxlk4VVMgud8rQZ{?X5( z74b{y$>2M_cD_x_lX{-Iu76!aoDbL!f~S4n@sM+`E7yI8yTAMP(4+ZVf`@Bd%Kjy= zH0L3AFXuk{yT+B&ed&SPEOQnSY(MZ^LpFg}XkBo6@Ybw_-kr`NUcbAAXP*Z!sN-{| zWxvseA1yi0cG#LKj}kwmCjMmW$y_+vwp-bYcR~+P$3J9V%kZ>>ch} z_G0lxhKJ_X7*x(EKKiI{2S1GmqLCc@;*75yHR+IDOFX`ub3nXM;bu}I;xyD)Ino8`gzriwe zv2}4i;@C>8Jp4jzLz#D8=kfWE`s-(p$+;tMao*{ixmo=K1s=QeH~UuduIAb{+Gly+ z2{sKDW_QeL73z`I-(T#$&K0*=ME)IbBRmCpCZeA6zS+L(ec*`PlaRPtHk~KODfmX0 z8k&BK5!0U3`|@P;k$I4qa8t;{HJ3HE&xo_M3gl`B@+5tmUPL}1XX5$zIOZ2qlaa(W zqO0CR)RdX{w7wJXp6DB!7Hb&o9J?4RjQtin5KF`l#dpPDiN6%vl2{xKMC*iom0wpn z!|{qxR8YB+Ph$Up3#2Z!?R?idx5K^YppNinc~0VyVc*t-#IYBU*u~Un1Q8x z(EX9y>pDq%?|OKfkhd6Q)=>M+j@h;r_U@oi4ucDyNc0)M9>6nmmG&u_aN3bkYEe3* zZPH94FJ2^CE9Bkk%|srFRv#;L2gWq4^H@!KPO)&%IkmeP_i{S$IjrL zUzjRRJfAGaEBNzReIgqSifOT>M9g?F+LmXl+v6ug=L_e{^nCn`MP77eHF1XI@s#reQ?@VsxAZxw@htgJHU$aPk8SPV6FSO=~8aT zE3kpm(7alC7$3?)kQ*DUqj(ivW1QAqW)r=GHQN|$9n;%bH?W#C)VyGPMf|3Ga}74U zhI%8};v2JS^bTHwxz=K%11tW$ta(OP>t$_@F~PdY7^5uK^OQm687pLVHzwKM*G?1d zdN^JWW5GwJ$uRo6dRBi<`^MOwxeo8Ib$IJ9v2{_K8mnyc*aiOEit6PyukowxdE-{> z*=(bem20%p23db=MMg>qI?r0FtG??=X9L?FSEcGuX4v{zBOP6du(rke#E6^MslDyJ zjXP{JKsJ8hSa0v3^doLomh-Z{gnU?ii~#Gr`-n+u*>|$b`nK(9rIn@Pi_uqc+YZ@> z**>#1v9+}~GiR%#%pT5d#!G4&WNn^(rS7(FCl*ej@{ps#u-V!fkC3JKS!08JTKYeA zp}kRN1o;z=*vA_e$;Y}CwIx;d9*&YF9bmEjkI zkH<=E2OVRr8(b@malYS-zfx~y?r}Vkoap$Bb}6B^%2S?qG8I-0+xA2Noc3aeOaBY6 ztG7HaRQT-^wL8hE@}^$H(>gQXxjSWd?Mm)brv&bgT@1L)SKZIp+v#sxdm=kLPuQMK zzvmfkpAhd6I1%e-8=tw)Ho+L9UFE++e^4D^t#-bvKAE_~C=7Nw? z*n8TRd2WiWHx~Ljs4gwb=;f-+jJ1tSuVamFczTY5XnW2TnlJH7)|cr!wGC#T>p|zf zvLpVD$n_xmoL=`ZtApJNMf!#m;4QjT4E- z8fWZLJsG!qvffV1&%RvRG5EH1G`-AjPj&Fu2R$^<^@!F{t!ZA$Je}R9@>l!2`hHK5 zw$duLWr#}ont2<3&mY*HQFoeUmHX%WbPF$JuW3z7T$0TjzZT zL|vil^Yjw`()b&$+sb~-Z4e)BlsmsM#+YAtK8;LsrM-7n&hhoPwY3JN&zXBByGEXQSvV0Y3cZt=; z`AuYmr=Pae^`gDr<>~oR#b!1yZCoSndE~-F!QzZ9G_ow$dBUvcc`r6ZZRMVsoD=w} z{15+rBRlE$c1cYeY z*GZ+Ney@7Ho^<9WPpcE04`}77eQt-^H?_>!+3afnHuiN^o6J)Epf$vGI`v8_6{;8C zZ@63ueUmx>&%A$KesR-K)Lj$-%E`tjmY}-)oJ*qxcPPB7LnoQhDE~bdFZqYS$TaT@|X=oStsv zzT5t9W;b4a^OZ5KQOPs<*S>6Rpz^Bi=JXS;wMt)ej1E}XGdQ!-IO%Li9Er++Hj5uX~{SLHPH;C8IDi!TVNxfwIS!lbmRNY1?YAYkr>o#Wg~ET&ZUV`-+p6y@+!}}_JNskU`Ce(zBXzk zcDiaQ>+MD90jWUF)KpI>ewDsNecv9KbF-$y4Aq!3C(r z)(CT-{b}pn%)<_aeV{egdi%pl(0J1r=U8oQRv$B8FxRbIsC^sH>EZwSSy5td++1ng862t@_3^ z$D6vUeB+*;eBAyRu|1x(4NvxPKch}bZ}sd+6*$YRi|HbFXJwg|>*}0&(B8}0I5ohw z&a6r7=_k>Xy2{J?GoBChnVIJ9b;c2P6WdsSI~HrNFr%<3ciG=y#!%asMTFg-t(cl2 z62LF&GB62^l;_lahDH>zCycuGw&q%Pe21Ew)WK#SWsTZEA7h`6x8mPs0oV@1*yK25 zbl~~0t~%1#iRaAk#*J1BVuMUl+Y`SUytb8XbDQ@OCHyO`mwk)zoTAykNpDc6g4k%R zM)Wc2tKi_~T3yXJ7OBv<&0Ytjqf@ zo2=bn6>^D{?qg3q8|=dto>LBi!`fvvRDWQPc{Y2&<3V;T;n*c>7W?w2tv|T;IjD98 zq)!|6+JB;!Qy{FIpyBF(?E$Itz48f| zsJlS4O^1UPSgNM%I@>9A0~m?`*q)u%1dc4VUZ#y5DfyB07bvU%yZE!Xrdid#cpjxc zB1-Xb&>VBq)%7+jE4xA0fDN&HU^)H*k#Z~CNDzs6H<-H9V3|&U zQYZ(1QWKoUuhhALE1QBESc-#%@PZ*b$Ij1gu8T95)xhxV1QA0NMS@2u~E90X54LFuVG{gkR zokMV=HTby@Xqn^Cwh1-4=(EADc$6zESS|BZA;_|+xhD6jg`}UNc0mHUkn(KG?O{#n zD5$EM@D+0iXVSb*LCqu_@ba8eMA;qidOsBRP^t)S9|8CDCs?PGAdC)BTO8zso4QZ) z`#<=3l&k+k1{9ECVSY2ope(xOU zyf`v@ocbqwP<99Svl6Uk9(@xEl){@R{ip%mZSbs| zzLwFC3R>Xjp9d*ZkoF47D9|iHKZ0D5gw_`+;R1gZsG?$;`g7234GOs7SS~cm1CJA+ zOe3yT!OmIGuo2qd5U$sSCpC};LFHW!rv%YwLt8^t(b7ZUnp$zqRs1jDOg>UwyXs$K zj_1P30_3G3$9rHU>QYM_PPa$idhss{Nfr5N&u4S)*9x8sdbBCOnsIdxWV8l0E*ET9 zEy}c~?xtvLXSm-0JLIBWwP;rsSLN_8i`L{LiP`*1Lbnht6AFk9*5+&y+6Z;*^e{j# zl1N+%YJ`zFJ1whCpCZu3&UY5qWRTSex_KEosvwtyw#D$m32(|l`;{OW!leku(xBc_ zNKq;L33FDg>Luji2xUaNim(F3d>(|;3FP4l$0|8;1-S^L$AV1!8=l3Gvn=lAME^>$ zG$*MmgCt#n$ECC)#xZHDAKBNrOAS5+-KtT}4c{+Pri2p3NL>)Q%7WJysa>wLLn8wn z(BQa>vZC37{S6=oHa=}oEP~`^ai>OJz!sX^V|(>r#scNy$Zm3Xr$A%uRr10UgCV61`qce z+}OzbNB)JG@7Td>9fR`!VF`9&O$534J2=g6K{0-bt@sqw)p0P4$FM5fx#|Gc;47@r z39QFHUYkJd?xXx3%5Gs!_z#>vKH+RzEQHUrnz03=^Kc+v+7UjrYn0k3yGoXO*SQ#7$APE9bNdJe%;2A#)X5+fT8?^>)!|Cw+z34 z#q7DvXI44STE%-M@u*%R)6FNG2Y*b7Rh0b>bmu#??lW4m9n9Y;WM>_}mvZKBBYEwv?KaoEk(gHoY(&_HfXhp&=AxI=~K16 zWDflZ6vtcSoH%0W=8tA6JA6B_mFKND)gfeF`C93y?o`6e>6;>>M?vnEoB7r`qW+&F zqSO|%fVfyTBGF(S)Z6~$+&bFY?n!-n`T619;J^Zwx*BxX}gFlbb(R5HnQ>@sM(%a zAXZGkD~}~c|NF%Kd4V`sHsdv7g~f?CbhkN|n0q6MxHr;VL*|N=`oH=`;)Ng5ek9)I zbD6yK4Jj@0C%NdnL@xd^5lnrP{z+eLUGHe)KI!vkwa%WDdn|WJ?zkL(Xm}vU+sN74 z*3W!eJDsv8PbNkZtv5fpJ+&zFlwQ+(!@5@8YYRDC&R^J%{N44ed!{?=?#5nH3r{o8 zDfd942Hx##>o{(^Mh#hmjIj1%W=eVmQ8dq`mZtGh*Ne?|s@w6F>t65O{(FM4tihq* zLj@s!*8PEIzHROgoI}~!I7L3M$@&~^h1Qb&ix&DJ{UhT`eE$DcH1%8Jj13|4R)6PZ zB8S^udnCd%aVx&G&nAZ1F+3!E){W-FtmTi=HO&Bfv6goSeW4y84(3|od9N@x5V7t< zux8hqKauNWp-Po$@%zeMp7t|^Iq0Z-ZW)%Kjvs2!RET`uIo28#fEC`<~ zt64gup9z0=MSz$p0@1Ab@HF`YyOr2Z(yqE;S#F=%b)Q8+#w|<@#?9JVXn>m$+)J@t@zIe5;0S#rD>$!@7xW4={?T zaC_`??62DJhgORjDf98VDB=me0~+_T*%v8qV|F+1HJ6w(!FYN=GWtY!id7^x5d6)3)L<{3JM}P0`;M&}VwN@3bEgW9>Nv$>cEptBi zS2*L!+ZikFjtnG@P5;u8l38W#D_cd^#rh@-wbsPWIAyat`?x>!{OP^k|GmFEd%?Q{ zq2Ql^J^sGFBc7|=4V*!8tq<3JP5qqsF?xG=O6Bp&N5iWkZQ|Fb9Qsez3Ht_jE&rI{ z=urKfCOO@*muF25lzDG+7dpOF-Z7rgwx`>s62zH#FmX4#Xi92kx^JeDHcLNgthAm{ zJht!cbG;ja^#Y%T?$7B`u&APTWXiiH`h@LMWTsrH@#LF*UnOf4AyyP#$mZUnF z!)$XL4tLnQIM6FNJNQNL`C#YZ#K18BgT9*HtK1iy+Z^+4(^yUVNDHSQPx%s0N7qL_ ziyVoJk1mRhN`%vQ8EurG?YFsG`}PH_taG8#(ETA>R^z}j?*!M|w)d=0^j4Wnaz*^E zSZ5-*cP1ZdzvQ7*I5Sz#GhecHDn;a3joUSQCt_HC;9Ta)bA9hBa&_}>57_+c{oDNO zvfd1w@{VyX04bL={?)D|CnvVVUWuinU&M+)0Ikxmx9(Ao6I*Sn=aT1&zp3|5Pl2Z% zam(wnv->csu?cN4vCZ0~*JZxSOd|SNAN?z+`VaQ0oV=4W&5q0_ZpL44fpV7kCI6^n znN8GWHf9;`sQ)5!2N7yd5nFCAksYGMPd%e2h#2;=wkSO> zbvpS}{GI5`aHgV0S%cCK%C=R!S@}ifT5ylLeCscNR9UbGEQg?fVLj zWn?U#7_aFeo?CB?-pN+y;ztH_D`6dRVl+1#Q&>}c(7?0YM)EBJeGN3ba9 z%X%XCZJ?!pulHMTp4a0pcRXXSZ7Z>k8C|pv>EX%Y@s6=Q(VfxF(bwX8lRYxG8Qql) z_FC?RzLvrBS>?oe+MInV^hxl%Z?0#z^KaYZ@aWpi2dQDnPZIqS6A}f<$5QUh{n}sT z4qm{DUPRSwJ;`%&AFt6@cili+vQUq zwCtHmBy|OQ;n~sL_;WFLra`8QvBLV*R_^TU+UafQx#BJK&h?D++~j=N(OEsIE;nZr z6ZK=(zC&p1n)I+tC==2~=q{;`HsA0Ao);4-f(Hp!TA+DYGPfHn}mrGIk*Hc;&U_=S#CH zo~~R}Svyh`|5b0UE^^wv*;!rkzbLr1mRjqtnxhIP=6{o24dbH-Egsmbx_pR_V_srndsd{=sB`=1DW8vHra*#uuV(!|m3lKGC=L6`Wy(YBGlDo0fGA^KN`l7n!p zpt5J=t(Z%TSuLEOdtMC;$Q@9i)#_F|QfpMLuM7UpZ=G%Vo^ZBM=NqR|XQDML>sKrS zUp}}jw_;J{gs7P~oPNmo*s5<2x^g{-yqA5G0s|R?TxeA1z^gu&?-6fb&rb3nkFo!% z)>EdN59y;b@l^ZNJ!!J<0v4;`Fat)aa(Hv2Apma*gu*?ycjWkIkp z-tq4B&X4WcwjRnta}*gtmTK8V>29HYp$*gL7$eQi_(c3?Z|}By#`psMR95HgX!eS1 zM`(Ta{NT}`%UcB6C8D%bKi4+scO^ZUe~2=FN9Mg$+SsHy)os=}+XnJ?wsJq>c+cJ5 zG1#@n@jRJR+Nmw=0qY;72L4yCkT2~O_7i4X?}N9_*H#h5_ zn>+QZ%nf?Ta2eS|(43@g)<4j?>0U6BKj^2ls~OQUAUEn7e;T!o1%}QT_bv#^#`Ne1 zyh@(ZULq4rhxBjBJChH^KaF{!BO~3z7b|Y7C@H@+yeYgnS}T4o^_8(&ZQ}aM^F#2q zobvp?3jQm&v1VzF!W!*!@6H-anXZEK9 z={l+P$&ZqaQx{V^(mCXI9$>aq`q&QJ$2#}B9`@vVzaS?5eMIUS?s?Di3fWS=aI_`g z?9=x5?H{TW?Lji1F6U|QU*#L~obtOtZYVsJ$Yy9hpboVr+J8`2IL12`I6Jt1ajqcC z_8#|(&bGur-2>Y7HRkY>$v^gi);zOS`##-6caZD;b>k)z4|(e*+X%IZ{cGE9TW5Pj zJ!CttOi{&Kx;_}K>&&n8FN{IN$!yNddVxM)TTY~}Rl3PMVujvX&o^3;*?c-|>DocKh;L)9qjMvn zNU!k9$PbZ&vG?LPrKW2Cm=|nMIbZX33$)1|np-#jNdAZUyYi#CU*-6-8v37cXE{t| zvwmB~N_-v9jy)WGH99@oJ$`@USn9P*QlD&1RkwkA?*@9Whxb!&u5Yrh2d^1qA(-T8 z;Ca=(%C*(`8WDMS5Xt-jrJ;4Tc}Z`jH`G32S;s+i_P;XAwCnWuk<7VDciU`xJ?A*r zdiVX_U%jV&@A$g-`!a9d<_mf?&%N&D^z372|910$_6HHL?$&d!8sR zl9i-3YY-r1?N5@ScbM%CTR&S1@`E(DJ#KT@bTahbsh(zqs{tztlHs6@m4n~zbXH#a zlZ*9s=-b8iA(@(T?Fr(CFDG-xe)0wPS1+)Z@+;n|vxuPB9e%&*oZx!HeYfW(_n)8< z&$w=N7CKDZYpfc~GY45t<5uiQAEU_JV~j#74v>-Hr0oQ)nQnhqz2AOF*`zvHlbd37 zG;aeDQ(?9;o-j`u-HA)xl-R~9$l9gkbiB;-j85Li0xNB7Be(apMCZ;XL**w#3!lvj z$Vy|l@v!l+-cusq8qev!5D#&a-kVtMZ{u&7B*ORu#C)G?9%1+QKRC4>^mjkvQy(P) z`wn9zF@zsxMc_SrR_-IxyWhC1Z`U38RvyfJl(~{#l|GwVmRg)Tm3ob5D35ldt{I1| zwQ7B0dAD*s?S9S^@znCpVJ_OAXucO5e)~4H8_^by6EFS&o)xn7`Fa`+ZYy%am^p}; z!Bff4SVYXY#&#!>YQMBSP1NLCwqo@!@+VCpD)kHGX1+!3Ow?b=g1er5&`zvWIaoK| z%=+5Tte0i67Bv~KV>g%|3oPXV+xND1WQ=={Y&;Kx-76qs+|aJn-Q{yt%y>?&5?ce8f4npL~E zSn1u#8f}>MXjy%`m-VrUtb&bX&Fnf>;51f=H?cDFIy1Ky$YHRMY@To8<^3kBj!zJQ z_f~MI`&c!YjBYk2MtL#OmyeMtv#?u+BROH0 zv&Y_!d^V$rS8jl^T|mCj0;5uIX&CywMjKWnJ~rB!J+g+vN1P&X50e*(Si z#+sq5?mo%-;d^A<-H0dlVKVjph$ei*iq!ki^DWlBJ|yDyLFEE|;6=p252=2nJ&bR6 z3Hg05Ddp^bA0VIcKCakCRG6RGN#Dg4|0xm0N6Gpi5jybM6`t28HhD4e^LLPOcOf3( zkFtt98ap+dRjhujIN!aR$olx*bK1aZdInY@Ax96h!(^w^bguK6f;LlLFHV{w!&iw0X4Pf5g!it)f>9Sl_ z!mKc_VCT~Cs2jfceJDK??!Ih&!`ks%@b@`ZlE3A(n79Gssdov#-k|0|)*Y-P&*h%u zhyXu^_0E-)e~mZ*^I1#&o>kxN))!U#68o6#UuJLPB=Vg?${Mjk+k*AkhOF}TpkF;W zB5@UZ!Ig)g$i2|!8T{eJOa68I*_ZJB0c(n@Sa1CR&VI`41J(=IaqJ5;U_GmvpYi?@ z-hGKL{c?Es0{H@8#}j@48YaHR{^f6LnUmy~+5s@#w0@=Dg` z*P>S!*aeYQ{l8c@-^CjEBKRa(vxoBe0MUmBai`^Im#p@0VIL^Xx{sTYtCAV`Dc0Ek zV-B9a=ZC*asMXy0-2B5)1(d6gxSojp3 ze;o;biPh)@#32}r?B_!9YpFwWK{sTreK68E6FV}Q?iCu?9=Q)V9a;(ON0n`1FgL9MUZ zW0HK^$B@CNq2X5O^9A3ppa!= z-dyB=lG1;&c6Wpwj=v}^eUn|>e0HS>nh*D6U;bd#F7IW|c+qK%J;7_McAr|qS=pzm z%}!P!yP(Hulf?4KsoGKMOsW3ZhDPj_f%HX+gRFncPJWQmaoTi}T@^p~DW)aGv^j@T zHIXMfoRfX6=In3fRqg1u;J(FBL$D8(?Ai(1rkq{O>uEV)Whbr`QgRvEG)2l%)bFFl_O#DnZL$^pYKEP=hm}wT zxw;1p6l_5~=&}!*wBVcu)$(aWE~8>05?Bm(bncP`W@9_t$zVPqCGwE_x8=&Ksb3 z8lEQ5JBbV9rblVGB)jS@kk1bELwI!+I@1O1xJ+%be{~U=+e!;|aSxN)&LKOc;8kwK z=A1zy&aU0i!LJPe_y zwLmrPhDXh~Zy9wILq8QB-VD!Wk7^YC>%@NN_4H6Dj?_a{vX)gvu?)NRLy-4fTsN6B zO_7lYkXav*XkO)i6#8bN+aXF6(UPCBd)=t%S1e=*`K`-Ncp>(o1AQ2dg=z#MC5Xlc zKp&k&rz7-6C%#QzS}*c=J@S@Euk*>#@Bur;HNi=}!zg$TeJp0>uOoe4LVr6@uSE9w z2dl9N3U|d$ybDi4$TKN0puj^lm8sGw4En?0zA#Bl#vSK=GQ$ zvzNX8rda%bP^~*XZx4559#Rv{Ea1r1a6q*227V9XU5l3>^jdM$hmPb!JBdhB3q9>b zeRVlnyQ;NCRU=v)eUov^2JNomDkr}ib5{r7`BiqJ1eDg*SOt~V^`f3G$W&YOFi35+ zkO{#Zm7<-ek?t5gJ_a9h(b+=AnATO+WdQo2LR0C>dZg|G`m_T1KY}i8<~k3h1kWIR zzeMj(pobeMcLYh=f_2zMo#)}d#3vI3pIFO0`XJWVViak`Ud$ipY#JQH4#sqed9{%` zR48mpX~GX0NowXNzJK|AHrt`xdmz>zre?BQ;Xs+uB5kQ*L) zXocWxE+FNF=#w3JcF?a}de;={+MuroYf%IGYqZb8DupRipz_L4O0Z~xD62y)@~TN6TF`pgUCZSy2tfsEDo}&1 z9j>}edju;acp@3ca;eizTLhhSnU`?mINU#io&T%Kh97{>e_%UKL-W(|KIQ7bXPMX6=1A?T*<&GB6!Ne7cq6`RQUYW{ z)m4odi<)p{vnpQJgs=6hv`;uHD4$&Jn9D1hb_oA#QzD-$WoBQGe}eY%(H;HQqk|Q!LJ+@Oq_*n`4DxlgWXe}H)1>ep>#|uzXtgukEoR=Ia;&qO9 zIagFx>2R`&hlGQGb}Zf@7g9FHbw!jH>5_kOJ`+$}aFs6Z?1P3rt`PYY{*_hTUo=NB zb(i^!a8%ADsZ(%eqQNG=G~U7`(JF3@#><_h9+3u(bJ8~(^$WtRy5Al+BogD~Wk-gD z$HHN`UL;UQMqC`X(E}MP9j%b@SoW3KQ}Y=4sP1e6gaagRn^wUoHYPs-oos<$gCbl4(U!oa8 z9YLj%wz8_lLPepK^jGYn!nhuHcm0Q(L(T(c26#0>3WmWHr|L3jDIAm0aR7sMI(X#3zW0lTZxFD8CI4vWGV7FyeLgpUT z*N7cSR+X2lG|I}iNYejoP@I~>a7GZ(|C3(fRgxoWl~orRaPgM0K=>j$uW>d(ujIZm zpOBhGI;B10;~-@P)o$Yod5dKDpiq$OMI&WCA=Fc=DC(@T;(72wI47$lHTf<7geNlQ zH->9l3Q# zr#x56h%8n^WfYbfw2TdF+tvo_r?{tCTm8*F)3R;rpU8`$d?Xf#RiCGSKAmFA*yv=#t@kUil)o+MfL2Y zT82eG<*3YVq#Yt}(w6F%Sd3~ps-BC;*b(O~W#tOts?;FTE!W9ie8XI5K`>d~xPzpLl4QcCVDd>78j_#|>6S4f@z z`z7;8krjE)@bW3&)$83oN zUgW57{wz-+$9Tp(RrP#%lHX^M3DM*;l#n>>GJeSVfoS~|&dKN?BaTRtXtjJ-KXG0} zie&Avh%1V?MxJP74k9*4?2gPCL?Y$cvRa#^&7!O1_MycxvdMTL=S8Muj1URT@T>Y+ zL?o(uhEY9NmJvy~EYAcgFBws*$F%A>rf7@MAx1fw-Ng7O5?%e&DNjY!Gl%MtMKn;f zQl2x!){E|EAptTklpFv#{I1sI>e;2tBV}G7YXdUVtW!0|lqXAh_K?|x%sAu;Ry@q) znOMf0>TyTL9w{Lu>oVFjtEx+$L*?nR4)>9@7I}J4kj`nvQqs7B{Cv&3z;{I<&zOiepSop|N2t>oGUYt z>L(46gzA}wv�E)Fxw(%tA!(s&!eOf<#MYycdnFwjeT=SN{`<5t}VqT&>MAF35aD zq(e9__Cw}_^7JCJ0~y1l5As_^eVNOO){EVjx5#4kGlk4-#m2~|@L29D^IQ28t5$uD zSSM+hJTr-fk=cXrTt-2mi)cyp47A#I$o=D0)Dwyd?`5Qrm&^c!TJi)UoD?mRaYd}` z|Hft+d1dA%)E7=wONa15^i;H0p8l(!3B{(1MHh(@UWi4Ixu2|rh%AUr5gC$kt9mXd zTopYQ{>o?~qkc7a<+n(%^h_jPMpMxhnf-~~7QGYwkYBQzl*PADURHJT{{Ng%w|ZSf zp6jaZxL9(TX;xEHo&~Gt_`=2N`l`nS(KT5sl9$j~?1Pl3zOOuKh;CQwxXd(V4_a0u zWVK43I_mN%JJPMHUa~GFyTeU5_J2gZ2Yi&p_dUEl+c&+DP7*>3p@-f?I?@CIDI$u9 z1p!eMv13J4us1-^FM=W>f`Fj(-qL$dHd~VHrfiZe`Jc1%<9*prLU#9g=9#%|?%aFk z4($m&QCb9MWEj7S!6)J~e#)!(qUTC`5sRmV|KD$pn7;45k(=7NB zO-|`1T}egC7GpBhEOM0mBL{`AE$A<1b$&}eavm3cg;)(zLIekC=c%i-48rQtrt^Cd zv7%SW*+k44+*iukhBF7Pgz!(q9KuLr31%dIPv4U9FwS7K9-M_3%MvSDe$Ta%)jwk~ zhRWTnB621y3T>D0L>;(?J~CrqjA8N}J!j5w!Z+fx*dw+tycPb%`A%p8JuW_-krXE= zMvEvs6OL2JtynSfH)lFhg!YxxWn7gn;qQyM0`-`(r1Q%uHJ{!ISHN5oGG0LcRCo_u zF;Ko}RT#Y&_J!+L5qqUHGHS_ri_**4hH*~LM2xCO`sJ6NP7q|C)$?hAxbmV-3Ep$2 zB>&l0uCIlqVMJ8xzm^)!^%v(G&M%yAXi>SwyeC7Py=ko}MT}cfmS}-QOp#F#wm~Z) z<}0q1_>b{aYJa<*_OuCHQSvh}Q;?3NI3-a;80jC8M#5j9UBJWq>qz#U`YJp(+62y; zq*N@$UTAN`3Ym7y1bRuq16pChg)U@V=yQ_3oUf=y zoD*oX=^0Yjs-UGs$Z8aKmBDwU|5%7m3Xq-1MXsU{-<3eqgl}1ltC-K=wWWUjF2$Y9 zY~00_%#86Y7hX;MWVVKOgZd_NCppMX75sl~K9Z3Fz4hJ3;#){1_39p}pf8 zmbQ@drO4aR!=>cY2D0D6PBALaXd?Z25$#b!C+PW7U#J_@MfxI~uj%^=`m6qr|6DI{ zzdO%3;#oDUF~C`(2RO-!2aMrzKBnKt$`{lz>5*|6#V8+XwHJ1E$E$4@+g@xXibCOJ3*ORjGiM9ks8j&jKFcW=J=Tz zqXpsWn{j1o0anFeFw%QsjsvT0YD*e9`g9Jxr|FP_Q>2b zJvk5J6upr#9fVhZa9-_~C(b{V6MA%v3OVsN;~V|J?P!c#=s_FC84Kwe09oN_D~y6M zT0+|`EG^d+M4!;dNQ9giF_LQN6g8Mg123|(l%z=X#sju!WRN~3EU;&8Aq2#pSTc@Av zx1sMs*BMXc8ArF!{!O&cDgl%r7kXF-i*Vha^Y);B%#HT}1=Pr%=iv^nK03io$}x42 zJ*Br|gS-`^PqaCD&@UDgC&rH}dawU;dP1xqKufFxXK4vIdl4H$|D1LeZ*Ywf8rQ?m z3Fe_`@mSx2F$2~sioJ7a#jw&`#8{E%uh++V>I-Z z5*CYVJJ4HNy=dSsGH^`-o;eJjC|8|aZG@s1z3`4+11lhiyokadfoDFB^Jx?$V?46^ zBk+8FLvNZnJy!mZKu?)zq3^+&jVxpLNJXw4d4lu^D4_~S3Ts>NEJ1q^0k5a3A)fRo9=dm`b8#ba%d1aQ9|lIs9HXenubSmlW; zMUh)2&xp0;yv3R zSeNV2@>pC|1o~XV=vl`FXEZ^YhoW`rf(zG4@C-^`u^*L5>rQWpHFA1G0wusR(Vn!! z#)Y95R?x2#EuY7oJQt8F9?DIwA2)jjPbvc!>(L_{Mw5o`S+#^Kf)3P02*IxQgFuE7 z38#jhg!_3w3vdybyxX7&Po3-!?Fj^xbbiF11$sUb2*!!{hPXkZHo5X;Tr3s1#^HEI zqkjfU1GZ``B!amgJ;td6X5}PMKOR)bdI~t#Ae_-x1k8%bZ)wT*k-*ULyufz!yBgPd zfotL#asUub{n3YZ(6#{**&PVEdH&P9GSOp3_o!n-p}+m1cTIlmDUl^{z*UX_s_r9T zb0^?=7x7I!xDXDGv9?lg(4O9JJ4UQQ4<`U=I}ZKP!H$dqhH5(a8jX=N@=j|-v|=*) zG9DPdOpL0~&o9l@!_qpKr58zo8M$`o=*We0b?|uZ1FAM1ZnV4N=kl8L*YMm$FYt>~(B4pR zFc>)R)!o*UkK!NG;p@j(3wlnutJ=7`55F$4eB^yjYs48 zTsZ|{WKRS2nE+lCW3DoR0)rv5ZP4#Pe>JjrR5GgoO87cxG7FsEhdON|AwM&)PoWt$ zG!pXSg^i&PHU!c%3skrWd>t`iQt*T5$q2uV>4~ayCqT0vpdV*AT64}S;rKQl^4AR% z4+9P4c$W@)91jUA!`Kymix`G0lYyz#LjF7Pc_Vb>2C#s`prb@P65qs?Z!h#A3KSro z)B$P}o2ADccOb3IL+J1fuD~3i{8C)$0-iGwC%f7q7uV6}q0mDgK6Asy(2MgyHW|%d z3GEt%k+LS!O|)x78>QgNXgoCtPc;JLXNAlWvpgGA?S{`;>CcN6J0ay=;3rp~Q-IS{ zf|@?ilv)u3q;nIHy0d_ZodX*BK;;-vsvkz}0yc96^rQo$7MW(A@X1PJsia4glClY2BJ3&l2C!ZyK!phAbisZee8#}6VV#|PbcKM z9enH#+J{00;$iFC(Z|8y1gq2i3L7AS{IvV+;7M>L4ierEy(cEq zta5i082oh%ul${r-%KRjZZx|k$Vt$#9T-u zdcmr8Tz_#D6AO(BK)YOZ&~K68hgg{ySQT1$Gvv?)Yrq&5dF@8uW6?hyM(lxn5YMaz z&&YXJVoZiNkb|qJvr50N){cpid62tid~? z&B)Odoj2qa#_+3N7ZUJ{TF-+ZMDy>)!}z;;y1J(3mTS zaD2RNNJW>v;Y)5&lVd>Lgq>6z_;gkAft0!*jc%r8GQ;kpf1g z8K31`-49o>7G@}<)PebPG};&f`x=h1apfsP|LJ#i;*|7Ee9C#A+4eq=DpH6Z&k)!P zv1Vr^iS+@wo>TjK-VJoHgIjr^K4)LEHKHtMs zFnc2KHv$zRvK1l@$D9pU=+p^DRr#B-TY(beDu=tzn0MiC${gPdOb@dm;vTMCxYFgy znYmihxDCA#_*j9p;mAY=<6cCQc{1X+L=KKwN_t=GY1`a6EuE@;L&lE^PW?+dD;k6QQP2?>Y8D%Tv9oylmgVsQx*u?+LsxXg985I2E zlNfsv*cZm-=<|u}3B4oQd`7*QlVygKnNsGzs=*OvtwfHN*>Q4Ju9(TahvpagALeG69b?T`<~$g?dc@DSrEuOa#HY-noLhz>>OZT zg4r^WyJuca(b(;RZ$k8&AMjfXFa>ga*jBygaXKb2zR)PDYjo|Dha-)<1 zfm3CamYio6gI0_Yer875GW`T%a`+9kf;Lfbj;IKc^(3#!M_Nr>;Fp1W2s3hnnFyj` zNPn)G82cu|g84aS$%L(;kHGvHy*iNzWiEt~baFQycd-rrEo=;Jl)&E6k~3l}tQPO5 z-QaI>nRqwDnIhk_K?4`gNF``DBK$#Y34?-g%Jrd^n#4e(8ykk^_t%{k1GYj+aLP7)r zNT5q7QF6bg&>I)%5NZo;E$K_nQ?H2`zL&q{E+C>d_#Hh5fw>* zGsDJNgZqiNqM~e3W<|DMcpoCGPixIAmq0Ht8%C+54Ph?61#OexT+cJNLJWBweshIE zz89k(<@n8t3_>=!Q;9MAGQ6ivqhttK!8slHPWUg3oAWPPCR!4%IT-O3=oI3Rm~o}u z5D39$T*>UJ@JL$yy$}c?Vgh*=wM6)CKD5S4>a-p7D2NRb`}v5#p!^B%gYqQ23FhU9 z!=?0bcNp^;+;hq*@VuT#Bw9MoaKu}2UJ_Y%H9q4kPcL6&bZBX5YpEyH0AiDv+a$VA z_$Pc4(OJ!YOTp|RwVgcRo<#a3!o%dcmva+6LGHkzSFFOb1iqUz68m>Vu8saKEedm- zTu*WqVWzGO+~FQW&LsDCrE!hHs_HzAfEjaU&!rdv?TqO4gWmVX?@YhN7=-Uw>4U$A;r|R=&#X%f$e0%Wch*biDC2Rh1S@or=?=meneSr+Bp04uF?hgC9cN-{ zpBp@*56bm7SK%(aL7)+(Q|sEKpBKW#&w&$J{AME5W&+D_gOzgqo%W zhnWB7e&gPlCB}g3J#Zzl!Upj0KFlQ|x60EAm`P*2g3qB(DxxFo1?Nj@68AaL3#KG+ zJ;lAwwCc1=>=947AO?fk0Ag8q$_S$#di07JK;}Jkkft>Bz>VCU${$Bkfh%me0b}GW zMeAin9I*f+w_v5qz1fuhcF>D>z9O{AoIbO3%zx7x;LdukcFA*b28bC{5)|T|KGqtj z!2J^RjK~ny`EY`|%?8$NAJX_20`Z-OASX4X+!9f5LL^ZcM|l_ga{Yc>-6Rlz%Ali zcybHr&6sB)`o&llb5~u6Ug<$;W=1(Dk|svnBXZirxTK=z#3gbTOhpcCCb%Dr7!51f z--0xDAkI&0&PYfD&I5*yQ7@_2>4-{=#qR`g<1+f#1lq-ecX582aTUe98D|k%>QIc5 z(P|>ghr=5g4}WDSC{zndvdVo8`Wg+NWg4EPL~Jt)v~oE-MkS`~_N82Du!8tk6u{-yLzwL&&K7h`d%MqFf2aJO}Z&gPgWUDNFA(Btv|;_yjLfv$c!|-tu}m9h|hS2 z5Ni@~mQMgR8OQ2}(T&G%qHJg%mFOSOfr!WVVn=B&)F4{tuRAmu9%paRhjF7s&~yQQ z#^Wy+`uGJ#eFeI@7`3NfMOI}DsK5^ViY0 zL1<|rD42=o67A)J28_WIhF}!I7%NeAr@>MBIBDnw5h;3TB==y1p#MYB7A+t>7w#Ql zjG-H%^;7Yk5w&4{!Z{H8@%o^Jeg4D!c{M_jW@U_o1u zdFT(SjfHh_K$n>(XO*!8^t~3(s>ZzzXr3Hn9)zBeiy4p@qNG?mX&BZYlkklOI$i;J z$cNP(0;yq4LITUpbC*toXX)tYN`DS76j`D3xRWbLQNth&9OF4$S3xo6c;|z+BVf61 zfxfIalYnQYqQ{itdT0+<$MKMe-gqj{!?=uI>afze5B+1-xB~pkfh?$ylXBr|55~nD z*(txiWJCJ0KpC#G!$A|`2$~=j!qar1x0gUW6=>TJeaM1_a0OZc8_IYkqdQ#pP#)T# z733rBuN|CYPmY7?@faEBEeCj44%x2(=SDz^kD*`xf+Ng(HbNHuML)Q1C`HeX;Kgir z1$>4Jzzys{ZL}XCE!B8x1f)I>EnV@;1Eqz2IqNdf_dJ2Siy$u}kiQ%ZI$B{}xSy&V ze|5uic>;_SdRPt}sl)G{;KcplaTlI*9xbppnaE;|hjcMoe;;(*3msquD^`^miqG80 z9*zS&9>CdPY~&a*GyRNFw74=za7oQGlRj4(f;`K9Gn?420F)8>uWf> zc_esv7yU4U5>1ennRs?QEJZjd)(-t)rJ4ItTdgINT+a7f%?z}$8`s@LUm{RxZz0<0gg!ojT1>Z5S#lfhIE{NKiC6Id zIh;0g9y}u|C=YW?75Kr^sP^0fdt;zcT!C8P^-lnf<1O6( zEo$Y(BRhH;)+h@77>E{v&<1U9E-0kHEd3d(c|C+Np2Jwwkmvg`BMgJaay?pr9y5-} zbMksXf-Ru%_u$5D&~zx|r9bG>3i=akbr+PV#91n=?i7LgYww|2QXYCAf~V5=DMp>V zx1qDlBXUjB0=Z#zvKr8uo&b!0j7vcVipkXFtDFae`29=KY;M&o!7mHvgXQQY8LuJan z(3LTMoET4Rn*~Z$V~mHv8G-T~4JnNSbv}iLXT#bo!<|#n*U_L*5)eHLQN4`$?I~z; zEJi*LI(`oNybz}d6#(h85p5>|@%A}-#JwJzQCWqj4Dxyie7J-PjGKWIOU2ndlcA?M z=;sAIow~(y92;?c8OEZ8mCD7nR!AjlIerDnI}bhmAD+hQV7c(co`$S&_KgAe>Gl7N zN}Ibe%kX@(DsWMWc{&D^*oW4SLuQ)&7ULmUyEx?&_^?9%Xm~B zzPSb+Ndwh6hY;^XWSq977R$F$df|16|ZWSH|Jz0;o@PB+nxaM~gfO%+j@3as<^lSw-)*zrJ1&{;KSn z4cqWFxMGF>{V`@t3#9c4%rkMI0x<=}&}=0%vmH3YM^IVw4PaV`zGlU?IpA77Y(yaX znFPDhgy--~tA%)O2x{i-fgS%;a#FGrvXPB>fr#5ZkXARajE};mH2b$)1ZxBQ2bADvlw)607pX%-f5gI zb_}0Pf;rJkqreckh(!ht`b; z*X?-jHF!`PLE~s>MGfZp`S^~=r(UoJ&!Hb&jVHrCj0MW=Dq7;2s}lE&geSsGdVpU54`+Ujxm!`(Zw$c?D?u5^V3MKqJoUipRX#7Zq=BN&Z8dyMbez4Xqo3 z(|KK}Py801T7ycwUqB!3LPwr}j7*1J`U>Yh4TSw`!kM&L&Hh9lkrfM?osxWk9xnzx zZP3sHc%m!8+bxin=iz}e>)8ibz+7Kbptnso)}(H~U*5mal%S-ac?!@S={ zaHl^|nPN;tq70GjmRtdqrhvvgFV=~t{D~Qhv-ua0@;LOy270<-ccx?Beh$`|l@VF< zkMl(2F^u5SCPrJ~0wbycpL`M7gt$Ujw+f96ruSw3rQgyFuX- zXz>xyqZvFn4Jn)mTl@tqbQ`?u3E)>3IB*HA8`0)U+{apht>7M!L?wP(%S;ndv_#r9=zluH$gN1t%KP`Cj4)DALBalN{Sfel)bIT6QI9K4UbbuOnFp37? z;tbH=U%T$_8iJOaK~WpdPiCJU0B?yyCpP(Y_%C^w9S4FwNf@CK96AjBFGfG&U^@nY zD%~)uGoY>p8uAz13-tF z@OFrZKMXD4S&3TcSqi8#8$4E^jh^r&Qt|W?(9FZI%w@11FJr{dVD5hy_O}l-^;uY! zPOK(2qKfuN*yn0U^IOoBC*f&+h}cIOq?~A-aP*=8@v;oz!$6z*j8+P|7#xMx} z;38PI-jMXY;N(HfleD`lK%eK)>P0+J4~<9$?Kk1>G}NU18&V^Mjb4XY>j>tW1-K3; zu%jgvJjE=?M-AvP6C8dN>%*DwDb3)q5j|vWi6yY`55fAL1J4)WuW6v*UCc={;j8lm z;#kbsn;>2N{Ogr=(2Gy%gvA|%tegBiard+y(-DC7_AO8>F$rYwBX5)fcqZ! zkTb#c@shVO;)k%7FyN%~QuypsFxn^lws;`!XEnbBc*j4&ri_FK^cL*lFw9ifA=5W7 zUvO7i3M|PJu%KUpLKW~N%#ejE_-huf=Ps}^%w4;oL)=Nw1D|OzM>`<@*?U>g_g5z(aJ+1{AZ{gaH>#tL2EM_Qv|t=L3v(Z%cULiX<`Nli9*KFID}V%y5vSHcqG{P8 z@%a!~AjZFKctRuQxJ&qd1gy_AXwyEl(Hl>t)t-p{jYO*l@q}Dlbq=e$C1`cH-|ioW z_T2{xYM^r?A=hglWlgw33r#o&d6htKdPA0OfM%k~sS!439_H~O=#c`W=mpI070^xv zc^QK_#e?2+e?&egWB}D?z#fGlcDWtgDFl}&$*rIO&!p}St5J=X4`O7~Vd1*@*9dvA z3~JDXF%H&r-pa9>j$dbk7qO2JtJr=TgngQq+2=~T=Z z3!#w%F`xW_-u?|b%KSPO11Ud)mh8}po>3-PvXRb z?U2wR@C%4}*^fTgfL{G1@4%W%;5U4XK8E4k1p0hS@iQFProqxsFImT%8kGaB{1%ct z4A%B9*ynYiRZooWTb%Z=6OwPh%)Ja_=BoE6$TIE8V!Q@H^0}`11EVd0*Zl!}ke!gF zf&M)xR_MahSb6mUFL+Y*dGz>8&?W=zEW$`wU!Uiecfc0h52>62xw{4a$@oJD!w4Y^{iHvz~h}1%YGBBGnS^th|)2p z64$I3fcju;Uvg zpJ0`@8nz<@9s={a+%Mk^x|s@(BM*78*APe2#(K-9cZt zXJII=ng>hI+!DP9IeNyOXHP!s~A2Kl;7F7+|yaL)6qJNRF0RusW2*|$= z>$oJmqpg^M=ceG(5YV2{=yY&wJZwWLsL7KU80{dYkBAKi#(ox`UWTmngEY;CrkMO| zZQ>mrn0Xkt@u9ywUz)4YiLg68Ko8>0S!FaD_R0lrQ1VGV=5Ls@Wc5)mxN-&Kw_{{n zA>4(QGD=8n7bEkG`w(ftI_a#P&Fqf_{bRnZ8otj+Xi6%0T!K%i`HW>V3deoJJkOpu zT_tohsoT8b(cHawdc1IFlC0lNy%y@(z* z;VDEybf9Nm$P)Md5MM&%2-n@iV6(Y*p%0!n030DIdjKdu6i?SecGx5Cn5J>zFkAKkHK;)a(BMh18)G1|#J+c6WCpYt1b*s3$5iNF3-o6n##sx_8o&jj zm{?oVh#vB+3sxH7T9tcC&igYC+;c!QM7v)Gu7YYj%Y&5%h%#uvr#y+=il-5kMYI94 zC(WS09KB=ysT8!Wg`Ay%u2!LS?l7;wJbwh7Vlfhx4Z{dz=;k>g3Qyxbd@11ohG-|QL zBO~;TsF~i-(=9T`A#s zKMFRH8F^xNxZ^zu_QC)?AkvNan_6&$b1D(-#ACAJ1?^@EBy17ph5PX|R(W9txF<=9k2~v$<)kDM>qK-mF(!#rneun$? zi7aB>57q;ff)4Sxmgfr*o69UCcax=H#?FA;xB5q@#Ct}+WBl}HhPDk}KCxUp1A#af z(q86Yj}zTR%pmiVMB>p(v){~T5wAu{@syny%o@b$vGxh$M4}?LsKLNHc?tlLeW{Sv z?idI62@qGroiCIH;_VqHAijr{V2D!Uh=`2gt~%n^i3AdRnTZu))qC!}p*1N-U%2~L z)ZOm#^NLsXZ+ANQS*BgE(m#2Ph?eg|<;lmJ#zB%VgpRpEYPJ@^^{e$iTS z4>(a2-0#d?z|4RSg64An*a&EEAIJr(ov?-k<%~9%_Dc+MTQQFzx4?U2*74#eV$n$W~)vES5AR+-`MB!Rs%pB}zSUkfbiU1K9@aI`+G|~7`2@<7O6^KkNZ5B6(>54$TecXXgkTHd;3TP zvYfkXh_vNi4$3XDO(n4R`S_pJD2Vwdf|kfhV!6dG0BSP#h}8J47txf&aSNmwvEq#Q zh}|XhIRr{v>GIQ z_i%`6__S)=|4q-AFG@JEnUr=to=ICQD!~%#z$=L+6#I>69k`#2qZ50+#Qq23BPipn zvLdP${{Qb~R;E5;LRE%yJ^LuLvgHHHs5i>=Ela za(4>%j1h^$Z;99@(n#z>5!RKSDRKS+TS!jP_7P9a`eWSrND7J_G^8{s%RN6t$#G8x z5%HpOouEFYgwi0x-2$~Lc3y~`MQl~zDR~c3G-59v{WI=G5PP?Xr02d3)PW(qLq4sbG4R`PmeMzb3tR!BxWS5_`T3s?gqS9@tM(fYC7v-vMMHL9_qWOx5r&z3HY1nT2|L0Zktx0ReqRHB7&PQ zo-{&FmQ|8?nihABapqu+Ue185%g6bFICxPVRMhQb^$S`a_JMIG_LV*nR}}o`gEVkQ zr;t+CJs{RhR4!+)i9crR^sYz+W-qwkOzeK4-%RUBY&dm>d+K-{rI&KbeR1@RD53Pg zxOSvGQtoIwiR7S#<_-$>q8;)_oHu2RH7r=AmiZv^iZ+}aBW9Yj9&J0T<%-#ivP(-Z zJ{S8EssG$XKwCjeKuj%t4@!iv*kWD~wFPNK_?@_xc7^jJyi4;lL^dduA$qzA`KdEgl72XpTl<%n`f50w|PVAE6UTj)V{!e-5?loE_ zdbixY#@e;QX0iSVy%_R=`bw+9{bKZE>78-T;v67&E9L=lc7d2HxRa4H4W*raFZY?mFihf_sH};+no#+mMre z{4tBZ_>6mHC}%x!4`&WO@!t6noE_--@rhiAioIs6`%f!Qiynx-Szj!`|C=ME^%AoM zzvZrVp3f*Mh_lz^8rMaf)5v4)dgX|j1)idN3S^3gIF>81oRzP zn~nRmSb2i=s4gL9dJfU6{}9n4_WNJNcCP!gLf8GaAq#(BgEhE@|5<RZ09nrB#QGAevTtj{Trg|te3M7ne#*bIQRiXt`6hpGFrR=`@vWe zYwEEcWF72Cw!dfGE6*B_tlr4#p4>G?nQ6t!j58Wn-Hh9a*}pgRa1i23{qe)KDC0{5 z@fqKb!p|5)k|!a`H3_ltIf(i!L{8v&WM5Y!!~HBWFV7(sv;sNF1&CHJz+dxm1*-r} z#$OZhGY{XdM&xoc?s^@0qZfd**@`>YAvgLc?q7kZ5u??k@ht8dqUOf=`9qCjbb&Kd z0JO;iUEs{iH3n;YaQ!6q)QdBC7!%^jz?`8OHFSdpw3A|f=Zr_>KKD*Ffl~FLAF*;x zpky8BUJAN18gm*IMMCHXvL1D6$z_k?;5( z8K;BDzwX0zzXOr+JJ7iN*@LXsN63q=1vd3*oa#LtdF|Q2FFcBN-a%&P3p{-<puJ}#^-B+?b_f!P5c==Z7Hr=hEG=jmGLkzeov#PYw`4#Fsc{P?_H3v-|eq0R{V-tVCSjZ-fesmrli44^IKKncwm}D3+nAPh7_7@k$kb+> z%)J|DnE#D)&96x6Bpyk!v_le!lj7G&H%V7XpO=0s-6H){`n~j&^e^eZ(njfJ=|4Ed z-X=XF{ZaaybhEUVbg)znUAnUHg2yeWQJuzN6lMz3aU5 zy^nZZp1YoJJw=`uJ^y)zdiHp{?&Y4FZlC)%_fhu&cfR|y`&aiq_bK;S&nQoU=M7Jt z$Le{~d&ZmYo9L_a#X{;L(c34{j}qy#vYWEG^2PEe<$L8n$=l=x#UjOMMTWv6pP;x; z(Nj?@zaYOZzbvnnm&=RgU&wdKm&zB(L*+qowfuzab=fr8a+zM{l>ROKM!H72Qrbsq z0Uu^VzlMP3XMrLg-Bsr^`@Y3%pfBB5>do=~=N;*t<&E>UdCEQQp52}sp7%Xfo?ks~ zyn1*Oyi>d*yfeK&dS`n}y(_#H@9S`Zk9nv02K&zXUh%d1+I&}ik98%1cP|0ovI$%~ z2#LA@&PKsc_y+OvCxLPMK~g2TEs;yj5<57uRJu%BEqzs*Aq$tSmaUbYmucnc@?r87 z@^>-vZi@5r(TV`Y+luEDD&?mNtFTQw>rjlPVJPsSGr?8X3uJGU*A68 zSVZQ(lB|)MrLV}`vNz?aiW3SQdb2}$Lb*+KNR<>I4|qRdbHFt9T6LU8r+G#5qUMms zsj+G@w2x^sw29h@+DY0j%^A%vnoXKHni$PKbszON0iCL^RBx!xD(5Q~C~nI;WO=d{ zNTeG!^a3RQXWupNMDI<{M9&`g)9%Zi>pFLLJk(L%ez^U`c0>Dh*G$(G*AmyC_?hSG z?V9R(!1bnUuWPKUrR}4(pW3>$J=qrBI;!=DmH{n?n?Gyb+1%;;usO83s(E(vxt0?x zKeTRYo9G(dzN_PMr{1&8d(gK~GEf>IPf|Rm+^t#_uw4CwrdV@aTc;hR8?Wo7+o+qP zYt@G9-qe1gylY8{#a&3l@c zG;Z~7_4Dcp>X!qC1bn1wQNE_MDmEyxps=JrgLil1Ahaxe_@>Invc{~0j~#quiB;hOIe|iD?;RHvdz*-sZD?Z2R1%wkO&~+5c(S->}}%&+&%yiIyd;AG9Cs*y--@lzCT4{+8{MN2|hB z9|yFkTQtw>YIMo^hxJ{$VTPv+YYopER_i12{(1eE`tAD5y8CrwbZ2xk^h5MFb?@sI z=|c3g^(*ze^#AC-)g|du^e^fa`X6)?bt>HwT@T$A?G$af`fK&a>bn7t25eCEQVvx- z4<2rkrAmV&SA4;~SG@0fcDdi{9NzI_d!y@d*U7fOTkBgMZwYJpqj^npK=S~n(^2DS zb)0t$by%7fH(hT0xp8mf+{PCgs_c!nYqk`-#9m~bXWwegvaPV+Zd}|r%dy2FZ64q9 ziA&$UsdE^t=%TJ$(p1=npn$r7wK2r>Nua~jZVEFs8`m4ZG$ot# zre}<~hHng~jXg~&Q={=4;~Zm!akc4X(^6Blaij59W4P(8@lE4RL$u*H{r8~5R^37E zR~ox|y?Ueig@E0vSC#K5+T|{;O6(&=swYM&*}mW{13?MWReo)^83cST6Y$nR4;8?ZZ|C$w#?exD)FINmfM&}@1v zaC~5|z|n!DO!G`W<22&~#-T>5p|3ID5M&%+OfxRTxHlV58haS`8}}K98J8H7jJ5dg zW5Z5EhT))olRgKM-Az}lk!gME8R|Cz?x=3Uwk}khmLHPcl;#4RZRkqyo%DoyH13X$ zu^pl93tbo5a#~M9N&dAOyo>l#%Zd7;I zk2d5Rzcq~wniZHAR1h>ZxH+gZC@tvcz^cH_fj(1e;3K9>#uj5g(>vo&7ELB!TgQ_vX@whXlHQY6+BNf{C ztnYJ4u>7h*6R<`-NxMY1+pyi38mJH475GHpl%PpL3xZw^rwfxiL9YQu5EtKjET`n|f{T7`Cm zW{`SKz!249Wxk@D!Y#{?9hJP%^{Q`!cP6-)(0QT#nCo}Q>Bg3+&5t=JJ9ahgX#Bb% zw4uV*!?wlxzGbDQsD5XCO8p-5hh~-esk)nW0k!tppxTW!>N;=D{A^&K}nyS%9DEVIhDsUp<|O|X8Qex|9~bR+QhAa&4^pthh@L2m{98CYQQ znvR-om|inIY05T!fe}1uN;JI(K29~gZ~D*ly~z~#nQ5FU*yJ`Qn{pw&Zy2)-xrRD} zLw`iSO}9-uRC`t9Rj*Ue4){s69*cv;iXrl)usBz{`gI-juJR0WH+CFt-|u>&ZCh(X z%be!J&P3-uSr+`s(^>^9$ym>n5PM!<^bDYUbCvYrd@g zpl+{uf%S-OR>NP72S)WjW2v#s7;XB-INo^KFwnTv_`dNS;}qj=d_K%2* z*vr^t_}36>m}=Od-wRvwiT1cgrCFgq8K4dbRvDC&6r1HEWj{%$Nk(_2`UZGqo)mX+ zM`p(@S9@Dy>$R4`W<~QH=WCArO>K=^8do%!?Ml1GHpIHl@=yK#`l0o+%nI}8wL@!< z)x2HvQ%!BHwPv)r%-q+u**>>%VAEB{lg;DX8eE6mb95v0 zd8R92rIE&{li?x3GXqB&-!&XEylMK_SYmMN&Zu8imdX3d%)k|2@xJRl=shaUQy!34 zDw4bj9VN}@TIKE6JU{wMJ%6@8+4_59f&Ib8KkTs$lk6>)`L#Q%m)E>hJ*+mc##Onk z>}v7EVn?yL>}1&^6@#m`S9jLr*?(`H;Y*cIQ6;L!s$bRy>jwvK2+s|BILs0@Ei@}+ zTks>n-|BPJdGbeO|6s+O=neHPb4xnDbDe1Wv~8{HO6LymF3DEKAdT0sIdoifMBLhh zgv3dSOS(@@ZcMf&Kal)yVtL%Q7<=Tnu;qad>)TZelzrvZvhQRC(%q8vU6Xz5y$j$W zpZAvfzLu@TRyUsQOg@Iu+EWdmvsS>JG>=7V~?@j;U%=wM)W$Uk9QBX2}B zg>MTF4x1nLcd#k2R?{ZayBpfP&HEg8o2EJbcJ6jGI7hX7*!FCP*B#L{RMs9a)_5%9 z{&+=?AA0vnUzUEXZ*}jTy&6)Ik~YMB8?A}>HKZ^&LYJ#emfn^4T$eoSS|h!Rj`3ZE z-W2&3c`wx|)eEZURo^J)DxQ+Rpco)ekUs1i+WCA(Zu3`;49gz#LzRt{scC~m8c+x%J zdI$L?Nft^US7-xX){ZcI6u3HQVsLoSDdP-%t;V37BKfm(qU(IiX2$|sxJ6rgrXsG) zS{PUGS>BsDf8G81&Uabu*;%<~@@Eu&Q22S_nWA@!s!Kf;8*9Hb|Lw?kuMSufvMf40 zDWLniDe8o;ROwgn)g@F<6t7d3^1Wj=2p3EGO%KH~&|ZgW&@P`TP8KiHM%X>h&by4^a+ncmo^;a6L^?ODr0tJ8L_aV0`7 z;mX7MN5ejexzhdRo>A#5GA<1?_aBg<>Ra5aAazxDZ|s(+TOqxT6V;dGk&;s$W9N|e z;C6Xuu=ld`j%uWCq-kWZC1ga{m9U_&z~Hw{$+}Mh*2&v^7Q;K^g^a)SX^fH;N1Rud#y=16#q)Mm%@HAy`%jopicHc*ZrQ! zoqIdRBZ>1yXSy#(RuYhCxDuodw?_OC;fQdBO$eH!Z&$AGTHoW`v}_sOK08m z+On!{6)%^#@)zb--g)9y{f*R{Z`|3Ly)f^=qLC$o%R0(GtejeVz&zGo+?w9?jb?CI zMQm;A!9G3vKHk^R`zuSqQ52xo8%^7T4Xhz0_@V(x~2j z`%dk9yYEwJpCnnMb{USSbGx+d4>v#2_@KSQyuJ2?+S_$|F_Pcv##x`W{oT~V71I^0 z*%fp?YDUWF^pwE?qn3<4IKeV@)bJC7e(60oX?WCeqd8!%w5#JqS4Qg#t+}w}?{rM} z%;<`ityZ4X+6_+!1_rMSNeo>TGR5?^ZmRN@Z<}j~Bgy`yrKRRz)r#^QBx>&F|C;|& z?!Q^myC-hFcl(jM&f9mgrf1K{2`QXkGQ8@yhTq*EsBVUwi(8l0um3*-OoRR%)Gz(@ z9^c322CdV)Bhk1HICnK&w!1CT`X}o1EdQ7j%x_p+mX+3ZcDKXTaZ`3n*Du1>eOkIK zGiFTY#P=t^Hu2z?&|!zuKkfcw#2ULCpSo&J=-hyZH&gA@(vo+_loCmTWxqI>M#XFU^_GgdC{iOK$y3m#` z+4lB&$uSJ z-M*Ew#j3B>AHuI391s~0s@x}AFMXk_)6>wA?6S1ZZ+1GyH9c>8%(AEYc3D;NV|f>H zPh`JxcSZKetkZYRS?BI<%=s}VtuVSQtnT^Nb&})82?@J;b|2C?;)US}qo)tK+2`Zr zYvIdv8zgg^-?0yY=KN7pUK>H!{89Kjn%lnTTsB8RV@7kCYo;eob}(R=?w0;t-7lIWu#*ENZ+Oz$ z2ef85uQWOvK5d-Tkl4^@%Q1Jz(}_ zJ6qyA9`=4E>!wz&r%|DejWsjCNR6J9+-Z8)H8~x0RY%a^w zbzAGcFxQyhMKt2Ht;tqff3UW;^1IUd!oCF$<~*3~y))*{+1uaTiOL;VY_5B_b(rk; zu;zqMdkr1@_<*88r5O{`UQaHKeIo4lpd9tPip^a&I=8m>a9wL1+q&4PZR&4ZVSc#! zWYwIyl!m9;dj^aT8=F$qf6K7iF{j6NjGUI)+V6+d&9RyghjyhRr)!+=gRZf%H&ric z?rMW|L)G`oz1~~sqSp`50x9Lj#YhM)6ZO0|GxDT>kNy^{C@SA@|vQ6{G&Oo zcYnH@nl&@KGcU2ExGLUZ>%6R)9+{9>-)nrogMHrYb1`{BoG)Tg@M%N2_Va*?$_m*~ z=@jp0?(6Ltt@`G54G&uT)sL%rvo6-=a{f>Hw9ydvYWkZ)t)o|rQ;k}jnVX)I{CCXR zkQ4fO0UyXGNTt#jqz1W5xiMf-Ky*NktjpWvTG^Jgh#%vfA>M zSyFem`s0dm<(&s4pn?$rg6CyT^AZT4y4X6@nlxO1B_IONj=ZSM(# zE@lqNY#!XZUu%!w6J|vn3%O-lpnqGtOH-nrrT$z!FW|T;U2&|dmwQReOZJO(F_qKH zrj?y3`=@e8?IQD|mPf4J>t*KL>dz|F<(oXqlEJ-oFYt*t|wtC~Jys+QJ7nKR9g)w-)cE?-sBzv#F8#{8thKMI{C&z9LM zE>{`q`r00I#&oV!3^oi7|1@So!ixz}@j>0%!Y7A>2FZ~6%`hE9%;c2z!GK1Y+LzUl z-u7Yhe&+{HSM%%bHQsfKi^lB{i3$CCzS}n?V|>OJy+25unlK|uANF$KYNNu~tY2x^ zVsPrNXnLu`H|O2W zwdH+T(7!0AthxM7)up=MZJQjgcD^P%uCs;cyZstl6TdZfeRN{@v%$lV>V3>OGVnW7 zjA?}7vibwX`@Yqk54Or%K58y(iR%dL+NyZRa4F>BZfS|K)H!LJd*=4&o$_*GNbIaA zXIOLaoZxvu`vMOKCY!bx&T7*F`pY^!3fD-7%66#kb-KJ>Me8Z30T`HS2+51(;r>&UxoXM>(*S)DJ|`nLJby#sl9DCmS1Og&!8J!}*O`0TYn>muD-?T; zcf;?GwZ#ueSRcDA$`Lv%*oY{~9z?y)8Ul5*19nLFyU(?L;+WsKzwuq?A=gvB9?Dhv z{GeaLPe;YZyc9D$_T!k*F%!GxMaM)3Mn^`qM-GorhK&jCV=UF&R=m?SyYon^&+$a# z`}PdG%zoE)-1d`oR{hf2AytFR))luFIPzB&v=?kC{;~A$@=KL3RxPV~xO!vF2j=G* z9%x>saY6Buy-Ml1>aE?5`b5PU79D`b0cm2s$MyX-`#t|h+lSzC)O zr*UKJ7EgbyU55e~NTO_KzM8xZU1v1u=VL2gLmptBL(Ax;cDG$O+>i%>?CT z$w!{M?TWUCn#VX&nhF|f?JnC6%X{WK)xTFPD7{#;xUgs8yG75Iyj^aq%&1Y-eo~WB zv$*C`ZJRaGImaUl7!|ZA>UvyV{A+RRqj!Xj3Z5Nk4pfGehFy(_jC?HgE#pkpEl*){ zoPBzIAmC8z9HTm3kQ`IR>0w2J@`8T|UK)BiY)i!S=X zOXpQBk^kAH_ttc-YM<72yZLoTMZ-h35X+}^d#ZygCzbav{i9@m@%fUp(ih79EdRLT zR>hMQ8!Nu9*jSxcAL;m`(-<&5XnT}9ZdyV~?EO*QLKg?dnw|-qAG$trSGTP(7b2et zics(IEO73!9H~7~`=RAP<7!tQpHx0owLyJW{ko<@qt#z9{TXsJa$MXCNk1p|PYp=k zlz2VH7JeYeVE96l5>Tf!DJ$elq!+#Sb;L9uu=lTjtmahJH<0f2He+J%tY! zoX<}xc(u@8ytq_eURg1z<^xMh)1{8t$`Pgok-x=GO)$iN+wErfn$UA0=fj*)f5x6n zSP_3H%4%AxnAQHLJ+m&WW>4MMwzbY9u7BJQd6#()dd_*9e9f}bfY?Ascu~xrBv(pt z>co`KyI+sr6r+#2FFZ6fGdMHoxamd1UTv)EX4k}ye$4~yhs+1+0_x_Q`tpA}ZA@?`)F-wkm380TePiPAIBE3pu%$tN8Z-4fG_zD|W%qf)+hUq>EN|6C*G1Ny zs#{juR5iZrnSAeE)2+rE7jI={M-)C=*1xi)vZ_K=wW984o2ui8>SSn9Tvzh5DY;3N z-EzWz4$BVzKITl)x*j)s6ebOecti7{``@N(_1nx&>x`z+&Fh^CN1SbL-B`2Qvey~v z%?=0&Iuel|wIk}~sJ6(AnELp6$!R^EJumloHL1E=-_R#bv-QEcSM)t~jp`=(FYdPH z-nORNKWpaIepL6J`J*~d#loVuvf^%(UHN|;odtLkX%olSH@llOb$1sEw8iCsLyH~m za5#s%9`14+cDTE1k%PnCp)GZ{v`O2zXXP8drw?gqlI>>Rd2QaA-~acI555-va29%( zdX~R0b(Ex3cCWi@_=sIlz6jnI`6TL9_?)17zB_!U1&QGc5+^sElzcF@XF#n?aO}|k z&@BNalubt-ike zHYNLZPMiGG1;dN_m5ixAuj^o6Kqq=)fu+G!!FK{~1a=9!5_%^tD%mH^H)Ty)jAC#?35$Z#+lTbtz|Fs>a**994Yyv$#l5MHoA52{^Bu2xl?|??W*_fuwjYu zX*q3jT2(ePM*kj2c#>|b6p8Ns-YATROXMxFayzDbSa(QOTVtzxt3IfiQtMSE7S(6l zzjpX^?DNt>RWSW;Zy;FSVdf`;qNqd9Izp+krrWnZ|nOBoPqAWo()igqCD*Kh0=h4R_+jFO{ zDX1~BCh>doS*e>_E^GEKVROWoz!^SHk4z=$zSg5iF_O82?{T%YMCrO}erQyhoe;(D zs|u(JDUs*VIeWf0W)^&p{IMaM&Uuop%?inBn#YuSHq5rGG2GqTXRWWyuelfHdo^fU z*!Orvv#Tj1o1Tm|1eGfY3gmMknp z%b3mFJfEDvsiA!$_r;!1*w(}WNm1@8i<6>a<3c+5HT4*)(91W-$1Bb-4MacjiuI7r ztKqTQp((CU*DSBIRppodSLmDHCYR6sl;1ipDJSd4AK$aTnZ8B-XkCm~b~jc*>1s54 zU4Bn7Rk=gC+52#aJ+devwn=Dmbo~0zc7B(Y1Gs!TfEq`e@z2tC$9??=wMVsgS#9~o zs=amT+E#{p=3lKb4m01C=;%TC&WiMlI~zYdk&KUuI}vY7n3TAnNk9S>`zbWpceHyf zH%ay@SxI&x+|hyJKej%meBH10AL|Fz=V_7}0@V9dORHX2+%0`tbfxG*9-d#Gvp1(@ zPHJvujfpS{X9+Ko&m{kdSsCW-m*TZvIYM5p zD3*K4rsLZj)Ah$S`>SM?Q>z%&x_U~xRp+JeX>b{?nosec=q`T4gFT~8yFV zr*n1`G%IOYvPiYpG}Vk^TZj|PTg4`2bKj~kKK@dotXbct`SGu#euM=EwhO-Jf7CnD z>x3*FYvK9d}emBTTkzkp)*5XM`Xn0 z$Bbw)wn^KjQnUBT`xDY5cZTHpKK5GTk*?^Yc*V5GBcwR{WRt&XtWJdV)BeU=rrm}L zL%p6i62?MPH#l?J1u64hmKEkNI$U!M`sLKhEtOH_VPyjgPh`!^{5_xg@wxVtzSPkf zixLyj7szn*1{33_4WFOzyvg?D=$MkQCxLeFlYS*$M&GZ#DtVzSQS56xt&XkUUG=;= zwsw))ruM77Sn;8}cSYyQu?-*0--TyXTSc1ZQU4GAt%7=nKM600eV^2|*}u(=&CevX zirNs8>Z|pR_fmPx_4uk-PPIZeJBux>Wu57Y<+=5gt+%ziWxi>dG0@!2e9Knl(hDYW zh-4DF3h9pP)?}-}kg3bBcT;7R8;aZIcg|~*v;OmyFQM;pzkK-mU)kTf_XdvNX^U~5 zb%l$|IY*E!N=$m)v}?jIv3G*51jqOX2OSJq5_~yui*kVMy(_z6i|Sj=`)WCy>ZfZ| zHP+GrMe>5aC5J1s8;4sv30_1ry_BtW3-#|D%toD#Cz{SqNk|n^wx!%n_K7yLNl-`X^$U`w8PwCTHXv2nfrt4?WnZ;Y~?cYJ}J(K3LUI<(Eg7T+Yufyq&xHbN1^QFNo(|UatKxzame+$Jj>bV(G>^ zgnxze9)aPvn{;b+w&nKbA<41P>M&zaXrwH*Urev?x!$7`3wW}jecc4rRrQ{RvziBL zZB>48X)c~qQW#Yos`Ise<8PC{5>tp{>@a1!;2kkr6VlU8rYTa-H*cPBG0q&RyX zY9#ZGSuNX+okGX)5sr3NXQM@PTeYTaRH;YNgM!F><<}vfTYt)a-RzC$o5IYpA2Rhq z%RXx_VXD2FPbWC~ntw!c=VpE_E~Ra5@-fCQtTsF+d_h#}7+u7Hz-jD#^s6pXvD%VqeuTmEdh#^?5b{EoNM7tmA3Wk{ z^2HWn%eg70aV4ZMZSfWT(QhAZ`lvkm!NOS?>2vE{?_yT)*mH>O)F#hblW|?x7fnDjJ)gS43H;o zNpVjtNtHDVim8uGj2jYF6RC__8Br7%;{FcZ>4?x>svD=#X)e}oubWuiKJVC%Wj|); zG%a3V(;wp4!M2f(x2|~U5MJ)lH(*~>Ny_pRCZ%7~$|iH0JdHOcu4&RPDKhSLSY2S4 zax}A))QVA%)!2vs?igjtF`Up1(W&zsNw@5Cp8ZBA?xVxoLQk!sN^r? zA|h;CGC?$)JOEG>;IRF-5Gj>}t>y*x|vHTQ$(*Yj^b zfBy6#^}DHHNLiw#ojBF@R><+PHSKE76s6^j$nRIMvaqVOO+%(}muaP=11PfgrYX0}et99C z63SCrHEq{yRMT(Gi;_Mj7z?I4qLJD>bT}bFi73<|z<90*Y=>A>V$L%J!kL|%!(O&deY^Ff+ovk6T z${$$YQ`1oWSINKm<%K7TuIFmMrvI>hNz6R(ZS$ASpDkZ^zu)jJJ8xp;LH=KAwv{1j zk*VBurDq@>J+XQBlmls-Tf`e)k;cdBBx-cubSCY zCby_;J}Yi~)Q8B7pq#*zfG1wd#i6pZ#g|Jn%8Ck-3a%EeF1cA& zuZlBfSz0*WiEHp!w#dUL;6-?gsFzKLrHoA3&~$vW_szYN4<_A8(j`odo*1U}pXXua zw6eQIkd)8YIG5WKOf9%#BuHp5HtMx-WBo z;D#ypDJHoMlSgq|$l2IBq0HXFhMSr+6x6n=vXy-)^({S;H>PA#_Mp65*^2KIeq?07 z&%RzTC#Of=x4hqrM&txmZBTR8x15}L#84cmi1mIG(kW<4%(W(|NvQzM-ydTL?HX7c z@YQocz%P&w^&Pd?4VGg4;<_#BXH{C&jw-UYQ`O1x^QCPnYO54#Qt$6<@1%gu>Q7na z4v&ceb3)t1YnzD4Ym)~jd`bMAurz*r{OQ=8v8^M5g7w~3w?M$ymtlJli(uhjIYwD^ zrXcfZ^9~DTFLHF^7Ya?pU&R3IJi3Xjp(n^x@>XsIkSlkEOXtejuIvzoldYlNqP*C` zam@bQl-+osZnKK5DXLDa98(%lG`w(b-pS&3S-yE)v&6jjIR$yW@<->S0tZoq;WUp2}%-i=?+1 z_0{!S&4aqix;dH)wSQHVl*W`4m3vlSuN|e^YwT<9z`3|ohfkkm;vCk5B z#?OvBlUNeRHklRoCaz5!8~rj|=eyTui{d7m0$G}|#A7r4 zAVf8G3HSL3F#QY&Y358BVvRZ$kFtujp64!aCio)SjN3=F)F3+7VE6HC zEC1o%Ow{VQg9+WMP7I0w2kwxs43j75P5OBf%Ecn6xbROVauHE%CO* z!wLPHL`QB3+2Z@qqeda)2Ffbv9{4k)r?}KP#P;4CVZLChwA4Cs9WPvYc6Y}(r-Cox z|3%MYBdK1DP1avFnNA@4LEg(&qCY+gD@Xe96J6(xo|cb|>Dn`O?bYMdFRJ2da`IPJ z6y*#n%qr}k^)`Rr*MExmud8y0ee$Z~}7~_p1Mn&}Wf!^tHIrG3iP7V=u+0#a~TYkdzXW66F)v z-M_oXRyTL98N}le!UNea4zNA2Z8yv{$xZLfGc7*0UX~m7BjwIlm7S})YUJ7{#yO^3 z+e+6oX9sjQeaY>)(kH+(^ln&Ev=sR$CL^Lhlq-71r^NjgT^T6`2L?rX`+0WZG8ltw zH&hPzNykJ_`x{G7tDo_`d8YA{*~4<(5^N2&4zWJ9E^u^rML_+~V*DX|bs}#q-YyZZ2t;cdRri_hWHU?(1B8 zd2CU-x{rRcZl*3#@9ml(esJw$lOZGMd!WCYB{s^tPe9Y4A|G@3JVg$pA=86+Y4gN8=jfL|E&4^Q&A9uAAt7YzZ-me&@EzBi+5& z-{~72^d@Lm!1U0q0o6gH0wx4z`zd^$cw_Dd5iKO4TbT|$$Hzb$Q9RO z*H!0C%O(3TTdgs{^2Bi5s5BJRkI_Zdla0PAvSD&8!v3zzYXQHE(dc`_eGb`LXyN4;Sx!+V4yLq{_f|=(kWyEDz zJTgMs%kPDFn}>6UL+R*hYi;*{Dv(Xi)_ezKIHnX`SUAa2iTFg>LB(`3MzK|1<@S}s z-7-9mv&)qWj+=;2%RkH#E)1hUm#l~l%d%~wRDPjfw`nJ&S31C(81M#_YhO?So^Pb zhR1C6E>gFVZ>k?;r%k6!0oH}utHwOtEp24|I+IXyNS{^jrhQfStR6->%~11N?O5Te zcCjccGQ^elm%QLU$6ki3&`H z4MCdW=dm226>OR|?iBjC8 zHxj;{ZRvf8Q~pqKL>$cy=Ss;f>}K>o%8%+!w?HQn5*1BU{A((Wc!bT8 zNPIE*MZ&NU=>{HyEws0w`n%fGz4&0^6Cx3D=JjX?DcPAKjdcvfi!B6xsIdtn=`N!~ z8wU!(!V&c&ys0`Ae_6XpTx=@S%6LD`B>uj7p4d?{*45bHZn_|Tu33k;8m?Gd+D6&7 zTfbR-#a@=($S%_a;&1Cse!lcf=tnNJd}0ps?&t^%#kOKrfx-U5HsUxshE_Y|vJ7II z>!qxZonvG1R(6@}78!$2w|}Bd^fZA(2cr3M#&N(i*wV)Bj{8PqoY#N+Z|=VmH1df{ z^>%kYap$R#Y>M~#`WJ~Y4JX6uEGA}s$kO_>(20CISr?B-8sEeNRf^E=*706Rwi)hE z_+HXoWt~3F|7-QJa8&mLz2lj#@(mtRxi)l1g~IovPLDT7`?}P`CEZ?K^>-xGravzIVgnlSVS1vKFHzl~gEnn>G zs@cv+_8Uf%9I1KdKc#rTSA;#pu*c0`bH#mvY6>^S(b16O5nKK=uuOAQ3?-WD+qieA zwkQ_ZIoArJv!ShWll~HV5nHF<=zUXlTltSA5n1CQR7?%F>ayii`C{wwpv;=>L4P)? z-1}1dY^!`$sl=$+mMlKO`(N|mpxX_zJZEBmU|G~q$Gza8rXcnJ;_W_wSfvm2%XD5+ z+;o0KM=GaU9(w$v8wd3=KkA2io-|MLEHt*DPw{yY%ak@|dwkI5D^i@i0gLKV^vuK0 z6yvtQ{;hs_h`FYdXG?t^d(M8!zTCe{L6zqo;V|V>+nGD;kA3eE<5B17w%OLzFe2zq z?NZDG=_S;-T`eGkIO`v-0@o{P60X%ka-^=ca6mdvGzf+I3M!NQY=RNpnx^P!ekNwx zTG2sLXYF#|SapBqt?d?2q_3fXoJ-kC^+IhC@P zkf|3D^tK_8oa$`FY(&l41%cgad-%*UsmM^WmHB|Lwq{W9HcNyv(|x9HY(SDa%Hx=$ zO!lYr&Gg-;yS_cs#EDXWAen-L*yOlGr$Wy3DWo+z1ixjkWp>)<5!w6);Sbisbe#Ul zNG&F+9Od{ULTM$@Bl4lz-SSFXB0U4u8DfX3jYHl*#IZE3(1kc~v7fS9!W-j{+1T z9342a@wGAl8>`yo_r%r7xWTVSQGMLG`p@=4&vb1aI?3Z>`Fj6B^#Q)!T}vv|@zV`G zEF0KGcz0*6G|gV^zpjGd#xTw5m;12VM)ymGDB=k{-Eq|E?mkvs;k#A;jIs)Z^<=<- z>NxkA_Eh2`3KjI`AIh)RW8@x(+=+y?%AFIQ)zCuG%YIrKMLLNd`en>rHnwI>#QlQg zm@DcZ%ys?^njvga%+RzUj`@C8&0*FF8>r6w@Af|Kd03~K-N7EN*Ungcp#68RB<*qe zC`&J5jyRld?;2%J@$wQny3%dweyrKxjKspxo)!p%49mLtWNKPhG40Yy~iJM+@oGXcELB{hCv^3qsT2F)RqK& zVVi4^{JLc@`HQWmU+3EAtQ+GrPZPQkQ{1Oi8-DH%=(YOMNTW||)e!%ab(!8jtjmSj z41ny%g+{I~jF-wj2c z&0Cc0;*MFmx^E%}6cofIXvTQ%Exi$T&U(gio1Ea>>h`{DMs#xhSa!Feo}R8qukPfh z&~H&#?Dxem0yD1&=%k95x3{)c>P@}GuFUV2`RqgEu^;GR`YS6E%{kdLrglS zkG!xxHQ~%UiXrEOV6yi7j>&r>D{wF-TS8DKcRA9iO*ap6thhkB*wNJBQi!)=|(L$<*1#=m*vG}3^Nw=LvmnIAKShIHjq z{-m8l_ep=a&#C^?<5f4tAA2jd;syGBv*OTEQswN>uKnT5hK zsKZ}P`kQ;W>#eu>rmWD|IVePPm;4j&<4jcUX(;!&Y#iV|-f#qIr-(63^jlpyJb10C zQQSlSXF}co)O)+PwmRra$7?ai{kOXHUgu4itf{!mxG>=Ny2oCx454% z{9?UHkG9lvEy?5R+a7l@F8= z+4OFcL~Dt6)|Esczl*7~E=P|7?Q{aMK)=#G(maXjXCLMA^7x@?r7X8R!-`~Y)SJWp zDj(@TPxIPipE(LC@Bm`I1*PN3$Q|bv^jw>IXG&sYqi)E zwFspoYgD@Jx5eR%^xbt&w#`}tTr^B=Hq zae}Bp`T=J@LfkKVWxa^*$A86(#Xy%rd`|s9E((OJ8n2L2@flcKzP}@jJRv>A)%+4Z z4j+zeN3)SCn}$~K?`^5dg)i@xLK#TB2w~M2ZJW!dalQyBlU^n#`STCorwa6}X z2vj(iq7RTF>;#&ErD6wB3%U^1Vo7Kv+K8-2B_J=f!M;fg&~r#C76inIzknzbh-^SZ z;CkLjDRN(`NAAF`#1p6ygJ5R8iIf9V=qi#6+-(v#2B0*7Mk5>1Yd~R72WrS6bO`bk zbdEsj5@zkU=w0bB==TuFVw9JDaw-l0abT{*psRq4dj;(B7*taP(5I*Z#>jBw71{yF zHOR2Bak!F&?8a_s1VGDPqY{K7p;&! z=oqOKm?v@|rX+#RMJ1?P90Q%WzCcttg82P4)$AvMB;qak15-kdJP;}5zNi9jN=NAt(nJbE*I@&(P&}O= z@WbeM9K(mh$%zhV9=8RJG!<&*yrneZA7P+KO7r;rVl=W)=qvtY#yiB{@aZT<6cPl) z;o>1Xzm`ma-V2e@fPs;%7eYdzkR0$4#KA*zx;e+`8{0?EhkSF{C zSE`eCA}4_?up99Lfmi@NvA(^~p&)Cp%|5 zH#yci8ti2b#2IDp?8vnj+f}w5jxP2DdxXOT-xoUu!Yix}i{rap?wadTINe=ou4F#i zb=qljv3xkcfuHGWCxr2T2tGV262e$9SGXP^~^f7A%XPqp3!+mZ-uPVtFujO_zp??Sk@oj4Eah+45@#1*n1-A}fhIm-UUaq^w=AqtaRkZ)AH zkhfOM<@R%L*-dOU(~;T8T$de{Ma$H5Hm#uDWpC+AG85fS=E=-q)-VU5YP$s&&Ygj{ z#YFiTZoT{)&H1OFa@$q8BhO4KcE)DuBV=` z6UoFf{1ohuN20BelTxm@S7;-|@m*cBox>dq?5k`Yt#2$#&8tmy#yrDk{c~N4Zi+Ur zQQvUAKBhij(?)#|@>y@yE>@kXIZ)H2W?jv2)v21jb#K%g8d9~Q)?&D7yk&{E^Y*s< zHNFXA$0ic>*uyc`OGa(LAYH~+*8CWf)w}V)4752k6Z;; z1XPLr>^U}q`^xs^cC&%NVW5F4QNL*;h4PJU0mN^aw3$e(i?yXcR@P!Q9Wbvhkq=+MtkCW$Y#_TSs&&!dzzca&r(1PS;ICk2IemO%Vb)!levGnO_S7BBRQ;cyt zsu&AwgD;BvitqA{isR7UZ#a}|%|2rynYpqYx*OG<8cT+gHwh1*#Vo+D;D6$6@Eh0| z{1#RK9GTe=$!&*sChBlsVg_!<-T}M8j4YPY#n1c$=;tjRZS9k-^DU*OMq^*YGW}z1 zMq_=$@_PUJ^Ds4+*0!x(p_-sl)Ob{9Rk4+URh=ub>NQpURhw$})N31RbWQZrP5+o1 zY-)$vl_2hwE@B(VNmQQfo@^cPLp&5e6wMUt6m8@Mii^;ax4AU=f7}eN4_D7!f!5y4 zUgxSeMNT`KRr=rQ(_ZM9Y08IsM9rcE%0expJJBoXXVhtG1C>do zQ9Z~jgdA^;eL-Z>6k(#v>=3*KJdYHTs&2 z>V?&_tI8|?sT@-=va);m=!%gQepL^vRJEqMT@BN;19S-x6Si7I?G4Vw`~&d{@*jST zm`pX3VT_g?$vp?2h?;eAPgteAl2vkZ!QVyp278#>%Eoi?><_jZTg7;=UD)|-U-mbk ztq81<>n*R9kCJ}{8plS(H0TH8M@Ooprhhz^1DL}~-DcNj8<2qWe} z)cz%IBgPZO#A70w98aZC6l8`p*05J=*NrrdHveJUVZZFUz|WMz&}H~Wax8URwvk!Sz2uI_|CQUh z0rE~WnQ%tj7JB@AZj(Gv;R%^`gMi#I zTw#IvU^{SHeC565{ka&`!J}15bH|ZAs7f}x<87+ZM^ocAa zt<-S3jF!m~z&{!0cjhD`%6yqsvdOY(^gHSwvOBR9--Y%-+{F?6W9KQyJ==8aALhNL z48vf|1urrE51seWIZQ2SK%q$a!Oh^kffZ`GeFA615cKg_QhTjQ#}qb{#c z)fE{2F+^KxEgc*SfrZsuB%~5_5%Gwm=()6#>C0wtdE9Vr344Nh%4}p)nYGLWSs0TB zJ^NR74)cWB!bptBDF6S3GbmfYjAc1?DQgC66a$^6t$Z1H$W5*tcaFWqP;3mlmbu3C zpq+F8Em0HbrR00^8%_`fkfZdNc#J=Vn#4U+6myj}(l%<8tX_7JY0t{o*Q|`|$jP}| zFn`QoR5CZ3q22>at0y)G4VPB&_Z*w;W2{C?JM#(S0)t+crHjz+Z>*}%(JX;JGPUkx zUA8K`_Ll0Sic_Ukqc!U)hgY|%QdM`T`KoHCDQ#?Pd}y9-4YMC~9CtkwH=^&bRYWaS zMfYObGou(alg;d8AFwryjtOF#vGbXSbSafhB#_DEb?Cd%vZ=DM%qhkWUm9jE@Rizf zM>s3@50@bSAuo4(2t0;`Fm6(WHzcjJ&uf}mXc=jDfxhSMyiR` z)GYEXkKs^r{t8pqrA2J9-MZEvrlE0Wxha4Sx#87iD(;X zw(!t-($>T3VfoM0!?;F&PrDE5DGC}NH)!gUH7o0OYSPrD8h`bXTDRK$)iY|l*Z9^9 zRGqF!uZ}CbQ(0M+SG`Alvd%}}O-Gu0STEWxIw$Zg5gISTu2XBNd|592lJ3f^qsOr$ znZv9kJ4H{Gji=_3{fWu=8ln;Nq(+by$RdiP&eB1&j;@#afJam@eryP9lC5Lkb5-&v z`BV7=?wGtD#@)AEIGe$Am2F|XWE1Fo>H+zQct>DlTQZJ(MH+}4Vmgsda@1<71-*#g zDQn9}tPhvVjplA~{pC&Ni(q{#A^^uvJ@znkN`h zW3=I>ZVS}VMrj{5&TRaw^{BtlFaeO%+cdZ8OzIJ;vbv(`&9xJ&hEx?)G8NA&pHy9{ zI;(E4ZUnMZcjIEqAlo|LQ@V&wAm3p9>6WrnGC5O8oo6S=lG&R~8_tVa$26f|Q4YM3 zctQ@Rw78v|Oq`}7siD9GIz(TPyw@t*3&G*Exj3+P{D8)b})XET|OvJG@T89{i$j2=YfV~?;t_*;A| zeiUHYCy17WiJCztGZPs%rcU-;c2(vrZ{Px8)RfD^6k&3+ytCplbB$B5!LUx;4;-E= zgbIIxXXCwy?Lg5R1jM;0k|6uTN@5AwnR-kQW>T3Uz+_5mR&>F2z}9cBi}&eJcc z>(p&x4DlW{A||9Q@E1o4IInY-JNzAacGSMtdePp@*3{b4B$}REIdc>9Ra1>&sqv%! znPHed$Pl61+!)w!s4=4cbp6BH&$Wjjiz2uFWW!_KL6cK|(%RRQ?&ydN;=PHhXg_Ke zGln?E{YM^>VY2&l4ih0`xr4Itv_#GWKR$-PAp)^fVl);Ab{T}8#n+;vpiOsx4gIL6 zlt6YS)9I7+e~g?R2lnw~ih;}KV79O}@)qsJ{3^S|1j0!32+zdc;$FmRB8gDr?@$9S z61`}QcnZ+yTH-JI4)vbY(+lx5YB4#Sgq9(j%RK2}#3iByd4!xs)nKTyuGMaD@eA)@){2&5x+ zK|F&epqGhQ*p2^9r%ED7v29hCLQoY$_RmpGI3arV`gs z3G0YVvgZgFi6Zi&Lv)yl%UBBPVdH^VJ6h84_r*q24%LXbJ1^S4nOb_5+fZYuy$hXtHFMqKa0;pZO$s&3;wZeg7AmysJRC_ z9ocP75{_f-v8}H4{GY-#oJ2>6?OpNWXu8zd*5!_QphLyYjuL4CzLpfMpPUisVJb^} zC8@>R)K*s%Ka0vh4a9T4AJ(2fj|`Ee%W}jz^HBN-F&;kScYZSQJ93Pi=_(`7QSXJZ z$V}HHdbunO8z~$Vw#w{M2WMF*9=SY_i$oc}8-0KtLh|vc_F?2uu?in0 z&U5@mX|XC{lf5a83e|WE=bx@GC}^${EhWLS4QnDrQCs=r!Uukbm`(}&eMbj=h5KK| zIeZqfkUU2Bs%Pab_+coE9oGNu`PI>xeiY?Albb(P$ZNbl3BD`XCX~a9(>&DZ*zi50oGl7}@RJ2xSd7rm7Bm8)%X7L>!w5U#*F5eT=WMmq@8K%Hy22knAQ^T8rIp8+ORD1eN)`cTB@l z?hbO&5aV;Fu_?0w4Ylu7csovu{fJHWD_oVmGhT``bNr)Fo7&Oy`0kiMZ8N--wIgov zE0Es25f_nk$bXnhPY}9D&BbGgfjG!3K=&*fpM;h=E@4v%S~x0A;U9pvuNHlzg=kki z!gUTmip&*`6NCA1ViUwCyOQ^v6~I=j6D|?c`G;r=@{;R!Bp3f7CgR1?Sl4n{j&m6L z0yjBv@`{)#%80$LPy%$yg<^D-(3U)B--CBWEY1*euImFj21(-)VxmwdKERhcPod#h zm9q+8!C#kfGR~fiUqSzHkyrubF8Jda;!vp$`nXx-(8b6%p$2^>_C_MGnfyWQ1hND; z1P8_4$PBEN@B`_9CX2=B0%VA5J2jX0m*x;HT%EBS80x%1>cnE#Zu+XT4AEhKIpe7z z;tSVT$`|A-T43LVH0-Ff2YMBYqzEkDO$Z3JIqv@)^`C@&B{ExTgZ6_MVPE8;xD%ZX z*GfWoaR|B%Vu>A)-4cubC;f-41Z~GQ=pcxq)`RYwH`-ge2VC#JBsFqO+KcRvmV?IR zR!}lOE?oz;!JnG+ccp6(S?!5j1r2#M@Vokfj%a6K`7V+cAlaY`yae7g64?yV+W;W> z=7APr8bpv&kpt30IK>4WY-BYko;L+$;koe2mB4SF2$bPP(raWYXm8#I_A)5D!*g?h zfcrc0Q}z6-ln(FwsZuvt`iu+#R&yY{UIHd|1w7^gCHp<_pKk)o`4NzI}6l4P2 zT@27QW8m&COG6N;^|cd)|^sTeVde}hl7foOSe zh;ohxX7hI7?hXY~^B{;LkC&!H?A#>1glLBbz8{2Wa$le@zkrx?5Ae%ZNItOJG4Qes zXu(B5Z#GDGpjB5Os~|r8JLs0~gX{cM<9#fhBxa+zf?QaKY$E0$QTza2gb~4sD1dxs zLN@`kwTJiy*v>nU(c*UTGO`){Bpnb3LHqYXmqY6t!Df17GPL7QEi5BMhgX4DJd||E zUhsl((jg$q|J3@t0D6WLwD)TGBpmqcXW(?d2Fm#cWExoa612=N_~{K;o&zHDV4y|6 z0h=bkGo`>)-v&hO(a>5pu;fo2+I!$XESx8BNFKCOCdAD5z)uQf4R~rh@H+<7qa(mu z8l@@7Jn0WOC0GQl_LB@B2C;T8@ROm?z65+%g6(|3BH8eICum}dpmUQ9YS@p!M}F$I zKLA>IEimA9l3eP9tOZuN9%AZFU?5xJ6>H!#pMX0xLu~yoxJpy7?{4tJchEzH!*JZd15+rU0% z;K>h#-mnB1=C42xI~x8k`Q|VhBmiB#0iI#us}b&237`H2^otG< z503_)l!3l$B)l#SxbTwH|Ns2!6)AdpXCNvFB1shc|bou3@20>z$iZl*N}r(*ntOr4_f&e{B|W==Qe1)E(Ja6 z-@pRvK;^d^JWd2kccirb|MdJx@cSr96ceQ|IPb6`ADHpvDidk8$&63%-%N-dyOcEK1h8~klOaQLGT z58%1S0BQZQv`FeAjQ|4nII$1-LI{ikFCpIlQJe~1?FW4F6;dcr%S+)bW;}S&2FVX> zK!K%2Xt&kSFGj-HCPQX|y-I-?KNx&|1$b6B&>>KZ&42{`H}o7Yu!9Z!Wuf>Vo-Y`s zy&{dyLY@g-K;K{t`1@6HB2poqf|h+H9)**jzhG8@%pmZ_UFfecKi-p`fDcWBJ6J3| zC;#B8s6%KwLF=3&Y=gY0Xem`{@ABg(3or2_-4YXt zTyz7>3jd06ST`gCTY)t4ze$+jh71?S;V<2T1DDAj+WydGIux=Qz`62qeXwY?f4U-= zE@E3b zh#=i{^ z<%>E{H9pGq#&+2?7BcQgbS+iEGqzB3s#L+JqrO-fa^I@4B#7sb1oAVILiz}!gznB& z&H}L=R!lY}7ob)cVaM8yVq4No55pEdy_5&5w7%j37~dPkj-r>41!qZfT!_;HIPs?(OPt5;=j%Ju(5^F zrcXCq(=OKrHayeBt6l0){W|jiS1{5Ge?%2iB~(v(5B-F3u*>9L%72yT+{4@>l_M3+ z*+sHDco=#`=;<2c(mKLi1Du}DOy}?RakdWD8y3BFh|TQi;+iYmff2zIFT`nbEA@#~ zQLU*nbP=4rwUvLDKXkjN_$c?|4#5eV7d4U^3NdelN+%`}8Q2jZe$IrI@^W#8*hzvh z5OE;a!KZUj8I}sRxPY#Na}*h|3@S`B@C=9$_~IvkXZaB=hIL>SvKo{UB=jcQ0c}7U zV15_^Yd}!plJ<+u#f?Hsp$*URlJ;ae@>tZdj9yTd01*U`MG1dTEGvTu624|{QNRe(Nvt^%TsoXh5 z3%9QxtcRaxgvWfhZHhl7U_1WiIQ*tM=lA7K~p3B(iP3(n(5K(*jB z1@qw*SpBVq*+CRCgan|X|HF58dGNWsOvqP~2goqY-{ia+aKTNB! zF!UfON?a4wVj%iLEEZ{AXCG`!GCwnCnLk;lI}W(6h(*#Y^dl%=Xd!1fohXB7&Uf~t z{J1i}JNY&B>$f0M3WSa4#qZNHq-`6 z3yVd&umZ+}?Qq?6`hd;C zosmwfbB8m}dDofmc;?t8Uuli-&h?9HDz6oS(9c*D;lS5pc4W5rOgJLw zU>{YDKc>3L)Ur2pYl_Fyu&L-R)DKT1FHvvk`}CibmMF$1!#vy!)&}3uhgc1M7q`F} z!e(iMsDq3f1@bpGiKrx!$UgXY^d4v&1c{x*EErWHAX8>5T<;HHJ-dnPgrSfxlnV3s z3&ajNBaN^QSPFuQ6y#i}#W}Fnd<1Kq17a7ksqlh-#82Zd@(1`quy5ZX#0qbCl%L^B zbw#+2xh$>&@Z1qTmalY;1KZtj&T=kvE_c$dzOK=(5w2`!f9E#GGRJ192+=#UU6;Z3 zzYFuA??wx)g<@fq_@{IMxd3CrSga@J26_}$bQ)HNeZdbCt;sFqBr=;AK%B+5;4ATG zFj6GKY5TwUFZg||4ORvt!wb}e#$m&;p;$M_*C;}B(9aNgt3%zfBiKW%0yAPC;ADC) zb_Uywy}&ZCER4jTV4tx)*bXcfv%`qI5c*pybRP7q_Gl$22n|J^!fhccGAMS3qco)=QVxSFEK^AZu?Bs8QHzvTI@+Rc$_(-2Y zb7hkl4lDJsFh)RK6zo~&Lxcc_b>>i*X%4|IFaY|^&wczD*wLC`9rzxuFd6(=k0{Vc zRDt5C7SX_%R0H`_IY=fVBB`LH(i5G5_D6Y8y;%o$G8XoFXCX7G5=Qk3l#cP^|te!Wxd)<(l^$c}%rn2%M zthhV4&;P?-c+l*Q$~p%6oD3si6hQr24eN6o-j${3UX~zJ=pcHbqDYPEN2K80n1uIX zJUWWJKq#`KN3D(Ps*PUbE;{N7s6i*1$%)8}7&(j##K}&qjhSqF9 z?$}OZ307bm>@6elCkK&ibd1;tO;=;QzY=zP0MVLy3E!^-Fd_@yvEQitA0wOW1$vz; zxK2NMuRM5fGIIel)5tY+P%Q3^i0kIjN6YBMg7|rvIpS6EbIKt*FE_qlW`4929f1SA zRR;ROkLa*YqLV&@?|Kxk;1PZTiGLG^@9+m{Vh>R9Ji}@{g5GO0>Yo|NK3aLAZ?~)bou@5JJNMxoiW@a0wBh1R%_ z6L1#?03+*-KW&2ivCP~_o_K^il9}%A!qsSam2dD`Z{ok1S+uo~VbmS#X((3sT&(dK z(9->jES_WdIE~f56D#)<@f6Sh3uyFSLOJLo6ygi?iPwoE`1)I@Q=UR&_zD`w)5Hnn zSk1xrXopo<9q*X(|L@anT-zC3<3*qe7jXw7KmzjO9_9fC@CwhsQ>?46SXWa* zKjEF*ht6#Vy3hGQd`_W@J&$L2FS^(%$ln@_r{Z6Hl?~UHj_+_BEB>kR$7p~qai*C@ z736l%zrocGV6PFC$v#F+IXyZs+D5(>X{ueu`?8Ok%fzv3$@Y|PF4woHt<>pi5k=L0 z8AG77tj_36P5M545jaZ=y{SG?E3P)z=I9o@rYL!bIzSDej+1kV@y2)kyEapssNK>x z7$2ZT>_Ls98dC+SE@T7Zh|$aF4E5YPy%Q9FA+r;?f_zS1CQCza*%0c%>_&NGIMku{ zjk4GwCK9`Vnl2(&V-43fj~XY8S^7hySe3xjFG2N~4$Y$tPf#wisF}}<=yT1n#uW6? z*Rg*!BJUABkw?`!^Vz`@*#&oV4^W}f$cw8*nr2z*JJE@p0L5ihtf!{HwptT4up06p zm-{mkxbBk{q7M1ntV~9LR+a%ukr(^tLz6d0p&H+-71!^mf>Bs&ZM-!m5RFZp_?Ik9 zyfd7}V08DIvWj?NsY2Gb{f-$ft;AgK0Pb*bsNxIfQuYUUr2=!;_|E*`nu{K)AC<)X zqou1KqgypTdIY@vQvQ#$(t18_qvfb&A$^!0Cl5g)(SyLw@cL*EV-Rr(y<S zd&NGs%i!%-@(awY(ec4q84*W!*JEFs&=^@{hFMR=dJ~pseUez*K2{h{w^bemANof5 zUwP+;%WEx&i?m0G=T@?tkur5Im>hZTxfZzPHNvIj|C!V1*IZsy$#+;6UD#-cNe)Z& zMr3oevi2P6;J+D-39u`uf^-36igpXhDK#Ti!^fgW|86b zYR||crbSNx#xX;0Pp%_NF-6GyTskwD97Fsfa~emYKCLCD+PSDjucj5yLZlmtzwXpY zx)HfaA8Ov#24nK1q`uu~t?wdFQftXMJ;>-gS6d7Fa3x1&}?h|0s=Ukx{XAgroS_Zwq~c{pSGo6k^b+l^xSDE%j%&empa;sm(~>daR7*iPOy1@nrwQ?IYxK(5&c zJRzHq`b8L`k)$8gPwCHbWg+#r)VXc}Q9qb8xf&SP4C9M2+W1c|YAn!S7%PlAzzXZ5lB-Wz$ToOpccRi5hZ^iIsvZWF zz~3ZJ25LW9nCgVivKMghlW3RC6f#+I&ZpLKSsd&x0M&iaj!s z7>XTmInKrmqBB_=SxbkB7_$QM1pfkRo&hv;D4CtCk0igW)OKnu;*uYjouO1uM)$pz zs!8Xhvol|*BlKkJ0FA+y~ae$MTJlA@`a42l|+wT!fSOpPYp|!Tn%QLFIUzTgR!~ z9WD=dl~dR@+z{5z{bXA}6WN6eaQV0q+!pR3w}`vTuH;6sm$_-|9kx5ep_|!6s^k~q zJ(&mUjw3+nWxcV{N?%~KHaOE}T*a<^U1^|YRgWox+ElrXK1$Hi)Ne{FwHbQPh00fS zRn?S-@-n%PG9$WJq2+SQS)df{qmRRSTyv+91MhcJ;u)qr5`q1U zM0ME>d&V;LncOX!J5oIIJu)hKDS9aS1jwEcO^hUmbBE3a4Bu7nkBn#V*vXOc$aBFn zI-@hZXL<)Fgp#A(l{Z=gLZY5BM>&sh7uxArmY`I^@>?1&T@ypX-$FUQ1x~1=SnoZk zUgSFSn7&B+qig|s@;$mV`YL)`{-2VlbL2EYKrbOTS7wxO8!~&9v8#8yM%c54JCY zcI%q;sAU3@76(YjM4fNMcW0)PlZ|;wugI-HQ~z^sYDNuTdtcVTvEU^5TUAgO>lB!u z!sJ4#5tG1`6UR%PZTsy9PaM$d5rePJb> zGZipf5T_22$ICCmxuVITkKqr&b;0hzeExBPJ{e2A&D~X9Bb}pD|4B>z)hBgJT1RJ> zz`&5MM~J6d%SZub-R&(X|H1RDs};vMOY*Z z5hTIS#d52eXHd}JK#zAG?9WAXy-m=`U(rA5vRYAVsa!xp<+X6W=)Q2>=$G)K=m5;@ zUkaxM_5~{Vo_m(N*Ewo9V*f<{g#OfU9`zOtb%oRRMQ*0`W~`ZTKFf+MbCO=gPmJqn zmu<7Hsp4&66u$zh_UcqW;+l3KnlCiX`^FP-taCg`=N%v1-VAT>Sk$AnCbu(KG^xG%`qUH`iY06i42)>uIag zmbf&JHI!ejN33MuNqJ+-Bpl2#EK6e6vkAlFlVjs-7o^Voezq}Pk%}d*88@}j%8SU_ z-~ePKR6e%ChFHUv42 zwQVgd9i_L}L;s=|5e1A7>PLl@B{@xAs5aIcnO({K^g}ie-v@i^G`DZ+KViKZaS~- z)W)h~Wxw(!oEH5Z`~z?L8@>hp)1HKkuTGn*B-GRG(z>N}cbsy^`Tq?kz;P#8{B2E1 zn4hI~wi!7(WLuRjH(bVR#EcLxaYLye!~wmh`bFLqnI0|($Cg*VcNwPVse6=rddA0$ z6M-Dzr%DXbhp|Y>w()Tl6BDy+&+;OveA0k~^KlpLKHLA0O=pt z2PfJj%t+Nz&*0QRcn(pA9>)$qO6Ym9pY%-pDHO!~&~WA#HG%Aer@R?B=Z#PUjK`cE zOI8KS(#qUrc=T=h0=>8H)5~Hl+yW*r1-QyObjjm^=f=pWFM*{l)#w zIoVai8RHo5bfkB3IMYX^T}^-F5Z(Q}I0K@a;b106+3aT$4`h9kJtOCe?0IrjN!l4V z$9kDJ>8ZvArA*{&Fh`)eU-u^Xo@A8rRrZeYw)HRdEe`Yv4v92Y7V0GYPx?#6VtT|+ zNy?e^V>W-bw5%nwu1h=^H`0E`GFKYL{{i;{4HNx^+>BIYi;ivw z^AfHfw}b{bFON&(pg3P8Ji&B919lNyKC0lXTMRwH5#lM70yls>j3H+Or+a{U@tau# zwf!IT7qP}3eW5-~E38#hUnn=_+>xr0v_O@BF+t}9Y;>x$W>JkmC?m`Gr< zpa4{&IWRRMV`{XS@i*T2*Lql2^g8&zNyc)6F%B9fQTI17p6eZvSeVn;ua(v2s=V?@ zP79BLZ`i=VHh=$&4<4(hf%~%SvZJjd;>hl>IQ~ohoHirP?|A8|;h7b97XGB=pjPq; zwyJS-($_5AvbWAwJL}Knb{`_V6wBt@Ad=6x%X? z;h+|}8?B{oGe^^9gxeN5rgZ$2#J*XoWXY4HBut6F6Sq93rbUpNU;^YTbA%p8^#$Mc z6{~ldQPM1Db^}g-*&IVYByYguMdS8jR^qOhEDe^Lia~w{rf0s=ji?Xk2_$n6xU;Tc zSqNYlZ=rOUjFsU5&z|4R1w}zk=nZqAlHQ|F*SqSEwFO!pb*fS~Ix5mH^xJ>MJ2T_H zC*FP2+0&K7<#JVXopn`pCAo*XbXPae2JeL6y6`Wh9nqV+V5woBoUkeBLpDdYW?BDD z92++=MvxpVNwzThs#WFVk)NTBNc7wpEFBmf5KtRT^*8ief@4EWbb)eEFGD49bEWpS zMzO*8=ZVXcawNS=?3b`9zGhs#7{->@qKnCbh1xGgVc_OMMA)y;aPB#*!nc zH2N01imQVuks9KAOqp!wJ97Kkr}QwIpcaz^{_Xz&$r)AJ#}6H+_rv-nUb1sComOdi@wt;bw$C!l*W6lB}THR$g9um&68esMPbDL0h! zvA>vp%rcs#-BeeqH+YLH)F0{vwTB9jyTH^xM%sM=;vaOoshG2UXmkPM-5kiE(|C;3 zby07uO@M#;PWgPaU}S7K3@@^(p~j((!Cpa2@ZZ4pz!P6Se*-V;9h;Hhnd2$uuJ3K_ ze;+EYu$IY|du=Ewra{bu}?06=O_hvb8=_6p;N*G7IR?wL0*%Y|$PSnkLj4?)IBVD7lLrOg6X8WtLYAL0H z5+}cnW{K{KT#MX^43BIM)y8RBCD;;)-`xZ6eZ~Fxe2;u?Ul3V$J)#x0uKH^*Y#W%a zNE-Sjbm7Ps!wOom+4*fR7?&HVY71oa1nl`FjM3)qn6#W7IQ{>MFv{@JM0O}h4e$u-34k6 zXit${|A;EgEF!PqS3i(fAGMGvOLeE$(<*A`lk|IeJ?|v9qq9B; z6#fLTk^00#ApH%ASTaBa$wK5$at^hF(#R3uJ2EHP`-5v4KwKvNBePRm!T%MYyV48j zvh*16mg^`N^@ch^<)uf`0eTWMooRq6)J9BA#z!xtW9YloEli*m$2?Uox&fv$C(|A1 z=Rj9k=m?sCKXwA4FH5$-N&Eu!=S|Gc4*|QA1AIV5;Qw!ds9yjV^a#Dib)ZJofC_V{ zF^)o2*B(Csh9d3FQ1*dIeSe3*St17DZ{#Ot;3 z+<0ZU@%tY}E}Q}tF#TQ{uOcqO#L8C-xg&|7}R$9-J=Gk5^3C$@o`enXhxFC5_SegX}5g2gz3@A(Z}-h0$x z--z?zJcdC@m>--=23V$Z(4sXr8=D07Z;$cLPz(x;Vi9z+OTjE{0FyTx@9_rcDEmNT z))UII-q_iCn!})!>jqA*103<{m|gIFAMbCKr>*?U>`s3U3JaBz>KVc3T`QQq(kO&yo)|C z*b!j=1lS`ISXd7Ht_V1*O!aLdkUtY&;RGvoAI#7HaAhv@8d$MRmSqo+{-bz3Lvf|u z@#@-vo5(z~>*Lxhj04-xp_75BypNSCb3p zbZ*=s8?GTg{zZ9Q|3F;p5-?*Az<_zdmi__b=EJMajrEuX3YW(C#Qp)~t~%JNa$xg% zf_h=n#KoXvEs6IjAN0&Q@VD*IZ_NOkI~^$Pzj&{A zW7S`QQg<~Nt@Yq|cYwRziF0%**x9AvVh@4KdWDbEIA=HEq@4u@tRsF}Td<&AfI+sy zo{@roy9^&|z(y{??Vqqrj929){(*bqc&>D8d$ASsi0+Dcpc|3TY*f2SwKK6lXqEL7S_AW)de%rG zH$zPo)E8h@=r1tX^{H>hZ740Do6oU3uBT}7A(02J7x(dd>ti1aNoH(ln+mfS zn>VRY)=NT~w#+rixg**j>e3#Cs`EuGp`_|;JjHrDsZ|0s_@C4|c8$UFHtQ7r6dlw8 zawy=^N>hC;4Ww!D!}&wxFZsVn2`xwTE2b@e8ijQa-I!d+PS$G?1$3vvkS)x$>^S2Q zcU!;Dh~{&8d$_DwAY8e8{`-@jtcvmz?qpf(ZL?&EtgRjAq2 zWOgrUWxu1-&91y4rUQLy0p0RQq$!Wk_9!zHmvV(HWY(iz87z?l^-V79hBMUn<`1$q zW2XzEHmsvhAqs1&^<<4951VzUXV_2PLfyDQxuL#=Ze^=^797nDtVLQ+0(Ti_bb|J9 z99e^&09JCL{#!e#ebC4J|E`c>>TqS5L?oH{jEL4pt)&k&>H<4INMRS#t81I}6wPJy z)oU9QvCG{sNaK_CLwlzGG6or|;C~`x_YUgaq4=nVs^cTH8BNiz--YsH7-ou!>NCK6 zzK3S|22_h@^h-!6|AD>#4xX+bn0)wZyh6n=Qy*aTfjS`tb<9w6hdvzrj{zT*6mu?i zzG~3bK1ap;#@vCQML<0fFtVwxDNhAYF z`fcnoPC-Xm0O(R(qZO|B1)lya=+y=EQ)hJ>e6EX=sdyfX5N(Y4&=#LIYmhsUwc3{~ zh7)qLL1Gv4gRdTh9_}U1)T%hQM`K@e7?Yv+9S8h;7iL?Iq34^4v!WYO0hMZgdL3$_ z(qw+}AEFoOCkm1M@dOsc^N^obDLd3{L&+^^#sb95iPTT6+L#3=GFst`Dt z0J)9q4HR@Vc^K!}ZhUPf_dWwVKnv_wlQEew(I|sG`!lf*PwocjnAYnB^i;hKJZKJr z*DY^)p$lE9JxBdD2#CTt-LBV!>SnwCL(ifQ)n=(>)b@BPTj@vj+Ngf6s^67=mD$QS z<%qf!Zh>on(9-4%sCW$H4|vpFKMpD+ z>-{mIP4a#H3b2YhKW3)z$a2xT+|o(x4Og~%(8m5|NAlD837o`CAP6Ez^7bcPczw=WzV2C|3L?V;p+I+v+$qUfmA8T6HZ^Yo=6_ z2Sl=kyM(%gnuc43vxVmbdk32Mih7%6bjtXZF(_l2cXP11Tv{s)Y(6_G!eROiV<6d= zJte)j)wO?&Ic%{BnV3y8dNJ$ZlLRl{gE0u&xQC?cvC&m>AFQawT1_pB7Eo3x4yBE{ zN{6p#>bYF^Ev=fcx>3g_`Cq(huH2ALh7XNSGJ>U9_ z9_|OOHol7CRAm5}m1@Su)12802+kO$gY+_HU0jFw<@UN32cOJ%s2>aorpwLsWlod7 zj2&{*=@&Va|noPZ?a>5^cubv<8Vw_q?Kc*)GOBsiHb_p>bm2iL3 zPR*p+GDYBMFqf6s1lGpngY=K9O_b0kr|siM^PjLeO?7 z-bmb>lo7u_rl=(s-<-_`mh>`Pg#DLlZN_NXBhKI-WH$Hl2mHu42-S;pin`^>%1~vj z_Dic`E}(4O0?TB(EADKMs)rx6^=2X^einiLA1*PIYR7H9?`r!w34oo#f*moQ+=)x~d7H{DbxjifpHZXUI z9LZ(l>NBP2-DDN`bzj9^A)~Vx04J+iT7r60-YO?XCjlKO6IR3LgY`qfz$<^&z<2LY z@BIuTqkjhN-I$@f%X?}$yEv+)|1WKjYe=YGv<_jS(wRY*$4R@1YRvv^Z4~QIsFv_N zu2M`ZNfhU^8EhH;1vFRrfEDyLa;jOP`NH#p5ByyM3xk!zm*r*3aBVR1GwP{%pzRq= zI@qWfw5l;V?Lj*g+s8gCwsG9OxDxTJ<1BH9?E_=_SocV&!XNek^M!gq4g`WS1v5QK zP*hEWXH>FY1BhciG8JU(=t}|J^^v zXZo*X%=VS_GG57718!HhGGtd#Xz=$q8w88UtMzfTYM!Qk;rk^U*Qo}=R@>;<)p1sP zcUwhqhtP;S0bcAaH=S`X>u~xlH{WRsqIDuKf(fDBVSiN8Y8V6b$(Tw%DnC}{>AA_C zY-hOMbb)WvS*eJvwQWtz_}Id+S>xBmb&ty(_t^d^CdNuxZh~jp3*>4lCtz0Ih5Yrt z&^A{#Mw+RF1&rfWD2B6RvO9AUWGvf|y~h0o&#Te=JZ=K-=iWeBRg%---LRDX!(5^_ zP&J`4DT0&c6}q^HaTz=73PqD|Ls$1TdMd(49))8=^MWG+=X_^;DL&FC`x?Q@p?sit zu#&fcf32%UM!}5NUfw?we;UlV;I#mqec-I`trC4F*C(18+f5cIe<^7C-Y|`XjS^g2 zEiWzg#Cp+ZvL!DS(@1X*goO&ob7nbkeI^ofgcfj&0fgnv%r-e*SVSvupOEE^Z_{76(
Qn z+%tjjdZ8cpY2CG2np53}X@o~$@T}2}(Vo%iktRq8iNm^I8!CXAjl01zerK?R|1A(z z61oM+|2@#bmjc(Y(s1N&;k)j0XH@Y#bKmlIk3unr`H!GB7YJK_sFBam9oSOh6JeyV zLCO%uiD4my&BrL<8D7BCWEfik?vSrg*%bzFUIl9Fi-rsHa$!R>{{qUs4H(Ez@*T65 zJ0u*Da#<@_KU?Nl^gNMx{VyK8>D0f24~sy>Mp^vn$xDYg#(@dg z2pz;fz^`{<9^sp|6G-iL`A0ZcX@=MQ+jyXFG#2Xrm`hM+S0s<31Iz<_qBj{s zTiJ44tU$pFX{KeCl^1GZbl{#t$YO1A5PdQdD6|EcnJ32sqifX2JbZTUCM1o@?7u^uv815Th z8af<$8G0HXjLDZNNbo)vDCu9|Kj-ykB>L>3j&hU8WA&^uNF4!&y&n10T#xQWWu9|o z`0^HAxGePNSF$Xhm9HyUgt0;&ZW#L>3Z6t}5Rl8?=r@*=hsYds5OZrU$!<8Ij}lwS z(wH_F&fH_VaBcZH{6yh0KJN%C#iml4L|9i@vRNlctYm`|RaK!pat4mG4tSf7z;{ic zQ^_|#-<}htptV1V37QsA;zzKiRs-dBQ5B&?n}hj z6*KR^N!DtVG8JQ__E;~W71Lj6gY*$yGF}L=7tuBn*z^o9efk28+sM`7+Dluj@iREP&RB( z{{>RH#JFMZCH9jmsO2=$yO;{hHF_7inO;m6I3r$TB1}>C1Y_Vy`AU7EXiUx_e}y_j zImw+=PE>6JP@AoyHqj3H8S|9g#HvVLc*s4$I@`ck;6qrijf5-waCSd5O8eREKz{qd zov1S7U{}&-sj*B)#zB^(o-k9XA%sfSV`kA?sE+0sdJIDwqp7N9X=V(OOb(~3nGGq1 zZbQCNo#>)c%o${7vp87|^E(U3R5cI1#(blBkauDQ_qtEr4e!;Qs8^u^qMIwKT%Pzx zTc@>EZ)hRyz4}J|p(L2& zLTa?Ea@I^GM>9k;*|??Vq~99bh!@IcbDp-$SWDkmL-J$e7}ZYeXMT;e;;Jj#)nsCk z`Zv5DR%w0orN(XTwNV;~LxW%<(SvRcO=&J7qNkb7RZ}lW<|2-3kMw5h7_$m9gwR94 zSlB7%_ed_YGCXz?^{x7KdZf`*`^fdwR;f-S7d%`}=pATi7dg16iJ)%DTY7?I!auwTUrQGx8Goo2{WDX$En%sVZuXp*R0ex=87^{-H0{*R*N-H=@p1} zd}XA=45V+GN#afPSJ{b`%uO?yVo0~zoc9v9bZADzLgWYS9^DEqI6JAqMh`PTv(H=@ zElNIQJ{Wfm+*2ryF9PExx!W?UG^X0q|N2J*Q(YA38*M~Cjx37Y(o2W3>3QTX%BO&? zFH&2C*FtA~Pj9a_j`UKi5S^3)+H5AP`a>;BKa#5uAu>ZPD9>d}o70tZP?HYO<`K=v zV#IX%ikhm|6#q8QDmB@`WEN9_el-^n2JbVL8KjM7Me3PRoMVh;#02^UXA34yBr?wMc^uipegz3TBN}>}}GHRix(m9RAYBDrK zrnZ{wNi<{M%6=`%&m*UUgNik#G5f{KAx7UKEYmsCNpz4G^Dn8RQJy%q;v!O(-0>COEu9K8B>`!y#Z2$3P(WpbNB6koc@ca2xQRAMm1z6^4a)fb- z?qWE}c0?s}J@_`2{Dey98r2Az^MC1dQb(s9!%l(9YcRN=r_3gD2b~@HMh0~l8M}vw z!kIY~R0Y)hUw|!6WJ2KhBH#@6)Ayi)Ed#7?FXk1O=ue>$fD;@Xg{}ba&Gg5ZX13G) zz=gl-A+r~7?Hs6JLh4N7w35}_VB8@8!yL{DqM=%ssHS};cBx~{(fUuIg|FegS=Lx< zrs&&=J^Ci|sUAn1hXcShbibwaE=Dc0tez7LHmTn-a+*c;>-aY(p}@51UhsMSp}@YV z#hXvGd-`Hi(s%1&ZHoB^XI2~NbsWY>@RnQDZP4sZ!1cy}t1GR)GkyY59u3@~pHY+8 zj|rzmW(wXrf@lxc@IJ0FFPQ*FlED4G4UW1Tk(VqD4EP~A1*-W}C_i?SpUDm2G8KT24U@)1~6;q-+sXv&SctBPod!s&I zOO7IYV4AuvIUXG7M6wL#sJ9U-F^5l(G3FVtYE{5y)y5oK3QmdHV64}Gc`8jV0rpsw znhw16g3+0L4MhGSuyxt^f$sRJv59DI6gN%m<|8pth)GWHmNp>zh=IbX&=+^-7(98p zL)EqeJm(td^Y6g>eukMBC>{x=XI}G>kp=8<7+i5BBcIV;Ka6T%5b)b{{Sn*_=IAHU zwLI5XfemPH?AOS#k0mEKiv3U(ub{xUU>EXWzkcUKfA{TUWWM0G;$rU++I{#7lQbddFk0!~-*~B+M-pm!7~oJi$7~`WlYk6Kwg=6K{(# zV+PoZ+Ox$PcGKQ3wrcFC*x|8Z`#yVTdky==nAR~nZ3S&Lt<5cx<&`L*&+%}TxPojQ zrV~=l5^;B;cvh+y61bwEdROVGG?$M?{ivqjMK(psL@Gu)M=peKhX+RXhQ~#2g)fBv z3oi&)33mwp8@?1860R8<7%CmQ72F)09UK!V6iD~?@}Kcr0=s-Q{JH!V-zxtx?^o|S z?-TDt-(&xjz{lX~(D(4e$aXndnX0~4ecE`#hJ1oC#7VLal?}MxTBZlHkD14oWwUU9 zvp3o5>_7MnG7H&cwjDc>Z4TE{1CtBKxCh);z7&-3Vd1m5Lt1NDVr^rKkNFz2)=tFs zL_+Jn*kZAs*rl-zV_U~M?A`6%>^EYz$9%T!x4pC;wO)t!_86(Qv{0-p&V#0T6Q7%J z#GPVyuy>h5%qhAv{K6~3$$%tV!H;R5!5R;>o!TjNy<$}g%8jFaBKN{2!sSExL&pNc z0!YF1-}UL<2|(D^diJ}Uxjin)b;WtZk)A%oQ7Qd=`ugjX|bg*-)Fr3GB?D()6$3qD<%l5=}9ct(! zK%Bl=-@?tc6MVP3Sl}rwwUz2f-^E$tNb$1Z7nTV>g!2L~9v5iI?bJ4@i(p5k3|LMOCN;cl_hxG_kex zSge368Ya=wDygZ|6~6E%VD7X)a*JMgC+lKE_^m&eSY+7NwfvDL!gv1z+}BG<6Qo;W zY3P`y;JX|Vwh9sc1z#J!`FZ)X90O;hL=cmc7x7k=x-5q36Lxfy#kezP>&xL-)*d`&~I* z4V+IL$I@@6KS|q`Ha^Xpx;u4w>eD|j{~Y-<{Et2LQ0lR?ZRsnWD_mPVZM@xl#{=bp zC&IlWkAb@O(VjxBY{D&JF-@~a*|Gd2K2F>%BuG2Oqta5zEp@PLvgnpF*7nx9)@g7J z-+;7$zpOpspPyhYW_@dUV_9lBXQ=>({dbZ_qL3O;&Qc9iCb!^;IvU>Ng7gJCxS`@i zu{7LYlSE$Jf)!UuC?jM6Vq)^egginAp+DY*cY+U|!8yg&Vkfbt*jy|J-&w18QP?Ep zKu25}{M$o4ak0nKAkPdoQiS2tIZ zvx&2w!=1h$ok`D=UIV_;6zVt(Qg(z#nIdz3PI zBAlVRkzMJA%u@CoH;w2U(fY1J3tb%fu=TbvUJoZMya#A8JOQoIC zVX40~Lh2xe#IJC_9wBBEn~0^5EYLx8!1?2_uu~WyG{BR%l0U=e=WFmWd=h_$`@!`^ zK4)I81XqYF&Sl{Y)(P1;omm|p z9h~EE`cL@9sA=x>OKGzlg&oITR#*9qA{o7XU;W1d=R#vKg;7-gSJ|kR&<_}mphU6L z>2xw%nDcQ%`Qn0;KZiY}zgPxz&q{F;oYz-k?X1V7(N?LwG+e5Ib3(r_GUxerJoNeTq!=pEq9nU_2AJ1+&J##XX%nEH z5j0d9$z_o)Iv%duje@8B9_YP88HFavvDJ~{ z=-_zZsOxk%|z`>dAeWPLq&S z5)03w6jU@_xEh%3d%(`dS#X(&VUE!rIFDrJwc3DzIQ8k;RFzR0%O4^y!Y@KAg3ALT z{{&w>->HmAK%a|vtZu8jj4J^y^={`UM+;Yf$3#~#$1djzhsP0hTyuVOs_w7ugWh$% zeu2p$F5EEM3$w_6Wr22Ie_&jIe^wcKDbs|_#Zypx#^9vcDeMrg2-}1i0xP~0K8bV1 zB4CCaN(F%u-h+4hBJq_tMhu8+#RgIWG~(^y6irJH@h9W)H7QaSsid?{%pz?SJ@9#- z2fuY)2nqk;o~8)dgok)a=3;-H&#&R@^JVyA_zYsDhA^=C}?G{_j> zHr+>E`(00*gGn*IXXJdq~A?H?pW>I>ALC}oDuZ34qORk4G)O+h>ld6 zt829b(18bu(_~NjArugW`842ttNEhB2;RpJ<-79Vxsq7XPM*bzsVyu=<<fDj`V#2uipS9TG0BV&6Z{_d5qPgo#q#pyDgU&zPv zgOOcvhMmTCVym#t*zWLaKfqqabmkhiHkXS_<|O>L65QG5qb6F$?c$nnRl%iJWOo6T zX-BuFBJf*!WbD&>Xtm)G)j=*EwMIOl8NoV%?fx&`)?Rx?VNW@a({;-=#N~GWbe?xM zb!Kt;F^ka9K{?l_?Q!%^%i%ojDDIBB`}nr|=7k!CzD3GLAIi0q-db~Qp3&3HN503b z<^;@*{^GCW8ves>*-Th3?BVkXh4_l#Yjm6uzmRNj61!_p=?3=K{Ne$8j>oyOQamZL zaQts4S)n*j6Pw}JPLdNS^aydMSRIwkB5@`vk(%OAJbR+B1AFOR?f}<;%Ymom3A-FA zNFCX+>=-r*%)2EwfLoP`%c3t}_(#b3*ny;v7F>Qz!tP)VW-a!XXXwke!|SlSkzGHf zE>tGSqod;@1;dVD{XjPVUGKS!SDp;_Q`ZMqLuWnb_p~CZT-raWO;W$6_WZLgwZxxA zsk45QsRdKpISx1j9x3oUG(_nHb#fl)e0Cv4YbI2j`=K{%3oY|dYBQ}sRq}$_05!Om zOe50Jx9>pb^N>vtV8CTLZyjbEZ0l~lX_wmzTx;PDrae1b z9!ege5tx1#yBKHDJK>1f6?=L$aW=o4eM{HJyz_gU3R|GdC__`^JY%>zT3(Gk<#$+& zIKopx-oPX822XuwL7X><>6aYi9lIO>2jSf6sFFS+ZEsrW&)l?C>75-nUH2TJ^wN$+ z>4P$|`8xEt(Q zwj4W-yU9`F9Lqv$*O+u$vTc#&p152Tg`#|UtZfM{{ms~Hz=4-TL%becjyUUI-WXWwr+ve{ z*?g0{vohAZ3#4aBZ=4pN*3Ho`jdqMptB_hNefRGbj!o&e96udOY9M`FdI@I=PGuMN zlT%s?vzi9ySKSWfY!1pp&84j5U9+=3L2au3R&QgPHyOynFrqV+i^@*thTmTbw}iKf zYmtkx9@!=z;5NS*`2?M~1sum-V#d&OF|+@Ibc27HL*Jy=LBHjqI?)yRH&Q8k(^$%$ zVE<_8ZaIfk6$R^Iuk-;m*(tU?CY^I38EYLkiti?Vm3~SKEj^{qxI5p3r($V~V)05- z#jU7ws-cF-E_~rvB5xv|Ys8LY?m+dv9lFHD#uNRiT2uWg_loX`%n8ZC*8aTylfEKg zk|#lR@YFrs9q(F^zR)qwaoO?C-O~BZRnM888smJM_Q^TlJ=V3!o5g1dWyt?250Ia> z9?mJf^j7p7@+7-}`2?q&>1r`;t2R%c3#B(aa-j+4z{zZ5o?$k;tMF8qjZ@}|*h+lF z&*2>0YmUU}O!7^ER6nLnvJG9BIm?vdE-)9ENz6x@VDB*X1ki)_e(@vhtz$Myb%YM$ zWq!J3m(Gi;_)Soue4uKg4;;!q=fnInahbFnwV{a`u)nYrryMCA5O455*u_W+z^To( z;d1l;!|vRN9mDvTf0(OuZ)z9O*Ia8bPvpm ztP=l4=Qxe4O!p^C(TVgw)Iho*`O1i{+NORcazZ<` zS*XhH#FXf5?OF7PoS=3Wh1txpRQn60lsJx~y9=a>@nyX~}o^x_i;e^Stvx)EG3wjn)%YH&k zvK|3ex*$NN@pG*C_@+#8`Y||!^2|OiNo*{p z@b{TwbXy{U-j0cq17r^9^$V!(=<={T+&1ziV3qrZcSesZ2ZL9_tNq#i>E5lVV8gCY z?i5#nv^?q6(l@6qN_VD~aMp0Va82>FO8?`n6i5ja(l4s>^qEw9XowfVgL@WHfUZvv zS^?v>JR`bL9SWo&*5Ki3BARQ+TjV49EK`{M4=3aR;S&Fef6MnlVpbei0;g#a?h?++ zznRZekZR3TW6HC&xFYO+W*FOvY0LhRe%VgMF$q%4Hfy{%oxdxr5e|#rg?(HhIx~ZE z2RVzm!Vcm)qH}B|6+-Gt)G|WiEG;CL&`zAekL9i-pQ{;ao5>tpZMpjV4t4}T3MYD7 zR;QcOHK;yREZl}q8mo{Z*jlMCe^X4kjm$*S(Hk`nT?mPoep=&S=o#aZJzHG^y@fKe zyT-Y$q#ts&bk@#TRsdU`&)$l%13R4aRy1L+sF~LL^Po*BX_X}R>wvCna1l` z)Pw3oZI!m(+zIc~ao95oGVQp3xs$^G(DODHxA3?5H{3X+XPpH$b%B`+N7XX4gw&KK zV2qovJ=xssZ*~)Gxok1#V~52Rw>Pk@5>xqULM*>ZILPN^XVH%^*>Df<#xbTJrwN^e zW72+{{xd9HfX&quDhTb+Pp5E2nRW1P$xqc`b#@J(M`*&{Vhl_PooAZUE9u+RR7{Vg zK^>mo=&pZ7&RP{UHmoSIk#hbW;cJ0wzDdC*e$&-6W0AW?+IDw!x68F7V}z%s=O1Sk z?;+<|X9?GP=V|xoz>dhxum|2iRpirZtoBpyN*nJL(0soO(s+4G)EV&?65r zzvy4|>_!LT92ul*V%l#Idf2Y)azmiqj08 zm6^hp7hl=_7c)MNh}&X zj+7cp)1|vYO-xxP!jHtxhUfvPff{pfS&3i9*_jVe#`&;^6(u`BbDRdoVTjROKcuu* zUq-uzv&h%Oh2`qOJ>k-UzdhZ;*U}mPf?y5Th2X#pmwSf$eR>!7VdqC*!QeIb?Tiu5 zaz6_=b$1`Xj!e&7##qCqd*m1TSL!fTiGRR$C#S>JeUo}aiPvJ)@^CXeWe}L)8;x19 z9K=1cFx!EvfgakAlhVUHK+;%%oy(oXYaP$*LFU0HxcKa$PB2xmpPpw|G7DIe9}-tB z=j+(zxvJZnaVO&{lOKe)vCD+Td^f^M){W#gtN$NI=NuTv+D7s5EVCQdMrzx(()QN2 z-PB0iTid!tYNk$Yo2i{fYw!3w{rFGw5Dr*^ zKY{}Sox-ohnc>-?kG|vnLhe((mtM+!)17Tgwmr#yo%7uF+TPY(D|jkc;n}Y^B08Ul+$f0x`c|v_jES4LnEvWIz59A_Z-3ms02%L)0s^60eI2}`Dp$S*fi-h51r=%%F4WEPcwO#O4x;VqY5ys^lPmz ztvS(kt@}+I&B^+AhHd6+rt!vpTph*-;`aj7O}wV6u-{mM@2@XtoMt?1{An0sIDs>0 zr16MujbT4uoVmkxqGQ=1tezXjJ8^o>2dA0&Q0mm67Eu;z6PbaGb6fSCyd0dX@j`te zO&k&K9#|ay>2p$Xe zA^HO+|I)*HxZ4qZ*;^( z_#SR!@0)VqBl&~R%iUrANVmZv7@~6QGAnr1gxdX zCp8kaNIg*VnNH4v`^J7yzB1)LN_WJe#wz6%9@MZ`&<^cKg-vbRiT!PJ?8o=>Te1H} z0*7DCedgM84Equ_lQX~(>w$XGHqZqp(BJ5hY$bhuYqJP0=47I+sJ(B1w?H7( zanf1D@jF{{%(eCLaIOxn47W3Thb!HF%C_Ajx#tImi>E}7rpmRnrBH<}B70D;xtZK5 z<~7|$9S;@VWyuA;>I*GLF@h*{TY7<9JwYC&e$ef?UZ|^1#94kl*4#OWLT^Mwq$ax# z_J13Q9%LcXN6n>E=$EK2Xb8uy6mtXPC(D^=uc>5Y*x1hS%bcjIY~0Vz#RzBwMMGoM z=8w@XDQw=~7^f zS>fR#pw7>sy(I>K^Ib?RD5Qa2wNpBaOjh6EV}FAH>u%_O?JnuIcxi{(z0!Wzb<5V+ z`3GmmL7pcz%^|oV?Pt6#{0U*{|2beg)bX)oF=8Me#~p=s?JKH5FRM+IV``K#2(+}2 zqJXmEQa8$tpFUxIiDf-F zQ$OX_BF)Mu)DT&ro@t7@;m)YWYlQXZk0L2E zltg8Vd_^=0UBfGb#lp5=LNIqI#b4K7#`nd&!~N6!*g3?r!}-(I&h_5D(sjl*!oJ=4 zA?LdPv~P#7T8)waCDy4iYA&h*F%0ah0eo*=4Q3N#q>h8}SQKaGG&NbNiP(i-`5n{^UCb|* z5Y(^~iV5EfMTeS(DhC?{`vtc9J-!a`8U5$Leo``ME4e zHSyAYu|AX*DNw>TCudPFNE5w}>Q8s13(+2GI@AYA)CKAqb%0zgDUm|bTy2V`luPJk6e9E@)B8#s*1WYH?Cq1s%-*)=O3(Y z?Ge3s0F8eo`UCY4)npB*;bcQ_LDy=#P?5b{xejLXX{oj}M~p`OQz;=BwuEiSekBDD z2loYr2G;ve_!|3ad8d2(d0u+5-DBLV-2b`exqEvqdsq992ejb9P=2Aj_&{2rd{DEH zE9*=Wl$oAR_hK$Fa}d?;#dKf_F&pURP&_ZBmP4s|6N-WT*fqtHhsY(MoNR%vkzul! z-RvWrXmhxlJcHbg1H0dHpw63d0zPYaY*=lGH~dC!rviNUZu8aos)*O8qsJY=y`2Zq zyeMiNR${*Yf_m)6P^nZx?d?vYHsVX!WGc0XDi6)|KDr1K%|zjp4b>Xm5>@d3LZ^Kb zq~;~4%$=q+L0rfwua{%xDqx3Q6T2f*-wj@Zi^DU*lfuoyOTr_doV^p!0!jYE;2(7J zP4Pu~vpiMdSH0N%!&S$9!qwR`(wpBuDA+KxJUme7BwmoNLZ4oZFhi|77<=h>hG*Y1 zDQpzGg?)`te*(U+GCdj3XCXBm)#!d=9@!mp@Ehs~ogaJLnyknsa;5nHko&8Q9O*jU zXvDaO83w_NwiwU1MD>a5!eN3!xLPMjL4_>4erHiF83JyPSN8 zD&bkEXfBJ|hlL>dRmW`m_es@Zp5Romo{43Q@SXk3i}-~WD?_z_zHJP(_~OG0%*V=#_7h9(5p1S+E%=A^efI57FWKRtOp)j-Rb>73w* zb3}rtGRE28^UnV+R76Y>?~3cCOHyCN-d z(4o`<8|O7#1{Nr-)O)l)e=Vp8$?ev6GPE~MoEKs*bfSK1BOheSfp{3s8Yml;I_uQNv$P$SE zY@{b)&G(^h@e-X*)u#HBdr{@x1ohbMh(^%fj7P+uQhlg~ETtU8X=c5YU-F0x#g5`A zAqn0P=>aCt!r#uX1Aoo#Yvyn7E9-0F8SUQX%;%iuobD{{`sk=@cVtu9_jBsG+XpKv zX*mB^Bdp3HwU{=LX~zfjI%8*DZvGA3i|PkQkh#!aJ%y$vM%kb&RdSS*>Q3d9GDfbX zRMak$R*uAuD-#a7J>jewglFj#(>wDx^Da}WVVX`u)^9tXt?OcSro9Gyg=gSJ41+7sHPpI735VIYC@RU)5jiRbTE-F7k^Kq>LvN57`Q~fp zTj?F|z2N!ep6YH3xBGGKIv~#6bB=c0br>BZ99`@Qwskq#InSLoUy8U08r=l;0o|N@ z4gZFMH{OL?vg*J>g{^M)q4E4rr! zk7u^3}TA&>4Ma0)VP;s=5c-WHI2uOElqihk%r#7SSU0Y z-4ZB%He*k69?Fb(&{l?1`7vXSWjcepo1dM5UCsiqUUo7QnEGHhwt)v~2r7LqJ(Mm- z#Zj3=esGo>l6}ZrWB|F=cj`2SLr$bKm`tr?R_-l*7B36uz@51eu=qE6&w5&WQaw6v zZ|@ZENl!)3AJ-fg?Yiu`1ow$?-owaM?{R&x)3!-Amvgv3MQTb7n1 zM1I=G5va8+%nHOtnMan+96l-VidU@%^6L?DKlEHG*@#M`5Fk1lnsI%}tFMwK*=3IzbjrrGsm1zUc=HD?#HKoUaHTPWds#yadc=T4jfzG4}sX`TzlA21C^hRhNt{!X!+E%KkkH_VH?vD5Ta%X`*w%oJD z6Yr_u>F=KHPH^RNZ$i{Jm#d6(m~)+Ln|Efgg&)`Ko!{dS?RFyFb^x;=<}?>6X-Ck%58Uj0_M za9G)W(3N)~4-q`7t*bEs<`kEtJD{Hh@AxEh3e-S?Io-sV;`R6V8EhLyPftPIv>$a5 zD|TsQfELmmV!X$YH%j6pb}=&e4-oGHv4ES;CbJ_z*L=auN44rQoCkVhYz?PskP4jd zW~sZBu^?N10rhS&_yJz2u((9{6AM)++p7mPc|4YKz_<*MlDl2PZ zjWQy_*}${Yod{}OhGUMak#|_&0MyFQ*|&T$DmVg+L{>&b!A%(G7edd0%;+dq#O)x+@{t8*qPyKFs9KfxclV;#3n| zoT~_!sAlI+chIvb7$q$v4l+^5hIZrTG37}G{!Oo;eJp`>>R;^m+GyM4BEn?sBel>{ zF)Sx*R%#91jnQG`S44bjE7wK0$gsk=#N;tv)i={qrX8k1rc3&K+(T#|?8fwcEVd}I5;2mX_o3u4o6(X22%by?ILuLtkPKC40ZQpP>Jpd z*M&Jy3FZ~nh4u#{0#QDrcZmCm`kTUqKdjI0CTD|2{hb&`t!7TMMc5nE3MlsD!?%Oog0sRF zsTzE!^fX2an8#JYuWHKKdD5`c*x8iZ+}lK&1{v2H?&|;3J>o7f^T3|w!7b}gd?DYn zf^M-XDWaa`f~9K2471bN4>5%a$RgdzjznyUREP|Hl8X8OSkMaRxy zlHn}X1@)0{nN@6W_A*xW+-w?hN9XA3pe1c5?}Ga>3Iv;ZaOFx;*2=Rb4@gaNXm_Ya zh=Xs`gkT`B&rkW!d1`n{xmZ`ond~~~E{O`eKJG~OXGey8hi$ygVo$Yyca(Qjbgp-1 zx>CJof`yc~WF&u9uNZddDXsyX2?cTyZ5JqZR}qN{(TPNq{3p~XFw@Tj4~ARHr?g2_ zh$gTD>wvvsIj)LM)bBFt%!SNl%|}g@P5&B2WUXuK>TpjmTi(>hgUNG(I>7y7P|UH` zyVl>9Wf2=pieWDFbeV>d#*pC_{MN6sZRjFUCCr9P;we-jHei||bH5vo<8|3E&UE#W zyQ_{w`{I)G<9noVR@!7->T_DOZ%1K0$lo=o`N^j9uRiDGM^WVi;Hss!BU zy9P4+$9*MyVNbMIbvysA-mb~6QO@tq0nP-M$<@*|&gFAFblkOVu&s3b<2D7>%N(;r z-#vo0m?D}Rve+}AiZ&u1lKYt(+!DSRpNsCI%@!MiFp%3%2k(Vr zhF`5)r{7>0f_!Nu(>#;O{L%;e+GJEv=T?TD~Qe3ttMB339>m!NI}%f#3exzD}N3 zu4>M+j(;7I&PUE}&hf4r&JoTJj-HNa2WvlP^XE*+vE_))o1WRiNusW9nt76?b7WW3 z1pQg$9bH;?;yXQtEyG{sK+e#1iemx~y)8Ujyn_RdP*CawuZi){3msrAY$NC$*60r# zo`EQxY0NR{k+Y7BxM_N1Ai+m?g#F1cwxw>TVZNye`q3i9C*Q+)yAoJxJGjES-@2uk zsRwcSm^Y~EJw+@7P-hAG2etmFvXu4QNrV;6UPBj9Be6aqYfi}h)8yGJdOPbD_KSh+V^ta52 z2%2^n|24Sv)%DBtC-u-C=z`pLwjeVbyVu5a1IEXU$C&8K{=vRY$HXzqnSYo{>=^a{ zb}MDTAwNYQ#YuJ{vxJF+r&11>BmLmlu@a~0cccS^m6C9ne4$)XDD;l!>ScAFx=<;m zES9HA$?$d>CRP!uhUW)3UtRErW5EZg;=Jl8;K=2WpcT8C^)`ESHUSQ4LE9?Z_ne?Z zbDs&=Bt5gjaKo&#_OO&N*U|$oSiE4*JPgdZQuB*vA}>*lk&0S!wBFCd?t@Zo?&gCH)%RLpB$? z2bG4y5nHh`=Kt3rrZDB;-8Y9hf}MUDJnwj}FS`-5f`htGK~P0|LC`DD3_+!0Dy;*D zrWtGz5&1%1$bKNn`TQoQt?&J^oDj*^DHIC7j*WHJ5?K1l> z+mmCtU+~TAK^PTAY8HMT%nB$0GSoBFQRpqDq28h=a*J(1%Fc}$uRQdAMRce2gz>DY zzqwh&qll*wt0ESd3z|D(T`s5_$hBlELOW}q=8!YM1o9v|{Rqxl%`r2qhSI4Nw+-W> zG4$Y_*r9ABG+%RYS}hAr&35_%v@1R6%hVHc5?q!ys{KK&I|1UzLNLpwfrIuMPGLN< zW&324Oi2^Oo5Ig1)GaXVF(0w?jm#T$D>7j9MW`m9VTnG!t{1-wYMzhGOK8v6 zfLRm?{#a?$m93xy%n+_B|C>LC2c0$y(^Sj`-W%JmPEj*85m z%10?E91gb+cMD$>DoJj+9qP{NQz}Z)Z-8|3l59bx(=Cy;8KUoRY-j!$vB(l>Edd(D zWz#R@$;Uw{tiwEbiA|%Q((@66FA1K~u)oihX#{V<6wniYU_UjUE=Jd*Gob7xaSGnZ zUT1r<4=^`~P@}&@)n0F`m1bztpCV`62;5d5T;PU+vHBiP3p0cQpm$sonhDjzWkOs1 zH@yo{tzXmE!^Z^7fiu3`zI69zm&NI}_px8d8Q_R&iS`vvo!pDqC3xMPUqj$^ZtD=dy0;)7sI za#TG~Vw%&p>1)`DZ^QjR`RiZ`vO;B?*vg?UDFia+NZk7k^{~^%5_TFE_d_jTcXhKv*9QGRijD2B@j4WtBln0ic`M^2>H(pL5$a>Cca z16n}^$aVBP_OLF|@Y(p@w9r)Fl*dFGc|(BD#IE@)ZA2}740{_cq#vPI^dUR?i@DCm zvCY{XXoL35IIN#*k+qtPo#G;#G6rI9h=L1KFSuEi0PkxiMW80T4kGX^i1q4j^{qTe zsVq;CSK?&0S++}pY>@k)w&k;U2V;J7cs*jy6N5!UeS_XW`GCuB_W$(0_U-n3#5t{- z`>Xqu=cAtwmsbL6F*x!3W()8lH;!|lW@{BO6x6ghkjBQr-N21{rhZBfxhkA%_Qh+?EUX)tGKXZ$%U%gZzRK z;=qY(K3yDVs*}tP<`s6#6`9G<5Ur%{lO*+)w8K~Q6q!j5!*^qGo)ZvVUHbnB;w$J2 zdxI-A8gpN5FqE#M%4;7~M*(=1R8=!^zG zU^Jp&HQ`)6ANreBAT)Lcf$=ar)JK6S5JZ(*H+mPn6Jxd%gK>|UdIjPzQ=wm&jtIpu z=nU(l#_%d7p*n6Bq7|($M|k0JTblfi$aQ_H5=G%FycN)c_K~g0<%swlM$O$mXey^e z8MlDYgT(w3Irg{oRWQhRpz7M}ku= zr5zz2!+S@-o!z5Xg8n)m^mhySLdyWxt{M?ipJ_!=Q}vV()km69{iIe@`@#9ljmq*} znv6*JL^x&msB`ohDgyMIOq^g}(NSQ_ndoz53S!f5$O`0V&?JxHyf|K62|pmz=)nh% zfLmHi@&+*!v(5*?0SCTEh|`2ot?MTgQ0P%(k1=3?Cr(l90NQ~I+8)jfb>JJg8ufac zu@bfh!*UeVQne9xHRG?o!^0ylT*nF`>i!f}<|!bhR)J&jXPm+&sa=(y%2TYD<*};0 zRY>qDSIEbaUGRuIl(kZAWr5sOKBAP7YN-E!8E{NaQ_iY4mDNNwXbv80uaVbHgjVFf zb{8|oUAQV1B#M!rwC-dpB86nZ%J9Hlrwn2fS5N^w6cJq?V&W^njN=ds_QI>B4M>Iu zP)*5G9Z7|(j*&hdw6qsEe-9uBpu%zsabMd_G=Zm4dGJNAfi$?1{GlB|ZOIpZ_LUP=HsckH`^pJ8?5A zQ%^CaFm`$oNg%cLk>9GnweOUy>`{h+s(FgMD~I9m)&Mck81*_FqLPTxm>0gIe>4QY z`HI#Wy{xJ>4#XB6PPMC*7FubHMnZ9d>emWPMoLbW_GvHGF4|9V0jQ&Q=>u>YUq<|b za;7?-|8MdW^&DA<0nmZ8riPLip${ueHX|jB0Scplfs^-Itaj_cM6F4-K@9y3#vKbi ztPxee+u>HX29eI%@Wsr9{xOA!1d+^+`Jw=M1+D#qXbSJplH@^nWnMw9t`F+(3P879 z2pO9~=%a(tBUa(RP3U81;DoUjwVHeX&(`yY)+&Jfi1Hvr!`AJh*7nTo3J zB+M=Sz)3Pe6OkVtpr1fZ-HX>w##Qx$+v8Gja^K+{%^(c70nKp|`t)tYC`E8*Y9abn zl57XJtSp>N4}+fD3-!Z!ko&x#j!>t7-_r;snoHH(T778FQnaz4$X*4*vMRW-y+E05 z3-3-HoLGP0-frP;?&B^4XvJLU-IZ}Y-Lcb%LPdDiUk`gwXaB?NuVatb06u~Ph-)C- zU4UQZaqRgYp~80q_Wcbp>#(@mUn)s_R~r!jVD>W-eYM3zhI$CvYcnwmdcz*@6gfdW zQzsHrFt4{k9rZ=xgL;6-Rv&|7^A7Y_y?P2x?+die*jwC0|NetdxMCWHdQ2^dO!8&WZQEkrPr*~nK?aQ_ zd!gqqBjWHpW6|raf1`5Pndz|;yarZi8}zxWARf2ZI^bSrgXR1MeIXxuXMOaf>ZqZ~ z232_j-nR$(PIuHLZ^CnaiWxTt4wb*~93FwOd`(LP5w<6u{Uo@$?S`}BBv6&>fM==K zA_!Ftqhfj!_Swbpm$~r_D!?6dIQC+%un%XUafv26p^py6EW8eNq=Wuu!_^_^5tr0< zYA)?0p4Ali$yWd!@FysxpaWu;lNbHT2DHK;PUKX0WNyaO--n)g z5I@bqC#x7md65lkPdvse8k70pECI)k&O}p;wNY4K-=N)!BW78TJdU=GhV$1vawXV7TR^69VlMcB7{hjO^xMLX z{y+3%AMWWFX21RFFw~TX)OMHwJ0L1C2R=w`K>ECndGbG0>2E`&uMO1js+v*TiBJ6v z8kq$roXMDL?h&6rH7J6eT`P>tX_&jllYbrGCXskTBy*$GkT0M*cv;kqR@cEW2a$c#vs>!m^q7leJN%I{S&8((d1~%UJ__6 z7eNyo55K^PN-q#^%7brS9#o*EAZM|yOejq|LSF3`Ny$>z*P*t@kwH&-Z^=dWv29;IA za$k^7JqmaXN+0cjG8q)#-*7^%4#w^=B^|Y0+o3OBxLvMYeG{TW2rLu4m1FXbg;kvA5}<~SAQMGf5w5N%&!=Ia0oAQ5>OM%BPo+a-@yOyJ{|hwkh-c3u6jpZTl}L%-dF*PNv&x*c_i zj;Ef`EvS0*IH(OoY6ardO^D4PHxvdxzX=Gn$K|4`S?(#PNX%cvTUG*l!YH;I;anGQ1u9#7wQ-In5V&7 zo{QCH8AjY5c$sY^Rjmv52X?G7eF;IWf$>&P_R2@4^3rYbqPP!4gL~3`!GQcr5uu>e zKm1(S93B`RguVSUAz!#AT+esFHPl9|fFh(f->d8@6H9{pQcP{FY(Q3hury9g6}zDaERc@MA;qJ*u^-!x?DQn|GHN#~AQF-SM&(Up zsm(aSd7-K}$)qxyk-272yCeU;AN~I;C>Cpx@gF8fDXO>;HEo^6hr$41 zu+T5uT4)yD5gr-N2$utyVGfi=E#RkkTv??gsi!eB8_AO3c}>R|wG>NmR=y**0!qVw z5YH+BR@!6g5G7#W>y?Fg(+!3RH3IYtMmqodXX0ckLa>_`|W6T7#eLLvJgi)x{dO3cQE@XpLlb zJjfKEq@c7+w1qE=GlK0vjW{1F9PTSD3cV73hmD9?+z|c3zhWx1`;k(rEP+_^K-sJQ zL)^w}C(#|4%OEu0K^@Oij2kHAaJI67h*=b;1q$<`0lu5qAE^sehx$o5r{q%ysA0Gi zUBukz=2ESC_f%(E1ps^B(Z{WeK!ucv54i|p4t#V#EAfJ zN?rx}&tv>Qy7#D3Z>4L7$ZH3E4^$9d0f+P%c>t^T3w5OY0yAJ~B1=ssj*@Q?l|D!} zVMYEZf%;^^!-R>8@GDuz+W@~XjBVuxUjIa> zR6be@>e`q`>HmrRqU)t!VLGQZ(2pSdal;scmZ__$BoXtZ-3SRIJv}BMV(kDdM4Pa+ldm;vR_fkG8fc0*h%fv+R$@|6WSp- zcyFP4YoEwOQdRHJc6BY=UMoRvkV(lV8(mrA2=5M2_3ZGM6=UtE1O3ELIU!-2Gt)LU zyw|bcyWTU)UBo3ij(RC)O@1U-g*YCUU_OwmyJem6OVktAp+CgchXP;;xs=(0J?1gt zvg?~SIkT${S0~ z#O$z5d8+nCCCmc#s~9HENEg)YAn43eW8uaB`3|@ILDej*E>6-ePGA;LY~^0Pcl}9HknQQaBHiWq@=4!^)n_hMaW5-Q|hi9 z)}TbD-$S=qUpE`NN-L8>df4Y|4Z{iDP*YjMMtwA&OwOl3QB*s@X{9tZ241yIHAPBS zHi=2%2sstwox~na4kOb1r#`XRmfH&pv4DWrxB^{SG&@ozZcT4xq%)8YI>_~7+A=l3db&!_XD+cr3~{;{`u~j4+()i6 zQv%t<70e;-71NSm#>O*cnRNA@)?7XyUqMEw5H$+^6<5_)Awg&ti1FEdQGrpuapHP! z5#gb)pv~t#m%S(Jqy2tX*ulA)=ltg@k#*A^m)$ETA%nBu^PUYpWoj|Q%)g8sBDzOd zj3HeWQ-4D_^Jwc&OC9TU{R{RmF;eUjP`%~7A#a;N$Dlp zkcPsmpttf^=VRQ)GpGiiWG-gxuiwv?(--9P>Yt;wZkB#B?PK~9H?cb|3U`sAaG>ri zG=Wd>AJk?iyJ{foRm<@?{MtS?FvDNn^*Gqdw$whzF~F9d6L9BsymgLro^n)oCfgol z0PgGR=Zn^slU9~D*mMUCGj%Q4M!F1czrK>`p}CS-)0g3FR7kN2JAzEGPk?|Ha8LNH z_%3)P7!jHiG=$%Tr^zH-S`v}#8GwveEkno10}*Fq21QMeE1c&?+{9dtsDYOIh9ZVp zx{J_TFQRMar>w=TWLbcNw+10B06>M2%5ZxKh ztW&ildAe9aJ}+d7okHGF4ZksX*7G`4+r85>+;`3{gRUEL)dYj5xg)}F_Dr;Y_KwfF z?+mzRxJx*i+lRWMJ^z8ubqj2rz2s?r8R)JJbq!e^D#g8~a^_hkAHSE+Kn%a87O8d> z|APx^X;9BffdUp2N`!W$k)(#5h#IyfJF}nR&i&a?))cnxiy9V_J9op_N4ajtRf;Vb zwbCM+FPe^62w|JY9-)2bQrXtOtlJJn7fHZL0T&WFX?5We8B{NAn-Mq+aK_|ypw&kgLm8= zeLHGWIZp!H+zGz+AzJ#QmKC2c=ZP77ZYnR=&$yXSGWO(0nLFyX zLw8qz`-wZPNL?nUQ9slR#A3Cu_F5W)6Us@ss>~@Y*-RZpcB55(3^U7g*ibuSZR9fZ zgP5vOKcW&W4Wdt4UYPHi%NmXtr=d1-1AB>mikh~L>;MKe;rN%q>9`EE4(sUSaEKa8 z_8{BCpQbDv=T?oU;*zl)O*$sWO_y4Ie#7ZMo&HOzF?XsI^11Y16qy=k-W?74s9;cf?3LT$2saR zt3HOmM{b}GQaupk>JIkpT;hk?87l1ah)H)uR`{+mKq0kae>2dutjpvM>64(>IcI3D zPXjB zM4e@dk@v~`>|)A+n95#wPaZ-=%Wkp`TMHzLG(^aQT0^1^+kjEYIF_Pjb1i5K+nkxm z^<&r5EBNy4V}@XQ&@Jd6sOd{2HqgrqxL<|fijr5UDCIGEke^LDw6^Rt9ZXfJ!DJsf zA6r={gPrCP;)5s!t%n!DrPloO)Ps4TfT|>8|>eNIvKG0UB=xQNGo2-0+x^1lb zRsKwkp&|n>nJUU^zM%ge^@&@gb`hh0{)pu$pj=#^A5dm-?X`>FSZu4{+6H55Y) z3gt0(#(H@+FvZd%7;8!uKIoj<0s1ImWDhBGnY(0t*`gn=o>!Oi9svq=_lww`whu9T zeG`nk!Z*|c#$UnjbX4$cRL$%jv1vK23NHK7G~a4Ru1HVj_*}~!=Zu&ACyj~Tow2{O zT_6MZv>f`g%pi~oy*7~7@Fq~hFx@|1uFtFLYwC^dP@a?ac_zWJD!Qw)SJd!dZ}XRO z1@uV{za=35Cp^K-aw<&ghKs5k&&Dw!ryP6Cw$r#Y@Hnb|_Tk84Qj%Jpi5C;pKAf4B z-0p}jWa01^@|IN0x-+w5u9-hem-(2MP-MXOnfb5zU#e}yMy8q2!cUeBQMc?Rt%Jj* zm6g#|epVY}*uFRIuwYWlm_L`njbSNUCh+pKkNzLek-jTV-=FH}ddXM+8A(=Xp zIi^hflS!5IpEu<<+2Zpq`&Kz`Q)!HCUXgFV=NVT?<)bpw>qa-zKKRb^?U;Ms6w`K5 zC33*^t)sqoS12_6TXa5mx~p)*Z}E9;-V4MHT`Qk2di1wIg=Y^g)QW!=QT%s~-?#(J z_Y1lW?r(8R(m&@O?@q$0vbuXz-Z$xsEx8p_ut9WsR=)z(KKH56>Sc1Vt(ku!+N18u z9DDWG(E?VRW`5?|u75@>u=7Qn-%Cf2BffhFM$K`hSl4INEunsV9JQ3|;VRAEG&jl_ zWWK11@;Jd6Ge1MK{viuF^s$@8O~HlW#r^yx6xn89tW>nV^IxWei2HVf7WSNbg%JdO z?hImQ`_BGEc`0Oa6U7ixeIG`A(h*{*Bxk|6QxgpGlUXE6Q2a zJb^Rk^PNE5?0Vq{eNOE~Jq!LaH3|(@XR)PeUx4Sbi0dGGoJVB%vdEgZJ6zV);pMuG zu5D2@1MoyQiv9!8?--aN?$n4e#8S^jgM<0#7;l}WKK6Gp6rkt$Ul<(16#1PYJxqy@ zxz+|b=QU@I%AZ+YSHn2n(^hPs=cm)^uU#zW``73kmn&*;PPeF&-crnZmKOJtA2ka7 zF;AWXwuP3(?u_VoyEE#Wy_Bgov)NODeWCv)+zp>HmI{bQt8Xp0fc`FQ<4JE9(`9Od zyP$3-)s~naaGTbP52=7}WS*iwb7J~aGo1AdZhdY_ zsC<@Q4@b6F$ieCuOJKigYakp^CbLvrnQ&HUSj;ZxCeBLdh>t{5#A0`Coer_a!pH?4 zWw(d(vg3k2Q^g?1%HdAA)_mDq^l|!}*y?{`ir4(|D$laOHS(Z5!f+lM<xD0!7oG;*i=Am^e!%6+x{ zLIgbPL}jb5y9CwvGKM7YS5} z${YSASLN=hv*hnoKRVid*6I{a38Rde^4h==>lMF^87kZ{B>JYYxusSy*M9xUHPH3j z)Wq}0R8BlCB}HuZ@x*IAt<`qrig==}QNFtB6>Mk=ahIV$$>P?8$}`hE_43VeO_WB0 za9KfIA}!}nL)#Y!kw(9^M|l=rts6<#mP>}6`WZqoeHGtTwlVuObdFCGs?c7oWN2hm zX-8kcQdMW?%l zTMqbMu;s}cLOfS7G}{#6+Z}PkHZ#{H_Y>X8z#hYX@qLhs-s8^Hr}=tDO>xe)>RiY4 zKIt|W;j?iS4ST&8iM!@qj!MSm;um5pnN7zE-RVR_if0odaDVv5-fD(6#IkT*(}6%< zwzS4%u5dpe)u|#co6bty2Z+tT#EThc{D~qx7H(NptudZ@e|2_4fX)PNr$e?sPa7&5xiOa{>+P{FWDZb2l< zD#g)`&{N$f;TwBOx+UBB;$jocz_yoDpu&1h)s>>TuR=4pyRFk=ZI69cVCzdHBh{?g=r=aCnw5^>|~$H1V2=Yg>s`ceI=Ne?L(9z`-ufuD>Gl3 zM>UtT)D-%PWP`@2D{)3mr=md>SS=QX0cl}!zxppV8JWV@>MZR7nM*0Gkx)MVLN3Fg zW|2|I#lMpeGK4Zu9YSqX+cPfdEbcj-c^ypV&Z!rfOp)RS%O&Ag=vN;|bIAJCEg?ex zRt}N_r8MTKFdm+*3#g7tD@L!bhQnMPWih!y#Ewb4plc>BVjC!*s1=fd9w?7ev&nl3 zLa`u_o&$N?h@3vKBIFXSEp=PnN`2KjqmI1@a=LcaO}$Xlpv7?`>%LaGjHh8EK4Qk6 zMNii3fKz&kA6BeohTqkO){mE)d3F;xT zyS=EQr~vLmY*8iRG)UzowcXgI_Ql^6q|)Vwpn;bpH$ksd1sc&iM4Z-}IE4(=UT7`O zYjcQ^_{-MXcCwRpMqNa;RvSTC{g*0#7V3;u8gwJjO_{VaWTJW)*_vtE5LCS%A|kW_ z9DzFHH7Y$;YYR|$Tv>AvPoPd9$U4Yi zHYUD7K~fjGi3-SVepHXb#b>EDnn;B=?_?+>>OdnQgHB!;+ZE5BN-aaQ97+7j3M4MAgI|=&kBN(NGCh0{5Ytpr9l<3{TRt1OYwJTxd)dYPX1{ z$gaBKm@7l|(Fd8)H^?f-!=pG|%a1%qalC5)%9gh9<6BI;(gqT<@dQSomZKByu_Qc2 z3litF+sLYKhKp!RXwFn*iE82XJe-HeXs2-pouJrh57I^mnxg~IER6)M<2vddqM_NW zPteet{8iB8g*#VOC|=4!JvN=#3+GNdyii+1hf)ep#soFPXlQ9DXe`q4j(%u9g2-g(zwChl`Ow<4CQ=snBT7f{nLQ6)=okUx0hlb)Eu^d|D>A1o{P=1U@zIYt!3g!bKnRv)Adf*R-#`d<`s*96+*B=Y`#{Qqg> z;LmHNK)3XvpS{z{;96-oGe;81P##q!>LJ_q3EHBIcw)bybUCb@!WHk)lHr~_87kJ? z#69RWVsYgcpxsNt(_N?iL?3947SlEnie&Tx?Pu_;Bt4}52O&h#dC?oKNpT!5)Qoc!CUhDth8*v}Fw1;0eB3(98UI zTHkS{xAE@(p|^XWtx~l9#0F%sD&Q&`5P6W@=0Pc1foE5Ucnim#4#?B>$2hg)Sv^1> z4?#U5;C(cZ5*TEbE8^W4Wl+*U&Au@Q6VF3gzAprrW$(%n0Je;jRp3s2z`zIzWZbqD6B zzq+tpXqmt7kd0`=Dcs?3D7bDx2lN51>CsT26+k=MF(VM@>8H^12A~I};&UG0ov+}N zj8GLapyBL+F6c6xOTXf~!}#4{%v`n6#`mCoy@e~jgwHsH*4&Q1pQN=z&E{jY++*Aw zhw<|j|6YVTrvQ3E99pe8#!gZAf>U@Bc@b~pp~i{CzfDo8-41Sb8!_H@qS~u6JR8bm zyywA)7tkhuebOT^50r+!wh`97y6{IHfY-Kyk~1&-;(Ouu9f;g$;|kCU88HL*!MGoU zKGhQ=e=Z0QM~J`Lk(u~+DV)|fy3_VFfYyYh*d5qR@LyHuIZler7y#U5Z11Pgf;clv-j}^k`DFodb ziC2xoEcY8asjp}a;&0V8p1~8eT_4Qw>FC*Jw1k3Q6N~xq6W)6o)K!z9D;lpghbu)D z%>139x-!GpK0nm$<+Zw65PsH?aQ0}4-*tqFr8A;EyYaVE@qFq-t2G#&(LK?JqVSvA z*u6Bv`(bWI4_S%#orzvB4gF{(o?SeysUNglz2NUr9~zAv_}is$yC074uRu3;7w@wb z{V-7*0)^KM)F`gPdv?NW2jMd}!7pjwhR1!eCyD2^5LmOVkMf1nWr6|>J0Rn z)x=mRIhUaq9KXd|q+%*Xro0UX1m?|3|%s`JyRa1uhnP1y+3Y-``ku zHHXcG!OI}^`oFM8clJneCKKAq9# zgAu^w!e1ro6^zRtaDw-sk4ac*Zea8u$E&~K?@nWdKY@O5 z3$HzjS8T=Z=P33!N705ZL}WF5^D+=8xoGZ97Zr(LEmhKWoUs4IIgM@=y% zdxVph$U(|oElA1w05bG76rI#A{8%a?Ho-13pR!hvLjMNBp+QnRoZnL5|5jZXfmms6 z5P3?V#w3DltlUucsX^)udyDx-e%A7UmvjS;EYZ{<=;GlttGp0{${(c&RTov`&*ZY| z6`4{KsoC^UVmW+QbVPe{v6@M*C#$e!h~>1VRiU0Ja|l9>r>{Uga$V!#RCYz~38na6 zrVn~-EeUQ|@;205_{b{SI`tZ~&=e7m{CR%tJr2neh^ca4@X!p{slL|gzzOEI{8(zL zq{`j3O{%0OiOZFBVqJXC(~VUl8J3O8RmBXl$5*C1b%dw}%3E_H6^=Ox^d4#tW}vg0 zRb7f6l8fp>T_m?DBh`I!7v$cDgXKD&PKDFFs+FQI!vEC5ucAj%Z;3XD+16x>^8c}Z zsw@=`J>+4oG&71_jL`-UMPi)tiLS3IM6K{Q2U4DCB0Z!9)T*lCUPz@Cm_g0rEFD`;-pV-GNHxMYLZ-9-3kwOy&R{E zg~oR!C{{ZZGnD6xF@92^egA?z{ak3f@-f$_yYOQ#!yf|)p5S}&BtH@K%nZXNE)6wM zH+9E#UG*;hi7r2PKzEPb4brNITg7`3zkkX#0159bOVQDEe~qG?$|LPx;eb*T_2dmg zb;R1iU^o<<7jEZw1bl(5;Z>mxq0K^DVZHxd@I7KEW;qltD<+10LQA2BtO#9{&%%eW z7WzjdoF?#V`MF^GgFMaMK^DZpuHt^O2YC}G@p-uCdXYElK5@@=`?xuJBgmQ_ z?uLFTzZp5q6Z9tP3wwcS#eHE)Qi~(q3J?;FrRyDe|Ka5RCmxn)MNL*@Xhci2*;%9 zas~AtIflHUc3{k)%U9rexDE8S)MRCrc0%2yj(~sC zD9VVO#sy|B+W^dy+r}2=W2U;6#u15DtF@ce7Fo}d*IE))T?`O@CiQPx@0zRghzHO=qf4w~e@<2xLv6@DmIlfFut zvYXniU8Y;(Q+#9t#JRRHnRI*h5;d0oK)9()>N{!@&q2F=3^f*=P>EzkB_qQ!Iu~ct z-A6U^e_)V5Hf^&QqWW6h=HC&uBIT&H(Yd0JMfbEEHH|UVF!he8XKHPnpgRsyA2|GA z-#g$X@ts@3wF0FjudW)G47TnQehz11>w$dU2kr$b^@fBal-M!!EKn$v7&zf8>>1&h zZm(qT?Ktf`?VRUWYl{MJ(P!&oTVkJT_t|%3FUa1THOI!~2=;VGK9|oKuorQa@y-ag z2XXGU*jA>Lw$L&u%qVzi?B^1gZ6Lc*y4~Ct*1`VeolRj1LJ?JpnMjv~Yw05PBUfBE z8vM9#x@*Qs#+!zv#s(&n(GsbP`VkqhBwPHEp2*#ixnk}`)v#VRpD^``@R??tY8l^x zIC7FNt-lBlh^PD;RI)p{Aa@mH<^KFV)(KixG`o;}LcgHbQUO$orAu{$%b^p&y}>U2 z1754!ZZ|ou+kZR1INRID<_yW&p7k@kY>t}qJUcrhIdfLVmdsz7BeObXU&?NoaX9-x zc1Qar`+v@VeUJP@LTq@XSPtIl6#biVgBqF}JeRzfIkxB}K2!e;^`K+m5737D3AdpC zadZw~cBESuF5{?`jylQ2wmP8`J; zKJ2~LyO!}Ldz0;nH?TMVlKUnI;vn&Z_*#sx6tMiTyb8M>@iy{Ec!jWo;Y%VfM170? z5&b!8X~dea-;QZuJNyhA;2>u4FPDN_j z9ISt5jqQe8E1<+lc>*>4Q!kP-2(Zcu@ma!3VFq$0R$B^y*-#^5LX;kPG-7{vLge)* zDf&b12hm@nE=Bmm27=QdgdcT`fgY!d*h{3v!$PER8qe=Yp5PB~lkpz>Pj}o9_Bwow6;tSW3a<&MCp<+9`Wd_NINza=He1D*0apQl(Go4`Vhp1hwxY;1~`^rc8A) z33TMTmQ%=3d51)!JHjOX5y$h3xGzYNae{920i2hPNXfV_juJ;(=2^elw}k(UI36)8 z;&sHRs1Ld3=k`S(%k?sHT7((SM2ru2IeOTG*1DD{;uNtL_`^TB+IVW5+!TDpM_`^+ zfGWNh7%uNf2dbWaJ<2Ez5AQ+cfb^ zGIK&kuk?hp>#2OIBXxdCYI3s_U$Q40opAu-pgjHp`k}e-$=2 zyjVnJM5lZ<0IKecaW^73Zu4 zviA0@K3SEsj%AL_xSBpUt$kYE)L_c6G8KH<%;E zsipN>s7giB7ud^O9kG+8mNnh>+1}l;-to;*-BHHz(0<8Y)n3!y-}ct}8R-P;t?j`6 zx?=HLa@)$-?T*Keh_J}8nD9s8w#ccG%Of8|?2f1xF(#s8#Ix|f!`e6^?Q1}+Ic1$} z(S*(Xey$)~(l?pm^lLa_GU2&!!kbrJ{UjHVM@v6LDZvK8B7r0R9)8Zh-Di69dV6}- zxwoKU*3fks1dM?>AF>m&<*b8Q=|~HDnl7i8N~6*nrmN}sGSjnq=3H~;xGVaO`)7v= z%dJ!j-nEv*XRt`Wv-`Oe;hea_qFKXjTWsfTcWk|F(`{F*o2{~CthIo(i*>Danze~_ zp>?LUEK~- zO(@8Z!Mt6J6_{02L9!DR3=fSu=|SoSr#<<*+&Rv)^TX%=7HA-S zl(W?;y4xsD45YRp4VS@grKS)se#E*r#hS;OfFyu|mdD6c>n|pOY5t$kLmYs-nh#bUE86ShCBUU*V?wa77%?notKW5n0+5K`wN!v6{zWdCf90Usw@^a!*0ozP#* z1J9&5bm3jhDn>D+J6JUdY%Q0x7oNAXz z*elPc-!sQqgL}f4fYK_*GR@M)@={zXUV-vCh?Jr)NTa&XuZ1=;6PZSl;z{w7wVb^o zs0;;=mbWbYWW=Y4RgncFTSf>Gbt4kP`-XQ8yJ}wm@^B;TT+0OUI)8#a2aZ^OB)zsH z=fXMKRKEcw)Lr$p@&j7@a&oHlPMRX^2^GX$OTRZe@?OZR6_Uf%=%hu|>jq!O*Y)x%LQb-_>G5j^SN z+)jQHn7*IHO2}nM#*CRId_hK217RPu&m6KUHe>gp@q>{K2jzivpuH(*6GE62wkTW+ zpBiy5e0cb)uzO)M!oE0s_L5c25w4OQW5RFnx?hZ=Ydx&&>!J^ak52(JYo@RqU8c| zCOUu*I7>JoyaprSD|jeL>?tJQeCHku8gehX+FID=*}B^2AlJTX*f2+oV=Csu4vwmh zZ2JMwz1rHr+{au`;x0E7+r&%Z*tbw!*Scz+1ADS7Rgfa z*bmI(yKx2CTg*0EgnxQCs;C+B9+!g-xzfAXlgAs5`LewymuH*jsJFa#wlCoCiTullAx)~Su0kr7 zTmOm9#UipF9PEF=L$Z5zXCaF!?~Ye{Xg1gt-Bd1rgh|navbtYE6IB`{Dmh(!8r(8xk%WC|)WRmn1N zB<&$jl7p$~@M_+qwt`^mgZkFaG-d`eOTlW|%j995$j{(d2m2CfEFGDV@D-J1#?m|B z@2BWX)Oq;bcO#dhCdDI#`9Elk&yo!Fg}h5XCC7t>IuB02vG6HxfjfCLDpg6Swr;}d z*A3a>t(hD;bM*13Kh1h=pW^P)!*FcX)HEA8M9EAsevlg zUG!eYpbPO4Rh*Zo{H%ieeGlqCPf{ zhbRw5O7JIH59@!d#{Axn{Y;Y<1rN1#clL<)Z1L^Y->DrzI}^+KpGT||9289j^R zs3mqp{qTP^<11zre7+MkJ`Qz15}g|ldLnlygI?t z*8n3Hga_`1F%kaT3&wNfF>;LmG3FTy;Z!^i#~29*o`_CW2l(>Of*vp&1lAU~1C!y8 zz6T$4161k(KtkQR4*skp{$>r-h+n}yUkM)B=BPogMosgk`48&j`B25%h|23WRO9;M z`&o)AEswn2rl9WFF>@vurKxJR9R4&h0dqr9ITA#X($oaGaIm&ok}l8fCKL3MYF^_s zy_79YozN$#!$2qB&FrB$a)oiyxD5KuPbva^M=Nre+tOFy5FP;kz8BrJtK@g|M3<2B z@Vk@n`@O*TyAeKSA5j9&?Jra%YvYc$z@I)aKSD9^0qe&Ca*_E+{}0jHEL zdJl6tcpV~F()gvlGS)LQgw|X|;)MR2s09ks3Z_2p%PjK=y^Tu7yKzvRt-sY5;(j=g zt-ac4hkn`~va_jDZ}9}oG5jFZREF-Gl=o>3%`~<(w}dZ;m2nhtg~>7_K!#%AqgW#z zLS4SRsS10TotBDFFLNSA)=_H8Z9p2X!VI7rC_kij(4~*oa)ax!9XW`VwIAwQwWB_f z?8Z!DT;vb)omL4M*~M_L<`Y-wo_vTYZS2*$Y3GgC#z3&dO8mF}rP>Pl$8lkHA zR)3;@f%NN|IZ>;jFEgB|BG)wo#yh2j+8zD&!(GHPQe--kpled2X@jUiHc^*pCy6S^IIjRz>~bV(*iBxIH?EV9%}>-R z)ZWt2dAI>8hpLRyyTVJ@#VBix(avZaQRmM>RkA7h3VpR#W>I4vNbHM0~MNgZ81a&%tX-QD4AH-VFDz z535A7IgU)i2wfyp!=`6yx?arO2tDd~stkJ1kF^0tS&%0lfaH*X+W&l_BWC#~P%nIh zPw}C4+Zcm4?Dt(~r ze2ZBx$Gl{Ckfq)iEPWTgr-?*6IM0XRy$?gRe>U+3KH?`}a#QHI%5Y)dhU&QpC@gbO z?Y(QXFuv;NjbflOz9N#zk>DW4?BZaweEk0vp$lg2lenr$)Jy14 zD`K~F8l9Mj_+CRGa8w~rk#EUIsETLd&o;xSw+daFX!NN?eEr4fM`12TpXD%S-YCq& z%Q2HN=5Ql7W~mjJ-=5-G`(Z>IlaUrN7)b!R;g(!rFvxd2g#Ug6_jIgr2@c{NP(+ji zFCR|;Jk&kVmns2Ac^pU>bHRfTf;~PGWQ;*@L-qzCuPr*m7BDI9V;=5{Jojp#4U7XL z;4%7dpOM$M9FEEWSUcDpkv+hn+Cmkl2h%g5j~@vZ?EhHH1L>;x*_qA*BF-o32Cls> zc&X3nf8jX)MJ=VQ;Mr`3+BlNh1P1RL>N=&vF?t7mwSnm3-NWkSAXD()#H3fM*L)7)^H}#FWQZ1$aK<$fEZE6%! zdeW5>N`1vEpOH_=r@^;cujue8A5gn#lePcRkWT0|QFY&ijzto&48GLn)CTGnctaIH zLElL~K&r+oaO|&vEc`Ehp4LE>E6?;pW!-}{d@+58K7k~pIOaEqQ{%zz?23H8?+gKA zWh?dwn}F=n`>YBoUlu(Jl%~O;dVZq|F*oUlP5=%Ah0t6mvjT>~IRwW9f->UnCmD&;@8WGO;RiH;|Q>0G@e- zI7Vp9ujZO@4z39|lxxVjkf2fm2?)ce+3=@kg2M6uxgG_;*DuXX1$BA{GZ`6--g=x)@&j|6fl_}s(M10M_i!oS8jFb5_X5ylC|icw;H)TYld<0+qc76~q!uuJ|Z4Up@Fk(g*P#?ShBpHq{r7>Ve=nGN5;7qRO9-c}B0pIJ|`(?IS1}6S1rM0gXUi z^iu0!H}?^`hh}DqQ41sqLf;5Z<8`%(`Vho{1M$Jwt%D?!OzyTYP%J1;K{mAlXM0z^zfc$zi``;bA;5;wf6Ol$t-KBO z3C0CXzvhn({1-SHoE_>c&6E!+AJiCqq_M!ffjPe#y`72RCh|wYdd+1yft=r&{5JL! z{Rs1WbMu=KjWvS;U*rgq3r5qAk$!=U0O&0GVjsMWkjdJ}H#^UELngL`>&Z4@iqUh? z-`I}xOiRp`QS>OV8X0;twGeu}Vnk=`Z=YgUIt!|h(Es*1lzyq+G>?i>nridp^6Qa4%Q1S_mA*R2e+}OyQZtKGd!nYHl3ZC?RVDqmI@Wp zQpo$<2n%h0iaOH~>sgUPhvPfz)zVTaBca=zPF%;zPvZQOh`%!ctKVVl z@>*%l)TMB$){(zr-ee?y=tJmas0y;@KL^er`JjrYo2yjL=&S}=8?q_acz+#bI{An9 z**=F~jkp&6-A-AqA*m}!6-OVgw7AuJ(e}~yQ!K&`GCk5cUw7Bt?0>S|Id|Np{F6hw zlu6nF{V4cGD{!VNM*68!Oi?~lkSsfFM^M#l;+SZkVq0T*&A(*ok%f%b>MVJ;bX=Ms zw^0kAV)_<3nMCt9cm#FziTVa(CsCQ6h5Bg^Z~~+3^=vCG{e|h=d*&Do!#vdp)Wjwav*84|-%Xe;%rYiF>TNIJ)L#J3epS$6O^{J0 zLwE8UC+Vl!E7ep!DulW~9jI;BO+^A^B>TXOQi=tZ zc)q*(yX$$^1p3M^wOeL$GK9TTDcs*O;4e>yvOv}II8fG~;0=1hJhePEy=lHR zp#kb`;xHR#sbjzHc;qN=kFv550@tQ~#-VFqe@tn7#ZVOPGDU!xmQ0(a>Jg`dlc@b)(-94UH%*uaYFi zfQE81G*mh*6_uAG{cWZ^Qm&{hQR8%uct)LN1N>f86|=>!LO;GD*Moh+{A4b(RxX~q z0#?LT>azJ*AENGrYdJ&fAXU`!GVj6IolGawIYt{|CtNXBI-Jw6cJ@Uw_k8voyaJ~v zJM|5EGnSmLQ)&gB)ketEl?<(f7NZRX+j|s}#n*$xGt%r%rjSBO_w$)cI=6->Ra_A>PTr_prP-TXRLdJ>mS!!&tC7)U}|Cd(JN*N=q}`ky+a^vNAV0ANpQNjCM-D zqBHP$J_MWPFfo-e_yo&uTN}qR$9;QyyW6^3oXY=8*CL{{gVOTgqQI5F$zYW9UV0=Y zf_?FKKnOk!y-+5ZkC-N6dD|`fXAnR83O$hGR5?(f#&X$Ilr^dJMEtRzCnRtkd+c;hj!mMOQtdwe!sB7Qo>fu!G%|d z{1TZGkvD9%C5gFXE>}(T4f4tX*$2M&wBS+SK9A!5&-*R-uUeiAECZ0G zejHg=jP;mhBM!+r`V+9yMx%0e&WJUZYj@O^YK*!?y{InN7AXsqI?}|T8vN;748BN; z|E@p3d!w^q=C!PM&NYE|1SkxV6QlD+C506gN+NxCKOJKc9YRF^hy-gSyF<$zXzZ>I zI(!XxgzqxCd2d4Q&{BDo{90M8l{W5>lh|UGzK&Iq&!gH!tqY%PozG`MUDO9>hPUKC zQZ*_4x^%~X+|$;5&C|y3hLi9z`HQ*B^%RPr6LFdU%@0B<(n8Bj>g78Yl*ZX&O? zO!5Tw2Bskge6q4rO+#8yJ!Pp}PN|}GGF#H$xE%4C^{j2H?W%3B{kwfw*v7C8j-b7- zWiCIP`9)MRZTe5FU@=M-@*}P(dDUhbiaW6-kxW$v&l#_EQV!^KlCEsQa~8?YLeaE|qr3ZP?0;B?DSzWlbZ+1g8Q_S%UF|sJgtwjjZ~O&*-Pavm!Hpb zINT;|pe!I-?h?mDjL)5%D=Dl8vtG8jyJj>^zmXa4{^74E6;%IHImIfq4<&~_$c@Zb z%pTkKsPsIA^R3NYF`}%+MSlRvq!D2Uy9U%_;+XQu=Wr%wzR&oQ{n9fgv_U^b#j{_y zHT)ZH7`F*3qwSVc_9o%4Bbr6bwO;ca6=ky0NLer6QWxWNzDhp|j#ydsZ~2}Uud-n8 z%1U*mvC>WXrghLhfCY90r}#R-yMbc9c>gi~!ocd#VBZO!|S%aJFdB;SHZQaOJz+!Ylc9h>h`WR|5h`&L;Ny6!LQX@b7{m%#X752;CDav;|K z)e{@6DQo6(VY9tuR8I7}2qi4so-QOL8--4^rZJ&&zIN``?hl?RzP3S5 zs*6P6;^4i8K=uEd4i{J2K7}ufdJ?%kyuAITpdbb26Z?d(Y#V89U|GTSH>33(!DZ6F zieS_M5$if~NvpF@sYYTfUXRbHLOZaYiyLz{He|ekB2b4=>3VKYYy`gol zWrL_7A2GMkP;d!fg*jYDI?BwWHVQQlaQ-d+zk{PgL3yRJQkksIQ0}SSj27f2_L@lB zFFW=-n%S>g&w+&h4-(cdaX;BSOf0p}EQs?}O{rZ_3GM?kc#l#N-lyVF@v2D6?ZQ;x zD~hV6y^Xfzz(ZdTzWT;uk{GanRAQ@W%_9!xsxiZ<`Q$&;2>KZEzP>Sen7q_gjNwiC z6}+>ZtwpTA#fM;p)1bt6;sV@8khzC&qmk;G3wb2DjjF~yqcf_d1jvM4^oD9ld2T4T zJPdcGBvQWr$8PGV-IXjMMXIR=_0Q10=jfc&EfDA3>LWvK<4~+{ZR+1$V}juI68vyl3-VbqI3q4 zWz3OYNbS{a(hF7BQq)4SKUgyu@GlO9DYB8k9N_Qc6g5Mr!q?}^@Cv_5EC?ca2XQF> z5?A8ac7we9pL9;{s4(g$Ww>$_J*|)OGc~u-hAhn>9ah*Nr115H4VHbDOxrj|n&W6# zhJBcIiO`PU#GD3QdM)Kee&s;(h+Ydc>?fd*A*3AXS9ST9P)D6-m$5DQO@dE+WZ7mJ z0>xz|b|O6-|K=GbkzTR_$cN`NQ`w>9!To=ta9S4?`P$l3b%^#_D}WkB5w)R`BcGK@ z$wk0pyJn2g612xsnDSk2qZCx9XzAuSCE0k7UK^$T?<6@O)GTxdJ>jpuonTd*^|o~1 zcAjuvaWxLKlR6k2c4U2!G3Y@SZ6Qkqv5PH_ov_(#_4zrRNv4C__FNgJ?naf*AvchA zhC;z>p`)SON)^37p^ysF0s64%TnRP?+OGxR$*mrklS~HDohU4 z&*2T)t-M2q*lTsPJW1*%Een+cr*DT+UiB&Cks3!!+d|`Ur;DPtKgy`4jth+sh{5Ml z?$F6#M-YOB$$nJ`tn-v{NBVcjThu=0SfhoW6{r)Cy_KbX3wC;ddpGn6nFn zk>c7_oPlqlIxfPal^dK$S>J3vBI3yc*l9#icd#;-h8xTd-TN@y(`aM3CL{eZpf=E2 z>cx#p=<~cnzCmuSoqo`ar6Ns68xZR0PY*1R8leudg4$%1R=)+VqBGo8jWWL)^WbJ^ zr@qG-dahhgz8eTin~j-dW47!(46dl|WuuAwH1D6f%5$cxZxnm}GB)}rS<3XJ1RMsB^J zF#=WoPE;WyH>x%YDts;VCd7Gq5`C0h4y4buZ`xrt zOC_nwN+}|!{Y5-rhmi>Lf(zh?$y51^U&wYS!0#gFlJ}V^Y6*I^ahT~&3g~dnr1F`? zncE6W9)s`iHQf`><9$3&UFntdUgNcvNj3rTY7?vbvWW6ZXMP^D1HJJzsAEmlvh)>H zIkOD;gnFvRP@41$@A7!6hvL%xn!q-tUXthK&uV}4TyQEgW$KDN2${Vj=pN(|qPaYo z>=liXm-P(o zgVu=r04m88YCqoO6m1dRQMtn?${f_w8fb^L!^SGAk8CH(vN!dw+J55)+lyLalv6)3 zlhq!~ZK9v@ij8xH)0L!3LRq7e)E}DlAI2iR29ad6C+*OZ+#_eQiR2qCTc(Ax@?7nx zmc*7fzf+oeK%5cub0vZq*7?$#@TttOR;+>J%VUHy{1w+ULO>v~v{^j&{w>F)hU-0i(;zo2~)X45;fvbmsY z@;jwTwk$H<_gSbHtYF!t#R<2($@Xbl4fRKO$v}NFpQ121rmQ^?iUbz$xdOY1JZdl7 zDYz}(@sG?a>H)gHvD8=v<}dUn+(rK!)T31jSsij2>t=bK{)8IF)+T4EedKqhjs9u; z;G(r`HcH&tc8m-KU{9JD-`#5fzyL7~vbTUk(=IY0(8cHMnzEnVHt;Y%1 z{cD+*YFn~T_;`1SC8!QY7j+K5(dYpO$aQr+v6-`z6^%C_p!WzAvOxJlCXofSN%WFX z$FL{rVe>qn*Lj%BXXzjr^ksc1*;Z|dtY2M=v5pDsQU-->R^fWpf0=8n8R_AU6_$bi z6lTA_rTwgV+Zd$I6i%oI+5h~BmgV%SU?(;V9B`A^IqD1jHdKMlOT5=^*#_kFa;yvXa{SDRcT}UoyoJGwcoNv0+m^j~z6%E_-R-NV%hbTx4waUlGp&*U6cdh>WMvb0x33 z!vfZt(OkUkUY{%0?X)gX_68?J9&~SK2Mf*Ie#>?;-hDFP-PE6vUjI(*Q`l?IM(U$= zXnLu<$=*3iW-d>9Z(CEnn)XXv;g`8Rx*No?S?Y6~X3F`i>a|?QQbau`A9YmrS&6l_MV|5E3v-{mg1aQgfKmKVYQ}BVPwJhRNAfb^ zu-d@nm`chFMkCVXvs{WQfW$r7>;uKfY(t>B8+WLxMhp6g@`37!w8g=o6TBqyL2I*) zIEr5BV)BGh$D9L7Wm%@3mWuBEJED#H3UAC$^yGKJi*3MdcZ}Ss_ah?Am((wvqc)n2 z_3zYdaEl92IY`U=4<6O+Bn>7PYxvEn8PCWvsQu2R1^qu_23f*bO>CmRXn(*#DN3fn zCvyl@o;<{5_lj6KGDjFb>W$$uY*Z~H1D?TsPzXFT zhr!?W9qhX+`eM>={tFLV6uM1&@VQCGJ#v__86IX8?B~Yl4ef;+vKcYV$V)bcC$I$h z&WJPu;lKG8@HXgy5x(>Sz245*Jam#pv-coaTa{EEW(!hLX ztRjj+Ie3|H8h@Y>dSKKe)}bdL5Ru>;Oa%$LA0ZhD&^UF+SDgvB>3ejOPl6`B2^xh( z&=HI?hr>I*6aMwO@CW{cULpi7*jDrvH^3d3Wkw^PW-&BtNAQVDki_s6_iQYjl0)DD zZ2+Bfad_OWL8*2GEQEpZ#x}-h7efY2EVMUuuqQl)YGf+vViVv8?FHw33_d3kj>}T! zLTD&1L6H>=AN4zQ-?Feftcxe77xs#O!9!jSccBydi*ZnKJpte02t4P=@-kQ9nRtdI zluZ2XVZ=l{ZHMu>-!W?a;bydfGw==LyAsN;Z1^H)K|wMVgl7Q^{{Bc(*$J=r5#$=A z<5Ty5%5Vdo@L}+qc7wm$gx~)<^eNw<(E5#$3gP}fGao`3Qw^TtSNPL!@U^eR-x~;b zsuj-DLD*v)0Q2AhzT0!~zaE4`)C;Yf51vjRT)3U^uZG~Oj)If>8Ziz2_I~gt@541T z0;M1udY+45P)s4lLBY`o<2(jV&Hvyz?h6OkI{5D^;VEnbm#Pij#-{i#FXCw%i!1UY zgE(k7Fq$o(4ZBHPF_Lf{3yjKeym!NBM`OHhVb4~G$cyJJ!g!5)I?O1Kudu_YiR%eB z29m`OM$Hb|-$!l=E7Qi(X zr*FvBIbLl-%wi5nwb`F~TQfgZP|wf42l4qH?b11892KKoW9iUB=8ZPecntPIihdOD z&tniFej&x-65gNcN-yz9a5WNB7l1`44-};a+EbEt#Z3HFx3Dki zd|1KIapi7AskC#^^;4~pgK|=B4GrD>Jv=V0SNLM3DUnxsWDH~Cyi&w5H_z`3Ci6Mc zVK#w$6ihU~k$?0P!~?3Qe+-+f7KJACD~QPD*?h_guDI+B&9VP*cXbR4In=JqZaj0z zD#?%Xeh71UnupExPt$+WMdjU=E>cbMfm~fV&G+>$w0_DtVxQ^$n~T$Xs14b6ou{r=qOV8jW~?gi#io%D`WCYz`=xTU#Br z&~u%<80gE^f+nJ}{yKO_=d~2u$xwpO%Da^*=PzZkst*Kcv#A7QA#yuPa5tb&O*Y$T ztGFoY7Vna8klT$l?){5=VKz6Gs9x<5UC|g$Z!^nkqm(Y%Px__N zfqg3HfH(iv%#G(gp1dagBBv5_*#%}gl`TIu7AWWq+^68+rxt)3>wn=0$ptJB5xY+{g6m>5l6CxXOf=%t@QNwxzV zg@@XFBxdBsNp3dsawH{9?`$-JZ+<%7{6FSVEkE(XSWo`Y2NApVVyMmiF!$-#&EIM` z=4zd`)3c2rQYc=U4XF$0iDY3u8fJ8+1o+bX5-arnWOcmnXK=D1^+)L1jxr5sXJY6z zMs3(jKVtD&7C>du z5;WwGNO$kfUSPh^I<*h$&PFJup)ApmJE1&Knken%uV7FwmX3y|hhBuL28V@If0f{G z-#dSj|4|^^zu#Bex6%8_7wy~cyWu_S{p$YhHr;nUXT5gs7xzH#HE&1%kif@Kyqqq- zQe*WL@R6z_?XDI(iK)ZQ=T>sFc^hAu^E0Vz7&KamOc^#FDYj9xNrf>l>F&%F`fr@t zSJ3O34CY_9D)dF=*@f&Hwg_j%S@a9rh+TmWrjvCcQE&ivjGMyufy(Qe;1j;XxpNr) zolS5CwiTK|wRDpIfwY58!c(}iZQ@n2A<}kViI2qzmJuQ^HWy2ZmxUA|uW*IW;r>CQ zZF|(}2ZO%5h`b28!&2nUsQL))wmMx&m1`+hMV9?idQb@E30@4m4wMcKg12ml_m5|| zH^W`s`?vR@=Y{);>y49k^>jyjE_myD+j{GHNBgGwZ-j;_gVi-!Wv!0h(!g$w>d5@W zZ190PN2Sttm=fG9?jW3qHNog^!?xspApbg=i(+~9Z}t-WJ*;qvpTI}*C%Gy@dBF?! z#2@|{f06sfb-~r&;8*gM;XTg9cSINFkYE#3z6*biJH_pW3ps~t!zHsz*cM>Hy}^$m zs86bFDqEY|#mU@AsJLeG8n+UkUr%_?|KT_CJ^3j9G}oWI!?tBlG8yn}H=_2Fx4@}r zXD&l#$zkoGYFCRY+hm_q5BZ^+gF^%J0vr7&{NH_Ne4D&uy@TOrn&wV$#kn3k`#W=? z3(?D2*)`Pp%@qb+QCoP{qy6gx^Mg4dTK=rGQrBtc^e)&}R)Z@2895Z4c?-U3CveU0 zLy2jk!uo^$NSB8eZ!NQjInCr}quBuKMGw6hcNbTKP-Ns+$aqrb!^OUjeaN=r>T`X# zM%)_iDi_6n;9hbo&?`tlR%{1$791Lrq4cIKOQS_pb9TU|QXh_$DeOe{ zBikM$cL8@ImMg@4VP7Jt{0TY{b?MdA0n{RkLzh&{sH<<%2CBu>WO=yUOX?ib;HXIp z_I|%TlY3s0rwDBbypjVOCwi;E9gpaM|nu^74Hatoxr_d zexy*gQY`9LZ69XR?&bz4o1c^8prSZTe+C)41apMB%t%OsHknb#mYIzD{cG0EMsjVT z#XQ74;!?PJd@^U_vp5%LK>L=PZ;t!96^WY*gsaHdZY<6amx&X_v0@|~UbTe+aEZy> z7S2Rc=K(klOX8DnA#e2nl&vb$maT!2IKbXOBKzNP-@fKl?ghs3H1r&$xgl%_>ezUC z3>-RtuoJ#-?ABLl>B!$nl}-3BJA~qqK|U(b-+$2;>pS5s?ET~sJyYHHU9(+CJMM!#XVJ`fpD#gNi6 z3_YB7)F8SJT^%IpzD#{Q6aOLYB^LeKIOuZYpp0Zq_b*3(`4@QuXNj-SZWqymY7)lcr_@Q> zjwd4z6dpVLXMKBpoG;#c+*8Ig%iY~w)&0~p2_CXFPM7novxuv`%K=YWKUc`L+&$8h z+xx(K+;_!K1Q&%?NRi4iWvF&e`)t%RKY$5XlbT0epp)p*xL0XNUaHA0#eK@pUxvP8 zE@$Brxd!Mg4d=`9XW{CO6=-3xFkV%TLQ6xNsj@#=}om z&=L~=6Jx~`p}$Z{xXWMR>+xRhJl>}+xbDul-e|lLD%%x*a+W&^+Q2{D8Fap?DR z2HwFuP?}^z8$N*QN6rM>YM3!ludCTqP1YqNR6gXwQ_&4?NHyOcujXmw`33JI;cn?l za&~o^Ip=e-bGqbY<=Apka?0lvc7Dv6=Ztg4VNBB9(>-^*-+b@={ex3O6~X3ss&>~z zWG3~2ro1MSzLr6I`X6J()BK4o00-JkZUmmzg$ne?8O!oD z;dk%Jui?-0rLbmf6;5JgbfEy;uGPdIqEq-S+{So~5sC>J{2YENp9}u~93&JE0M|nT zX&@SWffz0?*NPj2x9=#&@+q8%*|-v40rT}$d_@+CN+z?5ISKWiiIhb6!D{Hv z-vMD|JyLP{U>!EVCvjn|X^ktI!t2~pzBT`dzr=sx4L%oi&K|6q=fu?_!n7?DEL|*% zkw@@byey7@TQ5iGCKMNL^IQ1-m|$z1Jtut&Hb_0`xY_o<`ELoo@jt*lhbmPXW!<0qxQ;C)OhmHm1Kr#4#3;W;uX8Ba_ zDsNHmOV1L|AP?c$?C#-y?Yiq)?`q+ygXiUyGs?xg-n%UBkb8pXymyqZd7xY{7?cZM92N@8uo^#w>(PhL?8+5~ zH+4FGWYyrrH>#edvvC1`VV$y%E`LJ5kBm2o2Fg zEfKDU9Jzu#U-}g45?T-}f((f!7>}*~6zG&Hg1PX{J01SaOCA}kM`!mYD6QPC{H{@0 zSvtGcy7swhxj(sCZ&6>I|4M)ljgjKOian&yGw|$TSKOa&$3$VaV7P}MckIS>o#q$u zL-+!GdHya|5spj8Tz7+=hvbxA>^6+Ud5l9ftPx|lDj2z%*n1QZ)(N|XFfkswYTn`$ zZ(}^Bh!w;#;z`UIoBm%7(g{!4H>gqCVvI*}?XZq=*uw=dp8YXz&VY(=9iEJPNTn)@ z-A_2XfN|4Sx(Y>;E+jLSHtOkpw2o?=azfUm;nJf}Mog5ca?oA8i?BnQLi|b>%O3=5JtCNfkO{?l_Dbv1%X|EyR4( zg{y^8EP!inkG;iAB*k6fc5;p}97n{h`VMfs7s3qiH(7z-jFZLNzwf$-W z^fH5Revm>PLtl}e@+Ght3Fc9O@&0-KaQ{%>W#4gc5g+X>==HlBd*-<}xU*a%+{>I% z?&HqsuDveBHO`&wS>dhcUmYkO8Yzub&Z!G@65PKP@0?HW3X>! z;U1Qd0soF3Pxr-+>jd&6Pa-eSuhr5zs2i1!axOVuniD#Yb+s4z-=F;T{AK(Pe5-s{ zd~1Cr!Fw3sjq_T)rYGK$-?JO@D2+YXELSI2l%3N7=3L5-q{I`5#i#@302);IFU;*`n-urUNsSzKRUywxI4`M5<^7qndtBD~Pce zhwS+;QfIlJ^d&Sk^jAm?%s{65!oX=bgdX`X1{?v+S0nh)m)kQlRLHr>zZ1^&M0Z)= z@0{1(HSU>i%By&nco+HOLOTNCN-m{}I$!^cN^^U7wboJ&`UvxcdCB;=RK|@N<1kkn zijx{#w9tTi!;NK|b0%Gb9fR?@2LF^qS7ie9er71r0Pn*2u$bGxd!WLwh&H5L{k9yl zw6gSdE?8V15htjBCkVWJ_}QkeS#E&sbyltWq%(eZv~Ij{VM> zY*n@sQZ+UDJ6#yN&hzAMFkjyrul168J*}MTP#wxOd8K?GIgkH_`ob%6JlHO{J}@xI zfJv}9(8gah@ZOi^yXL#@b$QEpvpgR7+i!ZJ-Fv_U+3V`!Ug;7%c2BhTz3)H&kzlyA zN$Lz2+C3d^DWU^<*ekJyj9?prI^w}jXc+dE`LKp=!+E0+yPkQ&#L=bbJ5*D$4OJbv zvnxpOmnk<~ARMyI-!f;IW1wL9xQ~1T;hK;jq=^H?6wzW?1z%N3af6s4*s&wO03}rr zZ-E8Qs{GhF#IjSE6U-nyX;1IyoLj3Q ztI~%I6^mXD4yukwxc@AVRg&R1s4q_l{R6^A3ot0o1oc4K;EdqeK<$7f80Y`uw+15o zFMS97s`sXE9(Jg;ebs$=eM9_9eA9xT0)s+zWDioBa-+vQ87`F;gbO{c({LjAai_cE z`MroaJ{_m6GkBki^AbChTZ&U!Z=7ha!2c9S7pE6dNz@K1gZh9J(qyU}Jc8-;cw}P+ zn4>tkJwZpf0M4S-_|g2opdrKxnfSE}UmfS%!`RDvFfR_pF8mGinc2uxWGXYC;L-by z#GK#s0O(oQ;XK_7=0Q*oq7odd~tI+1DoHg$b4bR-Wo6;48A%V-YpeR|3h z)X$;)>ax%lJTo<<8uG)?)6iOIw2p+{NlT#qn*#!b6kLL|`ef${eAT&mm!3mJaGt>d<1htW_45tgrJcW~P8B|ekn4e5C z&ZisMnwTdtv0`puSFz~~W(cMzn~W^T0n9L_1!ki8bW0@HT&AbfH|alASLPKt5VPoH zaxJ|E?i(AmkSI-eH@hKQXsfxDT!=ccgc?{|IB}CS7n~mlv~i%r<<&dso7L6YP%Tz{ zq;}G#Lr*?Ut*Z!HwDL~*sdQ3$DHWAZ3J3LZUUjLQqf`LznqlW9ZwQMa4Db`Kllu0)3z$*_*IeKUB;WM2EU%&?{7CyoG^me4AKBY4-n>@fc zErGKt7Q6awxW8}cpSaV0+Dj+WxtVjg%E$Bs#*35jI-I=>CZ4Xv_^{6|z!s!?Gso#0 z@N|B}mFrXrY8V%oo6v07k=%KczC{k9y&%GUCz7d1G8Vh#_9jjFP^&LRRRM+I2KchM z!R|h690RHP1kPI%2)l6?9G(I==RQYnwZmMlM}TL(M*m})3J4#@H#rX(2HmU09I9IN z{P1|^Xmizc5Q)m_=M*2*e7E5`bSY)QtoSUyRRuMl5>jg`wbh$ z-Jw@7C$e+YBIH1(ztWgJVZK4nJS4BCR+B!pF#W?AX1pVQpodh({6Xbc{-yT9_1c-L zuRca6b*%Q$ykOi`phD9$b)dOZSz?4~PW^*g&8#I`4MOjxrs)mzc1lNcJVtO9w6c4R zy>b%i3d-bA^_td=SgOpWDk?4TRKHc@s0um{;^+@^2B-jqsGmCWf3U~CP1dKg^$Khu zxX9bFL(Moi;j=&iT!d7DEk+T}3+MAaw!AL$V~pbraxNLw2xnX*^4^8cN;{6G^6MmL zBEOsD8*5I?yql4cnGMGF0;USOY(J>=NZ}c0lp+5ipXhs-M)G+f7>wYHs`sEgTSN^s zZtEA>Wx*%>X6Xr)j~HmwrOp|>$sn!wb9yb^X>3zh5RH^fa*Ohb3MpTRu3Ax@Ca$WL&1^kFt4;M)YZx?H zTq;QI(~{IO)MeR9=EBH&$OGEH1}GNVK6o%ILlL!4+d?1LUV=^UQxDQ})elr{vjKUD z_?KJ^=gnDasrj54s`q7Png7u5ajqXp9;13u-H7e9<iKBBnrb`h-lQK%^#bFsqF( z>{0U`+^?s}4b%(dl%}!gh{1FivzWNfbfpBNv}LJwm5(qkFlWhFx*xFy`pH_v5hkFY z5WPlE>NQ0Y>$$z;3p&-jO-)5*FrCUxZ&GiwHPx2T3%=I>Fk_5a^mF|*+fPS zLiPY#G=}^r7v(s4GCNv(XhJbco>TqW7>bs!P@5%D@$yq$Gx=AZ5lrb|VaDcv#HPwy z^wRK#wNNh+4)rbGe5ZDuSgA~+mIR*>TSIG&LdHAI0{X*VW0+Fih?Z3Ifl`}jq}?@J z8NfQ3i`tU}DGxMpI>=$M>m5*CWSXq;(%<{xWz9SJUPBl0_h8)IyBMY;g z%>!Jj(U56J=Al1AWtv3SGuv~oi~;l$V=CF57z3V6A(LdB+F)*~+=()kY;t(0EAu79 z^SLxhD`=F^wrZl1DZdZ((uN~(xPr<_-2ZWO7T{4GZx_FAYf&J$yB3!M#ogVdxVyW% z7I(Meh2riMcbAZ4vsu^c-}!$}o`ewJ*_pX>-}ju~iFK0)L>8&zBY(%j+NJ0O-4WTT z%fUAK{!mwKZtPtwBi1W+A$B%a0q4yX5q~%vY7bRFNje{`u3m~P(;La2L4Q{*l((>#eT^>Ogh3!|UC5z2;^ z&;Z#K2kp_Mt6rAMz#i42WkYr`0M!R1^W4Pf7yV`S1Xi+N=*%a zSO0B)(X6Gx$?g`O*FLBFQgDs$_we)BAM#`+KZ;%j$(80^)SE}J%egY*L!qr`6Y~oW zEMlJf7M^Orq)q@D6K%l%B8PC@?QekjT}JDXF4$pxMIR}v984> z<+B#Hrdvu{K1kavO)LxT7W;3uSGK0sj}{hQgzjP|xNg=7Wu>vgFg}7)nV0Fp<>P%| zyte1Y3T638!gkc_9B|7A*lXM)b`1NBqqrFiPV`9seglVo8#+ujfpYg1SQX9D>B^&d zmA$bU(c{q(ky^oHzCC`sCn>8+aD?l;lk|Vhyy+FZ{;Wgr5V>4yJ+451x7Yj4-NTd5 z+spmdGa@iHFjMKHQ~;@9H~ovs;(D{PSVX8{?J2sf3$2%=2EtUXF-RScwXWte=*K2& zU)A2~OnIf^lMmwASZ_=u=h08OFz9AGq~9&yEy>mswyE}r-34-BVOv8f!mnTrilw%b zQ^;1_L~%Y>3tZFt;s)uW<&o{GE#zS0%=nv$_2L@X5^X)imKgtJq<7}zKH#q7U!1bu zfQ~VY%mt_VpWF@9!tBC0;SC&&6VR`k2rqaRYsFpacIrLJn;z}E5|;m!m1xOGr%?9r z2LG5$yEh^0d%DYgE`6tGe%6(&FCNk5bRPu`GL~5>YpeUDx0db8*uZr5bc;(a!G@E}$QCmn*|M$S+1grAVw}q)03=wiMm<*77l>rP^65 z1eM(-q7A!%yCTi8KCs2Cr>&2`Q#=Tk=P9uz`~>sZBlHk5hdE2ztu|6_YQGsCQnU=R zwQ{^o_%o@0Qjz4`$>WpXB`r%nmB=O(O{`*TD$He%ke}d9>jQpKE%Fq7lN?Q#V|(E= zR*l;t<`?%_-S!{W2iDP2BWWU^gT2X5rVrvyxda^}+fn~$!?cD98-vGJUK!ORtwXbY zU0w5BRkLO|cY6l=j%9TYzX(=vo%Akrk4caB1bn3HNMKE1%g=ePMy~p)6VfATbb4}N zR;UaWLggg`mBJEqJ#+yHTe9HbTiRE`qtJxz5GxV(2L|}FMN7n5$BM`&BF9<0|xtE(=fFVt%^R z+13n%xf%8q<#_Mg^!mJ7%orWhA^c@?N5!85LH9?c#7srrt&uYH$-shO?6=k^RsKb`s}eUkIYx4W*R z;hMhuS~sz{v5USheKRM5^cp3caJf|_T5@^G0=gSykRpK}-WlQKQ1$QuwTv=fD@z&1 zLh3ORu0IO}Y(*Te>_y|O5(DF zTRE@i-+j7!wOcICkaU94V;PL53QulJveoC&{^M;ga;id;|5 zp{JtvnI)!J4C@hFZ~J`PKT<#81xr)g$R*@Py`E3Fq)MiPj5G40Mlm^_BFc`)avH_)@=r^^8eB?cN?Zn_Az`IKO{6>)e)JFMW-- zk#l{dMzEggWm3ql@UiYQKT1u7-4LtPU3LT4%jlDO@sH=^= z>{KOrpEvKi8N>qSdPajvkHnF_wx9y8teZfr3k+G*{s_DL@R-qRh5Vl|Ew*IVw! zEzD6YnOE>^hnuo5r3cH zictRGkccllB+@8!H1xo0dfWP0XFvbL;5K(#ugjxnwNIt|MN^Ar+U}+|Ln2l0VrWtXJPEgClTFbQG7WyYLAiP$-sJ2ie$dKw!5!_0C zkJv@xkyd=r`o`LiXZgNNDk=@>L@Ds>ih|bO-^eiH!NNO6-D9e80b!ZAS5gye6&_~0 zl=nrh%}i>}3({z2cw#Nw_uc~6J;m6D)XCB+NtEIKr5D?Z+O}I#?Y*Q4*5>wO;wH-r zs1y=WYrn_Op-Ju#wTAS9=R2P7f=cN`VyM zbhbZJk*zGe7G@(!y%jeT>X=@7pXid{er>h} ziMOm199N{-!Z9j=zY8DceByv1se7n?^Z{;-WxDjm!ddD?u zrOb8sX{OO$CXGA8tmKDqGs$G;0FpH7k#_L@Ca4wUCgC*f7`xnay%~Y~fsEkT(06}+ zZ*5<3?-*BI?=Js%cjMsxkeU(a&Y#u%M^$$-kh+(83I#%;(cwh%p1MHG$(GYEF>NgO z=yLo_UL|*e!&k@j!E;{*p1vh^|Qj%4)f$#r< zYr`5i!Kq{^q_DjKX7wN-$m=(=YR{1X|oBHM1PCbhp;Ama_&dzgb4|si+e*gB)Tqe;aCr z4CXC%mWN0jDGk2q53{q zo5}>u@d5~vwFsI%BZ~I3wwcLVoDknJ@pRmegds_6{9Pmq9B`z9!qXXb3YvUF9HI7s zRyYZL&HlXaf3;-GaO-nhU67`?*w@+?S|&<^giBaCHY9nk=UA4c0>}kwVMu1ca2wB+ zd$9ts6tBYs>h@c{EeZXy~f! zyy$M~ofPO9tP&X!d#3bP5~1$yPo~q~nIw7#s7IfW7BZG7LiNBMOlzVs(obqY`%LPL2f`=RNGH;rW(T z%u~!=#x*7LrL%Nq&^aR`7dSbuGXHS@%=BlaWV$n2I~!+e8Tqq{xJh3&|L5Se@FlRs z&&fs6MLB7FHiuL7s1$M+y&UJDo769&60MQ-!K7)y6ou9%3zfdt)K&ThlHE^&Z1IMz z!e0ivzA?I}Yxw=bMDz>N`Az(0^u~6uL)mP|Qs~Ef`DWr7aRRg}t#DoGD(TW_%QQ81Ek%p--w2T1H- z#K}T6zAx9DUC7AP6w+nxGospbH7HZDb&*i0e8?KS?62rM>>cG9;cnt?k(J-o&Xp^3 zzOzVXdPdWX;Tf;fo26IJSeLQI`7pD7Ru1=0&nVx7z_8H!h&7f=xuovbHyK&PJ~B6A z5DTMQwu}3OZtGK_8hE!V`lXe`g~Dj`(#r_t#MbBxZAZT_QH%-A#oS_cXw*W&E-_iW zEy%(yp%CWD0%?nMSh^|^QmWM7vJag$0nh6!@dCP!31T_Hg^p5w=$)oPm#~bT3|ifG z>OSxqSYdC z;gnz%cy#Z2{&p93_sMGHx|;dIx!Bpp`7^`je3T((d`mx={v>@&Mxyi2%%ZO4S%W=& zeUk%naCA66`X+V}F6!+@PU0i+J9-5yp2vG!GJgmvvl+ra!Zx7lCaky*YJluAKx zk>mLnyaPEZUBKlYBk*vB4+q1{DZCTfh;}I|)wgT{$N8OfL$X@uTNYX}krQxT^om*H zS@Dc;p1%(L(<*4%e4LYI*)W3ybG9FgS__kx%?Gu^WUz@VLh*E)FDj%7t%PmpdT&6l zaW_KEnOIouk=iNc;GYxykIriW`S|x(Ef^ z94Ii4LwS(K^?^6KI_{ORD9IwUceFb%G_gHOsow`nxT=br$Fz>!h*R-71Uc_~YCl zz8Ul=@%*m;XHL7Bc$^&LnNsKtokWkR1G;Jp+0*PE_}?@63VbIJi3&rjG>f@NCsUo^ zbSP--&=+fO@%V>Edq&QNmxqQ0Px@0)*VySvbB}X>chy1y%kP=9obNI^XFN+AmbN~v zOnUqD%Nb>y6*77E7|&CGnc(>FkI1RmPE`M=Aa%4f2%%T$uIwZZ8cn`7zX&}(2NX2j zg`=4Df1(TCpI^j3SOL6t{*!Df6b0Rmwy~0HJYtR>{7|7vY4J|j zX91@(qoC83@jl~HMo#B?*Ch8WZ+ZVF|Fb~Jx;uylc*JJWLAO4~a-j@Y}| zS^MwyQVuC@r{ksln{AFQhpmk*!Io~>AuR&^upp8E7~wp(iT%MCOk41XIxtPpt=I~O z|7&V4b(<_imWQAH76`4^LGnQzAI#|C&=zoJPtc6o8_)GbeI;nC7r>8qp|2i}eTr7Y zycivQ5J`v}Ms7&QaN|&sP|@I9e`4TI{~dppFP|?ju9oM#<$cGzr@d>vv!Hcd?o0G7 z^Zo0y_<#9_2W|z;pf$`#ZbYWX=E-N2=ju1D1IVnG;Xx`$YGh?ob*08+Sm&^ZrdL_>N&31PooyIz(&~`+PU1# z6DjL7>%4?_R+rUdEoYBNyV$?kCrm#iBo88Afd||Xj;#)6MfF3huKHLjuAhO*zr3*r z-MI7Wee{W&Xv6gXkWt(Zyyp$@x>MOOAwo&__)z?bsR&?FskW$wPx`HhPkNS3Y)HhCn zd9jN<&-6eQ_X5<6z4)8L6u|`b-ZBr>x<$xiiAscJtMpYIAx#v^3zR6~ z$}$Y8KiP#Oz9O`Fe?pg(#m2LLp-!J2o!2{bbEuG}Q;&)A=;6%;r}q(&i~Mf(g7)NZ zW4=+q*st#~ZfUjj4{CF`T1Kib6jQ0FEQw`DD%IXt$Cx$tR30fWiRO=b!RzoxC&j8O zBjjPR@6ouZD^fZ-Cw3&-CAuoIATl1zv|+I)v88egd8}Mpz9ILP&&8(7yJDZ^o7j*3 z2Iq@by2&p20mj*)4#Y^0S6eE&oF#jdncClKP>m^!_5q0tGc?9ntgkb6>phUX`_lY@ zEYHT|duj)H8iZ0m)eJSdpUf1H?C!FCxh}jYp5)76{~RQRp>O;oHpGtl8H(89mVA;2 z|1Q8?)-Fp7xltKXe!Rz6HbX@nC*>DU!Do;n3_zC9Zm5C(U`Y?TlURtwh9hKbbE%gObcuFg*R$jdcUhx{mr|iL6tFL}gM`13k zlc%biH7{mCA!V0RRE<-&t4Z2rElw-0R@M4w)ztRdY;@|g>kYJ!_6dFZG^D!SHm-uq zvP;i_muKoE?VQ#TSG3ijYh72)#=gk+)y2vacu)4?HzgqLVkcOL3V1QE(*AS(_xwrU*sUQWHV{!MeL7u9nha{1N8S~sMCcHPl}AWu@~fmIcjUI7u+0Iv~T)o?Tz{luAV)hs9sc` zA@$+Ds;G8kwscn0z+O_60?7Y+C;L!quc~f9b@?~tymD1(s;UYp_mJD7Yc@yeqQ+~b z<&jE9b+|eR8izr~Ed83+QzMZB8A8{l5b@Q>X)Z$cClbo3lBf(0pjy$LnQ7=Q#F$bn zas;`3Y(Mm@WUf6Q;~zs8)DsHr%{cQXAO+|sPQEg(l6j=pVrTIO_{(Q;68I^!7w;nl zWv;MScmlV}LP3TDWGkv6{~>$r4YYKBvmz@0QPf2GJ?pk_T(N6-KYB?2n$1>nevpx76if%IerbBn*v{cY|8g1Ik4g z)^ri<2}|Y8$_aS_7>Ve7S0+Cj3G4&f}v_W)0-mw zJU4lVSWaf5zH=3ira9tdq|;fVvNCBo@4aaZ-CYP479nD`Y`Ra)?K>}ib_??hka^Z_ya3J_di5# z08;-kL zUIBlvbF6&Ki2M~B6de$iqrO;2^i-Zje~-?ImXxdFYB5I6h}M=5$)dVhxuKQ>Z>@&9 zRpa$KP%%Q+120>7R3KgwJHU@DMkg@SaSu{~8vw;>E9lQU3WdcP(q&5-i`UZBQr#M} zb_Q#9ux-8Vg>AU?sr9zCnynxDawTm0tV1mIEn}=DtrqJ_OIRY23&mMVO67!n!f>#a zr||PZW2ww+rzvJA(+R4R=kUrO0%d&${LHmL= zS5$Jy9ib_np&r#P8V`u#W&z?GT!bI#!R$AVd^j2s!)owgC>ClBH&zGlU%~$VroqF$t==WRJpRFOv9${S6*-~I#9eBCBLru8 zIVgk%^F7(a*hyNVB0t@dV%uUd#KP8(j-mGA_BoF1@lEZw9ntt~Nh$Hi9M$7?C+10f zozN}eOZV2oZ98pAu;jDENv&~p*ox7p zLT{%2hTb6!=#}n5lPLFSCUt$|$S&N;ghV_fZCLQCua(lpXeo4!Wjr~ZCCO?B}@SiaY z=@$b)63jZ1W9w_CX?w#a0;+f_t>^`0~ zAZudg_VkSO1m_ZGs|;&e*UXX`InxLHD)rNoc02u5xSp&8YI<4(cSM(KCCE2?8F37A zk8QvwTC0hNMa??S%Ggd?Z;J|Zml}jD+bBFU@93*$4{d`vff@zkU_G-PF^61dW+FxX z4O_rk#vYUeOEs(0?v2};JTC{CGB)YAq&wMQt4)q1%}Pv8loLM3AF9juIYT#OTk(E&8f}4^@j26i{tWHoRpe51rqifuAT1q0R(*3KJ2AbG zbsj-wzhCfqAQ(ETbjdfPr*yFe?k zJ+7pEg5#`hwxzOk0*S2(BQYiEA!aqCz%)lL&I9?A;>4A!nN~{cu6NTv8b^sT)JEJv zP7`XyBV54x&QUw=R$`CDzmo4HPfSiu${wGRxHK-=vC!7SR>U?yS|*<2mog;WQ6spL zOc^M+^D<>o{rkvvU?upZ4xs{?lP|=*!l*crMXllWe(1E0mjA!Tpft#=B5!DCMBz-zM&g-;!80X?fDM#OrbI zg~UwleTaN`BHFJ+~5M-IA{_I zaHmn3{XktNs}uRK`xgS$@L2S1FdV239{6oH<;m_E?)GI?&ODe^C*y+a+b>_nn)Esu zE1VD0PC2usxzbBGmt{rWuie%C&mwu0_8 zmVFkd*irnKSx)~ikNu4~7LR58-GB;1TF32+MSvntl=oC0}E3{_v%}BjS&G0*aAOAGZ zbN3MUK6hMZ-^_*ByFX^6I_GAzNMGbkOaJUFno-hOB7LdTmpR(A*mJ;t!apX~P${C{ zRvVhTs0J*}JEWyjytOj;`M+EJmbI2YQ1v&YC#bRZW;)U%;C#P?yQV_gN3}i9^v`1b z zPaVbVdF|b-Q!KrtBH}Fm5u5=Vah0z_*Mt*dD@`y<=;gSJeoNJdM*R#lfoqW!H-PC$ z@2526;H(3^W*nHqwbVuOrRdbqgV3qKSUBCn-kA5Pr;U4hR>XZev#0BmvvgMD%u%ko zu5KB7vYcu5tT|~ZS@S*1-8SE7fBnd&7%eYUt7!Y7q%MVC$#v9@{6y~<@ zR@^7mvTe8f9jn1kFP@m17*0Bs+$woSa*squ;=+WKxCf3gwz?RjR+3H1fIG4{+m<=b zY=(~8M(3n#G)>JSpTWo60O=^haheX2|Dlqw0QpNJK{T#xJkpyR+q8}HpxD6Zjwl_@ z10T=);8kBKe>ZgG=D8YpzGN-TWZaiBGCiNu_hd@0Y8ib!qtmW>Qaw%&KaR1~{I%yG=Hm1Vy@{9gP7ZlbD-^Gd6 zwbtQ|{~X8TYbR7lAQRfg=Sp}Hcfc{$KG=TP+T8jQ)Y@mlDz+fkjjl;`A{Us8i1LJs zNJc`yc9S-e@jv9ckeNn4sw+By?TM*+W3mZ#5`3nTSxCpO!!M5l#n`?`fIhpPJAp&S0??t=cIz6F8l?vJ68UPok9@SHp((lOc{{J=*> zIki092qd4?d@))VC1$C(9Niv6B>C0SR<15GVa9VhdzHRP-5{WjBlAIRnj2Xc0rjia z7d4Ud%2%WkG}EgSjfl16NNOtTYjNxq^h?@^?WEF{NtQ#F|13=;8gGY7Z^i$_HR4I3 zIC@Ga+3L({HaAm_Y{K+G+Ri*vp?j-)xmkLJ=oRZ0rMJ2Qtf2wwbhSOJ65aJtY6f9>F3SNc-BVU@_DYK1Tjhjz*71{?@8ShXj#OXjEFlWizqUf^aw%B*G1 z@EOcmW)QD%HIUCxpLxNC!Mb0AeRv+~zne%4;WWpabBzY%Wz%b{p?bp+aGwfO+xSZ4 zGv+l{jeaMLXa7O3W(uDw^kgT9lVFR!4IQIgLkVHL4f678Qv)>SSSev^uv(Z$=N6|79}unr17c zNt{&*(=GG^T39`X1h`qzCwd$8s#-f#M&F`_lyzaZvOZGD$m_cpvnY?F?Sp4f?Nj{JIK zkCc;JAF7;iSM5X|=Dvu9sZ?&PQqF!nc#zR8`}LRRP0P91Q1-LYm2X76H#*vK``kit zx@@c{`Z8;9|J4Px=HtdPatCujYa|qkZl?mF=5Y%`$LY_;F=9mMEz^|nhgS%t$VU+u zU)nd2azzS&YndCA%ZiE`t*9QP&PK23FSPwiD(Wv6V$;le%7kEUri+}1XccNNS0WC_ z5~8!1okaJbC{7Gz5)J5szSC5U8UfQ*1cq<_$DyYNK#g&^@j%XErntVR6i~y5d>k-nk{i*cM zYMioIS+6~ods((R)x ztW^3CnH=sR^l&d1{vuxc=L(|QI9SsKkV%r@QoR3dnAJX0R7G6ilwa z!sp_ydCE#1bvRYH3F>X zAnyxqwNotJ(Zye%oNjFrSVraG(!$@^DwyS+!Urs;qyL$QjSf_Zzw;2 zIpExpjC*PQwNf}CTk7CA!{2~+5sTc_;%=H#GM`$e3r#c6SU#IQ{A>AM^nyq>X13Cm zYZ)3%)ux9St0E($gK{tJyLj8Tj#zNgyi1?z_2Vq$QQM(L1Rh}ZWW6MVmgQ`*1 zFjbI?qT}M*`%c(@MP>*O`DKy8^ir`Tc;k7qRd60j$WE^EttX?g}=5gJdk9*E_>;Q>jLs(vFbrr0QAQ5~_Gx+w(+DGRF-+(UBP@H#VEt2c+J} zHN{0+`^Jh2^Fkk}VwQTL-R5!P7&IF@&9lZ+wWM{Xza)8%KN_K!W0BL=1#&;4i1v!x zY6g7A5(@fegP=P@`K~?ST-1f&ZgP^Nq&u0~0A|SFnZKv}9o%Z%ApPJ>UX`nI=ZO?! zh0>i#;ETG7%_oMavQNiIvU~3i|kr+kM@dK#T~>~Cv>*K! zzd~NEoBTGpG(8v&wLjF4%yi<5GLoB5{gU_TGp$N^3MzbA$OA|L1?w2KDzG4>{f}b> zioM@k^7PlQ2}k_Xg$~F;4aSlfmC@uUW>GFly`<8PSh^awg}&$gDBM;b+jeJtr0n)i zL4`>R>`N*0vq17^-!ZO&IZ~fuOg7!(Hg|pdI5-6BNAJXU36(O6(4vuz4adT81O4Jo z1V0j0tVjLFnF9Q$05AQbas_5F_Y>D;>}C&%Z}hLqdBbnntyNI6>kIAC@C#_9s}Vbm zVuqEC)6&RhmUphz4iA}EzG76j42zrOVq8^Bd z-tZ~0mA;vLWNa562PauRDmTHqZ5;1$C)l1VGg-+P5RG@_iHPFtKzT>DSXpwr_LP2% zWYtH?=J+e#PwY*-A#<5}98yph-5FfQPSiIrer+k?i0n=HN7?Rqk-R`n^tVi!?Clsg z)4SHPLEfaUv$c1V2?H{hXKU+K$@O{{!LLd9nh)OoagwuCTvn`%6qjBiw~*GrK8G~D zpKLj_P3T|z%k~a8PRL5Fk-SXash+k>c6ze^7TQRziw(Ee_0B@gt%=)fe-W)tO~om( zh2mm+E5{iX^yh{q!6jt0$T!kGm>4R}R+`Zz!tZk)|5UfoZ$B2svqUqWk^R`uJt-M! zwepXCmr?kR;i?@b@eY!=oLd)!~+pZjOpO3VL5XLD!8lcC=-lS;Np-%GrsPH-i|2~;`J6L=aY zx=Yb%wsx8AZ74MUa%L+L>dXEWPDxHmRq}5C$=YhMox`V@r^GL1r?NfyXJ#JUmSttu zinoL(v%ATB`f8?`V)DA@nBzY!5F2i(?qlM<$!W$$Wq*8TAVs-s-xWD0JqwNG{^A;D z9m{qhYhuE0S^KlGPA;LTI#(~mZ8KWMhKp4rs~rD$x+UMv+8e*c|B5(dpXqF!psIe4 zmhXGIRqANm6!W_V+thF)>0V^HXhcfbhq}JlU+5L}WL0JJ8*Pm`LjB-uN3CEs_Lo|M zl$kR6QMIo=%5qX|5?d!_kw!E_-eTQKObacQoO(Uc@9)oiXP+AmuBb|EDapKk_}GDCz(kR7z`2m>X=bute=c641d&3vQ*J zMS7K!#tw6WP&Ay$){n+nK15m(OUUk#AC@oP=Tb4<0l#Uz*jj4~-%n;1d5Hcu6weTP z4dE}jGEp=#Gj2{amt049fxLz)>UKEKr%@+k4A+8q<+!!=2+tgT}-P=_EG+|7L61)Ww74X8U`r72Z#h?qVbqqLVZ`_%o=QO zW4~IR?qyPRML3E4=2rHAvYuW=)zlqyDJ_?|mTQDO<(foEI+xj3X-E9W=2!n=f5y&< zb5$9Pzyri=iqcM#BZ&XNnJz;LU5Xr~}3BZgVDCSzSv_G;@-#;f|z0qkRY#Z(-U{`qGPu zK4eQpVs{d~$+_BEq771sD+RDw4eP`S(&f|7Ki<3dIFNa28y`rBNOEjYtN=I}dV80?26JOcWt6fqVOf zI0W@|)Jz1s_N$rK3>h8BB}N^PPL~j^h#)@G1HC$tC`%S68=I4;Y~awAhc9#`NWeYF zDQ14~t*4u-$fZPAq7K0t1nzkLAg!dXUqhPb8)VA~7_UA=S84z;9eUvin7BKT-JF2l z=`LdwSkoKf5nh87_7`R|^DOjzH&HQbW_pkspPhJsEZb+`>%JilgAaboxMdbYg7*vD z-&VjGeg_gPs}n33%cF_;})z~yZM;=^!!H%Al)xp}?$-rP;h#&69+CAlcl5o9FO ze}@CQgt-nsX*zPJ%M-uwJDR}LR|;v{c09lT;wOIr3A-HTrjEJ4)4Tu&e3DrNKW(C! zhq!5K#(waJ_aH&O559LX*vNKb5QuLPb1ab;X=&FnBS9|4{bDCjlB<9x{R&?*gP;kw z(GQ$F7o2K;;3xRZ6jYYVgRq>Pc#XeyGXCn3=4r5^cf)1h$z(wVy#p>lZDOj~33Gi0 z`XPtGS}p{_=t}&~B=~Ve{I>h(srxZLUhsc&5)bfmli{AshG#PezNZJ$z^f2_z}{|( zug*v8z?{kf1xQnT&Hv_RH*~-!noYr2UV*gq=a_u~Fr;;OP6r@U<06>U-_3GRqMrh9 zv^0LlO#EGWF$;R*negH9 zfB=05xn1u;Cr`%bN`gdv7Ci0@jB#<|6rRhngl0a)YR<%8(g|;0q9eH-pZDPRAH~OZ zV)n(r;w}OcoPu=XLa1JfnD^h{m3?aN$4~tY&te>YXGMI?eTvSX_yO} z%}k8Q|5jm5Vg!EI75uLUe#af~z3*YYeZd$$!?+0GZR7zvo5QU7`2U|>Y~IDcPCSly zTsev%<+m=z*n<%$ioesqoT`K8<%jtQ^DK#ILbN4rn`Q77x~bvu{VyTik4K*b61NAV z*%I`Db{MY(cpSq)ux?AVhH`B_9?cY@2GI=fpCGmoIf>t~Dz+0{u!l^>n(U0Zx`^;$ zKiQ4%$^--J0LJ4X66YUdj?6|1{CRw47gE^&`TsQbjToO*AV4g~W88_yn-eQ`3-;C4 zn7O}W&1}Nwv7=&)I^i<~@rZ52L;Q^qthE@P=hNm5)Q>-71OnjRda#S{LoVQJ>=!rj zoNh*KeFHx75>=LQ`2S+ep5>s>{$1TkmdD$Y*mq`_HO($a)R}|FThnZZE7@o;?T6q| zY{mNPiMQ+VNIT;FDR}K`&O+w$9_$?hFed+EWKLoJUc>&d1BvI~Fhd{W-?MncOYoC7 zGC=)R#%*FX9vYAs69%xrG_K9zW?XVl8oi7>uzUj`3^@?~aHSmj`P{!ESXCO#K;n zgu_7j$zl4S_Bsuk${-LBR~UupW!x zF^wcPV4YsYRqipNfcjpRY=g1cLS9GOZU$+9nN*jmNj0N-QGZfvsAbd^Y7evm=Rk)+f?ue78xt?T<*kIzcC^Phm}zqD}NJ4 zG6i0uaiAnThWDufIT52(6pwo?^`1(CPIEndAEeyhn4aMO>;S{{GTdJtxb5<><=Dn- zHMS*qT5(XjGwf&PBRVvDn8i@*bbvmLW3JP4>49{1+DBcWMpBiiALL(TQ_@0S!E;p< zyYm*Tw^GQAUWO6K1}gg%eIsbt_29Su1wGH8NbhLZ@!mV_f%Z_#)ShTRKxy)5Hc+*z z;2y23K3AWwAJQK|lbE3=7`h&h=cE8?1|5u!nD?VFU$+>Sj334`LqlSsY!o%iVV_uz z^=`p6ej=^_=|oL(DwyC7Y6SHU2vps{7)qzB!wvoz9)aedi(NKIuU?1VMMAGlUe!TIYGW=wrU(=X^Vk@IcS??aV1No$U=iq}%rkLq*vxOzg} zhD6uJ>OQ>gM>p;~K6+jK1Ql>9^au>TGN)EX%cs@VDrk+hnp!ih3erj|Yi04asrEZQ zQVuG->Ubu)Yn||%tb^X_n08USiP8P0`JpJWfoxGtZwnpj1bv-;3iCatm&M*N1Ec3O z3SuW+iK_>V^VSSxyE0@$%!G#+x5ij?*O6FXk)A;x1h*>}(;PI!V_@Cq#mpMTo`G7M z;NrM4Tpdu{`harSkL$uU!k9J$FSaM)- zS|BgE5$pCluG0l@UK)k&;3ce%IODcH4l^OD{jH6{BV*wf-=U6%C%337LXG%T`3G9E z4a#z51vq$9@j6!-qAXBGE93BXF5a$#25&pm!Z(!`R`egZMqk0zy&bNLFLCB?K+Z)jObBam zG#vRk(=WL*J$Sw3VsEG+{iRZD4K-0X&eg)Bl0S`3~ z#-=b=2;`FD7`;S}1n2A*xSRigvN(jT#1iaNE)D6FXhpr>saKzkOU@y&*WKnv9Rb~ceM2(E4}HaBv`IhMr8M415d3qSF1W&x<_r5J*F zPM@RaK!s2WG{QGn%@eVvlc_NIh`fw-w+ZWfDyS(#L3*l376w5niZjhK;wV<*Fr3=+ z;S6^ir?wRH8Ft)c_@qkd53&DbsZ*eQy$)|f3FVbM6}r^>v4ycBvDeYX(UxH5y@6hF zRisa(TV#CX52*b6MaH5>IVCbMvM#bR@+FcBYF0kh2{fUZP$4`~I;aLT2zLFgUehp) zO}M*>fZs#o6cD8L;R;t1{MC(kWa&&Hwmx2F!r?R%qcnis!*+(gs1jU%1p5){{#VQu z<_U9&xd3|Q4!k|WEW&E+kE>%}W)xgcotYj?J#bb_Gs%n}Jh@9)UlY-}?*M{eHRxq? z(=1dhA8_v2gDkxPR9g@h8$-X)i|T`J<~V$0Ak~{{3ARCgio+c9pmJ~&rBh}2kP^- zAXbvV}{7z1w*3wz@YpC!}f;Bn= zY)K2)u0!~R{2t!N-{eaQ9$o}-Ap=C)E8wQDfObrM)+$&0||kNfjfZ_ z!SlgAp-N#o@)_yOn_^?+rb{}0^1DgPLrwvKr>*?-b#Iy7bTtE9#U9pl2~An#&dx(#LMFOo;?yf`%& zU$c}th}F2Bnu1kW2mN{#mEcWeebPnTz&YQJT9RorHrD9*^h(+q)uY^zOUX@Q)1rGL z@8MP}9O@B#61eC8%YVaH))(+r@aFZt@f`BJ^o;i01_R-?r@ps=w}$VWucg0hzzk#s zONSqX`(uALXRx)GYU91ULmWWyP2?;6tLBpI&FQli_Q|R$HbeI?mdFl z10VfYd~uKSZa>JGUlcpiCLcuV?*_!j$D21*3)hw6qUcolxe zX2W^27s z{oz6`%q0cf*X>wsZNWzNgEYJqUY3$L5Bb@xn2Ae}0@D|a_6BTq_ILQzX_hvhXOtV)5ew}=9%>eqPCdZ= zV}Gg~b(dT~+Q>_&`Smh)!fj%K^0+7TP^-X&Jr_-iIwIx6dqc~D61=dp{ZTOJ6Mb{M zhR5{G_5AQ$^PKZM^@Na6aM_#U>+k>Qe;F`?!$XZDHzVC+ZRGlJU#-!u>u-!5$kMfd zVb}`kHubp6TwDGp9M4~clR}128>7(xe!S{pSB#KFe2+P`UC1wx!brX>Z{yQI3cto( z0n2|n(gxbW<=qWL^?KZ&@JLR9_v|J3*^fZ;|BBW33$th}&fheb3Fq%*c-!ycwAF^G zjeEUUI0ZDNdHNFO#bByARR)T@$8gn%*qccz4oS=u?p`RoCs1FoSDzrek_E{$RLBh6 z(-$ze81J#Z#vlXyE!1M;iZ8dXg|Dlx znXiU#iBIu4{rdv8;DFHZ@S;c-MyD9E%{Qu*wbJ?+aC7~rlBdBZe3otx0?j$L1XqAR z!luC?H5$1SBk)+qb4BrZCnDo!Gglhw<(GJbufY`f3{U$G+;_}{lYEx22Gq8F0tbdb zPd*QPW^2I_9>({9ziNcgNVv_n;UYL?e`mY17r5VTQNLl&SxL%dAza&LQ{SjmT*Kbr&Ul2$g8tJ~C-nvDW95hZ zG}bVthaZMd1@8x!1vdHf`Tp{ha?@F7{*S2hfODei{{BpQHrZa-U3v#a5fo5Fz=~3= zC|1C(JcO=qVg?tGmy0q%rCs@%IhyExnD z1k^_N~Dug3jz$f@cQE zXQTq)nXJoJ5w|+uKeWK_s2R{xofsp zzg@pNG&6QzBJ4Qb{LngTG&-*ITw;D6Y-V3Ji9i;yRZKq6>|fk+ks1Io6n4 zU54*T&l2BM%z`1|1=4JTGSQ`fcLb)?-<<5(WTc!ggnt)fJyxdO}gBR(U1o7VgSh=dEL%RGI3ZC{5gJeZ+ojwfjTI63-)9 zb>8>0m*>vTYsvdKKi+Cg@iX3w%-PfetjiwhyT%%lFcMA1qIjFs@6kEYxtwxyr7<>Up2eo-6k{oQd8NnU~~4-L5DwS9(7<=U5}ardq1x0Rq^EE9d$#G{MeA6C{CS1 zM&4n#b*tIQx7~Mlpi`dTbJYE!(Jj5s-57a+U|5jsV z!`1cs8lSE9R6o#g<&h6+9rcOIs@kgRN2-23HnzbEuZiq&bn@Tp*_3@@@I}uj!Rg+F zKPS6|$gWRZmqj;5s#^Z8-OyCi^iR{Q(6y0C(WB81lV@1n9L2tmvs&eJ$e){YZtlUX zV_v84YIf?~M&D#nVnd{oociNhTDRyKO%2J;r*=-7G6}6-Rgsap5<-ITqTDK z%L3!Qw{q5UWpZb{C-vKYNWEdKFo$7}O>!TXRUX)r{Zd}n?2of&2RBm{&FOx~G0d}{ z(?#bRsd&HS@$q;pkhm-PQ*s7XN_R&4CA!uRkKEsKe|7ujJ8Gh}k01L_&5Jd!RR$Yo zSJhQsSlg@O?W#dd-j+7W-;Gq(g5Wj5e#NGHesD$pLRYVXaaj+U=X*DVS2i_99;-gF zes0tJx|WuUBkx2u#_mWN#@FWKzT2|vv%3|Zk^fE3yu6QmFM3A0PiM7Wm6$*cf|*TE zHV$a|t|5@T)$h!DsW`tpnf*!0rh;v`CzW4VbbnE#pxisqdx4`<;<+T@w#j43&1O%> zrH+a2`IS=KBsY2K;P)VhV;YPdUlT;EM95v&CaOdcH*KzKWz&XV`vm&{77Iq8HC_26{ z5&S3Tdv8xpsR)}F60_OS_|?48`GWg*?;U|T-Xs34S0Q>>srNa}^k=JIsvA4aIo9Kjo2QxwY}Z#zD2OR=rR^ zta3rs@}}czqxFT&8;;e5uTO1_tuMJDcvjYl<#9)MPlv)Cj;C|l7o<4x`*6x?E{}|D z+0^)6%ZBKk;alPl8dDO3(uZAdq|WfY6?o47Rl&^MD>xaXi!(r8Z)M`U)CH+3a)_Tx zz2{!x8Wm{C{>sxjcxCqR;0p!!%d;}rLo*l+7y zm5yz!;y-q`%}RK8`5yFN=()yw$m2E3%=^-hq6n{EAFBXc$IQ#rGqbfN& zHZV~V+uJzM{cC7_1iRn8>9S^}Z7-F0a2mtMS-%*88E2mD9u14m}*HO8!;zditH(t-A`U zCRFWeKG$5{DxKZVu`Bvi!}GAM*QWT}6z-ALIf(PEDk1 zh(x*33Z=d?espA$>350$iom+u$AeFi?Xb*u8SB~af_Da&;71A*Cw`;Vi7Gic$t$gm zv4>()Q#)E}taSLT^q|nBu-kohRW17(ryjg7 z{P2-`k35>Xui=AOo6=%u!u?17>Qsy8?1Eu=MyrEGLvk*6PD-7YoDmvBeVE6{`yAig zp?PM@xTbidJpNA1@BYi%#=j;1uAIq%tg^Kw#l`K>?&Ue7y~n$~z7gr49rrrEiQFH* zFy0tlmd@t9t4loxJ>Ph{2Ulj_k-aEqN_LN6x8NzkoWMK&Wp2OYO3uLjBvBtf#F?xQ zwEPeb#=mdObKDehMZL-An;vc$75O>5HvCQK^z@_QhZ=rO?~NR)elu}vlXdj2=E=3q z2YOb2-Ta@X{^99`-+PB8*X3X0Iwdx*=<$L@*^A4c$(!oxayorIggQo>JecChJH7eO7?Xyk@wM0&7xvRN% ztUNj)**CVo`Gck>Lz5a7g|2sf+AyQJQ+jpv=XD>X&TF~2Wkqa#)1=0}#}?LnTYb3V zri3@~X6onMz0OyR4}4RI&0d*1vEYOJnMH5pUf_Ahs!gnlR7F-NPK`Yi+Y#*?`nY*y z!`$ZQS;zU!bEuYmK~Af(k;QMc>fI)g|8nkAMav4>7eA4o2()&88=W6HH@+yGpR7qQ zPtHx{v8(;1E5|$6M}&8FEH_v#_<`=VcAF4w#SS?aQgjTUhvNV3K#DvpxI%S;>W*u{Jfdtz~}mFO5qYe{NjR^i1gF z_-EE2=NwNwcX~;@a9nw$d{ll_;l)K`3bz;L<~^02?=I#XkB*6)$Rzd#ZjDc-7RH0b zJ-+3LXFV2pAoxn|v$_8%_#pQ^>L<+ee##kMz3{dLB3p@X4?p;4|mzBiKJHg$0> z$*XW*9N%2hRG#a8*EzMIqvPk;-Ja)tH^jO}BFRqCbDF;hwQjZ=T8A!)Zj06?4w-K| zUh^KwjTYQj^ioNxsAs{c1$ULy6@QVxC%-r_)p3KFO7Gzu!uhGjZ5XK|Y83aJp0K{P?ZWlbgFoS~aD@t(u={JRJI_eqqDy zp%yMrNT@{XVjLzu+sU?X}tZn?K=o`^bIIVO}OkvW)cy16+1m$iS4-Eh!A0PfVhXg#$vOh^9*v6cXMviNXE8-^Q=a3 z4sC=gb)S+mIsqSW!r^qD;XIwHJ3X8|IkE3Cb}H5Ldo&7J<{kJRZeA2bY z+3Xlb#i2FqZ4Y)-ARi{0haKmbFC||i>dtg7w>qSHrv_kMTtmIu3#>ZpRjMs~mMUc* zr%!qnCl+nC9wu7kMYN+j@LE^=(Kj8jhFjw z=UvWW&KypyisG5e<}83J_E|h!naaM(jhsBN$sFc5f=}*8;%5#TCz~G{pAyM&De+)8 z8coEGts!=2bh=HtJ+;aPq+P@(77-=1HWOFzFg3icu(GJRKEm3c>cD)qq`Kj=noV$H znRNn@FN=s>@v$S{HrF#_!n-)iF&+RaQ0el{>8W7{MAMMNTKGVhVi^AdP6h{>6w_QI~dbMVs73fR%$(wUW$htf-iGFbNiBgbZY1D znnn!CBIY!SXoSf`mt9X(*wsWKE+by;OX72aJmDp7FF-U~kQl=hmbs#%x)>kQ>Q+-l@o#Yw&6GH*cDl6{7!t3@i9dX;FW&!C+odvBK#clIL@(A8jP4nC*Zv@|3AuJJ>9CRq6m==LWD zG-fHrE|zu>+xLU@C{co|fI!&sCzUkD_-c~N*NiRN}hOv544O!>o1=K#Cu<9)=r^nO6Ec77w4>|OR+M*!btW=p!8 z*^Bzwr#OaM2ZnYy^d_0w%{?xjLT&qQ?2dyiuO`Mpv zpSy3f1}6t479}>Mj#%@EzkZdN#vp$9ACm8+ez5uxySSg2vu10aRhl|T49la*_SV`| zEY+TKL+2sO&$lk7#p|qF(@&F;a3k?#XTk9+4bDC@4>`W(6tyX?~^R127*3YTEwx*=|-s&N>D;wSopO{+Vn1-d(%eRSI#p8UA>|Y8_Egv9 zo00uTUJ-RGM&)K_$Fs&#ar0xp<=gFj+C9}loE~TXtcqR{YaK60-jsUT`Y(ICl z@IoW#c-&dW>GEFpG}kq*b(}f7lnAu$uAu8VDkK~rPW}vXj`nkQ-bwKZ@ylb?(a$0m zhkLe+WY6glYV%JD@8)d6ZIO}T`&te(d7BzJaj0+ObM+%?f2jJjYIp5UtgN|?ss5*^ zJ2T7O!yKL7YuxEMJNt{GS4yvG^+Db}-UqDikr!Kbg?l6&=6lY?j%~?TS{Bv3S-HA$ zOKsgKy!hz$Ff>Wm{ig`5-X^v-ZEaIQ2B z<0q>+^-`)RbvnYKYmhTFtJTPQ#Es4tSR}4;(s_mV;n@Ecu z(OmZL*0(&|>}zaW_h(h2;+Z4yBLk||Hyn$6Z(inamwj*WgGxBNdfyAcNUu1V5i4Ta53*Q#HIs9aN zN$OU|Q1A2pNx^Bs8Js9lO(nKa*3iKItTO|@W{u1$rDnl1uA9y4tuGRb;&b9F5aREHcA9~EDgXiJ3M zOR-&%>7n7x-#30y|7N|b?uXjfs@qp3Dz7`%y|!E9qS#Z8U0E;X9?5?(|E%okoEOp4 zQ|fy$`-!3rWzUueN=N0t?a51C-uzG9m$lW6A4JwCuf|%wIn^WbTjP1uH%T|ljecw% z@K4Qqq-agy)Vu}3?*7j_-+TV+e>K=C_tU%qc|QbadUGAOCc}{*L+6Ih3f~xgi24*; zlV<7@Dt(SiE=Wx>9&w%JZ=Z8P!N)~2i?1j;n!hinN1&(wZB8=}aH9Jm-yS4cFV4By zlq^Yn9v_|DW{r1rao2ekdhhX&bm)AFh=Q+O-MoKLE2Dj|ci=8xXZK2TTDrix$Qqh{ z$>?q7li9R@Y{8Au7b2IjtMF3%*<^ptu$*c2NS&PcGj>OGcyv{CUThCFAqOT`Cl)47 zOt|AakR%UBdq%RtZ??>D8rQh6K2&#O?b@1F)xl$%D)K6Gt5?<~!ao}S^-sxLT%6nL zvBDQ}4*3Uq^WDRIhqKQve5}=nrEivu$QkLH6RWNN?AR|AbE|Hw|2s53u_|>&>e~2w zp&^a0)aN!{73!Vr=I)X`rQn_7RPiH){@h7{8NTN|7g5=?#%K8F`sR2l9XD7n#A-ur znoF8SHl5OPD>Y-*#s?A6aE5hOx`Wy1=!{>hdv;yk=7N`t5`{Mxl;^z{JU(llw~J?L zCW81c>Q+X~-s!c;O~h4nRsa`s!X!Ba*uV*C5rPG2PfdTlI}y)n7-e$)V0|0 zl<}E0Ew!8&oY7QF{F1t6{gV~(Z(|#xy`!_DzeIP&9*P$y#wJc;*Yx4|kFlR)*T>e! zu8Y47wcnq(k{E>z@h-7nA``-|v|QaX2H*7L=KlCLZf|EkyTV3xjNn)xjnR}`Q(9RY z#7nb2p=x`l;Pt`V0(1O>Jw=Xk>%I6==07PmJ2}l-W_)M5&D)G`Qpw~K$(gAKjE>IZ zyt{q*)amW#-Q!wjo^K2x0=&q)!7;%3Eg5Z*)RK6gNFy-`jm=#`)B&V^@PgR6(NU3Y zp&wejM7xZLd>MTzzAe!`d2RB6L@@D0+!t?(9gYo&&5w4ETo-z_`KQJk8#WLF^L+iE z)KQvO`&ad-s`o3M?9JX^_fqq&_#KXez7ulZ&ATXndaft?%q${Yd`tbe1@F$8ocDR| z``KH3rH=WDXvg>`$GAoKLzS zwIwkx*~#E9tL@YCU| zBb%Zh#TOHGJvy-@c4~Aar$rqIRfV%+S0)apW{@8?4C~%uJ&=4U@e!J*n3_vJ5OIAh zu^^Te?H0PSX+y(aDh~$h#?)O>_at#e$JO7_@MY89aA~64+~~eAt022a&WM~}f}Mjp z&Bn?el=o`Ed4?XccFRtc%T$=fC^X?JZ9(y^J zG;=(&{W}98s#-r6e3H7RE3zKR`abLatZTBm`)7K8b$4_gA%pWY>XNQV+@3r=)zR9P zK4vU-%yG7I?I8~GZ1;DZ)$H(H?*_a@h+MA!`TO!S=T z+2Fp`b--~G)qfML;!Za=J3OxcxVO7MC-SrfZtv|_YUUHEHrnYVn(G;kX&Pi)XI-1@ zndlVX6Dy8i6kisvccsSaPx@Nn>Mx;C1&=GqeRuP|{Xia4LWa>=~Os=AK^RD>F z#0Bw@@gDK@)PvnkCD5}&Pd2Y>?A@@fc3ti2+RC~KP07$xiS6kVTuZ&1e4DA}^ql7r zVqYKiwDM;KC*~FwY%f@l|5(m3|7TQ~n3uRNwmG{(a#JByqe+O0Bz9ovS((CfS$|**R{bAoz-b>v#aZ>Pm)Mwm9=JB1* zK}2#t;BFv>_8H1p)7L>2eHz4JY< zyBZx2l2e}{Hw|B9Amm2E@~_vA)|4&by4c^WH)xI zY7!?UbE$;eGjU0LOzbx5ug{AZ?6!}HIq*LYgsUgS9*OpfZi&1RofLgCc4|BtKPT}9 z)^&Y!RwNMih3;wjtNFN=J6kSj2{)hHoZGazVO;%v^`#AMntF%dihr5j={(tcAJx^T zaTdxj;ulA_JF_Ewa4?Gg>q9+85u zH+*ZPZT!JhfAcn%-#fv##W&seFi|UE&hA?6e$X??TgQnfqy0Bf8#>!}yLUOYVxMv} z8@1_UL>yjAP4jA_-ke5M--(WA$vfF&bOA;eQmJdQXPLLJZ>@KSXBim-RnB>w{rnPj z^v=cm)$E+&m~A9j?_m)V-?@oA(l^ZKh;j=U-&^0Nwk3~G%_2JH9&$U|8q?B~t-q3& zB=*Gm#d2bo#_Flxd_&@m#EYEjx0{&uHN@4t7rP_oir*9Om8eUcnk;7hvVt1$bBX_6 zg4|pdeK@u=`Z5vm_r$E&J<%CaXJl~rjgXHBjX|M7L>v4XEsrlse3{Avm)AQgoXxH) zIPZOyCqWF{Qtt>~C(i3Le9t2HHn>j5x92b_iJI@rx!aFg3&~*_&Wi6&B8c`R&q(E1 z9a*1l=QQ14*{wd02=1-U#m>)(z}e(_jgvYTddK;eaB@tJZ!Z;#mk=@Z4iTnXInUr` z&OQx08#x8DfsFPvwzffxcPr;XJdBrc63AA^S7ccnAr|&e^HxVUr+fi}9CuPV?P@Zk zmRRSLvGp}Ixvr;9?UiJQzGt0--94I9-)_g_(~c^QdDQbd1V#2pY>8(luEFAYDCtYB zP2QGVL#5~?M6Fa|2VF<4?ypko@FAX)`hiTBp~;2GNb-I1O@>(;IiG$=^7X{p_|w!p zKa*IsMMR9KSnU_-uDMM-$;Zgb3%aIBonYn8uf_SL5kx~0Id5Ej=IKG1y?{aeJT2TddAKutFKGqgIOPefbdX3c@zvd=u1U{OR zSXGy&A0%7vCcL<0zmu)`8Q!P_S!R8ppnI$?_%UPTU~D8Ov?M(j|L4BcPpLv>(WEcL z=QoUO(Xrt3T|A=uQ;k+JzQ0>p;qJFix2m|aAGuwBlT-LSeSSlR>8JP|cjAv+$a6dJ z%6y1N>nE~D01>Q;zw&oe%nT48rf$tax!<5@%kFMT%VD% z@dTND4)*sy#nThEULm{oRp8qPh%Y1V!NN=X628rHa==FtS6;vl)5YoY@S=_-H-8?v z%$w8SkoUQklTBxllcUvatd zIGICVP*raP-T=ipo=@IR7yOA8>9}>f@j$wseFC?+8~@Rh_@Aec*}a4O;%sv)S*HE) zLO%!eJ!EaSHvdk~Hb!yg*hb4`+?D>ySi#x!S8>jHH7n{njjqh<+w^PrH`n2loQ;op z18rZ1OjY*xV7$y1;#Ho8Cv-7>(&@;?->9JQ4A^a;#=?)B)l`8`_a?C4Vm+SjgRiun zXoKhIaTnK0@kJMrqwmFE`Xw2@+v)o&Y6*101AQfa>4{{?UQS-pbaN|n>YBcV%*Us# z^YPBk#g{kK>}^z#uQwiVl?xA4H}eTjv$%}2kHgl{bVuvn^k1o&#>eSgvQM6eUhZT} z7g-I~VdB^3r9I{hV}eRr4)9`jkpAcxcOB!1*)tn1KC&#~9`n{l1_s^cY2cK(FA3*TAyrS@4} zt%=ao5l%sII4&`lksm$Y8btocA}U+F#9r1^PDU7EdkXpOhg5W~!7W2IS4Wt5xA zza<;m`jx$hIp%O<6Xy&rfHn?OVI)Y+k(<&6{+Dyu9eNe7w)pveYbmfgLI&A~=4wtD zoCRfEPKNJRYgu}+(VxirfAGr($meZNt;E}~k*xaXso$}b+~~pdwT#N$gQ?c>tns7y zUslPJ*nilbI?EbhUCA!UNb2S-MjO9^>?q;1z5dh-SxPOyGt5`G`Un_&3yF0a((WT> zw}j`1n=c?4&q5NuU`4E%aNj7#wUgbSQD#5JHj~`s(MZdE#sKo(exf$Ov&gs#V1FN+ zOJy3f0SMm&Jw~lxIfZk#RSBmIHAj&XTyKmr9zmKNK)UIa%P6tElabS>8(-kxn}ggf zCl9`eJ%>11PqU0lB;8IcfcO4)Y94&b$ut4;KS=J!*qbS2o(=3sEabjVsjX0lJh>mA z{>$ufbY~ahL?q9K^i}N690$+bN#%h-=G|l?OMVPXAGFHS1`=Z|(Aa1cvQu_7ElnY_ z@OkSxbnBb&WRU!aA!e;{J+fxEwbXjW`aOLKYmfQHXwzw)h81}VKGxPqiZw`v0`!X) zY5JaV5Pjd5db*9slp6T)L1gf&VC!^tQ)VG8r=?>?f6j0@mn?G5k4Ku#O~)DYE@}b% z7dfyA2{H%{=wW=2{sn2?59z*<`T;q#`WKYHlI)P5@NRxarI}phODawEbfXqo+tV0H zZ1}hAHC;@7jz!4m5yo_~)z5^&@`3zM#&yOcW?$|-z+Ep}Q8bw9Oa!S0dp%JNd1QQc zAq)BfD~L3{+4zFo_%>#H?2jz!?{>9XeN55hw;1QPx>}0 zWPA+es&?iXz~W>oT;}rYrP#LJIRWJ7^fgG(e&lUuqaRPC4#7aQ!%OUx-3cA+g6H3Y z;u7gG?9BC{QbITM=jYU-m`pvCGBEcS^$$8D;ZFpUPf&v)K$U=xsVdSxy@zMM24W9U z-=zjglSofz7xDoh|C4nhJ0-KJ>Uk^q_1_z-jdxWgf_gvOtkLYrO<=e7JS1u^wOSlR zuwOwnp03nBh#75>rYExtxEMWjHWedU&=oHjThZdbLnV(>mt(yZMDp*XuZ76o0_v$; zhlX9o{_*wjK^FRV7Mk@AcENfZKU1ybU8G|^S`bQNXYFpFcMm%e1!&}}v4I{yOV2|C zpAHpX%)EbwFD@Xi{`>Te>1HgSiS#s{+CLR&(h_R(v`%-1zSlF}d$6YtlHo1u;~I9s zuEY8%WuN#xxcx464)YlM$JCy92fD6D^FI&H`cs+rW@a~@3KFBCWI-+T=&j?+_ldH{7Y4zYi-f!$*Vr`v48`f9@%*8_D`OFEnS zFSFUn7SHqshKu3%7a7AZ>^)DR)g91!Jtu2;%y+@m2k6%)p^TkaBg?U~mZ6y@8#fsD zVbv`}>%WDa^clOScap(*H}=?l>?c2!bB6>QNs@Zj%CA}mav2w7Z-pFM0-j5A7JiXdF7a6dh3R_Qd z;_6dy%lYhh*CTWLLI=HoLn(X6Ymj06;o=*p%F+tkvAgj;cFY^t)APBj9{b@HDCh^c zZ7r~W23zAPIDZ>le5$$KC`XPJ86hi+^=l{WvKzqMJnDwLjBn@~^wa?)e>3=h24DAV zuHFUZzW{9JK%d*$D_?*_`Oeq@1w4rkOj=FWMC5BWI-(36co};mPWpTSy>>fvz1-+Q zJcNPd8*ki+*1X%Anyy7&-EC||<`5|y}L12PUI;${5u5w`qA?05geUgib-{ytbeg=$AjflVE|J#V3NK1SN! zK;5YT6>TcfEFPjgdZAT{x&IR|y%rdrWITsH{19#T2(wQx`(ij@KeYrd!P2}2Ox>5B zflfFdJU@yhR}ChY6R&U|PzeFoA}q{T(LX;yx2pMa7T2n&Qqu?76b5#6;D1L(?q12A zUlB=gCo*L`mUo2SZh`;zpv#J>O!YkU{Vn|aD>bNgpfk4-iSRzXECQkpjOH}>YA~?} z6O6}@@=L(6{w6zoCcFN5#;s_=O-RC3#6#SKd^r!^ zR3)ZAu-1>Kzw^xrP|QlsdMu=JEsXLHLFH+0npm`5``X6TnmcxaPN6g{G@ZIjX zk;>DR_&S~A@R@1|8P{?>zRom(JS9nMDpf1Jo_<-n-r+{>S==I{dJ6OFm}cQY7+HfRO_vi;EU6Mc`kZEaV0M@ri-zWK19O2 zhV)mYM}o>;@1aA+u{h#r&Pn@<2wG3kvH6jg5vb+7Lk#1iw2Nbt5b0t7u~+RN0E_2jQ`D`uYP{p2GV1 zR5YIMt^ww);n7~q`6Tq!?NFbOtF5py@~F`=k(EvMHI%9nOB3v2+GrH`2EQ7^UFaHbArhm=_vH8FzO$q95z3Gobog@m^lb zQ=O;O**wRR0sV-xMx^EnrfKnt?023@>{HaBtKX09jb!=jZ+Ai4?aJpvZK;5o%V z?SVTELNzk4O#0GeK4Sx=kw z^l=1^spGjIS}YDccQX%;zohS@%w#JxxeDpAmT|1ZMyLkohxvby@9WWehcb}J##U*? z2y(%Lk1?8zsTg>71~b<{fhR&E1NeLzR#hKpqc2==61eXUrn*sqt0(lC2aI!R!_Bxo zP>GxG-MF3&<(7aoS+_a7!t`iCJJ~$ZfnHld6S#BS{lM7nd8LxG@SA6JiqN^<$Fl+RA!-ML zbUykkhq{jA-Hm%X@G4{0<=mCe_}XN~(vg-*`HtJ0zg|Xb@?1%VM#T4dv=rdpf((5) z!Tn(%yB9v(0Ve+B{~_N0@RDrUo`JXUU!CD}@y8K9R{*J6a4Wb~^RA?g{oF6_&}LfN z0oHba%VR)t4=~*Ue)sUX78o~VU>@PuL(rFaZYRA-Z*1h*of#^VoUY@J3VKxy%PQ_X z%C`zeTSIS0cuHf?NHn5@j8m_F=tZO64p0AyMBjj{`jgidC~ghZrqO8?wG%4bmKpPY zMps1-;s1?Fa}vc@G1dmgAPR2c+dn>&#N7^o#0yn$R_zc4SXeU1UEsv;+yz5FVm8wb1uJ>ai3skf6C*3J78$TT6obj9I^^5 zvNL1QCtKpR%0RswJhb8dGH}tBzwO~yiSRSNjYlh#DjbTGNku>%bkIpa4-y zd#-oLjIlg}XGxB1#v$6$oFpw=8Mvmw5#>0b|Jsah7wuNUZ{l}d6W@_2#g!_!$;4uQ;WBeEPC2Q+=dLK_6;R)&GgK)d_%U-S>`Tjga9vLT2w3A4e+b31^A|yu4ZRAf@6yqr=*mVmr4KPEia$g z*v!y}No0DO@teqlBy$M^0YR@NGhazU$wCABmcyx;sg=1`&&vuGETwtFng5a(f{@*| zxKX;xlj)n(jLd&=JwdU&S(ej=@S$-@2Bb1lz?=Cb8q-S_jU>1vt}N6_-qqkp7NR7O z@YMolBvTBYlynI*rWk)EgCv)dj7sBW(DdX7HvSCkf*B}E9!MI?#>--4He6(1dAZ`} zIcEmj8n0#{jbOvw%cm^flJnwc$@72zx6Q~L@p1{Dw8|;vljKrS=B`$JmozA*cgY7; zdes%lw_pahvcOs~GRcQd;J%Ds3b{v8K$b``*IQ-Ul~p2nRnAqr$3pHb&)nG&%Ici4 zOJv`4Mt1e)bs|@KKyAJGZp*cHJS$mO%53szyEUImfE1wv~ae zIIsz5+Y-MTu9cm61h|V2#d}9HTqc>)kinoV1z91&Y7M{24$vpb3RT_~Rf(RO!K?16 z&frfro~#7bX0D*63O?5}0&!^q4d?<_qA-dPWWEb$;}6gH_zG=Oqt9wB?3&HSRF2NiK+br4#g|>^k)sqX(^e8o8p+DLyr4B&F686fa_| zA?}Q1Myx?*A4>PwHc}4sD{KhQTGgm}xpaagh^$XhwepQjev{nwa#fN~c$K}Qzcv@- zXMUBA(8@|Qpe?1Pu^Bz%Z} zYXdt9@+=?5^Y|{@I*T2Ft?cb=z^l5KU5I5o^%?DMz(ZlMlFMg?`@z*=`e~$JN!o1Y z;Ai#Jh#eyR?}xUtd22**WRfUDYhKC103L}5Pit*q@r&qNG?(CeE%%C|q)}xf=@)y3 zqP`R`l4MBmjO3E!l)MH03|wVD7Sdx8-`en3Z%N`JddvsvvUMf@b*+>!t0I3pe(eC= zl)|y%UtQB`LidU9OaTFZz|wEiz-&eGqK zJh|XpvaEIHJ`bOQnLb1bLGF`Ga6$*th|PrT^q-uw|1hC5=LoU(!tyMB`DP4Vm6* z=}T`(3(Z60mo1|)X|#>}7bVJuN@QN54nfhjntZ_3wm#(Zkk3NgE2~cYnG4oifrApX zMq7APD+>FsuK#OG*z!TrKr%tnMJqwso2{VKNz;R#(lfe8Hm5Le z%Oaghp_i?}g$LbfzhxcBpP;s6U8)tWA=J9s)i;es9Btds8jYl$D9>Ic+4L0oA9dT5 zFWI1Wbg%x(=9D$8`3P>3@$x%JO3MNi#$*o)GxH`soR#NZAVEyG-hZ;ze>)CZbUVr7!GC6@QP?fxR6aIYeV?3sX>^P29ZT!dp<-d zq8fX(C_N$Xr}T!ftiRGIw(m()Am6fl4e~a~N1?S$DLgOBRBHx#qGY|wBhnqawDbS| z%L?tnFIu^D=DOCi@|4=^T2ZnnQhG)FE-DshYgV#t>=mITpKKjj9QN$w9TNZA)|os? z(jE5d-d;D@eqLF1S_|6icJZ976zLN|*M5sHr6Ht`YOne7Tx_p#)tgqA z(h+LS)-JMmW#7muwSU+BYEQC9S0x)H9VDA zOID#xSHiOT)HU1Mu&o$Tp==h}8}>Rybfq3-Q`$O4^dXv-KSMLMWkwvWq?aHjDA^Dc zj3ot1cuRt7{U_Ni2`^099#v^{$$OpMqSe2?Nef7N*eeimX(@awzohNkkfo?Lg-=RGt*1jQEMkz2U;0vO!CXf-myJT`YXzo9I;m@;z)U!rI93Wgm-bD)^w4K zT(Bj%WVp13Y&YA}V0$~|2^W5C31Z7?S(cI%q72cD-trSjBiLS2QGMZe3>DJ~anN&5UNu_Sv$XOhT*k-dV~cU#`ca2YQ~bH!nl>OrZsO#QduJ6AI(U-Xv=?by6`IO2}hC;T3u-WLmET;ENn=}+ZMcS zsfa>^Us?XP1u4n@uZ*_W3c`$56S8+~y<-2)a57Z*FaK)nl75nk+Vhp<7M<#|ZTSlF zwse#Hur*$YnP^s;t1MJWZ{bU_K%N#GOSYsJENn~GmLQa!bf-Mo(i<9udC(g;^Bpso>677Hs%N>vd@^K%}Tl*T?)#}A3 zc{R1ap&c&S5uLbGdmf?^$tS%e9qgSkSu3)FwOV)6s~{vQ6bz+T?b+K}G70|W0T;iE zN5miT3@_RCfW5OJog%FwUKZzx*F+7{AKDYtuaeHqiqKv5dQ`MxTY&#c z1Fcono~Xw5XKDvjFL^w*4=!8taOT?)Y(#rM)LtRlt5s}K%375zD?3#7uubju>eJrYk!2&vA?rolA`ge`F;SF0+t$gy z`=YY>MXk0}7OhG$i~FTP>|IeG-{f5veP~r8y(0gkxL13Vx+3k-Izu0dyXcvr09lmM zAM#sEcgSBV%IKQ8PgJCM8F@jpKO!oTuCOUmwuG(q zfe_D1E{Jkueaa8^ucpze$+oq%TjQri!NZo5k}o#@NScVd9qmrp}lTz_Sq zO8fuoN08N_XY4(PROT+xpT;DsL{eP$+uGhkwDJ%ysz1TkUg>JgdQR&^St^p=+HKHFs|~GhwNjD9x7Vk3G>UCK%1bQ?B018A z|F+ej$O~B%vQXq56HQ4c$PUmxm{!h`M6!HjVMx2!`bBny?$avSwz&mqy=*=Jk$jRq zlw1*g*jiS1X#RqPFeJGyUJ~x?=Y>Pt&J-ux(GB8RdpA_OnTpD=`BswT|Kx}GR4Yn* zk5g+=`!_{Sh<9aG*fy-K>E$;SWeKCgzibZMCea+UJ`xq`4oNWkzi3PJC}}5J6MgEl zt@lI`_G;N)PfJ>86>V!i{i-K*zn)U-qG)+nTwqOHFa2S!9b}INfV!+kd(Tram)$8` zYo)2(bJ+r|GMp}(Pj`!-JD>%m7qpJlK9=_MZ7Ws&CwsjqdC?19(w+a(GTIrFJ*oC( z&q^w&H`%Q6jmuJy9GkiPr1YSuLHrTPNKAWWU&lN(O9#{}*vg8Ot`x;- z#CQ|<#6GyN z2^k^YlOIa@Ui%{lxmS{)lF!1TWJEP|qqU53(m--h`b@p1xh_oEbSnOEk@K-yVlH-bS(5hIHU(l5OrhOwl zBfhaUnC&T-{E|Ob*pY=JD_54DG=Z&A9K^MX6B8L4I7DQhA}Y6Gr8Yyek}aYUX~HN` z!ySNlm&|Bn4ax45Z6jZ!C`7VD{j@?}w_+sHsjdv1r6Cl<(gElTTDsrO=e9st5g0lr zT3nNd9?(6q`fP4fyDsk1YOpnL?Gq@bRsM`pM$?HIwgvNwpXtU^+RN_9r&G`=vQ%Uf zwr4lJ8~f=eFm}atYNuWi6326H)T!Jj=^*P@eo{qXwq<;>sI>D_!jru+@u|{hMf7cv z<0blO$DK}~w+$I9YAItB-RM67K6isvwIPfMCN6T{B!w$894cuM2Ev=UPmvO`&4k;- znP|w(T#;Aw7+3d$kJZSGI(W2(_g~zxn;wND*%z`aWdCVDD#de^(2Bihqv$~|yYL73 zRT!`1NyX<1r_zPmv6dDHBTM4IJ<51mczz?YX*+)vQ@fihIb?fl|8pm!kPQ?g4nBo! zs%Ou2EiGtvt(a*IGds#Pc{3{L(M7hF_7Hb5e#IkGK{q3_d-Am6JU7s~2Pl>?-ZF5S z;Hf{jGtAW@uAK;Wlm+k)PkJ&G(Vi%Y*6bMXM4CuL*g8}2PLV69yaYiz2?WI7ZRpzt zr?um6035gGeo2$A#!%>>6VbbEp*-8iCA#QBEO0ST+i}9D@M{XFXjNIvXK5G7J6SYN zD71ZsD;&_H>=DsnYwj2PLTD>5JIjh(luRv#Vw%9eq(VDJQVl*7Wv`eNMVc$mZzE62 zf2k}1t@8dL!)_sVaC4^L$?v^N$$$h*iDYc4%+yKZ!zdV*^&PM=k3@DN0A@$ zrHT6f;yLlLqLDSP0-|M(a?fV2hWMnD-QA2*IRX{5CHqCTPLO=dZ@5bDFRMZ zp?n`j)g8%T#m#KAC+-Cr&CF4&$tLJPQZ)y>?4bXHP-Ti-ou0r_tJ_1+P(EYPT0q$w z3FcYF+v3TC(2lGO*|5^5S|4PA!9Gxkps&@hY>8$lLecjP^xlZPQB;#6N2RxA&&c;F zpJ5NMD(#>)%HV+@QsE>hL)zX?3x!b1DDpStnG}}0@iz~NFr3)%?zCkwf*>uHbEO{` z4gwA3vd903}XwMCwR)EbSXl^jMBj@v+EG}hNltF2Od6xlW;VX_bIgRQL7cd_A z*^3yN;v;&J8Ifjg(tHN|wilRo1rOV4r8^^PhSN7g^&R=Nk^9qJ4>9*Y!H_}T)lTrQ z*v=zx)jHZyyrefHH%l4)AKbSWNujk|imZ|-W7!7{?aw@2lab$&A8ukFq_srD9;jBf za2@!Pl<+c|8lG$5Q;ODG;4+hZoK3VWev6SU_Z9aFgR;T8GV0%d5_KyK|C}+nHy^6bz@~sSLWih&1dQ;p^7P#-uEcVe~Jy4Zz`xJ1| zj=3o&Lt~a*+Y20K1Iavomj_IAnGYOxK($@@uUZ-9wAL2-*uV_B0m(3tDn8nYX68Bs z?W0wXqKTBpvLxP}SbuXs2SN$j{`d?O>$cGc!4!U%pc28RH!c z?dC)CCCJ+!mD>+ar7zEc7n;Z>`h*;iW0~x?K5%Rd3@&G$2cVxJ;80dbTcUaX0MiZ7 zYkP9ceDtEskYixBnrm7&o({il;;w_teJFS3z?}xsJe!$Ake0eIu54gl#Q*JNF~#Ar za>mt)wtB!pU-4`k`1L%#$#3K(TYfwEcLJH#Kst+e1jwX;Pk{E7#nXfFhsn;9A6Z&G z2ffmnF%PFrW%0;M(UYDkHp@C5P)i?zG6? zoRxe;2b0nR;X!OauO3zJorf-c0ViYMJid)i~ErY+fToq^dQIM0GAHZpb( z6kUtv`UPAUlU2ErYX!*DtF(a3!lw2)%#ZHWS#2c5{vP?W)Z71CiB&xV;vHW1T{ zqsXHzw7-v;pU$`x87#Wri&Qv3+b-z4k(sSUB5GGnb-4G_S3bXp!b_Ro+2C#mef9-+ z`)M%%KH8$sn}ACp*y{@>v{tp?(JR1pKj6O^NS+8rdI24+D^G(CwllU0Al4h|?7;O_ zw6O;mDw;y8CE3!74?GAbwZIXI15mD!AO6|^wA(Q{?Nya>mtw-D*^7AY0FZ3}_Oh^c zF{%jjI|lS>!GJ%b+5P|uCbTU0D)vs2djl}3=d%y)*h33(Akcs$s-nGrXjxHeT4(M9 zVyD7QYq{zn54sU7C|9VIHg=FjtkuPFKwuSP-U}R?=s5yCmBHCx0Q<#M=Gny9L~D|1 zzjM#KNSQ6Po1`^)EtZqjx0-&lp@+&0F37}3mMijmFSyrgL#=*E7Hu5f74<0wRh1+D zCcCtmc))X@i$-wM3OuO(d`Ixs3>=hC(~+4UK?`&P&O3OjKW#aX2hwWYfbU`2PryMZ zFqd6m+JNJG0_9?~L6QmuyO8OUK}q_)2$}L060sF{Rrc=%dB=k8Kt4!NU-@e9<{4T`L#pE2Z-`k8GNR4RL< z3(s@}*Y%8i2rX}i&&~%D@->;5gRa-i>aR3G_^ zmng_d|7m9G1K(?D^9VR7f+|`vRH?jl+0^~v1Opt%>*3^ii*bd)fxMm-P<}f`_6L0^ zPjDN(tp?Jk1D{p={ueZ(SkO8~*OhDgpr_4TJ;umO;0<{}cR^WNVMT#_K zwu0GHz>~Cx#mrYyd1g&|7}DCJl{i(QYRFK279BVqxE@2Y${T)=yLSUOMH!#T$hSaY zE6}N(z;=|}_5tAKcX-c9|3i=v%J@}<3Q3f6xF)Y@cQSW-LRnRC^2t#3uZ(RH*}raD z%K{f?fty-tL@b1UDwt0(ZFk|56O3L4mbG##K(@4@-6klq1h^|Z*8`k39ip{O7=G_dFD3BDdr0eh$y(kChCA@7tMLLA{6|vxU_2c1Ju^F-IV>VK z_*~#JgXazbQ$KZqK46wEa*;>F@$aHp|7OI>^gN!r8J&UYBV^Yv<=(%62r~={bzsAyBKu>tWM$kNkGhGOdMas%#K^(%PLK*K5D^Ay@!2yV54)IR(1K)-^D0;{1w z?R0AQ;Y=w1ec+b|Wju{U%!m2|(A)+%(~S)NifaL29DeJA4)^w zzW@eWVIN~vE1T;}=vzLeb>L|*v^Ib`I17N-weXQF>6582GYmW(LdwPIbu{!-LiLwH zM6a)6%s(*K$lv-Q9qVXQ7Rh;7 z#cnFJ1gKxohE}>G2ZGcY=mNByT(PLt(+Li4NKeUBT9PgO12~ute#6wZc#|q|4?_=^ zfTaOcCpZ(^V=xlA2bAM)#=p!*}!&M;09A-u2M0GovHBAhSJK z$=4YFg;RFY=UynQ4O-a&Ep|eJH&g5AMSd9yAG#SuXC&=vxWojC+mIf0j8YWyJ##!6 z3S15TzhX=+nYEO1_~ns3%ItsRdJ{E!`hmw%pjbm?pIfMW)|;A+7eQC@z%D`Bd?kHF*1Amn~=R~61Lp$5J4W zQT~0D6JpBXs2xy=GNR5xUM%Ig0(wb7b2?>6Ud<-1{R-VpLhHVT%v=ZL8n9jE4XfgF zh~9q!XCtsQin#U}(AWZ$4*Wl&&H}uOM1F4|5zqOZ z0Pk*brP!BIozV(B9UC#V7;OkzmIa-@W|VeDyADmP%M&hw#x@OJy=ON6qDQGn$7L{V z+A=BytB?-;GT>qmI=`1Ei?xrsSb|)}JAmspafPUblz>ykx$_5n!zn0O9oaCUsiHcW z^$GprF`lpz_x+5~S>eP!u9DHja&Sx3-yR|7IoeD-+;Fs0_~Jx%s+{B6f$(P?UU(F` z)*2~E#A_{t`@P}oSf1Jf{a%1~D+dp@FvDzQr5oOPAMdoo{*1>`T*cN3#HY%54S#e{ zSiiHMyjdx{^(r`9fpMON)1l~r@HQU2NFDgLpV0|VSB~qy<{dBexrc<-fCdqt-qWG2 zqp-6Q68jv@h(LD#ht*k6M0X8}WW(Jo{MHX3%5)~C5Y@{UX#7Et+?sJ;fn_9Oka)cK zS!A;w?}#46t9aQ=FliFNTndMWDd5=rinR*GdK$QsSk0-9O)Eshu3|mv@s~jSxcN!C zL>??^1DALw78I2#=!xjZ*$qnTV=7QJ(VWgmQv_Um0&id9m(Sw2{=kM_1iR-yvKNtI zhzipeLK19vGx0BFYe1@2uG^xTbXBrrA+hiw2w zrxafFd*W`5cMV{04F<(=5gcjG-Iju1GaXB~8Z8b&S1s^yF4W(L6kNw@=Mr%=#!C3( z8zk(9SnU%XF-IA-o`$Pvz^Xxhq|C%eNLokK`Ise=O!pw znT%luKOQ3G*;v7PU{{IrpF|vT8>tiY9*T_Cg4R=sV~QbZE8zVv*!sp;jz#FF&=5_V z#4PG#@kPDjEa*TJxnni#gg$7{#822bsXWV%cx4kg z%n#f}tn0Yp_8~Y^3LPKFlU3r+UZ8EM>PB?jO>A@u9T6+%H~FT(LkUBc$|DnHiNRLE zz5UpdqtLh^^zdSn{)RgLVI5+z77swls)q%t#qYaZ{}k#C1tCjFov0;@gSTnWCl;IK z!K(CUPEWM1XkZWcz8v~CWQ0M?;3sDP3JDKHmTusE+JPxIP%6t=XQEoT68RX2v?$VC zPNs74=?$b}DcHPa@#eRPQiNYy%+r&It_5~#CnOE@^4UE!ge(xMtH%qi+zfT!|)+3p^!hw zXd8&OHi2EJ`YWaC6BFuw2emZo^rSDWVA=&_6SBj?-OOgnnbGophpwzWhHd8EfOfYLn`xp z6d+8hO?f*7^;jU&47+Q(KH=4ARW#No^(bYMK^6VD6XJ49k9U*IOXdh zkxeYK6Zu>;w7$SRJAfNm6(qD(aPBCtQ<2+fva*_>QMLwAwm#YW7&Ki}Qx+0KoMH|F zO-j`E#VU)y1o{_SA~2*}QV!f{3`W+U%xeVap_M_ByAcnL`*fe<(C=%IC~tw@7Yjb$ zpU7A`eq;^&-hgaBz>204TXrE9%|@fIkoDXqc6q=v6?}Xd>1VWLA<_0DY*9b7wh~rA zRF8K;dx2|J9=Q?u{BdmeD`ftqwwUpU;yEIbL>o5d3^H^E9LP7&I28R5`~Fs-v8(W7 zi?#X0i#ze+flxF6Td^D4e&&-m`JEYyx=2HOR5W&>HeossV><}US+fJ1n>I#{6uxOpAx z;y~xufh1j*?;L<@A$aan*sTXpQ6PDh_*C~2h`w21*j2}l#DgRHGvB<*2wI~7JBS08 zBlr8T3U9G&{fI9`mrEM+pNO3o>$3i!oHhcNdlUS=j8`hobs^|)72=e`aBU0MwTFU{ z@bV3R%lMY-_=Mg__ZA|TjZkAhwtF?-y^qACK%Y_Q=K-{1H!>?^?HR~z<)D5HQnDAz zk`DU(VX!Q#fU0SRGC6po(rEYZNXIPR9gW^x2Tigi$eR7}e0@P9Z%NjWjs5#5t%Qe# zJZm~n>V?+2$(Hl*z#3Bd5o_I>QMF-C1;{}NT5bgq@8u`$=}9)*6S_VnE_sWrOhik} zXvTJ=BOm(`$eqLaz8~-Jz`Lx2K6gG*K;%au$lwGt(1S${$7d}9HBlgfRi}yshNW}? z+TMoyE43lab1S%V$6ZJkM6X^*jSK-Z@qa6aHXdLqyL`+eRzkehm zSpg@00L^s@Hg+joYKazHCKr1O^(TKCy%QYohhP|T0wI}10$yY}k?UpRZO(lJM_sJ( zh&8Yf_9dUfn^a?Fx3M}kndwj>#7DgMA9No{lq9O{+mMuxGun~JK?!U^vi2AL>r zzIeQRd`~Q&59axk`S!1HWhUs!X+#EVpu$F?oxvb*ZpZe=6FqE0!YW~7vb0TPC=Kuw z@lbUhwnJu)he6J(h>SnQ284m;zm%&+prHzyKNDFhLsTXz>HDD2F|d#K$PKS^b)aNbOJQfTk@7pp^;AyDTns{U6cIoiR^u_V7u|>pumw(S4rrMo)>x0- z7W_GcUw8wK?gw=I1J?O(NtVwsidT5soZOj)r;_sX+~6#5Qd8Q>@Jgq%46aU%}?|L$YpZmE>0P zQ&6yHa`y$`+4e)Pzn20*$d&ND%|R>e0tW73JjXnwzyRjHK#=T#L{$Nq{vbNCk2tOn zT`bT4|3@s}8r>B!L_etE!OsaSj=kXUSAa)uqJhimY;r|2l7AeXtjm3t^Ni=9=eH*- z2|zZV5Stt)>f6S58e#{o;mum%a}R>p`#*B0aZtG#p5ZU}DC$8*Fr*)W>ADCB{Ephe zL$cy&oYHhib!sX2lK#wlAl12c_|K}~3vVUTIfV5+PDDDG(@uwBc^5*}aBOD+n3+M) zXAhCZ6|$;DJX>`0-sAfs|DVDbPlHo={F9V41=qO`F^&m4@e-QuV)gPF`RhU=^Bvgu zzC@Or@i(iaXt@Dg^zvyV#!?a_*-$x%?6M{pp$E9CFVVt2d|SM>0^YU*8MGI;u#JfH zg|trv<$a~fekl2KUJsOdfiwFpJ`)UoY`{x(Ao)r?PT{@5Xu{C0UBoG)`Kdv)F_ZI< zKaw&2OJuzkD*VBB7f3r~fs9^7Ypae@ztYabv!Ai3(~yX*+ClXUK4YldP!@HI^TaN% z)Ygn8hjRvMK{-EgN#CepA`U=*I^g4e()_i5)k$bU9ND*5TQ0@P7m-W{9wHHL>!?z- z!de%=J4HSN>Un={B60aIvMfiFMTvUD0JLR0i0>z`rx&60aHPLEb3KjyJw{ZqP&>g% ztW!W;o(y+;5$8XFYIH-%zsP;jwU^kS4E3to5evMLllP7i%l7270MTJG5zCp0uH8Z^ zo5*W8-5~`&zSp*b10K!&mf%@4wTGYbi5*xbkz@U&eN?GOVo$!6TFQ^{hgad<0__f% z!DP@_tL6nU`qeX7HmIyF__-iD2KnN7=h&*;j3-_w)u`wq&2m3Zzb4BWr zXOQ6W#Oe_|PtO?A z?;x_4g)5WrIFp(EFH&teo>6td=1;<3Sdiv3%yS&PF zRF|lxGxn*BycT(ml{?{=I+4vMV__1>rv@_CYMeTJ8!HftRD(gQhN4f&+BaB~_R>2n z*Oy4ZGjP%WC)bxN;bj_03)RPJe|%kU?slG7#DzDw4Q~DotnwVZZyedr3N#+PPTm;_ zmtxV&GQ>cl=3|Cm@x-SXT|^eO zk+3dkm*|-L0pGBMd!-XQY{oJzrdl)roAd$6?JT`kW6-VdISJ`A;;CfLWgaOdtBc`& z9M-()rx^VvbeW3|AH-+<0DgNUcHE9VxJiEh6FF5fIlwWrW&ttsRwCVte5NA%qSNXT z)ZC0e6MJYXt1;vq#5-Cum0y%eKJQg3^b5N<5V9f?_?}N}*psz2%yZ?Bz=O7wX32-8snGvh?8_}_fjkWBy@K;y%D{yYM!Q ztjc06^(=CWzS!nP#5PCq{99p}```FPeCat6El7nE8=7^B?LB zGE+Sf(qgcwH8qI%=6m9VD58rv{NX%k9teW=Mdq=Y=jr8p==j%CX=03FSe6~c$(yhu zPmqtbNNg}vE-OD&H(~u3!kKGe;~bt99MiDd7?Pkce-e^UyU-)alh`si^aUcU{e?Honc?{YWMjk2H9 z4eFAWoJD4~5w-6nR+*1C=|pV0k82teZQjr`uCU$wn?# zGU{ZiVci)k*rM2$BkByzCe4;dV*j6V&d)^YjarwnWa05I5VQ40hgM4?)GOL+VySgR z5MOit(MY1cX?WwEyb}8^IuRYLfal}oGExn)wmbN@os{aGNoYWFkaTdsMM)FV< z{dHC(>mU5mSNPi|Tw6sN&Znoyd%t9i$FUBxvD!0`&;LNqFU^VB4>>da1$JxBrz$Xq z&8c1R*}-IcOIcwQmHt1`s$cP9-xCFo#xumqKZEOk1$yqqE*vCMdrMUQ6Js7irJ*f; zNx_SIdDfTYC@Ivm7E7-Sbz*^f zoGY>rYt@--J`m3vioRZjlkYg=ppJB#sOSZE`WbCL%Jq|>WJS(F5%)XJ_jR6xb$ z3U_WroHdTTrV3|7os$|sSCiaH-bsx3H@-}u@0*ZF(U~XKsy>(ZV~vNaE6Futv0no@ z&1(cT(q?$BUHHLX#1SdTX&hO&3Ke@{Ws}I2M90lnRPiNw4zZwEBdSGq9EBXV=foQo zD&`T@eFv55B75Ia9h$A3mP*Un=&V@FxJ$lrORB{8Mp2Jy#^_&T>#GuXR3J;UV=Abk_b+m-EgW_*$P>U_Tbpz|+gfeaKg4;3=r^Ji5t&IjG^H)x>k#_%nH0|X z9OJQCJJIh%p3;=OPMq-7oZ4z2nPvb{WedjrC)zL*TShK`M2sL0*#mE@B2D+P))w-% zCE6*n7>$U+0~M3-a8;p(%w64N{9hoS<*_aXyv8;-be;1{Y}oX3c%ju=j67U!NVL^Y zij;rBcUh=mE#pKVBm5anE|dm!c5s&5bZtL&cbBXqmuSg37Q3)J&pAbCEZTfht$?iz zXNIrgNIORSOy%SQc@$af3F6#C;OL)2A2WESMQ({k&O-(d5i`_7Dwh)jc0tCbF`Gr| z94%08DlZ@xeu6&`U2&1vm9Cr}GM#9@j(kSjN|r1ta}9|CPT@tGk`X_^DiOJIpJ|*K zP?Z?ADjAye0bV!N_j%T>}y&oAu{*Dc)R-ep)IU% ze}Oa&=M0aB=x{rOd-*3;9GLsyPp3H0;<2DS#~E7}Q9i@+8jKO27h$mAlGB1=o?yhgkkmQU_wSr$jJ& zS%s)dBxm9M7g*Md*q{FR`!__6KN0g@fWPgyQwUn9!zb)_{{PR38z^Ks;9!K93}}7t!k< z*dyT^9wXV0sNg!VzZqDZ*4i4i2AtYVjMSWXu?*`f+t9Yg_}$OA$5iH19Vz(z)B0sy z>MP%o%eUh>4H-`_wqwmF27StDQ{|AgZfLsL-LRg_s3{e#meds9649$hRH)h#f4dpaA@pkulX<2ZK- zE43D#?#>)XbH6=AQJnlpz@} zzKFLOif1l`)+ZsqA_hLpS!!9>=v{d2EsVA;aZ)mIen&F9NpP?-`Rx~yk29J6hfR-T z%QpDy(bNaql8e-bt4pvg zi;#u=%+m;WgR$9bgiU5<<>g2`OC#iZFwwz&M*kglW*ISy3mvRjUFn)lFQ0dp=!Y_J5qDTie4`cxkn0jH*77xqxe=|lyg0utzltTrar8cs$MMkO+x z%>O+0UU>emiFrbZ6@%dTJap%G;w+IZ(d`1&$1}dwTD;s5n`a;j{2%Y+k@wAI#50li zwrFotw67nz#-G%LmSb_Yqt&s*=6A5jt&sQZd>R7blcaecAb`%2|#$Z~^p5QcZC9 z2WI+&xO_4`?G0QTP3)IJU9kl843u|6k_6LGo=X2doBxW$Q-ip=l%$$c}uPd>(w3^XGpsk#E_-UOcifjCZ zO((M7z?eUBwpr^>9%MeYaR!xs}fIj5Uu4?s}6=jGa0LnI)RE-Cn0^JN6&}Ml_k>uA1ey$pl&0wxaZiV zZp>~pG37S&;0jS}GPLOqS0=)-rAW;`=+-up`;^LOJCXNd z*rC_Z<_J{@#-w_P(r2PIZ_)lD#7uEM8rcc)ngws&m~rkTS|5ZTKS$Qm8v2V}C#&(h zvCzDe20vGrMe(@%EJP$Yi!VafnYgX{O473}Etc4>Rb9p@+ zKYJOw_y$=HLUx-o^T%8>366-pk2i_KyjU86R&@jwjd2fJwkWzOS|!vZILkpYnM$Ky_@e4t{E9wb9{zClt{(~zr@d0g15muTUal#W2g-=hCsUx;4XrXqEk@t#ef!- zfD4^?fO+j=I5Bo?e!Cxg}O6n2aD3e{cmVP+5bn;?#jl_}U~!Eiw{u zW=8=W3gNz?SjKbMr-u02dq{?m#6T>sIMe7A@o@}tZH8(`p=Dj3n+P@Q@{O}lx(PFR zO9T;tR43s-%v{@tsOJh&m(1wIsW5)X#Vu9=#BQf0j6fg|i~VM1EI<(B--?atiLbax zl~U}V$l{4@;k-TwienO#WO;gv)BvP2N`mrpN)`=81652s)$!&2T>I` zC)Uly`7jl@R_s$M$KPU?Yi-6Lu(!qTW^pQLAev=GcbXzYVsC9ao?3Jo3f#OQ+ z3RV!jt;s*JBV4R22`t0n&_SR?2m~w9JuP@A_Jzjsw?JYR=QjB9iP#q*db`a07N~Xt zy()w|g%hs`oLsSgE`+}Xnrd<6PhgseJ+9);QACaPd7kLut$+kp;inX@1=3MXWJm1P z6`iO>drZWuBAjS}W+Sk+#4dKRw^d*th&^MXTT1Ldax*u9oP{K?2SnvS?3*@VeFX|g5iW#4?-uB#Mabu3m#;u(7Bm%zQ|VkU z`fS7=4zZs@5A8+ob_(x`Gin4vgcyUs4YBf_BGfC~0pCBv4HR`jkLR7dTP^9YkQuh~1t7L&(QngjFhfUtr~kGs=Wj5%MI?lo7}yqFZ0^ zQJg#^5G@3%rtqSI-{Q4E^B4QWM1PM!c@P>Udh3cXMa9UAaH0eniI8zI5}|kE|Dw-b z%tvUxK(cW&12HS{Q-m5Qa81OIIDez)6c|A(!8d^mBXEcWnwh{e6SNfgcSTre0yRkN zrY-sr+AF9oa8in}LyEAje1eobT_BAVAR49`P)(n?P`0TrE)B1d@UH zqzIu;U=InrlOi-Xfhr>qFT}URzPSI7dM0qtijZE#cf>r^PtO-!*@AikbEF8{NsOuJ zKk=SGrx4!|yWWKj5IhwqJOZ;qNQJNy0_|FSOL*iW974e#flVjur_f4)!XARo{~z9^ z;B67MauMQHNxmVF&jhZW=!+IRIZE?e>`O0=-6+ZT#R(qbwfHH*$g6~x7syEBJPCnz zBv4Dl|Njp~s~9pOegy87Kt~bx5td8%BVlO-dX}&}!j=iFE+JV($SDF%{Qsn@2vbj- zlp&A~iV(hvFunvPqd+PwLX8wyJ_55v@L1pt6(LfJ-P(e0Mdz;w3oCdhJ}cstu-U>d ziyismN})djlT7eY@JZ~v6ITgDjUq&nB78C-iAB;Vq(khyEJ9T)vMfc|sp2_8g2WX? zNSMkdv?%Ul=NW=4!p90* zCp1&wJc$2_G*gUM?2NHtJB!et1p28!{}Fos3M(t{HU$2_Ybw0rtTcfjC3Xr6bb&Yg zUsQo_^G}>F_=I@xFwfIbA5Z=CU9l5O^fC%eA%Uj-jwnJ{{&&PG_lQW7`M;Rjp9&0aF)o3#@`5LeQHpoPSyZAsRP0qu<+Z?3 z6K8n{+$4dxCQhyU$oECZv6XuYOa{?SEFvAT4^nt@v4>FLzJwF+2s_0#QEaXuE+d-Om#q<_jQ8j?Bf%$e^Tu1IP!_7cgQLC630Jd7O$xm&g2f6 zjP*Vd{2~5jAP3u6lM=hfc2FO_Mh+wL`qkvJ4#u^X`|aiP_snAzD?qCmw~4%FG~IXi z$Pdlbkk)heW902O$!@&N`cEh>EczYlVfPt(0+bX{&;@E_RmpP~vv${ux{t613y`t) zh*$=!+@Onn`WBigr}@3&@U z=`oa|PAzw!+IF9s$uOyv%}=@Nn04h=(ow2u_gQzN2a-B!CFo%! zZ;g?S(0LN$+pSr7hJzaZdahYdg>?cJ=@#p-{aAaM%t${-3ajXKOJK0v^G)B(59$uWO3H^ z4~jw;*#zo|>lEZd9WMt$(=Kpg5BHxSEmikviS$YMv#NJQE1{f_?CN~#2~V`g@-fE# zo*v3*ZL=IEU7-H^jW1r_F83kN|ABRxiOOtghpLx%v6@(dcUzzr32K&al*)US zZJ?g~Mm?boR)$Ka)N!haVHa_yx#(CQti}zwq|{JtC?M`JijJ1sBJpXg&+LP0$K^V{H41efZ%0lhpFtPf zTuOsVlZ&k3TlmvB+HcKj0>8pd9hpC2) zk*>)N)Mxavd0F>RSW9@$nrL154VL3J)kvc<1KGbOx2AK@EA?YN@s2tgiCu?guR%lR zpb;0K$5lE#%yjM~Q_cNTzE9=;ko<%StzG(Do5i{@mVuh&OZjKk$l|DQ9hJUO-l~mw z;sa``b&(b~HTVkBet9@O9ByQ*25TWcYUu{qBwwQr{8XLBH;W+^KhSCR51JVvy^|f< z7I}vjs#TU-V{`7w@40^)?VfCu3)D*TGHTXCkwyRGSf!~v7dx0pUH+oHQyGjN zoTVPuL8>WNRNBa0s9;ZEB#V?`bWS(W`pCDbA9iA`=M5Thg8K4*j53anqG$MyGTKJ< zB=wVU^f6ky#rSkuhHtT?SAJ5ftK*4T=3yZNSv_iucep?$x;orSz$)dTjrY;odr~(! zUHt~jw2CU;0laPsf24f2~#@(RZYH#T4?kb(5&-018SqhW?f+jggYzcWOzJ3e#AXdJFEOd~*md9iDcGLH>T8fZ= zNBZVd$MuufWAR>7DLsc(o{3et#F(1OuUVz9A^#zdl7pn{@&vh|v{v3Pv*Qv^ksyC1 zKY(jhSsS;=mE`J5CFKZ}!?9ABQV!kynhvUU$oMw80jj`3fqx^;&>A5%#_nHK$06t6 z)0)Qee9g*Z}m}4+Dz4No-_dq#NK-8G8U~JbM1p9*^t!+><(B!C(C3> zQyb#n!GPfE9rTP&kiOOCpoQY#j8Vo z$?6-Q)3;N7?%U;C>$~Z@@5}T#)rx8b-&St}b*Q$__d-4FTk4}jQSGnwS1nqQdR|>g zH$}2Kmuls7xK#^VQk-~YG~E|90IPV9~*Y0)h1MYtA>8=s3`K}CSb60EEDOUy8e%F0hoclM|b$4G+io2huzBkG{ z&U4#W)!WV2My;X-`x>Et(_HhHmQ;(w@}=J~%Pqe)erx?>{g?Yq z_uFYXWI1APVs2$hFm%zUE0yJBsif9T?dy}>`L2=9jgGMno!w~bP`It&kAgh~GYT%` zM_F_8CKt@HrsVZ4{3Cy!^|ZB8K~vjQ>jdjgYov8e!C!?z1tkk|^Yr;w@@iQpJC@jA z+9T{C1!o*-&RkzF>5@7?E3FQaZ|EwRCi~5@Y%~96{@0Qm5E0Zaq;<&W!6Sn%1#bvh z9=tuYSNMkT$07d)zYVD#IyGcns5LB`W0PJ7><#=7G&3k8=-4|B9>434BK{hPYchU{h zpD>m-$)@LqImYFbo9oE_MBDe^EHPAU&^S-rnq{nOD<4rv8>XHKpsvwMi{MUQK$R6q58;Qn};< z$$3eGlOLpio1K_lCVz9zf&4&QaaS>Em28#cb$PmG=GcHeA@jl>hvbCx3T+ntQ)G!^ zb)!~9)(hVgx+?T?Xh6uakRBnU0=ookHlH!)nQxeWwX8Np8*_9ObbIs*3=@spjn@s! z4E}~1#_r}!b928=e%<`aSWcUd`OWcv5TN=cS{hmYHcc=YOp6Sgb$#RlZIti5yQJ%l zBg!G!23Swz4bLu~)hhE+=FN=k)ZbESemwgzDCK@iQi}EC7s*wUGm_V*oX@J0yCdJ5 zt;_n5mtoU+1JLEutZa6YzA+Z~9}YSmxH;fa;Fh4Sq36QlBiBZh4LcZmIb>rLEjq?yZ^Zy5^Fjn=x~49$$|4CnPHl_N?k{aHg9Q;2D^@q%Hx zeuZJPv5xtUX^?S*VIJS?uGCb1l>H@}y2SgBE8UstSmk_duT_{|usPq5w=b_|?#aAe zIn{HP=3L48CVNqKznrZ(&$IVsXJvQD{w*gw$5|NR{9x}@a3}wct+=~{&&aAtdu^fb zwpva%&a~J3vpLEXX?kUBU`g^n6tKqse}3ovn)nU1RJ0_Z34<)3n0KI^t&eS3Hgdfs?`^(wxZzPr9E>RQ#n8f+AMBx2+*)l*(8pOg>DiSiElKHhA>0h$lrF2)^C8OOzo<9d4y_EcW3H5^C8UI3l zuOnOC2=2f-B9w((KZEYO8SJZ?D!Ogi$*=~n8RJy5Pc4#yJ!Hlkl*O!XhfGzHKLs#cz+Ds9ml8-^1Zv%L~eij`H$MmEmr9Z zsf>t~v`3$6H{x{Y%8V+KXM_{oinS;g^(l>*S=6ru8keXZw!!m@_43B_XS5_@nnYj1 z2zn0JkmqgV^*r*@<-`WR5WU3^a}TBxQjJ=pSQ*!-ZMmtpM8Z39a^7yZ;YS`MaP6{L zWvfJOq!_&zemp4}`q<%YL!K1JbHy1R;2)ooDnVEo|&w`(SaAhav8pktgF``CfW}DKa}_vz0*NkwvBdJOK++I9wMQW>ZFxd>uBypcOVjUFmC;sXMm5OxW2sI! z=(~wzw9TPMXC%5J&uGX92O^o@^4zx2Kqa>nh-ad!cNKPHC|2MMmij&xVi|lB$V<2g?dj`Gie-k};Th$3#Y3^aAdvv9CT)o47PO9drhrD+p^(!wU zI;OUjW~hUuVPxBzeFpiAZYWFhhLKjhz4 z)EM0y-$5nKSEx)=7s?tr*=nrT3?FCoSbhY6B_31f~w7}XU{f2&gco+X))30g!qx(l6^lVC5VW}PI7re%r>-bB5O!?E@nlSa}ppouc zvZdg({-IJ%-Jy1-N9jBN?Y>5le*^@ZBcd7>o-w!co3G`VCh7Z`@0Q!0TTaqj%{q1Umg2M@^5~6mnH9?mDl#Z9HK8(|JYL z%;ImT>O3HoD@-e>A-m1Zl>z!5jv3l%XPlvt#bAhIpGSx}#E=i&pGhgE9?B74HD$E) zozl%#T<%7_C)?TC`M0rD;Z60mFUOjyobx&TN_#A7u+8Z)n2a}DuXD{Xz;zR(2ww5-AQt@_@+9o~1I z-L8tBYp!})p1jE3Oe>Y&&k?UX=4qQ!vbsEMIk1cFW38-3RA>&tPkc;XhX?=M>v8 z=T5npu3XMSV@l!I&fPAb^qb+TeI5Po52#8sv_uw2`lDeTbp-|g6{Z^NNCsPZ*E#K7 z)@%3A_U-|Bd5v@3&K^)pB%;ArE$^b!9px>pHQ=Dxj@=_?GJj zDg8_peKXa+btcm#=PTU~XGwLW$K;*k>*?9T0w;L#P6OXyenf?Y3t5hj1*OI<0$k4n0#8k*`4nrrAr8eRGa zel~Wdm3NQTUC~{(Sd~(~^WO2k?ml)>N_7nJ-aV>C{at@RY2w`?zfl_LPaCQzx#}V< zQg>KiPk%)jXYi8~^u5i|`g!ViQWv#?xs5*2V3eL4jLKu@Q7u<$>+PdWS2lQG>n7^& zX-(xX3|Ey~hAHOD$`tuC{bggkf2tvqJZ7z~gyF1ilrL6Vr-*0p*Uy}#n8>blm)HRTQ@r9aRN#Dw!8UHm@lAq}&$sLsG ztQVx|mng%POHwagM=Fpy<+?mf{!RBYk@PBcoz_@B>>a9nrduJO_x-JoRClW#0Lq| zsXOPrrW-1Kt;>|A=)e#&T+~g~2k2Avg=!T=um4F&)a#{9%4O||^sBF;JVQPytygX; z9hF}dlQK}(Rr%5Iv+f9dZDGtcx%Khp7W$pWEry%s1ijvPLN`#qPv6KeU!SY=P*mSs zDc3X5H_)}-`5zUUgRaVsbmw4}*^d@R=bQc#D`0U`6k$C9^yEH*cKdu+er>ti2A+hZz5uPxRjYG>r%QFV*OMf!(V4C@y< zF8E2{QU90bT*Gp7;%oO=Usc~DUx=f8{`!nn89CXVGwWw`&AXqyCgVj$vyV#Bw@HgW z7JhIiSIIn|Y)n4%s?@7>4_`cre(>eRSKr2zfZdM;ZX9*kL6Or zlZ+qUzd!pvFVXY%&$p?Gd!OHWdG_U?H|w6JKHdE|=+U6t(z6N$^#dmb2bw>G#YKg1 zY*t!{v9Tvg&WNdAdR_d7m@NN2uG0lS7gljB&g+?xm3BApnY5{JzxP{b4{2r4n%K5A z{F=F%3~N@q>D~I9o5r^uQ14Rt_LZ*1jgLMNbu83h|J*gtG0aoe*G}Ew{mNA%SD!UI z>uT!8qLN}T5fF@-rU;Y``WuJU5XtPl@Xrf zKfNM zmtXwvS%pWx-Cy#c@54>cA|5q*S|ZPF@yheV4#cf5^G%$8Omg(vnDGIf{SE}zvs^Zg z*WdRZa?Z)0kz#&x^L4AYtx~0&o7O%)qkft3TxkFJ?DCNnDmD48=|4^Ce7?Iu%R0@f zH7gNSJUS>%-&c21z38LI#52ehE1gky`u^3QWDQN#nmRY<$7GF3UHM_!tN*>Z^y>NZ zS1)7VG)?e4-Sy=3qqzIY_a5IL^5EzL>CFJoTfeD(p5pi8r&s*3{I{h&#dE_?n}7G4 zVLGS0maljZxQaUt=O%wBe0lWs`*+tq_DuKYPVikaH3{upvQg!@8og_n8|j(_v{=@t zd(B^K7OQl%)Q<3_rbo(p-$~a&#|Y;`$1PVyskUy3(Q6Wg)GklUfT4aX%W@q?^H{wo>`@Rxihi5 zLYtbC^^fH(+E1Q5SAyeaLFp_dS$k9Q-HxR2hrPb5m3m9p#yH2|HlD+C9`*IHl}dU0 zR-bY6MeX--kE=g>{Ak>(4NsRQ?tgsge(Ps#Up#-fKcP%6D1-j9OKhn$y87^{&KeKP z|5G9<@EctZ?{QnI>j&HSg*)tbZNB{0DgBeCCZA2({-I$;VBTua?to+Qf7cw}IJHGe zizh8kHO#KFBK~RgiekY*^9)}0#qRTs@q~Lf+WQvXDeUAtsZLO?=^p4_=;xYW>*lF@ zZ04N5k{iBxlyLZQ=z~@F)dzK-97#Cvbkr077aN~dd7kk0ct)`AgFd#{pH)`YK3um| z-BZ;UmMTUaT7WlrO~e{z4yo1Q1VK3nyCZ2;>DvXJp7@lBx@11J1=RM7ym>r%OmSIVMlUXc#Xtq0hUtTxQd(+*pz~av;wXN5! z+1Srlwc6dhZT%xvrZY<#&A3Rc_B+kbX+aMg6DdsetFImfs< z*pKCp%Dt7fG&A<&wD;4KYbPyuXL{EoacpA#n;wbdle*@(T)TaL_^pn95x2F%g>qxc z)hT-~DlsU`R9Ve+BwIhv`!Z)*wlDLm^j|V}WFE`-K5LQnx->vvA#iQX*^1e<3L7+Q zme(YrN%?vqRep`LMd$$-3NnGbb`9 zKdo-&vCLf`pQrcDu8~zGD>dU(N{zHT$ze(HsVP~if zll)dGgMG5Cbl#{OXLiZF7P7F}ZC=;(jrzpdV@jPA$~i~A??d|;HWz|+am-P*fgyVYO|v~Radj>i0R6)w$vp8Zp1^|X4)rjM&q-X#x6 zznh(t^P}}yeuez0Swk{^$yky$BkP9U;Q3uwC#Yx4z%u*G7M3kmRxVK~tfFPI=CrrZ zZ*D1*f!crcn3*ul(&8Z!Vj02U;cEpxwWd*zg7Q5-ILX7R#+3e zHDZV*K)K+#U|(IhvhXY0m-bEe!HxijYP(n9&S{_dw+^w6&-*~j>aqM5t^`k4L%pD0QHx4tm9AX2aoqXnmmv#Hd1^PiGv~j|-1LZy zuQHcpeV04Rdchv!IUy~TTbWh`bt|^AgcARG#gWyA*Z!mSpEdre99}M}WMud}zh260 z_X7K)!XbrS3Jw=uvE|yoa&)tm%kQ2&HSIvsiMKu8eD|tLVzc*GQ|o7X^Ntk;+m2W# z6!b4>pK~xb&=Kw0t18AF!GA?5v6&^emYNi^D&lS65Zx`0%{n9ZLV8?UV0!h81({#v z6c+e-Hfe+9`sT(#10!Q%s+2xjrhetGtG%ptu&Slj|0^ZL)1xpZH#zf3`jd16>p)4lb8IcVyM05HE|$y? zDRM(h;}R>&zO1;ia*Zl?D@>~pQg%qG*TrgvZnvz`pVUfvKXbjbFDvY4>u9@Ga44@q z_MnW{AHR7&BQfUn$CuCEZ2#a*`IsJAFudS18jd48ADybbfv=}x(swZy1U!uRHu_+6 za`cZ;i^BYaUK*(|#E_b!ufVlCcT18(B4-33!9;<(^P4ZTCrP*HDs@h5x)XbZo6_K$r<@|>h?{>Wz z@oG(C`1_h4kES)vw&q{YA7)$YI^#HEzvDCO78|=-h6bh=vq#%v@}uuY=Y*XOQY`DW z1XqT&Yu?nXZ!><#7?rs!b7ubc&hj3IdQ@2+5D>Pem?yGS$^LP#${me=U#48yC#64@ z^hDPUFBy2nGDKfV8tonJx?~$_`>N1mJ(u@iW={I_6kXEZcL8rQ-!yyM>BFUu&P z53Qr~$JjTyYB>%&Zo2m?3k`b9=zxN7cXapIzA?t=`4Ps@rv4tKgQth>p>;@ZiL8{& zAz5KL%?nODGu<_{Ch{o%qrrniH-tZpu2~{J&RaURbhpx_OZk=DQhZ(bUqLnE6Xio&0#KpFPC2!jb0~ zp~mS4nr8$I3U(FK#WXHCJf>6B9}y!$`Um9cOZx^o^XyOqbmkYzu5Yx_x)jSp|5;&% zVs&FnmpT+PFFHG{e(3ywbbX|H#re_s&i<2inYFF;@4`E-Bko4p1~u4FDWG0}8ayZT z?}&-T<(TqFe7P7Uc2;!nr~zRIgL(w~ZtiZVq;B$-b7$ECZDIMtv!7)&NQ+K!e3N8 zEwG<4OdjB#=C1GBSJ=@uqVS@vmivi&30)9W4o!CRxxiKd<3o-|+=!eQRVFF`xeqTE z6V)N~U0?@la~XzjwG^r=n;i8Vz4Q0xp3R(+embX?SbVQ+%|4JvOrr_0pN`qucGx>H=! zT(Qn2o@CEAtcUodT*E%&(tw}*Y6rdy84=n$ymQ3K@E;-uh5sCRKJ;49pulUEdX}5I zujGxsZ`{{h2HRF^e$MplgXt5}lT&&nze~EGIw)mQTEnbc=__-_X6?#uk-sUwTkbPk zo};sGmD1BR)PGp;qwvD;Tg6U<*9)l{aynp@rHb*K5~)n0_oRkaS3RkYVVCC;4zL`e z|H;tLV)QHNH!*Nopnu3ep?RTcp-VzLhjk16C3sq($-kOqpXopSx5^rIxvz)2t|Q16 zp1&z~Yj$1iL&b~{DHk)lX5?nn$-0p9O_ndKPWH^4GI=L+PUnAVS6vO%bGlelPrs2t z(}Guog@-%~+!b&>u#Y+0*uxmC|3~Si6tEIlN8Tx$b%PC~^(FMG=`+g|^A^7$0r&i; z1#Ax55gZm|3VsrFH>hvmdH+j(@67)itLxj)`MJ^iyL-56relunQbE1Kruj$mf46qc zsZ%gA?_Iu{zp=oY|5;w~f=$_7^QPtP&W*~wWBaafkt^4GOj{uj(Y+$Z{mryizueTr zFoh2e*gJr1SZOf15 ze9HyXd82H4VQ6i*uI!PgNOxI{tn2IIUFh2CJm~tz5$e3-9OszgIAcHL_|;a+e$O_& z@M1y#{LR*ooIeXX8hor<)W#fCBX87WtUlJiMJd#8!R);3oXseE>orY>nmi{wvO31u#r>m8Z@*ahUBL>wEB9#L!@`w0 zm93saf7>L7zw?Rvk}JaZz-v}NQ>W7{*g&^R`GH!{cg6zqF4J>XsK2+Q8b6v2nfsZl znjL17d7SBQb3^lO(>La~ra<#5qs`dU7-ZaS_{*DzU8gSVccMzVZ3gfYA9=dso$#W59ec*T51hrrWEFj)@@h*bEn7+ zye-u=Y6q{|Tj-3Hb~(n_5Bko!uG^OvR`c4eh1Nh@TH%X=*ZJQ#ZWL6v4{@xpy>nG` ze&&jB7rNp+zdPD_CGS{IH`cS`bZ2EvSKoNnIK~)n*h)9cJ^eM~cEeU@HJ#8eHW~cJ$ioan{8q`|8|LdCnd_=L<%S%qeekZ4H+i4< zx=RbZNgjtbO*Pu4O24?9Ip=sRtl>|v*7mk`bSy06USr+v>S@1dJ?k1`|IJz4z6P%~ z&{^Eo)%93id|ubznoVveuTfN8Eu*HJt7}H@LAdd_sjqpTG1a)oINa3S zRNQjNxYT^jZ-`;N`IzOgskz0idud*1cFHG}`IbTQT&09QPjB*7HRz@loEu1z>0f6xK$F&~|EiPaaq2{R5IRXFT??hXzLD;r;T!!OLv{LZW}6h#9=&WF zWQ;W5&~BKHn=-r=EopkqTfy{|?yT;f{wHa$=_`5=R~qlKF1$jnrRWU**mKN(s1dGZ zy07Vi8)6@*J=2BR=BdYA{&XPzTi{bls<&O?ShAM7agLVm&pnlmtDMjB8yl|ZzR0h^ zs#R0zgR{G%lBtWYsnhKHMoN@hy8q`^rQ!bZ_F9t9@2+p1Yp~MG*hahI%riRVnObM1 zzh{PVtufrO9ewbC46<78WE!ugOAg&GgPs0`{iZ-esAsoMQ+}1-sr|j3EI)Y;dhg1+ zS>u}R&6jJ?0X;)oVLXaHfwQCVMA_tv{wZs^^s z4Aq|c{$d@jk=j+NpfB&fE0xe+@C|U8ObdN09ryiCyZ^P%vpjSxl=WHz?=Si>?iX%@ zsi|7Slds?}6xlPDwXt1>WLI(Nfq9g3qkLXB)jL+XCbx5Do2trf=`D;^=IN%o!u&6~ z=E*zt|C4)qy@rjtwzix=t2fJjEo6jykn1!5o?2&p@F1IYDLnq4EM@n>6khl6A z`r7vTz8K5A0;4(2(M4Wt{C^ysbGTed*TqlOVY~a9!Hw;SZF^$dnAo;$Pt1vJ+vYXf zeO%|;@ApSC$&+XLcK0dl+O^hi+2zu5XVXL_4s|g)YOlD$s!5(~e;<0G*I;S2DKkv} zjRMA8x-`xw)#6sN8^kf%dxuFmDZSulno``4gbh?4Q-ffA3safiqrD}H+|)v3lfKKe zU0H_0jn+&jwG(Jj%2g+0vl{;`fIrieax7v7DoHEg4$QhT~A^2LTNYZ zZd#Fb)drqO|L8Dp-gOrqG1E{XH~W3W6ZbCm1wY9BDDq|M?a02m5*Ww~_^!z1J_S)heoO zg#OAZa01LCS=Mox8^dn-S6FWO+k?jS7B!cNQi^NU?RSEm$ao=(S)GwTA;a^Hd9MCO z&8P@|ttTTw^oIBzdR??g+RLw?&xUg8(rg)KyK>0aGPa{P%-E~8GVP*X2aa)v>1E;( zcA6Z88G2fvO<2{eo6f>yjPJ3@%V3|G4`X)ePo+xsw(=XPt1~7uK4DPyF}9e#*LpwD zPmGL=3O+Qhdpah?WcE@w+p)OF^dN6W_(;!gp+6la6Mca>O`jJa0?#g0_xhsqwMn08 zy(ONa6?U!9ZspajbTw(KqmeH^``FqhWm?2K{h=&#<;5DNeMT(Y96w`qcY&}XewF-X zujPKgt>JeC9PC{xMV-r6_cRQ%YIm~F#rFstk&9SMsD*-_nD2iS=TOp1Z^{+UDHGe2 z>G#@RqBwr|gGyYQ^N1T1+3VfAA|v!)DZ>ko{W7P}f1Yu^2Jsg%w}s!!enRKwv%aky zo2J(n+a)Bw>RF~}mS9YxhpUR?Lq_qi`RZKrY2`rPvDnM%Bi~fUFs;Oi-UH~oxiZ%^ zTVt&DwBU7pv(&}XojLh)OTPO`6)}k`E)KNaREv2oxJIe#=qkz`cQN96cDvvM*EF>mT_I~;Vqfj4Z$5s=HZaY|i8Aszw`sqEPt22) z{*Gr^ebkHLtI>E*4_t(R()E>E~Y^@j9KSvrek(rwwzRhtAu9Zz8+*NXnul ziYS$LQMzJ#?_X1FILOYsin#Im`R3bEZS)#!7|2b5pdnjOQQ9M+c1B zp*iM0Qb+m^vt8oIepBPXKaix5ADN6z%;CUa{5a;ejD?nga7SM$LYrebADBm4o6Cr` z(N_Ey+9-F?Wg)2a1w~s0P(9qFK53oNT5+x^Rj#kDVNc+<;#W3>tTv{qL(p`*00sQ# zxMDQcUXveGJ0=v2WMbu}w%1;fKA>j^gVp!6Pn?fm zSY0F^zC(PZCZ3@m=AL_g>4W%mJs)YL&*0PLT4YLSne}aerv_@@cwp#h@%m9}CA%_o z89hW6xdG=ByE8X|c2-Sp$hIalI;70ko3pEx{BZC5M)#vcHB!Gu?r7WemP!+?ula>I zpMD@y>}lmBHB;?~hA8{#I3N|9*m`80-kBX@Yy&?}L!6A?bvOk>l2+=s6LXUCHCN_q>_=1Wj2a`T5sEIpgG z(0x&BTpv%=tDq0Y9DS+w4LW2ygZI?o@_OZCXn-W~8a1RQe4wAJp?E)8AntH_z&G zVP_%7QGR>uHcOQCP*@q}fG~;Xx$W9OrLs=>It6-!iUn%s{ z^+f$2wF`Jp8i2RvHZ>bvLUric%x3x$E7M-OAGeMT!Fe$r6u_~}5PAfAoazM%{x?Ps z7)6lVN#5$YMl@-zuhS{Lnr>E`=r`4kQaSC4|9z-gz?K^pn3l3AH}JdN@9y6_zvDkN zeX{23-Qbqgj=3s($do@U*R(P5c1&fvJGx-vpu~L9m7~={aZzf%n(;}z*|v&k8sQA)AU+2hI))*LA94_%+p_h??)xYfD$)Nf2iKn zBNR=H4^0bh@ZIyIdMo6ta#zcolKwn5?PurovEQG3AO11^W$*Vra)~=7cAAO5dF| zF7tcdE`Jwkxx7|d%bpX?3mYPygmaNQBkIJ{30&C1xB+29ERVu2n5%M%u?%g{mj}LS zmLMmt3bNu9!pPgmMo_)d>^(kB7|IoN?6dE(pzuA8d-ib03+n?bYbwN50DnXR1uAN4 znN|pG04i@WG#!-2IYu8m3n97)`-pBMv|&4%d-H$z-~3m;km)#!xJ#^rt++7q=&w|> z^gP%m5as)kcQbEq)~~$d>FHUn%wZ{wGjIGTma^cx`cC>#^7Wsbg`omoFRJCRN0m${ zl-Rn^p_p3H732PmjYu%!3PfFsE*zN|)ywSGl8mdKYMGtWGm=ODn)>rw>f7|v>66^= zvW9wUOHs-b&@grqK05Y?FNiOeI40qJf-N6T7#sUC^0uR>IoqrYbpdQpL;B0ddmrVD z$$6SRCZnFawr`$TK>JPvx{y%aUeI~gbuY>m-8r^(zOck4@q6OxN5zJ>bCofb;YJ#~ zz7YI3wL`~(6+@#!MMICpTHSZMO9>U$W%E zxwbi*kG+jwlh2x0`6JB`Hv}$t*ZIo&DtMCp13Wq2A>L2!zTRuzvfkZ!(ZO4}EwtEQ&DX{=G4Fxriocz>TOWcH_KB&hJvyvO_^*gI z5x&TQQGKFDN0o^960zO6(>m3hMc+re)cMLZ>31j#d`he3M$$Q@t*R@p^h!jaR^fl? zZP0frY3pb`U@vE@WW8V>XqqN`rtdQ^@E*ty=x*%O`{{3_Ps)t8DV;%J=+AXtuh`Q=G`&bCT+7o^>_v!A~Rgz{j}H>0f(hJoK@F>Ut^5 z1$KCNCCkgOqNd)KPu45er08#9zoUnTU5%<8W)p6qu5wwwCHq-cQbzU6R=G{{`g*VU z9|R%-N&d%Dj968TB9obRwnk1W=5TbY*kJ65$lnn@+iq(t&oa}wuk;yw8M;f?lxo3I z{#0#)$R7?IxTCe;k_fs>}jYL5wa3LYvRNLaW4L47=DW>vfRtR!m*?EkE z`G`tjQ9q@vR*u5|`yx~_a5C66csf*06r|&!S<=bSN%2^4UvQT=(05BcA?6FT!i{`y z#2wObuNb(N(>fz2vrNjYWc}0ilp<*lq}F-a+9j04zvK&AZdwnSwg^LQv6jfF%VBZh z=LKBF|#vgKs`ras?z zfnD?n{vV;T<&R~dE#3CkI@C7WTEH^LlF3J!%CX(pPv{~wm2_7Zs1Z_{G&aPCwui=r z&cP=bA4-rG2U|(|Lw&@^&=7?QdX?UxH{x4KRdr8wGBUeLs9*Nzob73I(gr1;`%yDv zmQvjv0ovdxw8=J^tz)kx+ygJ-Aj_yQkJB1iGwh~)a+GSS#9YPC)OMjE!L|M^{;B@K zuy$+`>r2%{S758WFO;Q?R63y~_L-1nZ(^V3=;^9&>+W<~T3dbyli9WW4fYtTq5a@f z|0!LTt_C{!$N2vaSj7C2U9X5BVTe1z*E1zqURaA+*SPjLhK9X#E(?3&*k$W))A*|F zdUlcC6xGo?>z~zk>N~BwT1S5lUZ;uVwEhjaqDjn1HcF@~RJ3d|e>Iy;dre>XBEnRr z3D+ELp<99Kc(8s{y)D%h8L>xjuLLj_ks8PuY-GXe`K%n4EfzY@9FokSx)VXJ4PHz7DRm7+>nG$oR0FOJThr0N z+QS}nRI;YqADSv!uL78s3Yzg7eY{{*3mkW`8uwNfF2Ze8zOP0|Vr|pm>U@l`R!DpG?@mAB{Y)#w&57+l- z9hFGQC%y?)l?y7r^{eDAxd4^oaP|$~m)~tVWzKIsXq|40voEydSsPly&0V=8Tzxn( zZ^SB@jdqhRklON#xQs*k5Ab`9GID`^;-P2aR!leU7gLV^&TQm9gS+fCTaEt4Iq?8? zKJw#7WD2>j>FP`6h`dNzEAJNFk`$UE{uQ(ZN=k>l5mHR(R*=Z6#e))3E`+X%Dx7g= zdTIuac~9o{^zX~uoAcd68LvF83=uJuDpY1UQ#12P=7!~+FxK&xHJ^QeV}`j%gv8wE zcC)RFeA+v$x{?t(DZ9lvawX*-iIXOU<}0V+G?t|HGLq2&dX{k4JYU#rD{Q)H?QK3| zwwea=E>jx!g?q+e`T%MT%G*!sCgr9yL^-0@($)iuy$rs=9b{W^ywII*$QR&Kq3F^q zWh_&quNiETp(p#^~vVhLXhsa3!&#s?dV_RxyZ z<6xpxO4JqXKdhDsToz{s%6a(U(5&|Ex0yu)$S9nf%a+xrFeZMdvD>=HSRwq)tmC#g zB3a29XMSV9Vb8)1`IQtQAmLFPg2$k}bWxtH%#ioXpM&4UjZ)>%LDeZqz-}Fhmz!>} zwB?QPov&@V%Pxhrwk6+7$Pcq4i!RIl0OCO#ql^}(a#}O(8~IJ?xB=aRt-_69O0$od zPiP7=hdIW4;5l#@r}HDgv$~9?nGJAao`D}z7wHc49~g&c>3(z*@Gfo!SJGT`fl~2D zIuo?f)z~l`q@&Rr)QNK8O6U=g9d=UBkxn9w3ffw&lIm1mC<*E*rHE`-6nQz!xce}k zW{Z==--=KCqND_Wf`ff`po?U)XsLtoMJ6$?_-N|rU?s$>+BiiAxFU_|5f6Y_o^z%+7@}Y^h_QnCMi|r z`RYwAj!ebt>6ZLi&dd)r4do9B<%KxDCpbml3SWiEg3dSRR`Ipj+c4{9(Q}#pY&P?a z{ld(qIogiB;B5K?M3QEdn_vuNTzuU)N#{{dQ6vpneRLaKgZ<8|V%u{LZVhK*Pq3Zn zcJw)106(P%;6?BqA-Wr=tGiP5NkM83%+ix&9;4Icg)Vmugcx!fc$W zwNl2Zf7D;fB6*g&TO29}LlfmCP>Jmjh0s-~&Dx2pJ(>Q_?hS4(Z)DcaoRc{twF-GR zv~`rBcjJ7WMS;fHX&0rJ3cK)M0Hl-HolzB$?ka_bm0f+5B&z7JD34+*6hTd{4_a(*3T@GMn6icVS_`ffql^oE!Pt0 zqUwDz+NiJ1)05S~Mp2bjbgGWxP=*p!o~LzDYshugIm!j4iFikSD^X&)_Ews%tWqjS zVQRK?Mjb5G(Y}h^6~EL|+Nk^;>?9{hP1G&wC*>;k$-UV%`VnyXPS#lFw&r0+8w2q| zW<0vYzBe#e!5~oKJTb=M4Wu;8lFxJmeuB5*fuL1Lq&DHU#yC_3l|(<8nkW(K_#*T` z{1ks<#$pL>8n1$HXA^$Hh0(dRoxX$5G8K&V+;$wp+@P&Emg#~j(nF~LI|FD~^Jy18 zLDge+LN-w~dOP(Iy&$FeoAMC+80|pkR2w&*ybKj|KMrcBgk z0#jtLc2Pc0j;o3KX>kNPV@wQHG5mUvSfwYR`j_=pdSf{SIrU+3Gh#z2+F|JT9MoE) z=~^kJm@bocWWD;^XoOE||6_plqWYk-`Y1eyBys1p^88)>0$r89N8M&a#tmE&cAz1& zN)9mpLIz)cbcNo7no>vb72_+aiz@pBm8!xgy@gVo>cel9Zm4O@bD(aYfE#}3ds3Pfl10-GX##i{Tak-M z4vsZ{@bBUS>Q`-=J{PFlKD7sZUP(Z$RgwOrY(?+X#awSZ*Z138UhKs-3mv9kvlZ3q zxV-Yw?Dgf?43Cpvq0fi=(d$yAt8h-NkVZJVn*0#m;ui-_0B-LdUXPAZo5WJ4DnNKi zH@a}Zj#r|E`_e%Eym(pIuQVcYCS4l{?3I$#Y&>4MiK6gOm0{N?ZG-_@HM~ie=!f)N zZyU#Fc^SJb)ZCPxt?!E#l5iYXMqGo6aHE0RF;?qc_fyMbM(x4?KtIG#0$(S_o&FzQW9F@of%ceS4;U}^SOsbRL>@1*z4dS zH9EK;-0OB)q(B3_)cnYoD*P)q6Buiq@t=yTcMx+m(rcJQH>cvZbcpoVNwTN zOZjXbCqKrWwIx7U8RDsDx2E&cUI;4%^D?llcq$SQOrXj3aht zLGHG~*FJxa9244Z%CA?<^%t{#RFmS;w%e96%QI`n$E3H3`kWQxYAaROE=O!h?G=-i z^TT;Qd$`SDlJnX}$|*$(fBKpoeK)Y1L_4qkX_>GwPy-FtN16T+bt9MQuC2F3q>e6> zm|2~j&kxOOrA&XoqD_!v?U&!Pj$RBD2KSLTASx8gD5 zFy09okgN1bV;-8R?G&7$hHPmxPHW7L)oxG;I8qtve4W#ii?BENwYzY^)N0nxxW79r zZbfDhi>&8qw5fUS57z`^qMJCa$~TH9R(do2(9_g0N{Z32`Lz7E){gs(@&~I4o7GJF zi=4*XDAU8>d$p)_uzH$QP}=a{XvzK9q3Un^w#c;KLkq6>qg#4fyveoVa|11;ec|6T z@7oTL&E8Q7xjz<%cSS9-SLPq`<7xpu^CEr~;s5%s$k$(`ot?QgsbdR-{~ne2_ScN~ za_;B+KCv^Gh9iT5vqx^X$Qfx<;~HgFGY>ZTGMmO8P1z7vH4AR13s|3Bs$Am3x42mG zPwu6ni{AQ*?fLq-;Qe%(Ux%~R-oRpb$4vK6HWwxX)PCF=X|!3*5#ozxHe{yRO61gz zxSqa0dWrvw`mggqR%FbCtaxEA-`rQ*d{%9T6f{huj2&E4$;nky0s@gPtJA{gXZ*0{ zQaNHXVS@OQKOn+31AKg|-QF+q~tGw{j93l>&d8+DpG+r`{@!WLDym@>+T^m7mr^%lPHg zcl|Z^f-6AQ%XiWM*~mFmkThVtdOj3Re&Y^W6}qa{hOI%O}sY)hC@Pw%-+zVwXW6^Qo{`I9#O6}O28YHPOz2Wwz zF;!O*pl{g&5^_vjNI#~8_}Xd@vRSbDf0|b+J8+(QiTgyo@h968#AhfTUD1~?o#di? zUFAA1YNVij>QV-n^J+4)$zb*8u(#J{oA^uFmZ^utSvJvE)-fU@*E|+cUJM3X&KEO%hdmrgSKaxd-IL>-X(uT#yr<7 z_XpdAzy%Xh86yCW!n)i5_m1cfnQq${rM}S)?IE9;&Y^!zrQsxAQCKf^V28kG-{e{2 zTJ4=JOfY=%1oJEJTt{JVRa+Uc3#b_)#O2&F-DMP`6G&%c2JWl97an@Y+gRlxo+C$D zNB9Ss#z+TE)<7x86Zc=1No1?iNoXF-u;ugiv5oatHn&w8(I1pV6L3`s!y+9hn!r8V zM0Vpb^ z26bjysc#IO86jr#dzE`wF?!&(S}}A5=oG*H=dGj#SaZ7=FX)N72i?&QaJ$rEdPlyt z@(aHt2wt7DzV!DLW+@h)3tTtNLq|eqg{EYJUYEMWE)~0QTYxTCNy=|_8kYmDoErl> zIZ4^d_EFz5LzPSX9I>g8EZt;YP;(TA@IqW_x)CT~F{I=4U%HoFFGu1`=7X5Uos`G$ zPt{4}CsRj{MbT;((?_v7Eoq}ooqPo?@8rAO>d-IlA+=cf&V7(H;fLtK{gB1jqqk-5 zYd>k1+{j!>9LFpJug`pmm>fzvSqUc>JBcS|)6!@_0!GJO;NI*_H=Nz@tRFd^b4Tm6I2_YBt7MF zRUUVM3<(__)rK?UwN3aYQl-o05>iFus_C;&GhbISflfA*I`~<@ zsM7XQEAe(!rF#SO4N^ep0upO}CthdAYZHJ+*j;Zz&(mKct3e@1|0E(Fs;@=kP#N`{ z@HE)Xv_btv;<*;m6PAV>#c(da(u1w1-D86CE^e)2!i$(*0k@-w{|s-}I?^HV{WhkL zO4o$r$~1I=ejh4j@8oS^?W$yvef<0&Z9Xpk$9C3-Lw}*Hn$90mD&SsvE3Ptp${yM_ zsL5NQ2xAr5OlPTHww|}FJ86w^Y^fQ_q=yKc024r z?+ANC(X5xYHuvta_wt!dHI3hrX0l3O=!^O!I-eRr=OaszgIt9KrlG)!V9+G|S!^Uo za2ws!ILzz=omF>I1-|8QBZ^GqUdo+~XDm{(@G)}Ucu2319x!)sFpyz??Hg@6D1V_3 zkV3$(&BTY~ejG9mi{BjUvNuF$raq5c4}OEI;kLA!(P#5s3bANZV5;M_|0&H=@6a$g zV0stq2Az|YdH+Q$b{96yQI^qXNmX^fnGek5jzflFMWY~Q$y#`~G>o<$^bT1NW{Y#p zIo=DlGC)U|ZlZ$SgyH%>S{U;mDM!`O2eP--4X8KCz(2Jk@Ka{!FL0i=#@NTQN;fja zT++Wp$WZ^G&SO#9FBtMWswQ5jO<<xz2WC-uiZiaG*V-70W1flW+`n3VV-_lBOwm(uLF1h| zfxWK|Ldp6YTuK8Y0?8n4Q31UHyWU+)#Mm91!ii7rZfFy6P=$**M1^}yd>dxl3JEI1O!GWr~y0RV!%4;W?TlZ?hN|0 zHVc2!e;E^CJ=@6j(VU=L%2xeAGtB^Y`US4qy2ALKOhK-#Q;?g3#cAG$<^ zkyZFLq{QDu6(PkklH%xT+859-7lIMr9lg`1QgGe@0@!Q1k>1CsKsV7_QKe8mV=t+P zx9eqrzjNQX03NIJ@T3Y;E}(CH1r}^?;7zvyqG=Hz(@r5lNH*_nJcI;E`AB)A5b8)qz^lIma(8pk**pe%cunIN z8bKfj0xcu6fo**flFmmW4LCZ7=$Sf?OXC>bL(ik8ky-Q&YAjSvt*I@<&SvQr;2H=d zm1#%{qK`DuX42`vZ(fM6P>V@3`Y@?P-9;k_hDxF%`321G0(x^$gDljdaT#idb^(VO zm0`Bdq6%n_&|%b3%|Lh1RP7J-9;u*UaNq(!bc#Z=(GF4%ZN_~_Rn(U5Pg1E(%y9i0 z8peRa9Qxr3e3Ri!VaSzk#oQ&YsId&3!BGpQyxtbYu=kZabZvU6!ZDlib)_eEGTr2- za9dJQoh39JmB}Be2Ih)BgfxEf-gY z1iK~ZE9$KuLLPKU!)PL|q?*uj{75MeEBG1hqE;VtOQ+>!c#`o|ZVP-OuhJjXFAKGf zp#4tOoMgBWPd@9*jIBVRvjIhE4Wt{y@zjd+GVWUb=xtkaQqGIjIST#rT7) zdNt!98lo$pFp5sw>bRt_l6!Ts456J>b-)wRUlDbQhpQO044gTB&k`5n0vTrM+dhOrRuh14gk0)XKS3AyfotgRPMh-2i$w z52R<2G6R!7nfeKQu?xW0ii6jP0MgoicuE(bXPX0gzRf^k{esXy{k}|2kn-dxPzc?6 z6_TPa(0}VM^rqnM^6QMgR%eM--=znDMtDrG4%`6`j88+KNfwdS`n z+XCxpA@B^e1UBw*pqN`x2^0aOp>Z%$f5I)$MP$K0fJ7J$+*T17M2Y{SPU91(10v`R zJee5u5%`sJfKWXYK54!GCyS*6=_?5P&vYPpbu}tL9%DAK7ykB{g08r|}!EsXyO-D0ed^`bGVks=6_oyfIyqf`=@-J9(x?lm%#l>(q?4lCl zu@lDe9i*WZz+3zV&tVvlT24_jU=3RetEd;MkW)Zen+(*oj_@4g;gkI@Rjnbg-!bs) z$3u;h1c*%!s)nkg7Vvy}!Fl94u+a;k ze_@0nv=^x5523re7m^6p0lDxY@bdZqm&^$yw->-WOMsoI0qh+%NIbj;-}N7$sXc^U zWC8qlCnT6|gS9CUav740V(_ax%-EP(NS|2>?fX6QmBj zhE(=uaAUp}vK7j~vrHw&q1HP99f_uprM`@6_W$fi$TkCVUnR&LFyI8o0;{|RItx{O zQRs;$k-BITwL&k8+fcA_p~Z;S5271rk$#qXfa~j@j8pVn=xyht#{p&NUwS$TXv-K` z4`}=8s>TzwCOwjRqP)Ps4_9j8Of(tXAonq))-a-&5~@!(u{LduHkTPet}7=QyD>v~ z$UqLT+LB&^X3BjTaOsE!cg?WMN&GD0mizK%@>PxzJnBRB8UIK+M5WOSQeu ze&9L{VngTkZP=-8(iM7@yg|3HDdIYEguW>y=wq4Ia$j{n>y@9&pc7EXNp0E7`VX-V zv&?8IWg;54lKU7jbfkJ)e}}*5%ha-TFDhN?#scF|y2MyfSLHF<&k)H&Me=^J9PZEc zlwRUG+)F8d*K?H2(`ET0@)+EoM{+hJ(URO;=xMap@A3jT*beXy)EFiP4I?Xgs~pdD zrsqjfLJ{PWYnq2Bp#G%qXvOh6U`q`}iMk0_CGGT!(6OzBr;wj&EN(;gsp(Kz_12%G zL0Ycvr9P^!sSJIF7DJbT=%WeLi&~{Vg*g=g=b3-> zU8IiggEgaoeu}Izj_T*hJmWO@DPs+p>?L7TKKT1IsNB-X1*)xKAr0{$SZ_k;Hz=Fy z1Cy$y@g8jfYJ5i{3gi7*fQ(*yK)7z-Qd&Xf;r#J7y@>?~~s z*AQ|@SI{QluFs`5(Ot+#`VbmnR0jqZgT|4Y^fpK`u0zM`pMeF2jptee+<}w^uGb^- zN_$`+!jfvb4OLR>Lhs5)!&pz=0151n8qI$s z*NtA7WlO8)xx1*JHpsL9s`)=$C-g;KZ)u@5VcPJ=wPExjK0+RE{iq@8y7`JifahF{ ze8hd&9PJiUmM%q&gHF;<^*4JIrHFqlR3OfDQGKrOfbJ5($8q^k8`EpKs#=&A18!!f zmZode6g{Nwrk|=)^)&oY&NU{WX3BZ$6e^@tVAhamG9PN%(#Bm_E&GvJvA#H><_Aqt%b2TlPdW|xnN(bu{(v9RpFk-+3U5Sb>Ap0Jn&2XA zV?CB_Yc!%Q+7)0NnQ$(73JE4T`Zyy=t4OYETeRs)Z|bU?P6}yfw3hmR+A(DZX)7%x z*|G$*um13Va5Z$;Sy=sJCeF2}5uj}vivQBi8@;K1P@l)>UV}D>s*}C?1JYaXLH2;EtcVh3 zgp^2K(Ed`N>FuQd^lkwMDHANE%=5Na7Wrlc-ULd3hD-H`;u9 zHC33IPKWb#*z%SM!eL>)MKpD?%o7%wKCyebt4v8|JkE#hMkynl9M{JxZhEi=nmRBtrSp4X(Yl8_zXs43kijV*}^JwEz?#@iusOdmHBTWL+EZ=#IF{@xckg( zHX4_PvrSLXB+e$qjK%t55)V~i8!g(nuMx6JKMu8P93&+~V#qdQO2Vr42LAyiNOu+- zUd#s|9yS5*_yj7Af^(M9AMWCL6}*XZC$*)aht?Vmq%qodrMzB79i`V%m&&QCSsor7 zua@<$RQI`~gF*00Ms9@r_1gFQlrLWh-GM5O)#Ho!?C*)-8q z)s|s9ZlC9@>FO0;FD%#D!_mrq*20+9@Sm9ZP$67|T4p1mfWO@vR60e}zqL);WvwQp z#8jZW;1uQpyOtRPfn$BR3tTeWTnO<8I5#_jE=eQwA9P}%qLB-0>FP=8xO^sL4{ngw zc^8XyJa>SKGc5Z}$dNTK7|1^D-Qn5d^ZVN7F7&kYZpqo_?kKw;$F~FaGB~T@IVargB*n8+ZR5%pu9e31yO0m%2A>I@1zMj!N`*c>FbUEjrKPB@D zru5D@n0G4WROXeeJDHg|7yXm9t=vRrmUDX8s`&1%ZL!P4e@9fX4hXB^k}P*XK^2FF zh)V+LIquA6sqNDnWt7XN-M>A}13AH|>OXn{MKe{+%k9w-m!j*%wTMxo=0*JD5Ug$4 z-Z;u=ss+{lQn65{;QC;d&}6B$c+Yr7jxaCHiPjyi2hPLc?;|!uTnnEP_Q@6IsBB$e z;`x=>X+&#$hop*L-f~RZl^qf7p@wsfyWMCUD zNOyQn$ZhBW{1@Bt@T}+}Q4^w{yFR)0S=yUh*y>qs@Gqd3`BZ%FznM$YPp6(rJ(E!) zZ(|@%@~B(M7j&524eH)U)`qUH5$j^7#x0EYh2OM45(?6Z`bjA!aLAwHdjk5(8NQM} z9`UIAsyse#UCsdanoKeCXL_C7KzdY8WY)yYN6E?QN79+R>iSRi zJ-aVzMBJ;G{)uO!pF3B(dJ6-E&Zc|pPhi&#mMuOm*PQ<7_pU#~Qy*q_^q-Tq8(W!) z{BroVp4kf9RzbqZr--EJMv;RZ6D?2avs&d)s^@r4n;diA26s{aq2QO`3TcWmN|zYU zStROQzT!!Tiq0w4ui%dOeK2$`U_jA7%^XQK-> zMl}QBKUq00yOcfRGG7sIh3v-}pYtkY)P~t%q&3ZXnDHmCNpAb!?Q-U(DZdu$r9zb8 zh`iz4TJT)bg}5Q{!<|DMZH1}aZeBFX;N6NV(AZxm^JI$pr}59V8AJ2rz;5F@d(qUu zt~w$eZb!O|cYkW2uM;{P$ z*N%kB1^ShkSG0biK?PC7zn1xIf;vL_DJ}@@lg`Nh1@{Ey02e488l#jTc3!hp4*L{+ zHl|%z{jf2%zqy%w6HqOFWH+L1xCNSOxU|2dyV?(VF(`*7u;NRShs=Sw+^;|W zEoxC}d(;=}!Fas@6E``bhpUT?<2p0fjs8X*aN6$iq~{&aY@QaA*6CMd<}!C*e?#1b zJ8xR!eD7%I`V&4j{Eu^u>lru_{I*S|ow%mz4D8R|CyK_^&!Cu@VCjiy20aqw>#?Uy z4V~A*=0=8Mzc~kjntdo+nybKMQQz=2+|g(U9uJ>(k<`;0>FY7io@Y#!xe(-9L1@Iy zz$VCl8m<4(7lZzKpvJ0=?T3PP@6&m8t@2%?nklEY&Brhd*n$Mn>?a9ou z=M45eOj#BfmoYf=W8lr77QWk(nol)xtM zOm|#HrM!0Oo3n%7qn?X8i}x}~wzsB8>jZm2YeiRU$2Qwe$5`_|ej8m|eFa>g*pIp<;VMakJTb>{;5@R9B zq@bLsX2>Pv&-w`}6<0B@w_S8R3%lrS?QCi}!L1VRpe{_13sF$#z)eIUpbQ6Pq_KFQ zv5j3~zRVvty)gluopDeKxnNKRgO5|!Nl)@wijflJWbZ8Cq!bH`3N#F~&c5K|a@wVp z&n}mnFRg>`sQXKHLBA`@oE1u2>u%+{rWT8e6$T{SD>%z_As-#yT6kz{M+d1mK8u8- zGSVc^tK8)&-n8c_^)t4*TL(sxOZW`m%M!^ywm7USEic1Ag&hnV5tZRsVY$sM(%XwY zd^_C}-AjEtg0(`spj%K_sirN_tkf%d8+Xdq-?cT;p1*1Sh=k07O!$tl6ONknV)i0c zR$ZtJls5#cg(e4E`FjRuOGUK5@j_3!yA|baQnW+2-o(@rNvJn6zby-%Fo;{ zb82}Xr?kv1oVGCQYKE9u$eowv^~}rroxVDyRBDIh$L`7cT!D=UL~cwd71cAQLBt;O z4a*Py5u0OkoW4v+2}Oa%S?6)HB)GcR~DzY^7R4EnFC}TnBSRc(It+A129JWZW~cmKZ6bYwJfK191nrW#8}5~g(A(&3?0mWjKUX--<(U8D&$ADp zBY9TVq_5(Tz<|IW|Bk#5IaJQ@v<=xke_hVde%{H7$vB)fG`mas$J8_FacND-MO$z7 zXHw(v6ESrQtaGjlAL0zN9Ae^aLoG{qm3k(w%G{V0k^K1g@bs%G^>VS-8r%z`fRgv? zZt5jJ%DFaXXKX@(DREB3$cTpa4%}P}i7@(UW3v_~9}kSptCY=W)y`_4Q`Ofv_(!=* z$MSE4XzKy%m+)=z`SQ0Zaovq#WPdA8h493+|p)gTdUd`^kn>h6R7 zY4SR1BAR7RHP^N-c9gQda=dU_!qQwZ&eBe=@SG9ok)XBSB!-I@z+F*MStwqTTgi#i zbhV7~gIb0zvUZ^fyIt6iufTYE0IPU8m}d##)_ALo2pkR-${CQ?Bf44a-@QKh{ zt&eoj|HN0v-^^Fq|IByBzac1xMA0M9)kje$_@9oak*}i8#Y9C+hWV&?!i3? z@kz#K+WWhc|M^azv7YYgs#mY8s-N7emL+`BJbjPUOsXOZfkzVMZ!Deid;O~d%{+ep zIQKQ*6gT4^9lWAkg%{)pBTzrFk<1jfKX-%s&hO-gT5brJOrP1SR4uX=?84n(U-(aH zqij-~;Nu%9{UkOJXGp1XS@^G}AfB=x%`^XF3bCh{pV=i`GS^L5D5UeP_=DUe_Auk4 z7r;lJN}eT$lV)-u(G(0sKWn4)yV?QmyuKTLbyZJ=kESYEa%}ML)+R~Hj~MB)^gcR@ zK1JW9R)8rmjyi=XyYF~@oIq9tqc9I&&k=Zqx*2Ws#l%_ojkX$z`tQax{T(ddv%opi zg(!{a>$}K5z6hqx8n84U1&hi?ctp>@6L20Jwo!P$GD`Qg|r{)W3>g^er*MO zHI=k0T1oI^HC4N*X-bSzS*@mAQCq1y)E8QB?TTJTAEY1G59u=eP1jIf0{CK%6J=rT z9*wO2Dda-(E$ZBa8V>KqbZR~<932qnKM&UFez5pgAq}u*?uKQ(5j@o&!FCHjFt}~t zokl6M!GU@l9*%3^_Gv&igiohF)r|Tdm4TXDnNE>NR}mc!lxW&@K(U% zQy=!XagbvX_(l9+#qol*c|pDx<|*+GEX-R`{&sNIY==kb0ioib3t&PT1fNAs{L>Sh zGR^Ug1Ef|3QQ&jDY=qGFYk3fIB7yo|et{wh%l{`@tvk5KIJVDEk^%p^ky;}j3BI8( z;1&7|j+%$?JbXZVICuoAgPEryd|*Y0y3p3r;AD#i7fu~8;gp8oM*}O*8{;uJbUwhJ za|d3Y5BPl_*Le*-nn&>U{F~2cH3iI4KN;nXL|_^=g8jCU5eK&7Soj$$BNL%5{6GD` zgS9-L`|3=-534Ay))Q9W(eQOHg%@@Um_$z^LJM9%;sfy-y|V#+p3(5?_C%dZKx$vX zFSNq|A2d8`OHkI%d^xhc64$K16&^+v00Y1`^_-u%B%cJyKVEJhcZK?}zN+mE0 zh55zGgA=C>ESSB)eANartd48c0e?_e92;gkip5+MW|PlCPh>-856}nG(DwC^;8XaL z3X)5}e!C8_7eAn#bHF*e0Cg!3Zlh6%qFD?N3iuJ;*NTn~n)Ehs(AfqNU;^%RUl58$^Ad$vA-dCLYC zEd|`C5ny<-fFr6nbQIno{Nlh$73O01WAnlL=EtublpDY`!+cLs(2oR2p#ZeR0{@_a z|8K#7Yge_baaK0e5qBz!$W5YFq zF%FwygujJQXBwl5f(=gV{7)BJ<3kHQ&>k0h0N~Eh_%N^77knmT3%_SzPinpdQz5-C z(1ft=h z{S36i1i6ISxC&q^0j&Zy8{`u1g|Iw%NXncqr?6I%*iYu~3p2wNfFz56f2>IUxeoN0 z4KE8e9r_q_jX(}9@-YJ8Haw3#r{Mb zRUExm39MqpaBO{itAY25_*D+|jlePCQXJ^lu+9|0xA0Z~+EE1c3v<*J!h2Et@4%JH z;HWUaU05T{xCV=F;kJbLT5v=hUPbVK9Ijpv8Wh&IaQkhLkPQB^DAY{O*X+KirEhy!?PHhjis7UUc5^YBq{KBHDNq!xkNJ5W~>ei8Y1qA$M|ew5?O z?~(90A%W#D-1hJt6a(pm3a>G3rk8;Cff*oUfSYPTB{lSgMVqWqfYNQ3&0#(EtufvQU3y#61_|*!lf)=Pl zXUwJow4@ozNWgnt%nH3Q|F=h6OIOSm;3~c>Ix3Kk{FN5LG~7~%sKLZ zx4>-k6K1I9m}R=5%&z!V1v5)6lo}q<9B73b_n;^Le$0hT!_0uG=(lk1Wa6HFhSyz; z_qX!rk*nZ*+=g|+8O)+9F|QoLf9J8yOzmMy_M;wR?wJC~O zZX0UQ9(r>e^LLo7@;6{U?Lg$yuZZzDj{0l?KjJ~u_5)tGQI8yy9-cj)L8r5zeHoa! zUP9+SK)VgpgM}2rqueuSaCjVih}MUoyDVlx6?Zy~*(CuoD+7rch+mA$*R;Cm*YIji z#&Ir`^$lm{;%e7X-~Et@6XjgT%H|TjhsVDUC`p0#6oe+lW2Q|&zxBXe5Bvi3-vGSo zVJ%!6{a6z9DT!4_Y3Ni@j1LWQY#C@`Rg3_|v1+V_ktG5D+VL5}oErl@tB78V#ksZ6 zyOp4k4Y7tUiAbU_+h=_sqjtdRq7KHJX3(9M7#|AYN^vOPgc;Wf35Dd<{R8yWKUlYKM0Cztun0~<+*5bNdyasG@n^76E`!9w{k9hRa14r~e@;=qIe}+T*`QkbZsqYKmTM zj8$TofwUF+WjvlQ)*y~(6C$xrV9k08>(0x>7ocz5!YJ{K_>365FNjw9gx6WTAHn)^ z2{>#2#yRuw6tW(pK$xF&0AhF-A?|Gr#)dIik(NQ#r47zZ&O>-O6!tCjFPZSvaT088OrzrH?j|ek%V&{zYd6 zuTzWUebP&Ho0=t+B&R9=Nz3(Cuy9>Ltn^^SRm?VO>YW%$A3`A48xs1Hn=P*-41KBC zThA6#MM68FlsEp=>*%v!H%vE*GkMGd_(>>Q(gACFbh6rnQeq8+WQI|RjG~%BMpcACd>7^>|B_B;Vz|q6fBG1GpGqOlQH1uV{#tsf z%m7;SYNDPJC0^41$StEj4-E<~)>5Py{sPKipDXk};FT)+($wvuMf^z}5-g+~BM)$s zX~U93r7?Ac8th$Rv!$LPi* z_(skaGeV8@%2Xe%D!-m|QWidf5!qD!CeS%L^WFK*$e>@qoW)#LmUif!h}Xs^#V;?G zqqG*vUwV1{oti)%rILx4+ylzN4Y3ehS@Q-eT38JP0~=>%@~Cd$vyLaG8w(X5vBkB3 z-smMsZ*Yj1?bp5Ul#0^7OuBSSI;LEZ8*n2OCzuWA=F~B{G|4}b+Ab|}C#bD8(!G=F z7i{Sri|C1$ep5^QI84nVn=>DnhWr`kE6?k% zxYeo?A@GM_Ybd9UVSPkb(_R_G2L z#7}t*<*C7fo)%IuR~3(4+#>I#e*uEVB-11Ag>{E1E5fiJi4dYaQ75Bz+NN+<*{z5^ zo`5;CoHjy#t8^8AmTF7=z#6bdKMoI8U$&BYA<*Vq3r^u5dsXw|=v3=#{!$E2z2i;X zGUYz;Rhs?Jm< zAJF}bRfvW8Of1lnsY0PLQg!8r-jR%_PU>gmd7<-e!*jtKuh5cDuRvbsI+%RyMgA;v*E%F}TujmULj{^e zcQQ?nJk7tf*(?JXhMUFKGOlSp>8=zjr>NzXw~~dR5HInU@t0be*~?!r6xMHULG`qi zp$}0lnKg)cAI**?-|8>)1u`Y+Qj+{6I6^K9tJ>Po@6tSRjMN-<>xa@Qt(WeyJW;m_ z@0D_yb;hp`Ff^7J)+A-7W=rYs3|4G@8^YPPmaA7*0y5Nm*Zvhl3jOC-{F$yMo^83` zv%ls|&J$h0rIr67FEV{hJB0t4YS^w?cR8$%T``+t)Vhs?Y-akh@uEfKWsB^M3m0?O1vP^^==G_%G%U>u{5IT_#h zOwL=*wdH$baczd**~7OttA=MwcFD}?*_(5Z!N-4CIYsX^WRpQl>;}sqvuN4qI2G9< zdU|9>%L}d?vy=XcXs@wq1LdCL^3ws{d*5Yt{pXz{2xM+<$#5;4~{D4SZNz*U4^I> z8fgEMg{AyTVG$n(l(-LI=v+utVb1ZRgwob}mf^x|^DcG^*G||ctTip<8w$PI9AITH z6?;g-{PhAgy$ABT2YcsDbbZS`kd=_L-L)e7g|CI%4-~+!?oL^+JvUv+PQjJpUz*z| zl$R@J)N^mkxR~)N?;ngMXVg#X54jq_BYtxTU~5tL1EnFX8II) zz)*=bN^0m+@TRYlH|TQ({tX;b<|>K6+AFIS(sHmOxIs@eX+o@Jfwh1+!Ssb+$Tt-X z+egPaYtNXGQA;CLdlO5vwRgls>sDJLq5=v=Trf}Jbw;M&F*TVfTrx9^Z_YiScd*6j ze$-d`9-qdy17cMp(*@IVz7!V=3^9)VNRMJZQCsu@YMdBGXln~Z#$2!8)6Glyr{^wl z9m?&SJJU5gude63E6sa7*XHWx?(dxA&d5ILNpsJ3zj42C-prYneKhNC)@1J}?>%pO z*(_#h?XVg;qW?nHVYc#3ZHeFr-EVDS`)HC)$AtUbX1+B47gw5_L&YLP%p`vnr;DYf zK&Z8RT1u0K$-Sh5qD$(hIMo3}4{|guB2uXlHlSOoG>xth z{n=3~a&gp^sHgVtwy&0S;XigfJCN>42k5`3UCd#+H~oTgk?#=Ao_KS^866pya3< zv^lzk>`uHytkfS^OI>6F>~Cxbuu=@;p7KGqKkwq2b7#0nb_{DnFP^3PP`^^ou*OUv zkHhXh8<=cWkYk_{*<^jHF>S&zE!l;LN2|qoxR+dC{!e~8f0I8B){_77ZwQjgwH}ZVVF=xh!J$&$)D%n@E3tM^`3jdPGxT}Rhej}Dy>l+!4TP=+(%3> zN*fE|ODLn?*1Bjxb&%RoJ*2Et2&IzZmgmV&WiQ6XQqml8nP?LU%n)-z)kCd<1%ulI zRNyiE&ei=r{FVJh{r&wj{Zsuuf3?8cK)2wH;3OafM@TQEZ}Mv85A|0>?zTjPy#frU zk(5fcqA8|1vyZ6-ywJPsefAaZ{1mn(_b9%}m(?f#i<%H}cWnet3&<{Fqj z=FyGll5`&Rl1ie^QD8oy{-N9yMOUL&(kFqEHHs;+;DCk_$y9v zd%00uC(g#bW>2%7*;w`{GYj0j+kyIW4;V=Os2J)gIi9SJY>^{)Mh#-+G2MtXQuMXD z09K}UTAI2-EudyA2bG_dBFaO#i)@t-OKqfpxK$h}76x19q!1OljF`m5!HvNo!Liud z1kVMJ2V+7%BN`w|oG2>dW@)&5Umk>2%6Fl`aAlH*0$X@U^Jw=Ne z(8GXl0Sr-kAzHngo zFQ|`H9+gN}r#sQZv6i1r&!nf*lW@0e#kU3Yc-$?0AxDya2~T=Isv8wS`H|1l6`3Lm zaSuM5?s%s1;7N5mYz{M!``W{(XH)>2vs?EgCuI&Ya60I1^jgR?vgpqcE5A+~g-92h z7Emo3T%GD!b+$S~?S)9y#%d|mr@T_0LUQYs*~(^Rp0Z5&OF5)GR9+~IT1;)D{-$nG zlT=!(u1(c0YUx@zy`O$S|A9Q13C3QFEzTcpn1^OvAmm2JJdPt)zBP>!_LFS000H09yJBHHd0OH3YMBHDrjz1Ho5B z_E;t|f?mTzcob!?M8xq>vIqQIg-MDOi6mq`oq!!{F+5v?@l@LdHi*)AX4YUK@#foA zUgHV>KiG2iKvo-JN0Lg%zlUe<*LWs> zhpe9;$U5@^K}sfEgd5(aYyDX>D^hc)5|EJVnqfq!lWr0_dr&>MLLt&A4Xo@UtUpyu_An#lXAf$x>^ULW6^ zAfC7pKC5D1O&r@Cds@K~)*7#l$VBUbyw*YR@%#>}%O7Yhc>iFFSq)3f=6svmo_s3} zyr{5v+{g9|{-3YVny^(&!p1-z(Xe`cj;xwMNFCh2n31*Ei$j^HVTg7MihMmaN%R>JDLwGO4T5%BmsNJw{tj3u$v5khk z;x||=24d?CD`qQL!RoMGiH9rncYknttFUxiiiAB-E<(C0T{hYQ<+9^&|iX#E4UKU~gR zl>Z&&r=ou{U{gxPdnW8pA!vw#cRwtFK70%NA&@Q&+nfc~;IOwf3R`TxEwCV@7PkF| z{XnJit#l>OnlfljseHSC*eg^Fze-?Vh5W4!tU)#KTwNWrSQY4CZCGt9;;67ENPvBb z!W9_Usl#56T)bz&zU9Q0joPK5&%&13ALzZ$czuFozTou^+Y3nOJ!J9$+y8mO!tF^z z8QEw>F4`5|!nXXdUn%Sb3Ae<7mV~X;VNXWbk5wLZs*BpyMz4hJh4rBaVcS*XeEU^H z=tToqw8Hzsp3$~=g>BJYVPEb6t5-LC?*lJtFW8(qVrz-}t2uhOS-$NH7&Q4MwSs0f z&M&=1{=ZEyUi^e(TA<8ouvIog9jc)&Wg)X-*b*U|uoXQTbq;&)!afZSG78)GqoFfl zj{yw}tA>{4p`T#dg%8XFKSS6@6}Ha4gT3!1`uquc|0(u7f!@4@2E5OgVA$*R20C&J zmdM-CknfPI3|dm zcA;%4xKqQuPa?nEjnZ@QdI9Zthqk@OUEstKVb59EpXh}I%&2cJ{2i${T0~nv;;4tX zS~9*r$Mx<*JJQg`u!LNA&p`>f*c0}>aFA;b^duefbU;r^p)YE~y4eV`9x_$(^`ixP zxf@1^o)|GYWbd%2%Q*$yM7kD>LZ{FP0?S&u&*WlX$_4ji_y0a zJQLlaBW3efpHGYu_`d+Ou_~?-jlM35o(_AsBJyq8nXoAp#a$S-zenI6jDn9yN3Fx& zq+H1D0qn}Z;)v2Hp)~HduurWtG%7Z~))f(R!Qxs4vA+V23VT&W^hEf+3;VCa-c1wk z0SzAG1Y9@l7yg0!(1z8Hik^D~ndLy@683$A^%6O^2)Q=JCEH#xJaxr%<)Sq6Z z^digY_kb6;ORtF5p3r$bAXyEF1LgGw#3ESOZzJz?5HJ-;j2`EJAW@!N ztCz*N)DCe1zZg9*-oC>;Fq7zO6bCQZCP?(Qk%btWqrm4YNX7$e5LgV5Xk%gp^3fWT zzv$15$HX~(IkDXMNQ@8rm5C3!M2yuZQ8%^kWPf8K#`rh-9O^SHuq63Je+ZVT8;Cn7 zOUx$YwE2`@KLD)c47713?$EaA?KZ{*Vk*#posigM=ti!AY%C*|JdaFR51FJzQ`__n zWCd*|Gf*8xdI+cXiFkxMZliwJ{sz{>c)dUH9?k+Qa4s}xC!!yY0nxJ~5GA@mD$B^7 zhE30-7pc!F68qcD}FR-*r>Z!~rTPv920 zRU67IRIZWxp&lB@;h1 z*w-2<`P9_N<=;n@`s3@S0(8(yjAq6t?S#p`;#Qwjg8Br>&Zmb-^|;r@ZaIfdb0e4HS0%Ig-X9L8igdD2Mh{$SWoFzl_Gced~l&@g@+Ji1k zEYmkLOLPbQtNeyvtSlj?Q5iseDXyL7TvAWktNllgKn~>^r4{W~jx+x%w~V`bA8G)y zl`n}iO}jz{HA{a=XUKc_F~R5jyHF1Oy9Tcep2Gh}M=04!edY-LO#IB=7u%Yb=QOw0 zM_f-V@v~_cGG|ASD*a9!&hOIVsgde%s-<|7ZKurBeWV_mF8uJnq>m_f$ZPUtRR?O~ z1EML8*`bZqZL}_@vMt0qbWh;2T9pdKD000viE1jl*@mI!^i8D{P&6wmy@**@dk;{@ z5If}8#9MVKuwTG`$rO`0dXId8oT$#HDX}&?6W*Ak>RI^vj$#G=2gc4C!~~!c7bBOb z&zL?+SvDbb&^$KOm(GN5r>1nBeJ2(`3~;vLFm_N+^d-Q6X~A@nDhbu)Gt4Gq5;A3F z@+v!3S;|RNbK^Fi;rdbnxrTjDRVG@o;rz=+^uI(QrVDWH(uu~@2db5kNA@QdBXj+T zp#dl71Nw0haZQ<^k5nq^Kg$EP0f@~UuhfuLbA$QW}=tOD( z{fVUM)rgGEVD?k{;r&=cnc4aDKxPTOo;5K?*?nNC9uc5c8IUPbkUIfR;%aAX4 zTg(mY0P21=tI%#9G8Ieet?cKmLCV2Lq(;eL=^h@bUc zu=xZ-AEnlz^`W4@Nw9galD~JbTre0sAKc?#8ffFY>M7}ZlUvZ$FLy&u^SsG9cbp?K z6LKD9982q+Q7d_L)}`c7M&r!vv_iQRvp)p<{*C@LZHH`StW-Nwb?X32smLKw3uAW0 zE_XbO=@rw-VTl)e2A&m}J6DJ1Sg%YJD{zjoWzE0ll?l!Iuc_Xtgi>9 z;!*6Jm?AMtqaVhWa(s%c<#=Wf+DTIevQIPUsl+YyrTUu~BUKIk0+y1K{ujRKfz$qM zzYrWMPF22Zb=g!d!`$3<$KEZrRBV%i;|o?y+@E;6=#64SiWW)iQ?PiU-wHI2UFWzF zSv6vq<&ba^UJsWxRC_DyGAk9AdW*}&oX~>M;9&n?&)~K|rO?e_3GnR9z<3hxF6f<> zH`sO0Jp^rRl~>)lA*)iV1bQ~j5Hg#zyb7HOqUaBG-o>tx%i$U8BKv1CHi_$i6?5}TAHO6C(ZQz&Cx}Y=IA=FVR zB&CQOLc4<(y$jua{UdY7xX0&p%bJ?qG^2A`^W=>w>%SIDDV@ruc1gdHDYz=TdwK_l zs*2~ev-%yjw<*Q+JR;dXIA&CI_2|3NV`5%8szrB;9vf9VBFU0#e#8Aow*yXEX~i!k zi$z190xH-$(tYjxIlfGP%isg>WG9jtU_~!r+ZLJV=o4Ql_DiAMgpWm=6;CO2yy&ok zzZH5DTPLl3yDWx~4uCb0J_NL3_BtPqNmH-`?3{R3oZpSVMsD(;nj z7HbDCNuT_$0!4zqxMl}?xIV*I)7v@9wIxe;HqN?|Q!4%3j}qD5l&I`Q8A(~|bFO-- z1>)R8HGkl$F-+Y^-(nNkE2eYiM^VQt3nLx&n-R&jdp59)SbCes3b#yk*k|x+w>3m1 zO)e$h5$B7?z=%FIq>1x`ZAFJxj}|qRuPt=6MB7af^J7OuT`4d+Ca2)rg2M_Nh;JR2 z79Zhw;Fukm6rq`3TYfRUr*}|SfP7O}nWK&geU=xR{G`*EhC_|{l!M)0O z^*6~W^$qQpw)+owCkIw|I*Y7-ooA?POz?YdiJYxYr|Vp1VrE2cpPcfU$8!p}>bfua zoB3Dxze>lXf$Ab4@y%s}^#4qMT4q@%Ti#oa*%w6ni19#&tLuhs*NvW@I3d9;G<=#|Je?c*U z8J;^WbXUw0Z+i+Ft3vJkEkYZm0XYMt7*8H}$qt1sxszP!u1Y>z&ilMg?%ugo1Alw- zq(#A%@+P9SG|l)&e@FhRUtnGc7H*_DTj*_`&tKrdXTXI>uS5-dMNT$#F@yTCBw1Hnrj_l>uoAxN#GKB z9`p1t(#Of96kt@w9oodDjv__Sz>Fxs_c}yk~`FB@O>wGw=%I(^?(})Y8{k^ zZsbe^I_WDssli)h^}tP!Jvc&0`MyxgkmG!%{gQ8(^w_st-0Jhg2k!Yiqdh+3hU zl9M^3eggx=ZmJj2TPn(x=blmT^#J_*E9mz|XU0X%q8l4~sh!+YBMH%JiF7rxkD5c? z73$H)$aD7d%w&G6X$g1QdWPy{;oxnPg=$~|my8K&dVms+1!bkJk*fHl2kLy^TlyyQ1ZDIXPGO-0^zUt*z@YFw4K z5x!7kaK3z3hY%z69BPg(8FQ$o>ISB;+LO)$l1DZ%o0>~hXD+cG)y|Kk&ykmmS>}~A ztA$K`SvzrxiWKIu?X)~W65@;ms;AWfOtc9`e8k~E6@5cgCD-M+3W2)RO50qIn;lKu zr{jZ1sNM9a;B@wHqfGEGaEU~^A4J>@{7Nv!JnDyjDX*3(P3~$;b>FcRCL(h<>pb{) z<^?L!qm5Qd6@98USgpe?5u<`*?2DBHc~7{bR5e8nRpZ7}WBmt><9Z#olQ>w)HWgKe z7}tZf&C&YdfSY@#4K$+Q%UMqx;a3FCn*ss9rEBOP<#)@fP_|x)x+?r0{6V&-wi1uD zlIq;3UcUWIG}}=71^qt>&$I9BtwY;Pj8x6COukHzww}IvTmqjP@bXKIf4H8ZLd^fH z^ZZk|MZ{h@Nw3emQ(sZ@2}b|N7%NxhsnA1{Bh<-ULya(;tKk12h zMqke-OJjhgJBRmpSMqnYRrKW0vxr0fC)z$!fnZmoJk^z0t+gYXQ+4Qo_ji+-pQn_T z##r|1tQN-}=AH#>t1sDQhC_N7sRh1T)4e;{@62JT6W3LG2tQrv&{JWMzm(}uc_n%$ zo{!^OWZkegq==x6OsDUfOJro4D+3LAtw0)LV2x~ zS;=Y{yWV|>cLMeB9Jf8RnXRV{;IcyH`G7cs+GCGTtr%B3@R@!_Y@qDwc2y8mr5Z&_ zULa~$gUv2ZzrdahQOX2)_Hn%1hb*^8ltdT+KVFbfG;jO^eUWUS=7 zg-Cjvyd?HvW|Xx#Wew#hgyjdVJ0oH?OW&E@Vk7H%`CO=sd9E}J*qB?`Vu4BIOn#m| zQtEBHOJ2`=%3AIJImhVj`M*td0+EtEYK7c4l&b|Jii^dip%yX_9cRzjLJo=8rakn& zWadY{%N3bSpk$UHrfILmSaL;Fq-R~Ipra34(s>U=oV0&s78zIFwj?7;eNwn_<~v>| z$EY9Jj6kfZm;Zx(R<=7*QsNCWsf3PLpZnj@>+~|FkaAFW!oyY(+;^(7OtnGJFYbF@+&xVr*-+T|gN zM%D1X(5{;{Fa>h%+4m8D`pWY&$#E*d@YK=3cIpRrR=J~%rSFl2sdSe){;>C)CKJzu zh4NvgrFEH$21`Q>x=aXsqyAAG5y$-xscYhN%RzCs(u(^F81%UTyJ?nEj#(0rEiaLC zH;alR76HwBKU%kd-wUkdr*xuxj&=Io=DX5iwG1~^IU%g@*0fd-&lwjjOT0^j-pYDw zPv^_X2x+IcUiALZG&PC*Y!=-y=HJ!5+~7PqIwtsvu5aWj&8d<4MBGOveInIgzoIu{ z21y;MirRUhh0A0|zOT2v(2jf|KPTR^RYF^cXXZYEm--2EKKGBecx1eH ziItR1WCAnNs21vJJ?!7X4KpOMCHFh|P`XXKxuxJOe2=I0CrT6jJ#&hSc8xP3V$`it zY1a1Mk@{X+yHI&!H~XCWRX#*@2WIs+ZnMakdIq1;XPNU-MLu4>$-ETTn+^x>Afo1` zq?+3XuTjP6c|>|hV%8xmvIOygTL6SBhcGu3%Pl9qDN_t&a`JrcmB?e#Zn%D`G>~CD5*7v$xL@T_Tv2H=@Al8O4u_3zCw)WSN$Pr87Iq%XIQ?VrgQcz3T(nr5 z2meQP0O!tL@f-iAlBZQ-noyn5Pw7Bpen(wro(0!&eWj92DJ_cHrF`Zp8zrSI)1RS( ze2w4(b7^Ib#_`Ppvu%Tdz2pVvB}zl!>8>Q6QFg5exstHbGxa`!Z`QqXC1o7dp8l2k z5-4ab7i?^6=>*uamK%6sspNeWv)-vPk6}IQLskaz*fGj1PZXMq9Zh5X)mblbRf*?2 zcpf==g|_Q^iNA$~-fna}TP@HcBF8s{?@xZ!y>cNcLv16x@N@tx-O|uqY8SAIOKKtJ zMzF29t#pR?om@mD^TTqA$1loyYyAS;j2YHCB7?CiNr|WSBUXB-Qqp`m@R9h*+}0N) zZkudm;hau!?fhf-&3;0-N_{}Aa3OYm$gno^6k@uX`v>0Zv%u*$$h|<2>F(+lSk4bK z%cYXSu7CqJ5I)yoza1zYvDGtyjR2j5gR=s zeTtbz?hXEHZmN$8uC`p)?gj>!b`W0_o5n^BbJnobQhzaB&OPAh;j?p3qzH47;OmIu zsxNqz+7UV}Oe6Y{)hz41gW+8%MLCuI##u*`j9(p7?|tE8Xog8J z&IkWT{t;=(>1*i~ED%{Q@IG`ldT3y)G%sq5H&>`2w&%J?r>MH*6IRG6TxeWQ0&k{X zYTY=^Rm;&_8XJs^9i25TDj2MzCiAUjgJ0+WO}J?6BI^V`n#(9EqRs~z6|6_yHR<y>& z`2upL|Cd)avEJ9w1#as*^4`Tcvzo?_%Ighw*KeVQ{IbA&%XyWTw@^)~XUa_TK;I9O zOmOZNjymdM&sax)58U@~-L)J872fZk(O+r$@8MqozU+&JWb2W?2az3%$b$*Zh$n;1( zTXI}__rlJPryT!SuB5zAEcw-0a&yibX?{%O?7E4@hblHLKI!{czNC6OeptqRb*a_P z_s=ZCxg!VVevIk=X-DN^sa@$WrrFM31yfv$jj!|>ropywl-_>cZ~Ty80iRG5cD!CJ!v}i*t!EN?J?|u>6~)j8u9-Vc9JQ+ODY>-R!;^9lfYc0=7^t|Xr&3gM{z2% zfz>X?v`L$0>u!xmuO2x;IR`xCgYad=S)O|LL^)F~6n~xeB_h! z&FF5yKlHNv*PzeTIMABy3{3rhB4WH#$jSU|Z-LnMSp_XO)SsE#ikTToj5fZAjH8jq zC#;g5_>U*rf20V-=rYD*adl{Fp>Bb=z%ii%$?MAQqkc97};AQ^NO9*9>=jotgB8u~9XMGyu5 zBw}~kD#x!%Jy@=NfkLLL>P4=VFO5G0K9e2PGtqE3(%KbVn$om*9rp+6GC#_jRlN4Q zu7%3G@7w+i^%ko{PfWX1bou9drLTOfQGBrzh&Sfam}N@k#zwC2bg@;|_k_AJrHDSl z61P_Xx2dg7PMUctAM_R?%}geL-QR#KE~ko`+tT>Wae*3-#P+X`1dU zN}8B4!n$BDwE$a&o)A1sT>!^c3EypY0Ck2G{4FU;7)j*?_Rx#D3wlhjvhbH$f!V55 zBwsRq;;COT^%HBGYKm*%n|K!rng90prL*ZK+H}5AAc}1)^`T3TLTkb5OA7<{V=R(Q}DE>0qcg@rpTaS{fV;Pu(fvm{CIblXQhD zXmd<1c8{B4TN3GHFXEYcl@r7qMwRlYv!=75efnDJrTMn=0N6RM6Tj09l>w$1Qj*?) zd0;qLpHz(RF8zYlVLjC;PXNxtYw3yUhJ2fQ?`dP+$G(!Ils7bQwg(PTF1f!gH|xB$ zgT9r{6myaHu#~vV@!o~PF8LnqRzIi**);u>GL64UlG=53e6VM16ZbOm80*rKMC`(x=A2j0V(P1f#i_QcWTd)0 zbknrJxGhI&ZsIn-TIx?Qtj!pu^(T)S@0A`zQ#sQ#idaIm^)Fzr$#sNj;vVBFv4T#c zCK}7NYvOOhVxYNvB#WDxfj#gD(^hB@T4a=$N7$X}{@`@+tZ9#Vl<$=?i{2giIP}0b zlh7?ErCY{G^>1MZ^WIxi{?&ekNKi)yF0-?uE=jlYM)9RJo=47p5cz1rVR0xjMko>beRd`=^NV^b zv1O>6!nnY5buAfh9w=YZTag2RU7RfsB1`LT_7DB4aY$2`h5A+X2J(#->Uq>gvZl6+ zA7%7c?G1Oid3=7FdZ5!zzE&|tmhTL6SOC7g-5>Kd~iLIttVr`YA25>vHp2}yc zJ69B}E=`D!K4WuYOc*A&LO6BojfnN*;s=#8;df2|jB3{eYPxx2ytsLn*5Xf$(^{$@_D)>8LOVw>ND9I!5r0HY9>=pTLTV`p2TYSb(hk& zfvP!y-%dDi5rQ6gajN)NliS^Wf>M?SXT9B?q_R>7SUj19YM@DEt>WKN00bUM;oT`ro{(oOGQ7a1;m_+c2p3@&1W%ZSG9@vLRsL}dq z@}yoGp4HXbJaQ_Y8luQE#3y|Q5N4B+saA$qr1hjWsU@f+;)BK;EK?HLGG)nuh=F{I zk+BufgAYNfPt}3+Lb(s{Sk`K-D3`K<8Gt99wc1qjq46Kl1raHqh&5ypqZpZ~>>&?R zW55D5S*w9ap7%r+nXEOSdTPtS@O2kg9fyqL``}|fMeYT*`78ZD*pjE~*NF4vdHr9Y zq#mMX0h8W`=$@}YuX$)R*8|kg;6VA$2x)7mT;Mx3F#ZL5ObbLmb~4r=@~9YDSszA5 z0rf^fjLl~)j#>qzUp9!h3HhJ6nUT6NNoU9$`NB6S-@aP zC;IgmA?m9Uec6pjKrB-L(W|?GYdW5c0jJ1evJ{vJ1!(zO{XUR9B;YTX(jHU)Y6q!5 zwF2}Wogjv(vzTAN9XSj9M_x5ZJ%Zo#FQT#jmb`_CtmcTm{0Y%TD$1jYTKY@MKy=w= zMBODJdH@*K`VM0c@z(Hx$wM-VAjj$zu|jW1b_ajgLt~fTnKT18sW>8ci-Bn+q)x`& z+6Qb+I}oijhp1`10~^jMpmR?o8>*iWcUy$`Y}6t?7!CCf;LyH9F4C7Gy3a=*)6$4{ zh>N~_5c>0^toqnTQsba)qorV``2$`-6h-0mS9CABZ3~_H= z5er;Et3k!;X=H zb>kMvLx%UYU0^QA(I26%i_j_>i2NnU+gb^_0k}MhY5S-nI!#~Heg}(9Y2+KIU=Utr z%t!3m_+s?z7R$2_{s*6x|BM)t#XJ(5{lVq3LxITXvLarx&I? z809MJONhb-(2k6Q#3Tb_3$$noxOtu%)e!OJ)D+?ZB9ktHx9Kf8l_;u>AU6Wpevz(2 zmltUs@+|m7?ijU=U+CL7_7G}xz*tST)qew1-&DkjebS~;{f!tRsE1J(?jd%QB-#O= zsv>$~8Zj5~c^9?9kos=i?FE5JRnZs(W+cQul54bojGrODk+AU=B+7#4XsX^Hu`xPk zh9iiv-k?tVLSKPC{hN4i1TpfJR7K_~V%FCiPH?=0BlH+z z3tHJ6%nk9xQOx(_5aYd{sEo3ciJQ3ZYSLyMNF#bNat5MQCn8cf#~7^lqOuX^0|q5{ z$3`Olu_kp#k0Q2_w_sDhh<>YKAR`?9fzo;p^1A*M%u77^k3N*FqnsUx)i}5UgGiHKMCKtHLg@9@` z4$l?865sS$Mh5CW6l38?M3uHRx?o)L6YceQ@+f#wKk0{vOagW$oieTv#f(jeQforD$!aW#IhSPM*=Lk&m>!#(q5x5&AQMKs%5e zk6gOP`VMgI&esp)E{{S4=X@dn4LX9Brs?&-{LvCxTHY9sd4MDB0Q{+Gj6{^rK(d?B zLQg_u=2AV`cm-C&i-@LgjX2o8hDWbWz5o~1Pt*kcDDHBbQtpWP=;8T}AbS-nfi%Xgihd<(YqObIS_(B&+pTNVb7dy;S=mlCAgdskBbB^{v8Do$-l}RH z=-S$Ms+jH}ifCD6O+49DBC3$9F+c3kb{arO(zjujQMIjLTUVTNUZy!M!U zXIutbb&z_XbRkY@m$B090g2J7gFJ!Jpo)?TKXrXNPhJm}-N&$dcZcP8KXFhiW^5(| z>WC^5kCje_i7JF_3!d0uh}s(C5HT9;t3Lx3yDhlg!YE+1v@V2>v2C{T2%|zH{UVr4 z)>Hd2+Ehi%+6kfwax6aU9xzo`fiLut{@9oa#)Wp^6Vi!&@Rl7%)Qx8J&?gx`>*H~a z+C&F^0ytCFYNs)KxU?M@Eg$K_Nf=A1U0Nle`L`hI=%tA*`VDumCg*`g zd=5s;r?^u~5y|91%mMom`MQtVLz2c+V+LlDjz$GC1#|i*aDd$+YJlY_32Z!Vum&to zel&(4T6Z2|aS7T4G>&z+`_5nmG@7VM3dS<>Ji$|LV5bcyvkZp%4$iRSV2+O^{v~SC zJ}_KR)GP7}B8f-Q$H`RU5P1%Ey@u2cu)XgmlE@0gesFwO0{3So5d)w87UDfLWE;^I z=!Cyhr@%J1fSeCrq5r57K()9CmK_54zC9s(3-!<#Og$nNP#eIzvY1K*dPQaO1h6x% zlIMt<;ICPaIO*kvOpP!;l7q;Z#2D%Tv4?t$v7?OfTK{6W5Zyi)SZzIU7c%;N!lNY_ z{jnYz1vJ9Zs9%404F($*fxp-bxlw(IhscSlLJmND`X*SA8&I*}u%7|`yL)63d7Y{X zzN1CtI`TPrl6(w~sfL7!>;N918YE)s$;ZZRDgjwkwdf&~iET=Um{@ixU7sCG*JTT^ zN7=hv4R#ef1W$-hx$4N-d%|o)cB7jtOarx4D=KXXeGV{z`qIEqBQZTVIB+^R(cdUI zBGANtP8@>l{kF(Yq$UXCuf%&2+T}%byfAd_kEXk zmUT{Zwskgf)^RMbzq6Xm47-EaXq1u)a^)gpLuEqK0-CRm_q1nG?&I7@Hj&*jFPO8? zlj&>ae&L_yzUH}+$K-CyZk5?sIO3>lACt7r`M~}sCec16zO7}uy?R_rSKGu!v5PDV z@bpO2&}V;%NV@PM@{zmC`^AgmV6G~c$kVvXIV9h*;?Ikwm;P1qR+-AW zk6OR}NpGAsGb7>8z?_Dm{zhkWhP`h>cKoxXVu@!GR>vhfm$*JV_dr^~Rcf3;b2B|z zc|SAvWg9u4ayI)P29(HWzNKU#chM;^Z+rF`g56HC7%`hx5%%gp7}nz3OZy{ zSMs~z7G*v=S~zl?dnJ4obLG48W+g$XqZC(HD2LT$sw8|5-|<`BOVZgto4=)e&i-!y zF)FRpkC9(%|M>hn$O+gRGT=1X3+9`eKP5ReIU%-WLfLqK{Ncp<&P{X&txjwBELsrxz0Ej0uvDKNeXC z4%)WitjP7~?`U1Am9$(S1x~0E$qI&i<=ii`>SayIi2XJ3cgH_v(}>@{zUsfZtgq4v z;xy66`PKXhs00roH=(V)hwE1C+L%GHolNb~7NXPt*k2{@Z+BkKik$iZF>D8KR2OXl z-HUD&Q>9?tB6G@AEi1iqHsOp*@31G@uQ{jNJ~;MT-Zk8#dtmyG^nTeE_j>Q9z}{daSUh+&uq&@=PJD(Z z%Ntz6`Sm7jN%{@OIorgPVL#hy=u7RhErk-zt^?+!SWB^7q*-)qv_*jMH4WVe-9vez z3iNBI;H|Aj+}xDG#mg2OSYk@45+#b4Bnl=cm5ps1`x*SBF9GXu8+j$wku&&{k^iDs z14n{y0`Z=6+2b+>r{7KQk*#_Lx?lM!2S0iL@h|s`%0)63{{E2rD0D0`lJJu`#As`A z^9u7=@PHpP=eJ#sSs$BjlUb7Jtlj1Q<&T8g1PVn<@T$>|oQ|?4j$Uj@h;5wEq+rdI zBP9-&d{_Kp@k=Q;6Bi~tO=uL`-oD0uoZUto!|LigrFPLwnA>@QKB3-TEXS9J8@$un)d=Bn0cwi;0VR}-v=r5**_94!tJdnvI42f-dIK*230Cjw<#74%)TNan02^rjDz#gR{pv3)pFUK6^1+(7N0*!P?eT zpY1>`Ag2>0h{vEaw1TVl+ep?<>1EYib)nQtE-JU<@5&zjl<+ImEjl__H`F>fDe%bu zF*w}&GLYka3M=1(P`$vY$mmD`VK6^O8YnfEODdbyfB#PfLs&?)IzhrPgtl5R2|)s)%IxY#Z9Z6;{i zW6m>GWg_Or>@9OeO9Qq8^NGI75Y%o+7U_#mBG!@D$ty^#F`0rw zh?IoYH5v3Cmp~ZpFtN@?wKR99~$iZ>OETsF<&#zQw@m^-1H)N7{w?wjZpBvU$w7Kz&M-+F2O06)s8X&LBgL}n0P z?W$Brxo_VnrK4*DqY7D;VYZp;XGP~H@U`Bt1{WxT=kK)1u{phA2(f$ zceeA+G4~_3_)aGz32hwP!dd7fCZt5fMlw&#}-*-Hh!gBiFxkstxaJoM<-wpB4fz>{9KFXYruyUGwcN(zpZ{ZPgjkkg0oCFO1kV=iEUPpEzU!Cp$ z&!q#;C~aJ&=ISx3k4rB!=vQ&abQx8?nJVy4V<%<>3;rE_%D*q#H}4&!C$1{IDI<>@ z;;J3Vk>i+6R@76A3Q%K>64AOb2$K4zdg-K4EeYmLQrah1Trqky|2Og#gK-lq#SW$SG!Nflsg&&U#HAh|dzn2L_t2*9INTJdZgUc=BLA7bOApcU z_;q8Z`cXe>4tvAwU23_Qgd5myX#>&#>!9|L3)qsSfLaU9Z_w;^-w^XzIyG8ZzfRmF zNoBl7Qv=9YVKI@Yb+e4*6>Xl;F$NER3l>Nktb4sFu9c7lFiabQUPV9P^W}WTaLZ=p zFTFYXi{2^C*EO^noJQv=GejY#v6#apI+sNEvO50;&m>EWn~4TQc9^%mQvcx!m~)An zd`n_3xTc%p6}7o)Z*x=77Au(=iy=B)Jfh)_mVspSqh+y5YJOxT`B-YCZ^3S};pi;7 z1A0W+LmU$)+1f{jvK>_cDTb#&O|T0&+2}0#nYwxm^erC13e)d>JMB6d=ikj_&?Ch& z`ZsJlSzGV`sVI)SNs&;A)RD6I2JryioZ%d$zxx>{*%M(#@ftL~mz3 zUr0V4n@_kG-C^lNJPEg@CKL6<0Y)GEvvw2mW!FPrsD*m1yo0>hWk`O=GpI+UpkugX zc09pzrO_9*s!?6tKrco-@^w63dyTvFNAe8I(`XOM1pUc-)rPo(x)U8q)khjDqs`j` zcgzckoKOS$80wa$5Iv;3rdn9xV)QHCnz#)et>4rOx|!Od zj#1(u@f}xZt30J^uZ&gVRkkI%D{{#CTKg!T##k%?nWYrfHMT$OD*;o(=vU(pI~sUN zjGdLfs+cjE9TM3~tb()LLX<^Is2Ny0G^(ViD0KNYz$dA8O0*mgbrX0d0k=g*>YL0D zBYzu*?d`;7=x#`4J%Ud$_REdwI#L?mORh^TGYUxM2#b+KbkHZmJgH8WhMGbWHjU^d z8^TKGV7ae#dl(&;*HPt$#qs3WNX~KQ5uM)S3 zH7ui*rD_kN1G+=JfKM|fpzox=sB6kK<2Nyd`_I~gUv11Lx2UHfYKW^bq8T zoYZ8!C)M_ zYQrbZOqElA%P#61B=26;BG_;gG<0l}k%py9t;i~RN&PumM7N?B&^%=VF-fhDABW0t zMJ>Ov4IipRuwz09HoukkPNCTKy~?m zh2`q#M0|+)9Y*ds*+9E2HMT4eK11bhIw+EtQ5tz7uV!dD&8UMHQ)5X{2vD+KUMoRF zK~Jcwc!@Q7IsHFS_dG_bmWMZ1uj4nh|8sY9^`NmCSOjO$f7FHK3T3gzFqh=N_2XnS zy_ZZgCza_yOq>Sv$)Ql2t)j4A~@}yF$hwz`(dD=qKA-cpa>ra!oYbTPrNrqY59l(S`LicHT416OIS?L z04Cvlytgq-JBZ9i?`YK_y;p`S<^ylfQRS3z2V1C51T|tjbV0T>&fyJ}=h!i$iT)9< zre+$Cp;|0MPsw<-6Pg6Po-S%R^010wV?fWCfYj7C;+2(Ycp7koW*C!^ao9e&4W-H@ zFbv(Pv?m6tJuzOJjqg)5JQ>JV;~*9KB(Xz$&kp0f^ke0z9;L=aCG$|hL;yuaIYpNQ z7vp2OCH_daAtQ}_D6btvFKQUJ9_gj-CEmzqDVMYwpNI_=XVYyJ7kO6NMf;^eM1mdz z%Hn;{2_d~6H-NpXpc~W&pg*nx7fC~HDKSA>L4Of{;g69~$~dxt;)K_DhfLBRq6c*! z`bl%)6ZH4`Lp-4TfHbp$N>BWyQ5m_eeS(efUwFr@|{3eDe{c)fE9bEx=HT%$gXgTFKr7ID2CEHYpLeJ@L z&@S5`TfkF-lbGTmo+xdIc)+*;+T`G;lZ?LDKFLtu|0D{stw@ zra!}8>%&0%y@Op2 z(+*Li#oCNLy4N~2Jdl|MtA8hX3B6HxL+aW@P#7|`PQ+YkDv9gK>Pq~NUc-2;%jjpl zBkEK?P_M-Eq@YemP3jMNuZYu|Vf}c54dOdm3j|kLFUakb-IhgxI`&eL-PB(Fq#MiYzrR5gGg-)=~Oqx*Yn? zx!Dtpd*PpIy%a8Nh8|b-AL<6b#vF~j1vlMmo;Sab9xyKm-?h~WCz=WxS0hI3`P@+n zpR*e#?(yxi{^8n_MUhqVDCT8wVa&!r4R*WnkM1!F5IJH$I!~BF9t3~MK53^toYyU% znisaM(_;AxvG@P1E4=U5=43v1n5$`C*|>B5x2FEmzeoi#oiAtU7}}NZtEYYb8#&M7 z&T%KvZ}=ty!&XRE2IKk>rGy%6{^(zpP$=0khVNw*(7)ov^{=21E+Ou#Fwsr~M zDAc0Lk=OW-CO?N;vqHOBQGE~0lyUMKW@_{abBdd6c_oa-OG%U2_k3@Bp*A^sum{)OV7}QkbFM{lGV4V@5n9EEc~dvMdeNY z;0Q|(0b>X753%Cv26U{^RvC;nCw5BJf%sVsy{+=-IrKX4OK|Nv-V^&JwLp8KS9L*a zXI#RQv|n07&?L`--jo5}%=6Mlx+1@dy$l?O1LR`0HNI8;2UeBd#8hD)bw#~Np5ZsM zZpc`DBe$Wu$WzH-!VUZj+DL5*%25$xr4rVTGhS)8JPFQ=<>|E%7v%t5cz~XP7s1EN z#eimags3IS=mEqAgufL!pabRuDm70#j5Z=qXan?G*aYOH(O38410g|IP~Bv2;G17mGl_Qy3jTsh z=n&XdDx#gRQpR4TANWSLBDWwFzA|vuW3WfsJ0SfesOjJw`wnYNVYCtUMXQXG;FP|v z3_zD&7Qyp1Yu82Ea9N4|qmTbQv_qDcU&WE)qfhQf@=`mRp}?yw&Dn(^SYv z(JlB#BNqIbBFbRBv_qhM&NCi?7FiZvbGIIgq+mPsT-9fc0L|?pkk~dNOSO@}N4kof z(7x)$K}|e~EZ6!#U&<8lc1$%K*e`7?xS>h|$p?dE1bjA4;QfyP|H(S=muT?a zFi^++;4|C;I`&t{5rd>KV2+Ii1_coD;PsoLYaorFD!8kL09UR6cpw_4OqQuS^|JW?0vdAEQH*A5wA%t64f4W5yA;ME!dW9J3^z8<7bIKhkYf1lT2 zNSL^!C&L&`H^u{<@Gs*NI5=uS%E5R2o3Tg#Vyp*lO%OU7dm0t>)}RfaFm7m{4J#0& z&cG8zEY=LY2)SS`pl=eO7#_g#;~%gg_#I*{9wJyglZXK-!%w0HIg2Vl)u7MOr|G%$ zNT}*eXPz+Y7@g_Lq%v`IGL1v+q#;@eoJ{eMbXHHEtZWs(i*LBwTzVuiQa>^`%!dnx zTW}jgA0qLgU??l}0@8|y2g`@21a^fg2AlbJ2jaPXTsYcRnXlXe&89HWGR@=}a)QZi ztzsqZ)g67DYwbTA=dJl2Ve3#kZ#!(;Xu-@CP0Q(X?0RY>b%{E|l!tdTkU2olX2(E+ zdK+_jsLs8x|FV6sPq){$rP$xu2H6ekOb9zVZ+^_opr7KvRMKB5qFjSt6Ly8#`5Sq? zx%aaU0yAL#@7BNYU-I|rKYRTc^K;?%yFXIX_Wfy<^V;7>YA1F>LTF)A8Pg5Ru9zHG zmHhqlS1b}O*gB%7@wpQi8366T<%n zpF}Qje~U$x^+*+xHkGzEhkEh$_%TUn($NAF3mhr5G^JD_Q^8irTazXybdGNlbHF*) znqe8uex^zizmd97U0N$26t)9XrFM8^z~*N?&2!DUU$WchF3A3sosh)>67tQT3+Xd{ z{YmdKmKD( z*qLNaU?0#X>Jza+AFI6JOK_o3+i=Iwz~GhO{!q`zu*iC@udrOX4xT0l{ef*{JM4Vy z663lgbV;n1R5-si`E`CI|D?p~`R2smclC9~+7DV9nM<>yDX2N4la2Lio^nqr59Ab! z_>Um*m$`7{W%yj^XYf>DL4fl&4oZRBfn)x2zHi=`Cts%hC|t7%(pduJuBxMid%gT2ck~JS%334<*QXRSs*ljJ`-L)&-9t>(JepBW$e6Z0T&Z+xpp7!OXd3>uh^# z^;sHPmYPws&aMMy^#wABFT(DkD}dPc3fA41+GH&r;;fcJntCO9kyt`3DtzU$I1A|K zYdJ0YU*u@GN~ktute5h)@iN|Do+93wUbjc{{N+94vj+}_Rz>^sHN+7jXokutZHHbH zdRwocJ@6Jp5Ap|DkD3Q0+}@zxUnl30Yalye3*?&*B=d;E#A{%}cEELP3s7e}G_O@`H$TKUmMvGEXVX`YqKrcBz7VE{}aP9ZJ}OlqGVzU z!2`9nCw>h}z^32pa^0 zjU$`VpP6LS7j`zA!<1t-)3ccC>~hmJb`LY1-b(joe$z{-9P%_&m)DZ1q=)9S{AaWyVr6y9)N6ox~tM{IL%wSA)dLq3Bay0%>Tc`_Eh+IhaBzF@Ph`z9+ zcS0`%Sue>rsO?k-Dr@CF(i_n&>=4fJ&$!pojgg(<1EB}O6@kQn$-l;z>0RaJJWo9x zJa*4=&n@qH|B>LJaP`Q%=#}UmE>5^9yajenRL0esS|h!Mu@BjgWFW= zegXRGKK-*+K-;I5RM#rIWkT*KO%M+Ys8E|<2;}E4krt7Y;VR)tq1(Y50ZU+*|D~_K zkM*_k8NTs>^x*aI=g5<2d+sv#i{B#b5?4u^^09c3#J(+%yf}m2^@>bFy}ror6`RgdNaKa zKHYix8O_mUm{!bW<~;L&dCXj5t}`d#=V~%(^hUZbZKf|!!>CHs6|xSQM${(m;LY%( zmf^xkbs5dU ze&UyjH0XuyOP6Ls%o51f9&MTpU!6?``|ALLj!jRwj4$j`EjC5cu@6t+WXH=8=uX0VUA+G?I^*AvkG!s(!Cj5S` z9uT-%NAn^}B9$Xg!o9*DLw|?n2IGTW1G4|PKg~Zma46U#+%qyYnj3A!HR7-G^@V2Q zG-<6|P~ENZdUIqTWVmd@PZ7<@p;TL-mG5IFuq8|bOvg+!O?6FSSWEt9-!Y?@F-%oh zMV`YfUIv7Z9rO^oJ1tS)s4$g9U7>dWf9;}{!Y3X^9i_ffA1Eu0(+PA{+6lj+bddT* zNfZIBu>{Z=+tOueA9am73xB5yl?`K2o_s{KA%eidUWC!uIP@_1SIZe2b*H`@lIfSK zR&|ErP^QW^rGnBau?x)c1fb{S<9BkUxg}9Nx+4-Bc^>W>4u{r;T?{_)YADB$tzL?V($c zsm9c0g78@j&?n&Ytfl^@S^_7(0QHLeLY^mEkR`!!MUx@o7BJPq#B1UK^fw!XL_`QD zS(|J_t|RY}I0f)9Ag2FCwV(!2O$rQJ}wsV|gTiXT#NO3D|d&Jrzc6MKojgk3^+ zLFI4o?fGQ>4tIc?%+=+-L~lmBMFWvFks^_t@IT=q;d>z^R6R^Y;-lxIwYh~{Z9V{s zWe?E-sU-1AKlP^e8|qFX@(7)Tneq8}DPX@RlRd~iqzT^tRp7I40!sS>DjjCD4hn~h z=75?1i@FHEzXN{%EL_nI)Dr4CwTl``^`SabWq|J93_krJnEA`7Thu`yysx9~QTM5b z)MEIlkJNIwb4F9g;J@}%gjC5LWC|$~w~6LNQQ`~a8?^_9`Vp)Z7D8vCI%LuIha}v& zaC#T?^WbPMqvruNvaKemN7Ml-qFz>pD}=(!E98-KarurkU8*c?0j7I%&~%Opt%Njw zA)lY$%Jqe_=;COHXvwHQaxStWav_37cSWajey%IOf-fgz3S&fBoF!F~D=55DQv0Fp z0*+?{GJ`3|sgK7#V|DQjxQ=gt%pM2a!*OJLxZXF(ASuG#&rp>qmHb2=B{N{}`bFL* zufRUG7G}*v@;|a8?1zWob8dk75hdfPvQ!b+yNbcyY=JrEBhz5Uc*x7-6L9iPBwNGR zCn7*>CMFWqiDcp){x{I(8{;ukUZ@Mz`D%9+RWpF*-c2d2pvp?Qiku`Dk&jAGrLj_q65C9#3DMT!8KXsujA>7=k~do^C;wWvM@IpmV9^bU&C` zL+N33GrA!?8~$t%U5l;>f6|_w02vzN=*!d^>KWWaO{f|WtN5DCBu#KiBFP)X1R{*W>2gzLpYk23JB`(|>Or-Hwn`fXuR0d4&qyGoKQq#x z-(Vgjt#tvG{TZ|+HW`TcComo>jZeaBK%e$!NJ|(5c@VRRRm3ZzH<6!Q4y#L)xJz6j z%981jI8y<#8NR~z--(+<4l$XS4(m`e;yLjR&ktD)XYfsU2%CbBhrEJr@EqvEj)Bi~ z7SujogVS~;I8blEDQP`iw+Yb0binwj-`8j9#~?MpqW6S36xG^kt8|ZgPWuGe0sm^5 zT77M?`bpoRZr76au%|Y>K1*1DzEu*d0 z`Y9i@R{yW0tBG1cb%$0|@oI0iiONg$G4P;DsnfM#>RWJ_Zc^j*HJS*o_fD&>^)kxn z#f^^oGnie4fa_TYsL4L#0a_JngJytFb`&uKk^^?4yNO(E6IKz%V>s!-{zbpQ{yUs3 zjGe~f>4U^Fq7VKZZv-=OC{Y1-U?0e;WFz!1B85yPfDnzV)FgC>VIx%PB=Gk0(BgPq z930G;8#J+MS~tg1E?sY`iwL7vmd3+()u9%!TIi&$#`mJvqP0Z_a|_9*e3LpO`PD|M zo0u*QMEj}b_0pOoRRI6_1f-I5)gZ~s(v8R=^dm7}SrXod7Y4q^aP0?lV8!u&AlLnp zyh3csoYAK%k>~!^$2&hA1LN)4hmw;$w0#F$M^7O>i4lO5F)rHb=B$dLzpzWR9xKrLDi&`O-5r z(Q+8Age2nc)HQHw+6uXQbLgV*+&oZu#TM0H0(l7fxnPIf4zIO?nyXwjIw&=m(`aAG zhi6g)l#o$NJxiceN$d`U>seR}JdMu=I>17vTV#__h#o_AjBpr`F3G!50zHD{P`xApjl=g4JfupM zpvI{xnjv)|YZ|@LD7V2xYcJ4CQi$lU?goV-QBOuBEtwf6ZO~2HIC`yc6LRYr(_3M! zn2l96U6QWzuc$-NVHuK>nE_}>JtJ>7+LBH*_yFHULdaQv&j7n!5~MwemJB9H7^I48!Fxh78FhS`jvVYo@H(>C0VcM+3XharK65_5OqsE?D+0^L@q_%m`{0En-3_} zO&hs3S`77A&7%GOua-B)PizV=U@P=w+gEQ>=CHEWemt+DsTtahY8WnxU8L`6N2OZ8 zkhy`ay#nkal zwNQe~g@T<6aaV!0Y%ntr6Jpyma#gR{pUQB&|9J|(X&OO|^P@079Wsm3VFS0znr86My zm$aB>xo6_&U?y>iPK=y0kBSs_7}0fVVcV`yCtHR$9CJ`v@0*jH7AU0`XTL@oI{ylE z#=AI9W|mG`;qzcUY|q@o9C!T7T_ugH>HSL|2U@}s^>1Pdaz_}*G{xeg3)uojvKr;b zCI9#>o1e3?L`iW-i`a^}bB)>YpK_`>>IKfn_Kb*PPQs0hTM2ujLln+gAg@L2Ztooz z;#bI6wvb=RxA{-IM7Oxq$Ci2i^nUr*-qE}}G4W4Mp~~qg_HR_sJ1J&IWCqy;`n43T zaWpR}CS$ezfilVXZu%!9rgXk<`;t0`>+8#6Hl&}*zt;bpEQp_%dn3EaE-E;`fkhV7 zXM+~{LOY^f1tOb?o~XZ&?%Eav-OPMsiToE*f}E<*dM$QPbf{@=^u5I-GrW}a;>X6K zY3cLqABnA=FR`n65xOcfjpg+Ba4SdCpkbXaj+7TTf8-R8Yb5JjoU?FrIC(*@$j?pK zpVQB@f_NlN#Gc8s%$K8g$U8tb>QC;JzS33HH)2=kIe$%44Q&b}4D>~c)6GH!?1z+O zZggzKqg&QX(4$SxjxM2!nUHLAq8>em6vulC$*viBMJ+NpHE_VuE7;MhDc8aylP!O` z#y8=Y!#!9s`&B}eE3LjYb@sbruL}|Go~u>(Cwh$t2`AZ^T!^mBb+yIu8HkI>ub;rC zN1i%{tKU2;Q<`~h8Xv7g0##xwXJ^N}&^B_rOy|^6N_m^fy~8zwt8H9k?A&v-x8+ja z&iD)AINYE@fdY;qQd+cF{M%?vb+5IW_orPG8mfOgDrJu<5XxEPFFtl*BBbvmV6PG3ICIEBX0zwPM{f zmYG${!Czksv=^pEZMIs*r07fYH8ov|K)c?v`GfIE%*D1+fnXy?E&i2u*1S9t zHcgHevuB17x+IXjuj|##bGVIKHCJ}VoBW@$Dke+`XX{1ntKGw6fAJlJ&ap#$IF=AQ z{72_f>(T|Xd+he_+lo80TbX8ImADSp1i8O7-!a@z#)f6reT`4#aKx!6F*TK^|va)BnhYK zwek=ngY6i)N!Mp{a*HQ_ORtmMKj0Nd#FN2|+F5&CK(PIZj-g%~--U>6SagJ*%-&UJ z3qMWo6asz2mvVNGJTWFPGld*eiQsd4G}6@)@aJ^{TtpR$jsx7DuROnMqPl=zFT69l_#_)p9*Yb@v_n zbNQ(<$`lPaT}6CdV>SiXF!j*R(ZdPj)91$>;XmT%B%GeERWiOQ=V4ZMby0E&ZeF}K z>ud6ujIv30+)?uiwLIv&LxnN60={E$r*aR)kIY$@G|qbj5^C=&&Cxz;B6~49mR@h< zYPF5KL{FtBc0=}AX8X%pnyOpSp=xpFi*U{uVY(U|=-BF+m$)FeUd$u^_xMH`6$-`Y zK!}~*Pd;k1;pX=DTvcPd zsbjdPDOc`Kyq9Y+(;}a&tz=GEX`3O;gRanh*i|*3UK)Ro%+&50-KZNs ziQX#ZC=<j1A!E-pVY8mWl=SWk!quFvfdYV)ncM@YG zUdORudrKz)xRiQL7WKA|DWV+>pRw-M8;J;eK)6N?#Nzz><4P%4gJ+ziwwJ49D-r#R zUWHT?MzAxD6rqwis;vm-%hx-2K>KC65RIiWjpov7Q;bqrTis2VyQ$`K|*gaP534_Lp2a_BVgLX6*BUOw`}*EmN8?1O4GqO>T2(Apk+9-!}U-| zvZfM1^;Hp){x4FN9AJ7EY`}cS6veI=Mn;nz^v+6iWHEbQM!9tRRBXAQX8$q23fw1l zlTY-=;hT38DpPmiaAZqTr_lHCihOmVmw}SX zGOLAWkR$Sfbnr8YT6jn;s{X|mlCNN;UEZKBDWDhWcj9 zq{vue3Z{U*UI|h#wrgYPMCmQE3roR@2y+O6{G(LXiy&`Ip958>Nn~whlsMQvFmjfF zj5O^R?vt8O71aVT*Uqt>!m-p0qG@EHV|j2F?SRVgGij4)li(w_$iI|&2o6!&2INO@=Iy4tg#)W8#shUi%0|F%+zr~;GLA(mpD0GO;Aznu)`Z97}U8i-R zUu*fK^^6OxA&TfdW?EQ=nvlRduv998?@zmcNkic+!~3o4k#~XJj#uhM={X{p-iGJk zGx2#^6Z|^g&vIL8r@S^*3h%Z=#FuCtZ3lTq%0aqP(;$aruYMivg^d@_fe@R>&o+LW zn(@ofWz5xZ8rxNvYI(=+P^&Ty;&33D;aDs8RUO0JQfJ|A?mRP*VtrjLr|@Nw4&)(v zjnX%?E&e^^Ti4JRIzy2(PIeqYo`&QHg>Abzs7u80ZWt!43zftGA5qK(iH zYowny(kKsKfk5a_{544Gz8yYcA+#a6o_)W%7sB1eSSU26+o}`P<(5}rL3|Ns;-13_^;IgYZo-@Jx+5Ug2v>4g(7mDIET(-{`!Pd; zADHh<0j{)GpP4HP^gOAk`mgPoxGy}`4tcIZA7mL?g@XPI&I#N_TFS?CQ5yoGJP}m& zz4{<*8P0}3(UK1MIJtCB(fs~jgMtEADBp_Eoe6O5Gl zSvCp{^t#pz?gl=F*dt}*3fVB+&iYOtA?~+#4|7Nh^AzQsw2PtJWtTR$2H*jHYTegLUYkCa$cV&dtq_xNbvZYuX&ZPhG zb**vyE`2UDl6!@eBFsZ94%5roH@-J3gJD_Z1i_52sf5;rgtq)^|iI22a7L{M5 z?nACp6|pKp0{VuOrBP}wRZI>bU9l002R{icuN2@C|D41`+ue6c4C;uX+APGuYV+nCy z+o{jgE|D7h-LlECxWM5ocl&;iv{Zz!MA${V}Y9yrh@v}yW1RTQtGi-ES) z7Hy|Gfgp8Dt_=yLaX`ZOt-VKHXnnPQ+BDEn2BQm*#^_FPlw=^ia4YV^rot$-hXkX+ z_zSE#HHXQMohRE;tI5B%i`=}X>h0gg|gn=QLf!cpy#Kk7Iy|2&K^~^gfUn>uITtiL5$g5j;qQ!wpXo^c0qwmF*@Zkw zhN;oeB~+Is+d2HEX;p~O1x6CI@ z2TiG_4bW$IifPH{VwHfq0-h+Xg1!K#U}b@9b6bB44yQRtBIFDg z1GOa|o{Lu`8c-Oh)FqkY^iELtCNedclWaG3CHEfpM7g>&4|3@lv6y(9|H(b%=5ZIJT!e{!h#n5Tj&2WSczX$#-7e^P|C;mG zKh3kovmx@@~$=R?Y$+-G@ZrSzNTbKr$$f|Q5NryIML1ECT)m}&}->ewd3j; zEgSNa2Y~QRx@oRyUT(>>+^~GNd@|Ls zyfxiqo0uLm^_Z#5ImkL5gFnR5uutIp%F`dI4b?T$G1)E;5pD@P_}0(^ZyrCxZXYsCy8m%YU+LEhmnA_ zQ0h^)=xhpSou*6nJo7=zF=r1)yyLZfjcbcbu~oC*c5JsbvCXzUv7Ru^fnK{F)JHNK zAA_~WY8a37%bH7#S9d5I<$ZEvxhY(Yt`tXb@hc0`Fj4nG9%P+n-gzel)QptiS6IM>tKTPwKL-P1eN+begM zw@cpZyuEH;Fu!L+UNF=+x;NNAIGt}KJ`(3hs#1-3t{>5=(GF89>N#`I9JP+Me6s`{ zsg6$e74~s4-D5^r)9lS1l`WU8-L3h}^OzTOSLy=Un0NwOz;dHBsKvL{EP07iL!Jhi zNx-ldAMx8H8S)|0<>ty(ZMN1x--qU5COTd^ z=fyO2&at1c*SFuX?YCxIKAE4G9x@zNlWt56#(QAJakue7D+PW8MIq$((s}V;VH5v7 z`hrV}PKrDWKa23uzk=0+*F%RxUf+bs!GPrdFR;wJJGk0E!*|JF&+{!%9I|>gaw&PY zgFC`YgGcypK1nDp-{%K|)}M#$GdT4i%~G4_GUjbeFIzWD1xw4Ap7zs@`p$=r%1+Ey z*D=Pv*P3d}w)`=rusi7-swq{Hya08cdPrA&vDQjkuAER6xelm%`=l7@k+4(#qEweR z%kku#&h7Hgkl%XZv%w02Z2I z7j}A;U}t%qduhnYJ##ky1^U% z{bCn>C3jjXBRtXH!K%Lk9|P(!M~r6om>ZfFSPDAlIfgry+y1kqI}>e`V}A+}J z$nM@l6sO|Y|4gKLi#5%<-geUV-l014Vy?w}b0)f$I3tj{)epwxko_L`7Im}CoMOu| zx4;$M0qV!QfEuQ0Q0!IPD>vlRQl5~-zu+lY_ifV6$WMO(-X9tkEE9bnDiOW!9T@x> zY!z4?Obrb3O@x%3kG}o>N_n0@IIn`+Hu9=Dgi%nIaZLIjT3YYmSG*{ zJ=V3h_4egyfoH?`Q5y-)Y}6w<|i(m*_7QZXC&!3rgRm6?!4++7NB9K^@DmiDyiAI!#WfFf2mZnqlQQ`))r0*XM^^z} z#nG;JXZ=__xI>ZR?(XhT+}$B~ad$87ZpDgQad(H}L{4&C*YEq^hli$wgq$-w^UWu( zl(girIiUmLj&B#gDSmz8w#1a=W62-0UC*{8=~7a|q@hV&lCt2-v?OkS%rm>)zR=pp zQc9e|H{ouQ9Z(dzs#k>f%v_*0J_Voq_xU>q=KHVuE@gIiS8x?{PV(%|_~hP}_C0N; zEB;4<_f6XLU$wmhew@n~m3jK7o+YMtPcM-boBATN?~igB>pY3B1#-w$H?xPYzk8PX zS??B2(C=yA$Veucby%MYtu0&aIUV;BlI^jvU6YQ+^-KPi*dgIyeEk?l!jU*<+_tzw z_UhJ);$^X<S6BICTb<%OxGszD_{@g^}X`s_s#L7XAN<6&)ShzJ}Z!M zH9eAH{F<9J^5=xqCz;02q>S|$V}D+A{*`(${bI(#v|COkwO_j8JpHSd`?oA9BWLK2 zb7!Dtpq_iBx+DBaesA>A6uPXrfEg#Q=KYqp)~S{XvGr`f$L@*$9 zY4GmTbvC~K7E@-=0I)>{4F)LJk~ z00;gV41;d`523x3WxZ#6;n){bH@-(g@5Exs-lWah+}TcNpPBtm_Kw-lW_zElb#i?2 z`^0GpE93LVjY2d5mfxZYO?j3pNqW&G&7WFv?TdUMdM|u3cpYwg1$?c&tKD9=BXhRP zmsK&P)J4D_oQLPK8NZw#(&@~$sh=}Trl0%u(-o7tBU5nhNp0i0{42q=%Nb4&xI3nI z$@t;9kowLw$vr4j^UZKg^xX~Y_caZl3jUNkt83**|oqSSC&3o73)jK&6uXK0}^T{>`FY8yf&#=wiDU* zW_yvnSGLL7_axU%UYo2W{F&s6{}TH+He`Qkn__i~`NY=TF6KFX%gi+DXtmVRO5JFI z2p`%LnC$m?>v(nVF3)Dyde^H74sr z)?ZoW-NmyGx^sAB=QGb`S8NvL8{zeOt_LdnlLGz2p97WQ?{)^L(7wiN__ALkG?&Pa zlL}h;OCj4f+i+X+n4^wjaeZT`m^ra$fTzA}x5iX-JhE1?J+NF5yNDn8eaIMerYDn& zw5ML17tDEP2+DNh;pMayN|mLcH2ItyrE)PJ+1}g|p@cw5Wiwpjca6 zpSYxi8VSo2`zFpw9GP4^DPPhbi8m9@C*+H7f&6u8jN6_7Cd6#1i5S68`ICtSvt_$6 zSX&Qd{*&nQ$h**x;HH4!_xQT|i0`F$ljo^>g!`uZwR^v(gD1{?z%w9gspq7tswby! ztFxKEbY>~{YtJZWb)Vw;n%UNSD07kPH`fy9R&T_$&*S#Bbf@|I_%3)mM015U%S+^4 zic{O8Sdsc0)6ka)o^^Z4RGAFcd`|;eX?|vUWr46 z2_h@lfr!k;$n;ElA`pqU;efsj*tTX+gl}(Lpjwe(PbW2{#8>u>wyDhnWgGat3%8E2YerWU3@Kr?}3+I?0@g6?(OM+?&;`l zhf`MHv(Vetx5t&>ac14}7WPi`cs)_qPsT7ALNGe8+e}>$Mq1KNp;2dwlTJ@)>F2cmfI3%sbXmf-@s#*HDVNq zz=G^)HihfJ_F-9OI{lTJWEO(Idn06qo6IfrEOU_7liF)ugOl@SLUCTQn9OARa1(&V z&n2z1JQm+bc`SD8TeuE4wI8vUwfQVNEI%z?>9920x>$O`x8&XIKg>LKEqZ9V$#Lp0 zsvmt%8?HYvw(HCE&gv?ovUXSgGkQaAqyu*m9jl&=u8a7>l~qrumJ$l(mM;ZA08d#m z&?0PxPxv;5e6r6wIWW+_K8SfZ^5U_-pjqgA=q5S{WrNOW zgK*PGC-o1xSGcFzM`O*FKrqfGjiKK?n7zbY=O3^hraZVO@xnm9zED+4;S$6e)(4oX zeE}A?EI*Vp*nV6~wi>g9slg^Oam*Te0BOOzq0f^cY+b4y9pDJnlRL}aF+1{3wj+5f z*w}AiMs?#J^A&_MjKm*dPx0&c&+K-QWHbZP-S;`a<^GGY15qw2y9p6zHBB=|^0uB2FSNNgMHm`bj9q2HE%U7`CwwNDZ13 z>p`tzyH-P-3ysy5WH_BipGg_a2Ww$@2zw!1nVDhz;^Q3C%p6`J=9u;abN?sNS=8Ub zO{~m3H~!MojeZdqbzON&k2Ly5uS9OMg~Bb(1bMI_`;)?BnSYJlp`mctuIWuRo`<{8 z4Sb2_Vc%zYSL7gl($_R}SsN6JRa?soBbBup!4aPF?3T!!Xo68VQX(2s<)~;@Ra4a` z`Z=vVRB3={&K^R`8!+RI?Q)RTuWNXnZj%Hsj}AYQlAHp zOZ~;VzLl0Ix=rn7VM7ZjtF=xrKKj71AUZ?)q^Db_WpyTPY-`-})kBVrp03JIOGp1$ z`41`CBbwD{-v5T(6`U!d&FML}0=8uXVJ1_$Qs7_J8{;^9d8bMzddZj^l`KqnCPyCD2*BWM6XQPh%6FX$LY^C9ZLQrItow z#irg}T<;(+PWJr{#(myUiYd_h!zbB^p&ZnDx}Gn?o{{y7&Cg`|GxTG8ePgsgo*%6& zVupv$iXD{x#u8|pkCa`uO_9mUA=b2>bNMixXsU6_N^vZ8T1z)y@sgI`ydmUPj?>A~ z94*y%qrgPxhKMJ7U!!Jbr9AI6&eJ$?sJdMJYOSKbFt+o{*pq?#@q@ze#biFOw_EZ< zC10p{(lvQX;90K1fm%w(T*)4l-s2b__+~GqPSpm-&Q(%EBXiTP#U#@n>(3EaIq=GE zS0@X+a-V8rvniMK#oQ{seDrUoImtA$aqB$J?2CPi$vLK;(#U!=REngqCG`B!pVC9R zv-e}RDbAPTS9Y487ss1jl^pDF^`)FG>G5}qWf%X#WwC#c>}M~?wdj6&106Bdx7hZ` zACBwqIb?`&jh%k(kuY6(qAid*1vYb4p_28+J{AmyLiX0a{gz+!9v2srK<0a*(gyp2 zpFd-JP}lv15{tSez8KpKRl!~UwD~R?Yipg=CFz!WBy&@qPVTFuB41q|N-yX9kpjkZ zTfg8-^1GRXYex6*e-sX|L(Q1rcX5w+E^Dg87S50o&{)!y7C>&4VA#J+e%1 zq#N`;VZQI4JxiOc`os^eAyE7q=t;8_WzHH~9nIZTY<{*zh7rFN#EuPePW6s3HE0*_ z$}#$6`w3SD*ycyV6>OC}cP)L@xsi@iWqN9Oj$@NAkJOCnp2-)mou}T^h#bB+F?v~!z9uI$sjJ91i76qyJo$fWZYM4)~ zv~&!Chej<!1!7jm+Bp_bY#X=G4gxZswXLdnb|1RBERDA8|#)--4 zx#<2m#P5e<|71FQTV&gmS}%ElUdVSXdunula7Fg6u8QI~Gez%d3%J_E9whsGIjtw{ z7t#jC4i{7Xsp2#09tn{@J)LrCfoajXG553>rGX`0>&Y&P_-*ry(cWK4+aiM^{{qps z!rTRitT|@&K%%sl8ylF!o@M4~FO?NSNmVp``Y5u5GQ)59HS8~M26@fSkWb1{rVIBa z&`g_X`4Sx-(WJ*_REdj}h%)Fc%9OV&OikKF?uRUQ(l49~wLa2$_ zOS%E~PoSJNq8E+U=ZX1{o~!z)bh4Uy99_r0W`0D!skD@+ex#r{#mzPLIJ6m{ zd}~Fg9#}p?KO~l4M(*Oh4WozH_Mvao2X1Xs=|;*aaT0UY_+d;GE|dI9d%RQHc&naZ z&X_mNxpE$6litOw8Lr7p)85GS%$*7e?}|(}TdJ=j7sA8!U;4J-ETx+bdjv;dE+0{w+;2Ct)B zY(?%cSxu@6wYVD0Dz1q5FGqw*>}{!y_=KA!ERl$%E05NvWsB5Pcxb6$jhEbFGwV>> z2(gNFs_k#|Dt;56+G|*H3xX}v@!6AC`&IVCyr>b}2>zgH>qA7hy~D zXV`0C-<2eZ>?=AuS;*8RiDW0KiGJK}b^_<-x`EBMRs0|gmi*S9mK57i+gj@}TYr0^ zt(JA3ZN62ob+AmfxTFT+Utm9H6Q&BqxxuJKe&gz*4(ZN*VGgtHz$x9wR%c$aU6_{4 z4x*EsbTROtXPF`JVQ*^tl*`JdNbTs~5qqdosB`ePzoUP-?}?XqgYFM5zdPM!aaDFd z$!wlACF`AYc4l^0y^L^XW9OBOCC=*Um7Firtxm~V#?>Rs>lx_j5k?SG+5b{cWZDXwGY}@R_p6s|`U+374&Nw_lEytt^l24cpM(ief2Aq@K%G-HHtN<0@FXAd;rqo*+CuO(Xl_q1-_D0$&)sT`U zhg4MpKSpdKz7y*4NoGS$-tCTK&l_dP!nwGv`+20f|h@bFpuH~<= za8hFk4bLDW1q??bKa9RX*Su&b>bvT@=xgX%@A10ztX5eWSx=o6GPh=?X2dzSJ11vU zbyjt5NT1*=4&8}AGJa;X&M4q4_tX!p2|SOSmwzZ0a4Bz6TZzFw6b}`vP zzc4F+L%tt;#gEa?AtBr?c*g&`f0lQ#r>f^#7I80fPRWdQmV}CWo%Dw3!!kysdD7pc zr>3_^UzDDn9!ZaJ9!sz3tmb^;Y~87XEG4W>tY58Xt-Gx=tiRbdTAx}QTi;psS)dpv(TG(y_y{+P zy}(dp9i~s!^@-XM_)SO= zet=i);D>|bd4oN}-ezAimDy8FbLKu1&%R&=fP=gNZi_v^oKI)AvBlZ>>|fxXuj00W zpDyw_`Hx^^JHf7QC~OnFLNPH1I^-QimsmxbB+ZvAAFK9FTCH4Ya zKMUBWa%?LmiF~7mnoslrT2n9-heQKlLyZYW{Ga^Ufo!Sn&TwDMTH-36RW$3r3)?3(3tCc7@Vwq|~CrMS1b!@g$z34x2Dmf`&J4P`d= z&ue2Ob%?~Wow&wOi+C*lu*|alYrA5*Z{KN~>l`q3%_# z(^-FHj>sCGsb-~Qw$G~IESD8^o^-9sOmXeXjLDkfTH^ZbD(jx(n&X<~rM<<29sPfV z>EK52vhQl^w2j6fV+*)~dAYA#ejzGkimAAtj@e4s2H5Xd)9qF}ZGUImk1Bb-H6FV8 zEi8R4lcgcjB54y&$RQywR4QhIW4?{c38jUlY)7^sxa4P^zOWzB(e}sqoD55SDXz^kKIrwD8cu`PFclHW81RXSRXTqS;y>S z`Z8OYT+k)yi9avIuW8Ik<_U0fjab5RY>>&xR%M5=L)qrgFB!%j25&zGYt|WE#*)xS zu;VGD2xElTf++-ns5_1rUtPG*cjeD><+%~)+2>#ylezR3sG3eP`sj?78#t!_q9-C* z;lZKwU}m6UK=hwR79)FldS-iK-Fcw!k?gAE{^T0$?(A}VOjln|SC8&l|WM$Le|N?dVPNrv<78?}l6vRJhs{RRMC(N$sRJ;I6*|1(tErAC{Tc3ijo8hy6Fl zPg@sz*gn)2w3(K*);5-B7S%Fbnkl^#(}elrB=GbH0}Z$kSilc#Np>Amn<)te#LduN zKZ^P04ya$Ah5FcWx-0#O?m%h~2QwY>RGt~id|_rY96K8vrFhoLmcmK(F#Xv9&{6oo z_JW2F1$I9pG{eogRX8h1Vs%Uq+ae>7L|s^d&n|=p#}dxX_G9ZX%g9K&J+$O%7-jY2 z>Jw$6d@VXDvOnw%J__aylnI>n*YkzEPR~QAp}h8Nc3*L)yDtNYwaSy%JJ3_y+s2#I zHvs+}?>q~9^L%daV}Ea`mJRZ*4^9s306X-S92bq#6>T(J`Z(lxFWG!TWpRY%tmK69 z&38v%$J3Yxj{c6K4k5;{|8y*}^|E;^BdqnTNE)ODl3iROUJ}xT7kqC15gTDopbt;8 zgPAlk8EpI`z^vLyM|wQwIRDZG$U&^xYq}UYPfin)93vk|Pi8L8=NRT1bBu8TmA-_X z3LS_kY!0pyG$kayKQuH5lxU6$V5|uF5Z!wM`&kF8brc#aO@v8&XTBTvmVJ!#eG*z$ z58%?*$(XNC*IKFdl$p`{5iz_Xv^qF8Fy601H6`Q?dVYG=x#zpRuB+}dt_SY3t}m{? zUBf*ot~$Q%u9m((T_yba+)SXedk#Kc!3@5h|D|_xFfn*FG%R{KQcT;Wo-q3xo#BwR zm0Qj35;Fu*ms=I@yiF*w)E6F?3z1!>X?)2%N9sc~DM1K))Wxw0A)nCv%EHFQq7%c{3Rhw2jp-2u)7;Go0$I03?>N*5|7ws+<5L!D6&w{sYw=o@dJf+!Ww8f zT?Q_=38H&x{Ny>@Y22|*m;|OUSqKz%L(JH|8NE<%=FqpOh16H_?&#*o{P3IL_F!Vb zg!jy@}o%-h-ZB_(;6Vyb0b7z{bsnN}m+??mGy@nz#O%fjG=t_rQgy zV`z4$TBJg>zr0=Tt#5?eTxYPqM{zy*QesW%vD6W2Zewgq9qIOAj+ypc_Qm$ww&}J4 z)}oeX&@K|C9%2zOi$9JmTgOaeDpWLzu+y1Eh*}|XfPALQle2URoa`RaC1{TRoi0b; zrKi#F=~yzG+$V)`Gp&`K$Y{`HTAd2eN!rK!Py7%ai^k&+Bg(~Gt1a-ToIwLn8Px_S_g>nV)oMZSGL`@ zL|Y<|2&JT9;yl66??mpzBT{71G4Rcsg(+$c`UCC48n}oEuGl$}ggfsiX@%AJmyKmL zsE~ALhcZFLvclvS{RV0)o8VP75O_a5$TvXSZpMefsVr}-iu0j9yCnyF&&c7z|6p^ zKo9>ie^p<=f8I+6_G16t@Z|So_@@S{M=!~J^h0p%Ww{CB0)D(@u=w1@TTT0Z#}V9t zQ?321qovWvjeD|FnIWV%c}tH2?*EH1-$*fZqk`#zL+cvyJ9~#K2F0jmh{p?1r`$yi z-BH-dw->^wIkIz?nB!y&Er6w18s6Tz`HZ>?+*p93=(6z5o=!KXSCRhcl^eZp+19aO4ja@E*BxsxQoI!Qt1B9o6U%Cw_sx(+;G zH^A+)8Wn2}F>e8f^--u2-dzFdJup#|=I@(lO0=pblkrH1l)@&((t`?)&!9PU4qx1m~?t$!iwg$Mja z%U!`I?X^0rtsH{wmczj8yp7n38^CtPQy*m%GT#}OjKk(QoielQtMwyLN8g40;^04W zWyL1aH?gR+4fuLZmA_GYd)PgHX0x}jB$gOEqX`>%l4GXj(S{J30(jy{=z63{O`dr<&)&J6e z-#;V#HqZc`a{0p7LfM1qk7xdb*LjLZ- zdFn1LU{qxGv$v!o*4ozNmasI^I>1`pKGsS*qT*%gF+ZIf%FF{I^*^;5G;>Dl+0Erp ztbJuRhg0rT;wMMYVQMY>5~^F)NWCpfBvZ^LZsI8}NFLJp=-2R^eE~n;bo72%=wywe z9#U=TXQVwlnJWVYrEdHqK8F8}OluWWfqe_KQcJELYlps+pJ~p0Ab+qP@`aoMf?^&& zmtDp8;7)MUpr|+&)nx&4fw@c%Vdjw!gl77XC8&IM(6O`$#oT}~%IJbU_E0OX-+&_0 zhsc@8&Cs>LyFf>8p1@=8jX<41YG7?JE08BFgf55LM~4MP`9$zwD_2H5V69*aNKLKTr5@G|mc!C(i!4r& zZt-{cOQZ>vsIOMC$y1`MI;3Pkpt>jci5W9Eaz|&_^IMz^agfYmjL78M4zCh zFoc^+D$(=cq1*zl@FUGbW?Q&M@1_4HyUA^4GuwcV2M?x}m?iEIgJL^r1J3e1)Yc}u z5#5QE&?tO_sK1cA#{R)^sHS}EU98qoJ|3EQMTJX3Z`{x2fh@SdzF>Q?zu23&E8Ju$ z`3euvyJlshmY$@q(=xUBT4imkS`x3mWMoLVbx`-m`rXiCy%Wq8ydEwQ&P2Z691Nnz zk!Y|+^hc;@WJchhV9=i&viW!Vw|P%_3i&U)-+N1E-OTEeQO$3JU#m6E`SbvGw$)|Z z>u6{H-3pF_HN~3W`T>2Hx6s>5=3jytmRIx3m%@`H&7xg_hL_bx@Sr}59;hkY5b}#H zrQa=;t;=kat(C#B$!-13LP@2-?p#3~KtF5-bei74{i-rNu}!fAIe>3A24qS)Ue@iodVD4x@KD-NYB5Pwf9l}6<)tfbv?_x z;58!>W@E46nUT}l-|;o3qoarYo^^pmuc4R)#dFbbGCjJw3l|b4u zy%;ByRo`>VxH$d-)Z)rPMe#Owmn)2W^E!K;{R}| zn|i^$Ys4{ZEv7s@&g`STRwl^6WrCBkN?xi~)32L#=!@hUlL92jODJQ$ME|odI)Yb( z=3-Z|C{T|n+)*~3@gcswgI;zv;Anl+TRIQWOgY(0>_n(%PGC&Z3b zgm!X>{td3ge}bXe(%b|N$!mSAUQl1Hb<}>-{!{-^2diz7;o8)6<(zU-nE*uiDdmE~ zsG>Sl-3y)Tf$Dg*wyG<~p$O;1jBAx#M6Myfk6Psx@^IM|eHHBoMcU1gxJUvt80W}m zm2%nuy#%m0b|%QIW)HEC*)q_@e8YxV9!j_iQSY@!59d6ua?QECtW5eLV|_xkqdu8M zs7X{sdJ27k=AiM`pG<>dVk5k-rlc7;L+s3-jE$|pc0=VjmhFI8^*^oA?aWH%Dys2H zh>>$JYuL)p!^Z^N1&5d(OfuG{5i^TfgM95VV`XhDi@ItRs;mM`n8>IcevmXM)j7!v z@_;PE&v`<3Ms@a%+5-pWzVH#LLgl5(Q^{~bwovz=NIo5iyjtcD;|=&c>x})-muqF@ zG@k0G_0BrcPiw8Ug4!o_u3AIQuI5q;LUk{{YAURnpoSD#`Jn7omM8<1_DU5n67nfu z*Ilq{2f_oAu<=* zjV8Dd7#?a6|20PImxU*wa}lb!l?NBW4N_H;$tfLg*D5G96Euy zjPF`YhTvBSKC|*^|h(gH>wG;B`c<2-EnS~ktyIG*2I1rg=%3XSqG)+8^{#Wpu@Ng`h0yz zHn8R?gi}13ksrnwhlU_uRYQLb@I7e#$7Dl|5 zOEXlz`c*xno>Nb#^VGF!Z?&wN8)v&Plr$yPtGrbV#iL}$pI#{@c56koj#^LMqi$DA zBZo}an&~TbC$thr8yk#!Mt-xqxfj0c6{zJ@7Ur$Pp~;v@=RpJ-NG6hbU{t?`&UX~& zwgyuN3aDk6SfEbdkSy$@Yve9!b#R`^Gjb6+v0KS1vH-h!IGGQQ`9`e%GPs95A+K;M zAL0qTBwz8bT|*SOK(6Amckw+}@bMBkg?)GsU%e5%sPWJ=97KBI8CJuN%!g=5;2Dqs zNAmM{GV{?jC_@)PZu}hit*;iwe-jW%^Qkq}N@^*USAWOH zCiqiFwI|}iA@#W$QccwYRnFO{{y05FKcyE!Km3g0H1dJVy9>_i)$y$N;nn4!JJRdu z$2hmC*kAcc1yT>O5GYL21MjmR-eF1704KQ;P_OaWML+4UIO`v<4`0y_(1klo@4>$M zKwqM-(YN3sZ-pAM01sXfUtJ85A`9xKGM(@Lzc&^uWrKgaM)Nqu3ikCobXT@u%^uQQ z=@fc2J&7KLnz#ziv6>_80%=I{qk~uytFe?UN94LpS0&$Yd)mndIHDfKPBiKL zh-C$F`{W}|ti~s}+lSy1br*`cmB|1g`5$1tixUsE6aRKmx-Rqy+S7CBi^%ax(T(X} z)JkNf6%pft=+3w^_$Xg&)-uqp&&BC}*q$2mG_q*D!DbDR97-80P3;NI`1KimkoQpdu0!I$%m0Nn z=|uhmj^G($<3=(banXvsSrI$z3T_2~xZpzdzuYB}R3W)A9e9Cg5lbG^gV5hwiv1Qs zJ@ykC7%z|kEGFx*azm+6P$n2k-$So7gG!erFJu#2ci_puyYliVRiPThc zCiRqBV8&2$kVBk+k4bm)2V6;9#tkS5e=}N{Rx@lYGS(RV%&l0(Zf0Y{XG}Jl8nI?E z2C$=eZ4`a);>2#~p|JrZF6X4D%>;j_OZA^%PFG zYw7biQ8|!v#3A~Rz|Faj?v4HTJ1I&}pbf0j8R)Wf1loNZ_N)tEzX3Nx8@eFQV=w9f z?Ll5Lg#Jhs#Vz|5yQ&$w3m&|R`}7Bc1?%CC`N_#F$$V{aKp?s ztY&9?O+$Q726T1>a~^imaqP|jx+Hy#SLS|Wg|QC1iZf3c$<#OFi!lJVyw_+B*O)A$ zqd6PyFKf&bctYRI6m!0ro2qVJqw1SS5KnHKmnef`sam)NGqCsGLl=50c0nULp6sS+ z=wv*l50FN5e^Q-!3bxV%+x?G zShsV?hF77cNJR~oM*WKj@x$y+{X_RKw_#_MhAVCb=q)6oE0zy_cGb{_5YV}h;li~X zjy5B3-_|pdC}LXRfb*Yzii$JNVW$Q3snBSs0zYNL7-sA8i)0tW`=$TJ8-O#3^$x(SjGEB2o511j0R>@w^Ch=spfoRCA5{77?a_S z@XpMKb=0U8MpNvg&t^r6Gk4(WTHrC{Fi+CE%qEC2U(L1nKg?aIHB>n=5vsbUkU#8& z#zPC@r4Lcnn9lS;GK}s@EKEt<3_LT4z6>|Z-6Sv0VJ)(l8q6eEPk%Fi(jCph)Kw~qogG2EDhJ=#>DY0TjJ9+@ABr`cXdE!6QF)F3 z%pLk9dY#5lxy%K6F1ndIiYcNh<|FDB{G+<58Pp_Ws?pa>0BgOGehR&djwS~^>)Gb- zhR+0koxD`9Qm>E;mM5p6G`C7y3(cb>^&nG9n~wK6OIyJ#Q5Q0CdK>B&T~6!34Aq~~ z6tx{XK>4(N~Sxs^0DvNI9GLf0pwjAw8*o{z{>oq9&PQpd<=swKIEn{qzUsT4Q^ zO=JF{)&co@jUK?fq;oK7^fmH>c}_hhZnBC#%#^0xq&uxK>x^dH39}mWA4TX)=Au3m zep4r))e$uuOgmlRs%j0GRHH2Q%e=2O2QKb4?a)!7f|M5 zs*?KAFtijy&>QQjaYTz6o7Gq_TSaJMU20$akzlYmx(zc@g!6PNJHV?zs1k;1-xnO(Q8(3i%ADIVZi0UO@&@*@=xB z!#>ydGx5}Qx)U{zF3I4P;d!3MX>0{N@ig4rqfINOJo$7l$*muyV$2-WFLbBxA;ySC z87fmxH1{C}U4jnd2}3el>woJl&C>cZ^QG?B9~!ST3r^NDXnj{ef8vAI(&(-`wI_PK zR>Not2J3QU9WS-H$T#lmDOy-9px1_O+J4mjo%L3Tdc}-6=67%rP8-GO0jOU(VomZR zehehVk!Nm2zL|<@Xa*^S-1Io|&4KvbG}H_^$zA#_ZpB@+i3}`)Oe~T94eIR+QG?0! zbo78)lSlM20?s+T7HT#dP_G?99p3?U)C=SoNAdZ>xEqJz#vFm!jfK=j?sfzC`uy}o z>KI*vK8h-;g1LyUj5zieRi6F{#q&$lM`UC1m{TNTb$j3*m_&U?y)YELcpX{U1ZK*_3>k$@Yy5F*1uDh+w$ z2HH>0B(30vafzEmUNhB+4YSZC=poLBkJn$QB%3qear3{&`t-)h3R3ILVbGCZLf40` zdr#_ynMA*#s?tqq8y!aeJ|3%_6CK+W@*XvRRkj7@FI$=VY$IrRWU${zKDI8~ow0)b zQ-I!;nevDJR2ut-ue&qj+%wFt4`IR_S2Wm@4(IOroQ4P&P06NLo&gS zPbP=3k5iFX&^Qk~o_z_NnFsJa%Wpaj6&dUV{f)Lt^J?w2-f*&*p*2@DwTp5~j@Nob zm#E#9*>XGhEAKJV0{^HD=vE z^c`62ic8S6J_fX5CpHf$hk5rkJ_~%ntIQ-4VCs^=Oe(6!{G>1Hg63v9^Izi~{B1g! z&rPR(M(=_QJHNRZnaFEn2JZYB=2=u{nbaEEf!&&$X@PpnOWL3d^NK6R6~Y`L4Ksu+ zIDotY6Em6H4|dmKCZ34{k7G1*A2qx}orF)`OH`ZJsp+_FUg7?mh@7GYRhBxAeu#?x zd|TWz8}W?#;faky7JLJ9q2ln;t3`F7nou(CwED<1{zTsXOC&POIgAD~AhTl8`0d1PxpKwk9{K5$(zc$Sno&3rB7g|{iGj}6LeGDP)ALQYGl3v%Hb2Z!qfF+ zW01BKuFwt94f#jerFKRSZm#kmvO_HcPW!P);pm=7zHkopVW@&SJ)9nH0SsVnxqhU9 z+&fe)x5=Sxs1T4+8q>X-+N7O2^d!TsQf{tA4>)A|8; zL;Qu_*dx?pB04O?k)L_t#JPg3fv-$1bl>)}VXhc^k$=ODWuF3rRfqkbH^EA#CnmRD z$p+>*PH+P?>-cIhtNgwZNI}LpQ+lVo7bP04Ygd z!JSl>Dhj4`P5Lz4yBO-G*~qk+C6KB9fjtXfD+@7)y<%W(oL}6d`-* zf^f6ah=bcrioyA-hJgMxJBqExrenIY9$d^_;1iA{b(v!5P~1VBA5Fi3{|aNyF)LHw zkO!|o6}LgJj4E%bu@m=J9>b^4M=z(g`Q7|!KEphxC~Ae@P_e(Fe==o=$)w>dcf!2p z3wx41f*7%wi?aRMD7y;Stx>=(Pa|$-7c&o2R2Nknr#A>jeqkyX{m{q+#_y<^jT(s= zQLOo|DZ;yByV(QzKml_toFA8%eNeN^Lwz-c(orkTFdZf`8E7-SHGb%v@kx6OchHSm zy8c1UqgBu{)!upyd7j=yxr2V(Q>CTyFM8oKtD#P!ifWe@;A?QR*#cTgzRK;2f zs6IzhD{#k_q=zvc`T*iwVa#63F&!`q?*gTt5^P^K2V0aqj2?awCV?HmTqU#66&Orj z(dWn@I*UrA$KcjIN=>63z!c=6CQv?93}rCciKFMEZ=+Kj)eQHTgnKJwCQxl~7k5HV z9D}~jJ=B_`jbF&@|JKhVf8me0g~>qQ ztv>S@cjZ&^7X8``;ITOH#dp#V5n*PaxA@H54i$?As2$%!LGKhYsRc%Tc&B_ozPQjujqrJg6Gv2VGcAyZ>L?*-)YU^8Mg`SzDnv)ZK<*r zGmAUQb7iAaOsS)eM8?umX{h+*5z2J!82&s~ZKf2}uIc~tqDYmnW+U$z1qOze_{Fo8CjnVA=CJe$k#H)Q?a;-1!6WN~Ynq8%Lc+4nU#D)E@Ov9`md2G|Cvw(bLGL=QD0`D4Qb+z^v&h!lRi(*liC|$KIxv<&?3YXW^pgdS%V7i)ocvJ$>Q15s-vP>*l}e@A8g%iM>(R|da3n@{nXDcFa7&|92}&Pg-64RXXq z=zhFI2lG9;c#n`vCL;251O)!Et#WO<{MzZF$w)H6yqc!0LkN}&6G zPrjzDgm%_N?Y7ncN}KykVr++Va98R%wHQxu4edmvYeV`mo3PjBfsOK=nE}<45DaZFxj7ADVsquX{L9oj@UfDiW8iRv?4ti{(luNIfVT-j9Cvyn$GZZ=GZmZc||avS`oLjMTeo}=bjbbjs{QN5Pg z1B!+B^t?J_T+qVs+R9WnVg@N;rZYkHD3{SqzoryZ`(Vm=Se+v`R_=o5QB6H0zm2A- zBfuxzDK}M~$nVsss;Y&xCu$2lC%WhF_0C2g%%i{Hna!m>(G%&h9iGdfU>n4R!79fo1|}d)F4t)PHH;&^upfgwWACMpdO0st0w8 z{y(PP0^Ew~`~ROgT{pQ!r*wmKgCJ4@lF|YqNP{4u(jX;>NOz~8fTSQP-O`;m&b_B| zW`3`I{eGVR|2faO_r#o;J$tXc*4k^W_j?!9J(n-CN*zIkz!NHUYU$JTU+Ikcf>@JF z%$*myYw~HhGk4?1EF;&`nf=C(5=pBL8^@<)SM+GYrt&f=K=vwKIs||HYoeDcq|5Y; zCg}_x$b&&(mX$8*@5rrK_g;~X%d2$nq%ykFaskG!Ch>+DcyuSQsh`Sza$OzePw-*_ z(mcryPvbWDVH!v!od+^oYOskqcunuO`x_hyb->(cL zANCi~mamjj?7*0#N}Rz_FC-n}I}FSH@j3l70oWDp<0->C+7N9^GV;rwc5 z@K*H4(12)?dX{*BkD0p>>)22IN_$NP?y?+KpUWT0x5!H=axHy(dTJ}rQmmxb>$}6> zeH?UMrmi$u`mL-EhRPmtR7LgW^i_3LbRl`Zt}Btt-Dt20y0v`cB69hYZo0m)!HDCN@hm}Iw+vqW^hngQ&s0*y6+mU&yMrN)r8tyD{fYU@#J~b4k zs%WpS3Ug?juADwB7ZIFbWL!Ryt-4rNM-gc+wz-*nA5`}r)EC|-yIhSNUq{_pc?Y!u zsrqA_ZIphO{7C1;g8B@BPGEoGw58nnY%6h15&RBZo8_j3E{* zw8Sc50dwtVX$!i3B73;oh_8;)RCa0mQG2F6Q~RO^-0bmwSUX7+HdEb9WcDYu9x=;u z$`bY{R1~u^HS&G*X!Og-Vr4_HO0;uidgy4hPpDI9Z{$c|NF*_ICNwhiK71!tVTKQI(E{bm?s+)3sA z4E7T4Yy6r$7`qWqna2)!I}D1xxnU<>*fO%KtBFPBfK^D;`Q_{el@V@Nf9LSAtv`Nbcp@+>V+*BwGm-b$^s zKJpZ>FEP?H_K^Gx&7P^8oa(P0tn0rY+6k{$%it^T z+kh147uigeNuAJMe*xur&>poYXCiGiKRY^n#<v`zPiR9km~iio1R>FAVZvW2V_Tea(^ z;vKxe+uDoGSjBU;@jfwtr8@9w##DVV6IeCfHhoAsjd#&R_f}tDIwrT#9gyOQF*TFU z%VBMyHb|+`vEhc*2yqDa4M9qG|htfL>WF8xqT(?;X{4pm!d zzu}MhBO{fe(G$v($XvXJhJjDiYX0w{lY?fMTi*DZ1p0>%`F3{WPBaq3Q~6X7TDs&B2s^j*#0*-XYdj$)Pv<}SA5 z=1m|v23e+C6X2E(T85aaSh7qNj1kjh-KWO2dOtg7JkyQQAEW*yuf8$s)h*N&6*n9+ z{9`z6E@ql<4#LK*TGm>NS)Wmja`gW!DNkRcK)h! zN`;AN6p|N^|5b_A_g7(qr3VgZHs;m<_V?|pHBui_rA;nH%LD)UcqN8CEwxAk>KA*5 z$3(4>Y%G3SuyJUre^+o>u)P0h@TBV>|76!u?_9r>eaN>qE7tW#&cf`UJdeCJ+~2sb zy1KhshbQ{ZhI7IVLOWGCGEpt3o26?FYUQZmps5iE=7aVhtWRRz+q+xq+izN`J4V~< zfXq8%Nw%~xr5pXG_oh}xr+$Du5^LNIHu6(iJQdp^D!^tcFQfvxkL7$szQ)QQ8x(3P zyIY2uJ6j)sX1nD`bvy%c{46FhX0M});}gf;n8uDuAnilu7wkaY75@H8`g3G9pC~$T zWF54oR1^Od9YyU=!{{OQIL)9YrD^zjbV~SLw6C%;993FHwnn}Te?jF%ZDk3%MGC+1 z?F%&xrUVBECk5X7wt5SB*SWL&AG<)v=4{A5?rP+0=H8hdlQlGFU``#^b+_65#`%}` zg?mP@Y*6pFL>h$4M3!kSwEeQ5EcasLBU2euBIvql)|!q&wj%cP>?5Ane$BQEtZQjY z4f8-#C+cdu8tnS5cym3p($vTPN<6GE>(u4c+Rcxih@K!z>?T)d!!tUde`y|Xd26`} zi{f*82}gn04@6SSQ=Otg1TF1s%Dr?uIaS&hxQyj zq9-r?DIRobcE|lUdNFEb-+~#@L6O1A8Rf0gPmNX1MF&PZCH_+2Ju+8;d&OqNUS&818u5+0SJnzyk$Rg zuNvnVYE!KiAUfPz+DW`{0_)6Bq$*WiHzW7M$?SA|C2~)>AnnnO#~a_S1Mg(nZoA>| zfw;cz7!|uH{%)c@PwB(~iQD2P#0TT5$GwPq7xTpNv(;qDF*-oebvJC&^_3J@(@#e> zkpJ4Q7?m24J>h(j+mVOSe(GT5U*$Vi9WRvY*dsIhNI0~TN?N37v}W{1B&gJh914HV zuHq4Yqu^S9IltSJ$K`-GYlY_*SA|TcXJz^vSAA-^Cpc$j4M~6N`qz2fJ=D9^`L`#Z z_o=U_ayRsu@{u}8nJj%|SfD>{k<8c4M{O3kBWgOP*q+B^*f;}^5HYpg(rDcsq6Sa2y3yy#boH(>N@*GC z8@2}%0#p2Df-k)D15xJ}-ubS0r`|i<`7rY@xX%w|ZE-HlIGnWsKK%KfO3oAB81Gu= zM(;hZ)4x%x4YSh#d4`gy`^GfiRM1h&`pov!kzvurd;?ErOxzQD8T+^Pu*qymH}*5k zH7!RwRM+_XV1+D8X-@pes8dmw@%EoAUWg3~l z8R4;^dBL;6^w2rxP^|K`vX8wGvXx7kDo@mv)i;%ctOV1jsXR?3SvBpfI#g*CIjd|5 zoeBRP>Kcd%^mku%?{+WD>E+&%?aXPPSuLx<`|4SDGvA~qWt@O{*1MUks zqr8Iy8?rM4LwqA6v+zE*8e8aB8VVU3T4F8E*v*dQxIA%XZ1ZCrww?AT=)y|&m*%gH zZFDI_Z#zn7sXqBGG9Y*=Pz0Q?fnBWo1WyDC2iEyhgPzdcNKvh-)Ptx}P4@Kmn08yA z#Qq(>At65DN!*Tv_&hUGUZ%}V+n4%0&#r{G@q^<_CXR@|8{6Ko)O}F{bH-&g&-piFf6l|KD_K|4^E+F-&B=c8?z8N1 zPHX1pSwCf7%_{7=?kVrLQt`Sp)G<0Va)dbLAoF*Y@z!pY3h3t367W0@tgesX=5 zdBQm~`;*McnT2zv!t4v1W==d6V>!N#n=IM+vIYZoy{VkOik+SN7 zXui;$=pA*Sv6=0hZANS{>n8J8AgH=p55~5#=Z(8)oo-CiZv|mw)#Zp>r8Y8BH}to^ zkNzco(m#}gYConLa~t{(K%F?ZPx3J`@qC#reDodz%5m>O|i={ZS0@g6Ktn! zP8eK|fou)iTR2u&dRrXk+r~nsW~Mr(&y53#xUbi$s6*6=WK)vhcgdo%=UY|*-%InU zitPj5rX<;-R;GJA`dPM`KZivm*4zaQZ(m~*eJjHT`uLJOiRy0y8+B|Cf6OYf7dry4 z3*8P*3%v1_^tbh00y*OK{^9+}H`6uE+uqa8HO)7~`NWgttmiJ3UD-X&^~5zI`?YgX z`kkz3_5x4$$W}G1G>|rizXs2>LAS4g95%)*^GK^ZW^r8Y_|>*s<`ME9 zwLa^U0>u4^E-3B7d4pSgivm3Y--gFUIw(t{`N;`cs9!2U#K&x$4F-P|yZF4oD>`o3 zYn^2K&0fJ#Gp4@nhV`O38HSi$>`T1Y_LXg`Wv6M7;R{_FQM!|G3;2x}^}}@&wH3sF zi_2#~Kxl@M`WgBiRIRIIVE-_du&e-=)7zY4j4^gI4m5tq*{ZOo&|X6me20Tl5n^Z= zMA5oQBUD+r5N;jHkK7*!)h4=Vh%5~~^UJ;t-ZSpPzDu56{-xgAo+|E+IY#G2R~6?i z*W8?MTyLB+vJW{|Wt7UcW>(KEl(Wmd#y3j|MIT{DCn#?~S3cGsv(|_`#QqsEvGc4c z#+zuvCbr&j{bGhYCYtj~2R*IB$7rxo+BTNH(mMo5GRO%(xcz-e+Lm z_Q<1Qy=@HIWHGt?;qX#a(fuTCC#rE?f5w9jYLG~4D07uHdk#N`%J{hh=&TdYqZ|H{qJ^LRm z@bvT)^Ii8gbWQa2b5?bA%02+Q#}(M?o4S`~r{ol3#a-lcKF|#@gp6sXF2-ScA9whG+`%72(pDKa8P~$@^g9(cuZadtF#HY{#cb|q z(iyjsyLtwOXMB$e0`KI$MhWIwG#s zLzjca162a2yiI(0y*0gEd|g3n-SnLB4s|cdxt`s@{V=;LdxorWFLZa!Ii0aHGm`z* z{W8=`YY+ZsH&Q=At)e~DJv6_G`6Mx#cq;yi{a?#CLuY*><5Tk%>t$O9%NG4k?enk{ z=ji5B+6U{0?aD(fuYQ;DwegavgsBxP@l&QX=9gB@cFKO$ zJ`j9s8DnE&wr=@`ZmO}UxiPCxab>O+p32gTlkZFT+#7L{{@ncAa&e`Ldg)gqKS>2yMY=&@JRmB!NB`FRDORssVkqu2$92;U5boP` ziN+L52TLELUtdc(8tU(_=-2xTg!+ajN82gSBNle}iO_R}uv6Kx%do-tqvfC_#rmUV zlcgCf+FzP$T2zbO+Rrk`d>tIAux`tBAL?J27MshWd#)J~~_QfjZptA9r0i6Z_&O}vkNIXZ+)!8!iv{yc#l{u;sc{%-!wz9g^Ye(&z%uA5Wa z+r^!kQ_Ne~)h|2FotSNM=F6FwxjD<5(aKpkyOi&TU+*Z#Vc;q&De#=?zPIZrU?QvCf9m`smwfp_ubW`>k*Nos5X_4-{(QS0d z-)Kib(v&ima%998N=`_gmvG(D*))YZu@!O|;zCx-k49DBPpKTN;{DtG+*Q}tEOBpYPJmFc~Z-Y-?e=UkrarUt1|_af{dVkEw`xujP=fsC}a2klkpj zW}ap23La*qz6tZFFS}65h8I$4wKn-z196@jy4w25hA{c5r`kd^=vZlyewb-CIK=O) zv6dUgZ}96TfEu4|Tmn~=Vr*bwM@K0|n?&W!9;*9xQUiKOxj|0tX2=!16{s3W4i53( z;PHXa?%(5m>)Gs?=Sg!<_w2~&@BY_SGUrTobJyFfJ6T_3^?e_CUpupMc1=&c=o0Oc z;gqpDSkbZ6#kDqmY0<|%X8(@iNwX43)!swc@%p`YMLSdb4yi@}qyLdumQN zXSZvZe_rTlq@L1D`+?mbA5r6N)ctAfX34QPja?u2EdEX0#MsR-8||-cOC6^i<>Oi< z{Fpd3;Z@vf`!4f1eOa=LGWC%8u~IHws&q5DIMg-h3`WAY)RMY0hD}DdVUXTHW#uQj zONOfE->mQLE9|RmJ1pf)oeYii?aBXsX*g#XYN(?xAurMXP{u`5qvOCk)}d~oMl>CU zoo_>{g0+Kd0w?^3{YCua{7rqOd@p?8dtdr0!nI&>+1;gbMmPt%>O0#pzxHSRlU_IT zx^t?lVwoKYZG+}q7l&b!$Ct;g!A?r!QT?=I%->CB&*{{H&={mwv6 zFaK(-t8S(#(K6ron?9^tCtWrCXl@#FIpIm3-if729+fqir2WdoNO3BQUQ$(gRaqap z;cM%6H0PpgcDCk3d&{Wr&6tj>xN69=-k~$?HN*$Z}Ny@*; z4f8CE9}xF~eeCA3lU=H-v3!BE~ za^uW4?=@MBR}CdiX>gZcG(SWC?9eZjd$RxC3gUBLDpNsw{T=NQF@^`BadUjTd<}eA zzGuFtzS6$mymQ!jCCgROUEbX}=Og!}>`~cgvp&w;ovwRdDC1_%zMR(4+th1P~vh8L0L8y21$`Z;jISI2k7z1*|b{h?>B=cMNc&q2=^ z*CXOyo~%b1ebeW^`|j<;(z7Fw7+DNjB2|*e>&??IH%+w1 z#;r)IliDX=i~Q5_j!7GqygcD}T%njudsq8g_OjSu-fhgUE2Z^|DxtaTjaV#vFIrr& zsU4}J>j65l4V5QT^+ye*;Ttp=YZ#jtFOoa^2R@+R^m=_ootwJA;xGapBr01pd<{*o z(cjS@;|qD}dcv+6Ik&TGXE${2bJliV&)S|jCS&0Hg6}?j|2XS;&gqa-UTQAjfDhL3 zgDr!4&^6X<$F79g$qiB#Bp*)r$YwMwP$!4%!7G8Xq2EF;S)qOJy~mzcXEGOObB8+$u7rmFudt*GRr%Jx_34j8Gy`Z9)E#6dq`B{Y?4`&M!iNYRE= zC8bD5sHJMG6rnbyZuk-i;}L;@z8+rHQ_+3HUC(vcWz0#-?whqWWBmJ%-z|GP`u(NM zh`VHvY^d=QYfbxbTN7Iu>j3LON1?c1@-$CLPQA%}u5d(+-^;RkA~F{2qbGbA28Lzc zGVl;|%DkSLKYM{I+gB;nftvcMQa62yX^nMwY^%hgDaX_H=WUy}R9b@+LsF^4>G4%! zV;uRcW8uE|T?;F(BYQ#XwhW&~8h>Yf+bFy}@_?OaU0~}E$y?=(Q~^e)608D0z)WKi z);%xo0^tZBF&rfRq|?f zk<1(aHs+44tfh-#tJFmOI64cD@k!)mWLcyVSmG`0ocvRCj@nPMQfWDZir9VB(^=s{ z>TDQfC}t?i8e_d0j#i5-3H1ye3%(2({bxMEoX?$7W|xd#GumgibM}DKcYgLSS^s2g zPe1j(X4ZVy>A*~_vbmf6c> z6>c9c6I$+9JR5WRXZOptyRLXUh74*cU0GAmvd>;Ec1--%#2rZx_olr|EuFSG^=0zS zJk1mCU~xY(CmUW8x#c? zuu5H%HmZrr{P1LUQzg^y|IXLWlh<{^>B;;c!sROR&aBMRnFBM1r@zYh z!8yP;CbU|1nY-B4nCkF_-mpcjkL-(MPbGYkR3)WE(wz7K_AQ2Fn57GcJ>l!&1CjdS zspLGb<}Aw6XZ3LY>H5UqJ~~A{WIS*2*d92dj;V1Q6B3gCO^M0dBj3rqt+VUsTnN}{`)f+aL23rwJd6GjuB~EqD=-BdRQwheE`kRvfq(LY-w!V?HD=M8dpBPMxvZ_AkXIs#bak#=IZNcL&2MvLGS+?su+y-SMtQ= zq&wGT7jP}`nIjve=Z3=8>voqTJw}UVudu|L$wyQ3=dG9bb?SoTr3t%Ydf1Aamly?$ zQZFjZ&APnOI`)DY6+Ic9tgv5_v=LUcs$}LngYH=sok=cvp1KDlR9)SP=A~Gr4BOn8nT}(&g_i01=3v@F;a%)Zb1mW|_cDxqYqNcE{(_;5 zNDZl@aWQt|2>j1cYg^lQ_LR5-^va&Zo(aFk7IAd97PO2ouQWX~uZKS-p#Muwl-g-8 z)I{wU&`1KIb)0PQ5_Tc#quiz>k_!iegsaPrrv9 zL?W=pBzU8qJ)Y{`X5LJ1aqlJ1&z{q;K&5%s_&x}j!|lL~z+nKF{59fnF@^)irsh+Y zZ){&#%UBAVYZ(NOpiTBt{kDioqlMZPr5$w&8Z$gm8mFrP(jt@fq}g2G@`q)B?U4PZ z{gLC4RP?LsOWy z2kU3Uv)a5BhOJGJ12T8@-%WbvNsYNJswGpc8|1;bPZP_ z<6b1RI`mtxRY(_V6nsr4>5_kee^X#g@P24$xGgLLH6u0P3dl#z^HAzh3$sh*ry!x4 zQ5UsI_W|*P>4t9b4*jYxqc`Yts4-i{u1{yE(|8V6qY0c7=X5oxN=(;xHI$^T=Y_t5 zp; zt2dZcFJn@dRpQhf-tSP^%92Lszg(APHiDafPrdCaYHcThP&y2ISqZA`I>C*Q zU-D|VKoP$r>tC8Wyw8Y@jR1o=m}<}NL>H%kv$`$4qXzRmb$c7B(;EYVZ=Q6RyFUX) zc?mdzPpRzvn7ZzXa4sx|Iib|dluHFtO$39wdmfGc8+A->= zm(kK6v=LO-kK{Ro$7irm6=odrfzYnTyRjTGPn$!Pe-9qr$WAQN{?PuUW_=xZGm-IH z$4Fh%9%z@T(BDXWej`_U&9#lRAOn=rMXvpnDkeQR{X9~fWCb0TAQhyAMY#L?v@jok z6R8Q4X{T4q*52~_9G=f;9k&l5qdeX1mse|pw$c&(Fau$B&1a_GlEcIrT`ttfG`hE%Zz|+86HX zXG(MN!++-J>C!Nc?Ev!Q6C`c`tSxQ1Rt@6yMY(c8z02Q(-B5aOA^-Mi_BAbl@!l$?A2;`hb4 zQepbs##N&9kR%1D1%FR}zMwbnl7UF4|KIZKD8E)Xn~(GQ;JFO*K19u$j%#_jl9Q{v zrSCKN^*bKaL-OkkelN!BHYlZ=jM+mzd&c;@;{4(YFBqTaNQHR3{=ZMecTXATyUZ5R zf@i#Y0>1q|QgDa=clm#v$9;}@$B3u%H-le4=b6b^2WXv}S8w?3V}ARb^S$K%Bi`N3 zy}D?pnBRG%voLS}>4^?T44SeI@QLhlltyhwtBULbtV;)97FC$f)XIcJ=(Gg=+g7?MvUy>24 zz`U&q(zOb@qO!nw@xKnAG{GOP%j~Sk-wMob@lHIO@oi(O)9dhThF)ohCitAnjrQn( zkNMjPz0eazmCkTWv_b1ML8rGzLv-ZX29CRiQY-j7TA?|_nd)=SPkG;(<7(4}%3Qr3 zpElr%wRm5VI7e+-_#v;WO66#Ycwd|L)#I}oysyK%N+21k^R1A=BHUqF+FgO~%Oi=& z$Y%i_W%#xr|MT*h&{0BrTbcP*`pm{$z{x{CMSoi9ZG~C&4oQ8*rxNc;)-x7<7;GR` zWCS$6H=u)3koE#x!$#k0%t9OQ<2j?s_ukxxOi$6~>bT!=To|-KN&3zxW#md&Hg}N^Stv(Om|$=4bI;FcK?40&AotFE$l`0>e?Z@U zAf@JBH;!L41#Q&|jUZ-`i>t&jmUZa&IL=+05vaj+K0p>@$bEQn?@0Vl=GqAyQ2=T5 zGEY?cFpe=P!}!$7mBTOL{F;m%smHabiqpl~BvBjoC9WG!Br56h36_r5l}f;EdKX6RUIAR(}06?;g>%D%|@+ z@Z$ZM-}Uf#h9l1}vGwPXr0U3NDmBnTOBbMDWRPD!;A`9`3TR+#&*7uh#aBDX3a$^9 zR>32=PR+z$T3vc%Ha=p1?qm(#OG|319nvU9KxnPm;5{AO(bw84dML?Jf*|QcAGLrbs48+5)XFk*GC-dl0kOGWYa+eE3w_El zE5XopW~3_9mn+~6=t~a|!+#pfH5b$07o`Kt-Vp6x#=ClO;$H>pH;+4whr{v*erc6X zGS)S~Tz1Ciw1IH!M_+c9d_?mW;%O{roX+48pJ08o8xF!QpvSKxg`dz5U8rlQhc@{Q zuhj*=j#1hNisJ}l@EEVI0-nnga2^{$Gqsg=qDQ-cq!}%(#hZ%PdQ*YoQ)kk{W+@v? zY+tny{_ZzO+7UF>RQmHge$8T}Xcb7i8~Bl9z&;t|r`juZ99Me+W_T#_-g$`jn_Cur~A8RcGF z>qoR(1AOuWL}OFouz9UD$LCs%C$&nt0`fBg-oGjAS6@qar9aV+z&}PTr86%XU({#2d*t8M~) z*GqN6mew9T%|{YkSTJObR&Q#nweR82s>bg2Q>fjJFlvg{5M=#0tl4e!&mi@C^-t|Q znU=g-3bHUCS(&C40DU!=8j|5eDZT^&k%t}`#X9daXzah>^~r)kqdZ7EJG^lB!JoEL z2ZFcE)P7(d&7!S`uw$FxHrqu-vtZWx1F0yawI+JqGq=4()RD?J>OAd$bi`1_w8-F) z&uCvO$D+&Bc2WoZ4ED_)sh^_R)v3x1? zi1H(~t#j4(Fw3lCPLVNyb#kFH2=<*lk^+n2Vf2_=IZEwzIc=XbNLPVX^<&Vs-N3lD z21f}eE%RDOMe{~RyA>>ZPFShlNeAh-DqQ6z+(z-Bmky%suW19+J*?snstc%U%0j#M zV+?<0E;L|HPejt{&{t#8o~hXS`pl{6;M~$!fo#HOa>4{>litA3b^y!WK%NeI%Z?n) zg|lh{TxzY^8><=ng#XMK3{{>e!?a)FJdorb@SCoa-BKS=v~}UMxvnT$qP$C5Ef<32 zsR6U}uG*M&Ul`_uf8o6RmMZ%gRf(2XE$ZK(QMbc@zXbd94_AxIt6585#>XiHPgqrO z*cH@(b{>|l@$lU|fw3up3cd>3RL=8=wtptyQg11&Ex`i62itrKriT97sc6AyKDE8n zD3|H70i!!rbw`g!kEw3Bi1vahY^m!gH`i*Q*OFnw$frLq?~n^gLzRcru%C!3%1Y@& z{ZsISIob#{zv7C%P%mhMbY}e*u#SXaXZaa??r&O4tlDX`#Xj0lOUYJmsB5(w(hBfm zBXkwy_sSUc6jECQT)q#y#TVG1WVIHywKrFt4rA6qnA+R1JKZa#1Z=B!bW!~Wx(is5 zm(h9(JnmTHJ8~nrWG=_oUTn)yd4RkS&V-}tP@*E$h!?cfhRH494}Bq@luJrcn0B5h zPPK~sE&CBX(qw6fHc;6Vy-C&edg(MQq9HT|@p>go?F|xojT{Gu%}QycngOzUDEBlU z#<3pqZ9GREEOhVDK*OYk#DLt`h}zO*m@V{h3LC(q&eb0#l5znp`zLofQ7Q(Ose^nD zw!m+dbL2z?<5pYxb|GFyOSuUa@ig|Z7)am3R1w^fzSed#%Sz&hFQNz6Q44+ty>^ga zl|v$@F>)GAfV0Tf&cb?*LFcDyx8ag%2)kPouyQfv9>2u`8^QJ;kzUZ7Q?&rI-$X3= zTj^Y`XLu28%Q-ne9I2D!+1gE|pK?oCt^UNm1G4S|-Bh>$iolI9OlhlLMEiCHiM9>i z*stYkpzu~;Rc_;r6oR=R8UJDxdS{zj0S)_FOUGJX<*d87w|Z&~#jpM5C#%`307}a>HsLVug zpKI&!5p?p8_yMQ2CGs$EbhG$-hQ1p`eQ+Ob4|eBkszL>(Zau7e<)z2cZ1^X;=?=o; z0r#Qq7E#yJ%$p`!UykZVTW?c?p~8vtt2PGBHB@(x{TvG5zcu1)f3f~)3*IsZ|E4T# z?Y*VDV5IvKA2aLrGuNAmIFYnNnjmZVP$RW|Sh%HlX&bP3b)@!MH|Bf1oCnNuG2Kq^ ze;sJwdHDS{OCGrd{Fbkzd5mp$@ZH0(vE^af>@R;T#W0)v$|`cPZL~SkGZ;6wV&R5r zbG7;McX*e-NZ;sQgBZ_|`pN~QwP^c|;Jg;eBhY+Rp#NAU$qhj+&m#x*fwqM9F4k?3ddm`) zu@F9WDfr!X@Y}rd0l5_!-`#Q^Z6(qtBCNlvKVvJM+|y6m2IhcO_bGR3z^nKjv~EzH zgT@;!cSaU;tj;$u$2Tf>1xdYz~7z%9cZ*bo%V5QShi*iKg=oqC9aqCZ* zIc4Pu+A29t`dlqfe6kImN>yqc@5__nCX3NEmbS>BDks!8(Y@L-^#kVG4p~(?XjjWbp~l6)m+t7v-&O4UmM?I z6)bSK;NmKxJy09auO^AAU>KL*!Wgt)xvhOc^fCp8xI%aj-I3hh@*?E@vU&wxm{d3| zE@;!>y?zc?f+|JihIpCluv$Orp5lp4Vh6Q6a1uR&6Xt8(W%S2xSY-QaW7VkE95i)t zeLYzJhQcZvRtm5tIH-TEFJ`!*>kq@@htZ`<7v%IKn5$apUs1s^T&)?Mr5x5O!0dbh z&V}L3#Az_r3{lb;uQXjL<1zg;ZFBT}WO}q4n)#w!Twj~Yro;FU8oAMUf7T3f3u&tz^-0}Wtd7A;b4D$CJmWt6MZ4D}Z| zKdYev+7hg+9nACr@(uO0iRvc2>2KIS!WUgbjkR80!%X>}Or{TJlJ8l88FVGdq>Pe3 z(k?4$>UnTW8>s3{)0_2cVP(A=U7&u=__dOp@?E%%d&$P0lbDFO;uXJ2r=Nq`7nm4uQ$c@sa9k*2JHpfqsy$u#blq zG}~dNE0%U1IZC^HLX)*r_#!8&yVP&Vt&GDq?bm(>lUfjt?~T$Fcy#W{oyd9AQ+ne$ zAJT5AZR7>8(doWvWrK&&pxy6g|ELOy0#Q~Y@^ zE3GVYN=Nah53pYOj=W}FsTWLNH_+Ba(cFDmd-d1rb!PbqtK*f*Z|H}h=EbY)i(k@N zUdbB4sI^viV}<6*FNqY6XJlL93%^tIao1n6<`~4@XKt7+j<8A_3x}5%y|)FMjMb7} zvbs*jCnzG9W(E8!K1Do|uz+~RGwm$Qck_v|+-3bm^aIA(1z6G2FyXP4J6Xk-M4?7W zvsm*@Mq>1E;_Q=Su%LH{hV^0f*qgmUE5XM(liYI&@Y%1VAG8@@`Zwy9$xpD*W3e^Y zh%9YW^Ks8TVIry~&%j5##g0uiVYKnkwqC3Nrr~9jgTWyO@3jwp&owG4*6b8(Vn1ubx*T&I4{tL*%N$nkZq_Hoxx*){NC#uNPQqn5l{)`*XoeR=RCZ|Z z_{9Vwjq~NMu(?#07GTA%!{^=*y>5cLwks>BHz1yBFiZW+*=%?W%E15g7BA@vzUgS` zA>3A5@B+%f$@5w@%FT6Csk-kI*!67Hk6#wNEl7_;;~Pl z0`hCTzI{?_;s(```{HLd}rE+p31O-9lV$KRx#j3(#9i(tX2dRzjYOlcPM%8qX+SW_`Gj75|6GRw3eU zi=_QzPrKkR)?`I_No@ktSy9#*pV3R>&@oe)`CF(yYbotPcQ54LGhr>B2ZnzhcFHQJ z)886$mqHvXfrwl?=Erz?z9M(9j40jDpaMptWd^bGYDU~-Bs*1&fkh}kvw5yu6jtC? zs+kytlYL7Tf_WdJ{zP9#$%1wv@*6|E%0z$UV72<93Epc>S(9wzY=01$JF1$f?CZ=L zaEjIz>(C5KpRSft?!k+>myzyG5B9)op3G=>#J?y*B(s$K3)ax8b%853Kp(xt12)Ou z6T{QgRmuak8u`64+PW`R=X*?)flvS z4>F4fwF1Nt>fq6RgGSr0Mzo4@I{w>s)?Fe`F+uuR+ooCZt@m*B9I_8p>C0`zFAu=D zGLIZ~1fTH-GWuS6M}cG4%H1~Nj4~1RrNm9!&{L^oGw~r=>&<3u`ZclPB+Vl=mJN)F zOZyl}X+pfK9oJooURr=G$Ko6QMod;Ca#)o}M!fWldZ9h&(U0-=dTI~x^0%_;n9JBV z#!s4pJar)2cZ+iu=cwmuar_yD=wvI7-w3NpRk*+7kd1w?a#dw5UrBmFgnltO_XYU% zrP1=2h&-jqExGf-oM9WjG z@ztv{_rJ$VWTVHPqg58M>RQPtyv8bxk*1LI=%V$feOHKmhp}-Jx%WTdA@xcnSzT0$ z?jlMQBhLf(FSrGdp~>zrVylSX59ADQ7`K*EQ+l&CvHH`*?5wbtf2nPP`8Ld5PG@!( zWHgMdsTL4>n21&VK}4Ha@AM%SI-MTALKM9W_q3N7(nPe1iy1qFIOrdoZ4J744pBz< zp5^?k4N5c9Qt`<#S}Sdm7l;rIw2zbL+qlezgEi*S(`<{R~u_AQu>>hezFFQP~;} zLPHmky3l*07XzyU-owh|2uL?^|Fi=MhJ!Dc{nX z5b2Z|uR{3mVPs(sJy(Njnb8lmb=XAs8 zNkNNDNA~Us`zn9Us_rT+dZq409+nXCy(;fPx6G5aX$RB?#JnG)t3J~eVwRp}SHtyc zS?yOe#X{JNtIIYl;2^YrOSDZ-`nLfq`iIQyY)~M>v^e#~2) z0FAH%m9?hy&M36RCCMi5z$aKiRZSLAlO2rVU2Qf}CCTrJo^RsJB5z0)8d7za@d>hH zz%)FD{X~=Xa1_N!RJf2!IfcX=W%Sx%iN2)Y2XXeVutZV%XeT){6XR+|pCl2RYDLeK z#>PH{1F8`SwjuO%edh8AJf~*JTRp7FB(%#bysfQxIe(Lh>Vc1OhWWCP)-L21g&F%i z?BhI+oL42zolJIbKOWG1Vj7>2Z5c>D?J<##i>x$mvOYY8fAbl(vm2Uw5*o#+9!1J( z5G9*MHYS_3>3E{Kr_f8wwf407K<;B0fA65V{=j=VfuyHkhrXr`=^%NMX6#g2UAw22 zBT5MwGWj6`Gh_hVq#wzRhydkh%o=HRh(b0KT7uQmbNVq}>VmB(FCy*OLctAd;;82M zGrQ1*3y8u!!m}KVA32NP>6zgz zp}C#ps_GDBd5SI_gN8Yc?kkCv>;zM|UOEiBLL3#2^VJhrul1~;8snFo$JSJ5_QMN= zU-}-z-S7Bk`IzC|$)c6U8;cT$7*3}7I`i-{8HWLUdz+pcB|GpRdJvH~qz%{hV38ZM zSH=ju+$QvHZ>-!?X}Rn}dmcd>k3#lN%#S7z0EG{6n5HfxO*oUH+U*$!#XWUrk3 zc-uv>dkR-?gC6XNMvuWhG(q2ef_D7}`_hYhjzh!0rS_m9I~@=O)6!t-JPR{t2unB` zyE<5Hh5Yzv&wCFUW*0Wfm8b#1!V{vwg`}_0xP3Xt2y{#%yc-)ahkRtSsV`);;-!x}P$5|!JspKl_%HNUP5Nb& zXdT+^CM}sMZI||9{Y#UHe8f>LsakMz{7$rM9oD?>nK>GEzZR?O&ykjE$juvVDpq|1 zpTy&_{DdsFCR_0@`lbomKL;=HETa<7y%-p?U3m51<42UH$1gC?J?xMoa>*r;qHpl| zBY2Sq(Ay6ABj(Xd#v{tyJ%`0%UG@j2wR@eu~nlEJK ze++iZ)o7fO*qtGAUR?<^Y+V?Wd$LBE#%Q)cM_-^#`B~eD>Z>}meJv4}*~C0=!eU&E z9us+UQShcNO+Ws{u1`<#d&ePlZ?GeOu`=3D4*L?h1TT8w zPo$|D{+^kcT?DCEiT#Uk#*M5>o-=cQ$A0HwlvbnpE0G1z)6@C6!(>|CfaC8lGmjxX zP0?RfSVfj1hjRtbO~TfXVARHOt`W%dODtMGm^E{#do3!g2AcC3KH5IK`=WRW#W?md zM{cK1wGo=(2>pa#jMt~(n_OiM%tISC;G0>@tUBaF(&UQrBi2&S@Ulf_WHuhKftJO> z7CMdjIhOIg$(m^oR=P2Bw-e)Z7EAF&?WwURJ3Oi@*uzXOmm}K09V=kKTgpS$RA`_K zEPX!KA2aX_r-K%nOEl{=v*ZES7>phsik`2Im;MwVK-3nkCw`Pe#vzN^tTeoj4Au$P z)hloreuRXMCpK{$xob+~=WA9-!-@H{#hWK9#so2Wn32NLG~DQJIL&TXyt%O*Z!~x>lC+l5aYLUvQH- zvzXSDru8crhri&Rm`9z&bTVEBxvSg^Zlg8Kq#am{M6x93@v=+Gi)d>P#xM{x^xzzb$$QASTg?|g(t zT8rH`5%c*1-_WCWz(P%wN)W3WLyo&L)g@Jt(hnJrAIOO8=lW^X#jHSIrVwdsh}Nxu zj@g8U`UrjL!WTQl`SPI^8sKZpL;mR%EL;I9g4R*#u$rEGMGuKR)}Hci6LJ~E4^E^{ z)-m&AsoVL8%9Tz?%qUipBapZy_=9^m_Bs_cqGm2XIy#$iDS#j6#u9G86MKraD#cpn z724?>woO#Zrf_UM;Z?t9Kp^(1i7>N(}utzWP<}d?#!EmSl!D zQl0cQb9(|S>_734m$Qr5JviI8v!2$o``lfw*GyZ3*4<5xQNbTS$M}hO?IPA^7a5T$ zteUr=b*p0qmoUR-;~_qwzuV)dRH1hUaGiC`|Ho30N{U$gw(+d4BzkLgt|W;1?g5Ni zl>52G@0KxNJ|M1InLXzGY64ld7_{C;RDYC_7vO((V|F#?}2^_ zvH#yTdbT5R)RHmIgDt*~)_Q?`Oy|F-*ZT*}at(Pp%i7P)OwHjN z@$Mv+B$JA^lw5m!j}~sGzklU7QC3?6k*F}Uvk5DWCH$%*EB=Z|%1$)JGLF241m(fT z*38v->sb$c&lw8PE8lR(uh1qx(rPj1YV*5dNT`e@dqAJ=rA=OZumQ}0s>ti%+)qMy z51r|W)>u_h)0&?d7p8q7v~wG-{V`FO$5_Wdb7i3xeOQ|-$1p2Jb+D*Yb#O!#e9ERs zlc=#Y(bu&&V>cwHGUH<9(*!=t$4tq~C#~@+MIEDuetN~xC3)YRxl@$3#&WJ|98;cm zX}M=l;K&r7aptuf8XH9^Ry-#3AS^^*PQ3C{&$mu~DeAqmc&E{# z0H5XXoiF!)7RL#>CKNz_sk!a)B5hV$F6wImc99=Y_??Nq7S+v(^iK@miF)M1^l2=k zC{Pk6?z8}XpTZFW6A;7S7$iU-SMu->I1+&jh%l}R{I(#!&dWoLS8*QyRr(i4f{O6E z5aXK0Rr2t?z;-0^>o|TT5C<{*K928Fd0&8|1b!fa??rXGgHe%lXQ6>B*|?_=y(o4B z4KdptJgm9nD{un|xox*F)&kG)AKrzc6dvMuJ)h~ho4V+o@s_0vBRe|Ra{UX{R9x|;kH3QEM-1 z>&58jM-D{AcMNCG!%^&X$LpBftBBT0TvOn2{5+xe&AqQM@5K8MeG%ba1N4n}FXo0P zcck3B79%C*ikK||5#*M-}E(addb=UJLw|K=Fuq9?#!2{^#W> zaAh(46*6PxDc%W7k-|}8<_mow^s&Gx2z?aGLts@>`7Ac~y_ntNwfI*27iSRiBU&Gy zDQhld}k4Os(lSI2k%LP*Gf4HSr94C-W;#h%U65op!38b9RDgVi~K)4CHROz|@ zX$hexgbonifRGx|`vU8gFIN%@@mGw$umXZ#qZt3iYhhjE87(pMg=Zk_ia>n{FQgQI z%i!gd;`X+VTHrx&}D8&bNCt*3PVvRh%@9ZQFL5 zG-%kMQDd8pZ8o-@#>U#lm>2V%H~D{=-)wegu3w+$obwRvB`HDDOUL4T<#EPeFpsN^ zxgyaFGy-bZ9P5vkcz3}3uN`K5?eT8)zn^U|cO<`ez)G(RW{$lvBkhlQ@F1)>f5*CQ z98jh|v0`3;_ex-K+pyv|gx6WD5zgR!8t-F3O7f$5eyR)2Xe(sQAY%?07s$9$f_6ldiKIvPh_ zSWV5xdmjE?284GLq8Ep-5d%4rHPkvVwH z0%q|iaFcPU3w^NGB3%EvD5ZGRH9cypgeOk=KT;w&=v7Id`T?`+$GCzsz%aJqUg1Ui zzdQXCcRLsB{N-4AZbuoMz_WUc?UPVKWHvzInBiE*)WMm0VD&Q%Je$3UFnk1G#SMfe z9R8O*fC5dUmm>4oF0g#B(J$x^(9h?9LtTKJtTx(7Tj)%Dk_c_XbNU9@jJv^F8;xwA zzW`$*nD+tPS4W&D9mvTrpvdnKxu^s7_Ixo8If{&c) zNerSFD&!#=eO7%z(~i3w;Jxd>J`DF5u@m8DPdA)(NMgY4gLiF;D^jXRShO0+jy4FC$ zr=dKKqK+KIYIYTBRzF1Gyx>AF03Yrfe8f%SO=MQ92Um`l{4RKwkKza5J(*w5@8NGF z512{lE9?~tp-mZ&JakHA4bKHu?*!(p14knO8G$TJZ}t>iPg`&^x!rJXcnhD?1nv!Y zj$6xhM`w+$bVY|-9%aFP5LoYnMKTg_(D8m5|}jPz{$lc1G)0v;kj%;evSrA5aUGsJ%KvW z6*`tQL?XtcevwF61aPEyX}RbVCJTvtZ}=c}_hoy>d*$ALJsXiXVvJ{|X9|8hm_HAs6e$2G>N4~np=kHF z5XISm_SE}F_UJP3f9%7KKwiuB$e6gE{evBY?0A3Tw=>8)m$$ny zL~cAX9CYFub0HkdC9&VwB=##?g8aut>?ig%oCr6wGjY5ItepJ{H=;j~)vXExZ)^Gz z>Ovi$v70c0eL@du@IOP2%!YXGnJ7UT-1Ajns;Dy4d84-7YydyZ*N57{o+6FAP=;5+kOmMVKBJ1pCRED14KVxEHcqmNyQ zr_qYpKwk!yCP`(bI$~9!GcuFT^xpHBJQLh{_c~;HNQTqjf*-Hr&Zo|XuFWpS{lZGcSD8GO_A+xwAtOspz71&MPu=+oR zr@8>W`%a9BA2138<5>>^;_m^!#zjv+M!S#f1g@4WQ8raxD(|W|uE>BlsZ&v`cnDAY zE(%WZNInm)G&!;zvPQC#+(@njS?`MBlR^{)(A%H{=VCU}4fBR&==EfXi;qKVx`6)X z8!~!on4!pKz7Z~D`=FHSh`vF?90GI8j@mmMXxMu6@Asv?(j)W%FOeIdkDw8*@D0(c zo%a2MTt!d4YrOTm`#fDdiSD}YB-d0|3)dLuVK_;u9qk-J_7V0>+euqb+h|)$+gMv2 zdsD|U=U8_SELZZxKmA#>4H=u#9-;cA9H)#>B+2S< z3z_4H!N*H!!Zab5Z^s|zKO&2VP3R$Jh_ewTH$r`Om3hOqM7EZ;N~fxYrl)qY?tt#O z&Vu()U9onqw!7wE^+VMwWesJ0g;(~P>yC_LPr+(EiN2^eaS`0*+ zPL}Q>3L&EpLIKbXS@9~vul63=TN7mG$^ahy349_8@{1;WN4t+ZM>y8lu3B$d9-8Nv zlS}88I!d3H_A;x?`nG1aX!{HM3uj;V3$IlO^6Qz#+-P};vX*)k za)q?iR@GWHjkPn88E1~Rt7eV5xvC2?sdQ$$(;j60o8(LOwD1(T=eVD^FT0iKr;
u4FbKwe9=N^@BEMXwB~Vq637U%xTgc+6PQcru`Fz)-_vy;e70GgtLd zVU|_pTC&Mtne{*f@)+huD}iut`7zIWj!}0feV9#!!cq^nV6S|hB1*YNF-cidc~enG zDJnWEBzdlEBbR{8p7Y@q)&-iUC(s*8LN#%>(4PP1`_0?Z6YFm4y5lf7uGwZ-*IUk8 z?B)e<|ID?tv~9NCw)R90jZoWgTO)Xk?y#rW*V~=;D2En~n_7EUd#+=k=eT#e*ckJc z-`Q8(QrR$hWyKSDW5pGvN@Y{2G>_EZRQ;7hWYJtZXrx|9XOL01I$XJj_&#}qeUChx zNA7v-wtLR|mJ4zIzVst_$)Axs;5Ph3y;L(ryFx#}upnTdX_=|BX{j+LV5YvTuDRx$ zN~4;ls3Bi1i{o0cF1j+b0Z+slVj(h*=@BKIjrmVs6Fro>B^>=4+L~D@%OlHT zWUpFd8EoxhX=gQBr&#hVBhAI;L8bp#R-037MR3b5Z@XrDXC7opHIJ}Mt^wYqlAdaS zOrdYN|734vlyW693Qf{T+Bn@XeSi4l^^l#uCr=B?&A?!4q0 z=8o_wMYsO}Gn6|cJ1B3cxT`F$Zmv}tDh6nc8;u78lmSEZv-Br)ziUP9W=)u8o!YF- zQ_O_>cL@6zeF$^N?#LQdAERzDU6c8O?5-y{wQQF>UU37i#rIX$RQ*(Il~t506^9iq z6b151^4Hu(_6sujaa5o`L+mIG5f_Wu0*_p(OEJ@k6s$rWakZcVfBY}xTXpi&y=6W7 zT-BY^9EtXSZ0BsBtSc=`EhFI2+p+Xrag)+>C0(tRO7B>jSr(Pvv_x2&TdrFtSl#9- z)+sinr?qsS_A}GDUCagM3ENsRQ1xE>O4n6;M)QlNvzk*TadnwH{+?nZVK>UQwy?nK zbGEZ}v>D(A?Dw?euZv5hhEx^i2v;5s&s9|~)Fs;80gH{3OaqO77`RgYHZYM$$7 z=)Y>ut6#}K%l>0M7^8=Cdyup7FYYRrAbTydu?H9ZTe(uf zD{d*qD_Sc2GNpVB9Orj4R>ISPj{G(4r3Zl_JVUC9*`ddO8d0NHh(m3KAHizOHkt|5 z_?dhMWR^YZ8{~cKY=t~DA4^V}Pnri>ldRk9qpW8wzgfN%AF*jlFIop%s@oSiCb+U( zv+Nsf%PsHBP0fvLx#m!(%4-t;M&x4(_yyaf-Bf*9C-oZLID^@6M-!(usRqI=_XF2} zvmuZ4O?ng+C?tEfIbS(7EyBb*@ayWzHp(;t01n2Ao zu8Qow?2@9M@|LQWs-Eh#qO5!ax0G4V{Eqy>S`J?8>>pH+bP9Q|dW+2vOL>R6#CTvP zU!^AgBT#MC7yrT4Z57)KTanM>yZw$e)za23bFOh%93$*X+h*%-%NBE?y{Y+tLt!6l zui;v1m%FN!M7my^=Q{@58C$BYLg|Z=O@;j}+w8yi;Dsg(7gs~+R)yXyljH-HC7K}Z z6vLmo4amFF4;h`h(Q=GXcj;DsB_HH%<-X^d?3BCix<-0e2=lC3BIr<;$=YFpIt8k26np->mBt*ZK6{!AuugA~Dv1mzTYZ^c7ofAC^vQ;*AG z8_KUJGLf6^yV52fEeny~l2=hM%0zizMXG#`+`(0sEnr6aS24ROiF<*$$iGlA9hCNq zl>deJkVJbi$0!5~;ElLbTqmT95#XVA6z}l0eD6JH-A`OPR~zSH$94N2S4-!2M+e(V z>v3~U%QMSP`*YhKTS2MJyvcIOyv%Z@WJgKal1b(k<{PDlO70duD_!Ay;!TlyickDo zrT%OcnL?STuB&XO-Jt1#`($;OR4LpcZXk08Sp!3vDb#ypjp^!J?o0G;@?7#>_RWL8 z`wOU2#v?8}k6A6dtk|qNrQWR$(#*y>lFxNqL9BuFu#tlE_VK3x$>>@ZF zc1NvNvVq8oo`u!&J0J(AF<;ya^uY)o(`@LIdx?p{Q{;92%0KqihQn-#_nLc>bCxUL zk?7jt;G72>r|fSXKJ#%$n)#favG%c=tphAYwl@}&b&R#VSz&csa!QBWvP+iPs+5Mg zhC3#ESoHmKh2DHy`kt80Os8(p@$#zN5LF*pFFCRoFuUbPnCtR=EX}I9KFFm`K?U3e z$isMW2Zu7-k@*<8&33gjAPchy1F2DtRURjp7> zhG+i^+7&&>7ZGg@71_Qn#G+)I)3|#)& z^jT=F-}qC|KisC?NrR|NaiG6g=mNI=7pOBwKyiIsszObG5A9j$j`$mRLu&sq$tZS$ zB5)`b0iEVj$qLo^TTutj$qw-dIN77Yxb26mzI(v&Rl;N9BiOxPkb~6#oTDH3CS9?H zdkJ3BSs)q{fKL{IJ(3L`S7-W~e;N(8D18a@=;e$GjQMZiMffp0yh_WMeb7HGrWe7f zA%ZpmHBr-Rv9=7MZBk#X8M`ryp~C=YoO(!S`#I(ib(J{-RYVw*0!_^pY6Me2--VOq zGps2)(G$U^YY$dL6KK1p;mjfair8NZrspufo-PkAP*-Y{{^G5A$WZxviv7Bb!M!B2f;g?}L(FA+A;cBvahBcpSI^e0$Wv%rF>fqU33z4tG`_0EDqcq*a^d%#RS z;*IfEXp|yES2+8s`MY#0IC8r=Fw~X3Y6{q0 z91y?xBj_pMp46v*qee-mfj9?=y_x-ZLVtt(R*!BeS};TX9mwG%|9z?q7-q|%i{YSq ziv{~T7%Ysj^kr!VW0hV*k2Dght$hDn2L2mhoHeJtQ2Ab^bNqGa@BW_fsHsJbVgkV- z{z|V0)_a5c$owvrVJ~8JaTF!Hj=9fo=juXjQIj#yH>8#R8+0$)Ep%W%P#I!tC^@zw zX#EE>jp`}Zrhmu0o5uB?pe{=nfsgM+JC9*RXmNoz;_qz#bb5_6p4lul2J0tH8ig(M zrGK!R`2>x_a!zfuFF6>t@L2lt0<@B_FjHl(XS zr!X9~NtgD+qFK0hW9D;o+0-D|oJ z7+NzVIeXk+Bz6UQycAloWVF~wgm^5-3SCUcK>xK&tO)FK3s}Dy{)zOj(A&iOlPS4B zlP;2uLUDQrY#U@rM?I|pEo>I>X^%7nE-;DIZ&EN5L_LsZp%?m>P8V8G3z_-C4yHC$ z8EAu^+DrF>T5kf~R662MWO@PxZ3k?35M7;bEqhEE!~%LI91uj&igMp2?gYM=Ar5EU z;(n$T>hCc&MkoSy`3vxhlm4erruT!JX)#(&482f#0bcDS%xY7_oy>Lc(tcy+AvaVY zc2Eeb{@t9374J8p?de_!mb33qnu%0WENx z)Q}E@Lh}Q47NtK2jH49j1E5p@TXiEn8mv%gF@Uy@L4>LUeG}N~aVZD7>f`=qOeM6s zA!yHo;O<(T%BQ=Fn`kv_5T{T|==jpn&vrzZ>j>_6wDcQ&*zXaQ%qVyWtfan3>**Za zzlnMwrn3u$%CwAnBYkAb3v1EWeHQKlL7oU@`9t6>wWWI~Z3mQWlJt_ACbeMpNn^lU z+$de6FG-haTI?+Kq!;>^(suA+-Tn?h#xkWo=$Qvg(f%g%5@;HGfVZlI`fn(8PZ~kL zmu^5qby~Ur)xc(OoS#XV$gSQQ`r}n7#V9aATT@GrZ!<#rlbXfUl(f*}Na#tf3pbco zP{bZ*PUHUfipQn0)Inu2@1pP0G1MP4n3il8XfG{H3;I3AfRmV6@5NYg715@1h?^u) z9x(~ApfBjTKKV~Gp=j~#XouuONt~oc1E2Z~4x0p}z~7RB2}3R0fGEy;+9LAM=v;@2 zxwjO7l?yNZMdu0~&>l;rJ2dM*Lu=swED(kg#++}=HUD?&G?d#9{T;=NK9ILq$123qb1!J z4^TE>{||tdYS5;tOF`7{Qg5)d_e+ocWHep?XNWS;LA@r~{lMU&@oNDMJrbr*Col;wU;gXiYP=o17TNi1E3YD6E$8{Rrl;+27d{vtTHJ;71jg4X*HO0529`LEFWp5k>K zn7IK@wgZ&c6X7G#6e`a0^i@PvDnQp2L;s64bOfj_ zQFZ=>cJmkZo`TO;fEzp&zs~>)ISzL>6F--uZ&-;4=Kx@n#0zsMVrT=g#ccdF0`n;- z>hSk)>^~i^iP)zL+Vyy}_%`S%XW`qC*k?G7Mk2EBip1kO;J12+;I=?~ zj(B=jLqAdpQMYP{-Ic{}v6z*GV@^h5r6fX5SbD*T21j8#BcknuHAc8tgj+^I7ekzB z&EORf7veO;H;8{_IuP^>{6&}tiHH^^AbLlBOTvqM3LiaenTv1I@cGa0;aPzv;Orki zNlY>uSDWx-tCSz-AsoIOY?=AvUNf;h@vbFN;-3-aVmu`a&PaGeUN8X&@9HN5i$#n` zg{MlqM#J!va0~Q5;_`&8KuUn{0OIiu2M;0Y2M;0!(QLxti$IwWAK#yB8N$aR$N9-? zAf-jPH-xD`_vh~F*whPd$N;ZqWOCl0*1_@0yk;kFRoEeV@|xDylZ&(Gfpi;yt-2$PKbPgs67{AT-c z7CVkeIHo`U5l#tVjuAc!sUc7?qAUm(htv;J&ZK_)QE0z$s4M@^qajQh z;uKK@eLyw*tb&$Jcs7LLMWXG*i<|JdV)0j*AFo*aB%ct5CgFHS;s3xN{1d|VB!7|5 z34iP-3zO7W8F&DsmJq%nVWSaV8(}aJJ|l4?C$;`3qmgirNC}YVOx&7FP!6PoiSzPL zhh_4#ewF}f@juyeKmR93BaYYPDHEn2;Vuy-oD;v3zsV;*_aRI1k#d&@*E6qQSRPv)#NUuUzQhE4k{!ucdZjy2*Wn22=KFD1O|2rD_jCk>q z-pP*7$yE|pU&03ZfBqR^Ao%eOVfK*{{mJk9`IC5_|KtV{-W4fl@+^siFyRIfMiS|9 zNe@7nM9q)lMu~4oT_op> z#+49$AYl!WvqfP`!Ve{*31J`-E-~4Y^i}d7pQv#La%}RG@Jq?);)Y7ogl0f`Avx5c zG|ofVCZvuLwl(4L5|#+z!xBekI8^+2@}x$Pwn!MEgnvb9jP&C(!owu(i+oE;oj7$9 zMjmnRCM>=*ypn(PJBk16BZ$vCsSBh97vM-%{9pK^_T`}_kUq=B4Fh5V=!#CMx0EeUhfiaiKd zk&Fx;Njv>u+_n5sDh)A*tOvGUW6TLFqcr1iC-K-9S&hKktBSpvV`kV3 z`+uT((;G3L9tWkaD8@r2rvvxr1Bi|Xz;_x!Ux8+5IBHuCo=+3tWY>Y>R|1{<6o&u) z)L^kEKc1g1^o46yH|7|a83k-STY=X5D@Zo+DOPwX{z|~Aeg&Ii3D%+#G~h%vm;k+Q zANoA*XBt$^^Q0tcJ61Jm@O<14Zq_b5#Z^G8ir{WJ5j9~HSSxSgbFu_2=?WNbWr5FJ z#1n6T8S@M?@?aF{X87|b-@$=ilf$r3f_mOG8|qb=TJ5}U}R*hNk-{q z;|QcpllJ{HTLZ~DvJd*cMB9w>Z zzi^sDt*lAA0z0@P@O3cX0yx?{Fyp$KYwy1T*FZRzVlBYWW@gK|AnOI%4(m28@`s zSow@cdnfb%H|V`?qh2>f&ru&}stvrgzp>WQKsEjc5P{3mQ>h=Y*IBenK3Tg*wMlbZ zk;c9Aopx<>y|)c={OU#~S}K}zD(kAQ%il7?8HE_;r9F+k3w^JB!Tt(NG#vFfZk(zq z*O>jpxTTH4T$e%|?5!xIiWa84tb}G2-MLRpsMH?}iK}=u^ndW|bR=6S#IdJ3 z=X#e2?djXhbLI(Mk&cuO2&;Xo_#u2vahLxNU4!XQ8RdB#C$lR4Qtem$CQDYBxRLZS zVL#UKvph5TvEnkZtEh)7L5ldVVC9p%2RsqpYrZc0J<$YiP!nX%8^Znz56**98_CLZ z!gz6|^iX=^pTT~mcQF;1+YIrA+5uI24dkHQA^j_ElG*`hcnrU_2XJ~ECuaLQA<}$U zxFMdDb_)}z6loAs1-SEj{}!r`)Ius1|KUqSiQg`Ti{*jL<)d$G2UNO%Hp=%aKPdmy zOxL_pt>jL#BmIBz1HDFHU*7<4b76uQEbS6{`bB{eO?)Te522CJ7mm_J%x$(N@<;WT zmvBvG+3X<1!)L>7EQ}rxUgL0JOeTLtc&k3*{}cxzhte>9y?8`;0>0o{uv`X9G4P-H zm%ac@Y%Dxh2cRByM!!-Gs_IJc54nK#a3^3U?QsWM#7*TT~k znTyfml>oy|pi}(Iz-666w}QXbZ0Pno;%L$EvYP-z-%9TQlCch>YaE_VlK%i$eGP%G zHj^p>$sdR-ii0jRA3kft{T+Zm6a!x#z>EcBEghbHI(SVDWp2_2CXpUOAHdi)67Jcx zrN?3q`0|+OI5+@307lpkh{g=;pCc^?GS?X#tp;!uErTOJ2j5i>7uR0QNN`m*(~-VPz0F;bx4vIZ{p3)xlInF#8$+g)!2L5M`Nz{pO&j>H;108n7=H0~fy!w5KV3 z1bW+E(tmKr_zs6Pmenzz;CZ(W82R6#MPenjbQ}HlJGuh6b1T?KaEICG-vVZ4KSZFA zF#%k)?esvn&a6alb_m(A5)pYLk)PU38E{XR!DVf`WI_BP4No!$tlPcdZPr0vH<51v z*TOMinI^Fb%pvv#xW#>dkWNT{iOb=% z*aOZ3zcDqK_Ouc{x~ZtKUg?F@6K+RZdJ!{$j>P!x5i5yC;Ev0Hti3^>c8pC0Lt7>> zcqXsl&)}qPv2KhS<&l|tGF(`1Nm+12=#7yLT7NnR3`vsD;1DpsZIlu$xQjS>Fe+j-hR^~HPlPQD`z&Ol-SHZIhN=~%jj{Xwq3?h5; zr3IL0Y@^>G2VyIx4RGH75EXV1=ei^)Pb5+=$VrIUocjXdpQcWq7klkO>n#{;Oi^>??nX^;qRphf=)tzD~!<$_vVE z3b*1eI8x&kH|3Y)1Hj5FX0tGMy{A`Gd!#x-Q{N^J2$4Bjj5#mG3^#M8)A2HDLc-N)P&oPWEzIkwneSeu&9nvWLiOK%pnDV|eO z&+2uGo)%Ip7;FcbjqGY!s$!+8FLM083A!2d+!PVePxDmSLh)I)LB3Qz5le-O)D@|c zkM?|b_4Jgw$9WHkr|5aIVueb(Nx#xC(->$94el6PJECb+m8j4%mr~*N5oD zglyf-)&D5!$ttrQ(DxUj?R-VF^ngDSGq&qs-3+x3u~f7cTW8y!*;d*PTRxZeDqdZr zE2>cxP&~bGO7WONW6`Ai{M_1kCE1zzrFqv%A6lD)nVYLWfsA0bv2IkAgi=gl6`-rc(1Y$#n)Qr$e*Ug)^uxhU+T64^X1O49&Lr?2yo<3G!*Ql``jLB;hPZPu)e>#0Z6Hl%h= zElHK7-^_TNF)25;V6i34Uzd{wL03%x`M8`g9^`pS2E5!(Em~0!VnxhA?kQU?U+;1-{LpLKPnqi zabSFxSTQs|sF}W-+$l@L6aOsS0(*IbcYtq?@1XCoPbqYix=5?(G>((+kY~sl`5dkb zbDRD^Igt;&rRSPcW2sN%kfVNRCXJnw*p}A|)WBPUh3n#zKF2 zXaEybEA&wC*FZ72LqtNXGXAe}z2ns?RnJu;Ced*U?W8v<+b!H%ASv36kZr?G&}=eE`+xD z{{X|J9k26V^ZxEx>b>Oc$S=a!T#-4-g~-ZsGdP~P4?pi6lq99`mAuPbDGt#hH_tAv zng2CUm95MQNjIcUPVJC1F!^jkcuN1&_8D~w`no=`p^8?<_WBb+cMRtOTZA`>To7L~ zI=JkY@c57&L333*MCr>s?!>HP~z}Y@4UeDVzBzyLaJkGv(SM3}v>cu4wNC zt_e9EUboE9m=R^mm5r*{x7>ud?NK*_*O?Y*$E$iOrg3c%H7_gW^WpHoUc;{dkGig; z0q4Gi{mfQnlNpjVuoHZLmq|zXKfRHzmCl#88ul^Pl+xBkZ3_ePV)JHZ|C5!UwIJg| z=6|VqX=PH(X{(AKI@ik+xmO00wz+<`;dVe)*qX4)(f!M0l^Gj(Bdk;CdhKFaZRxtN zy?eg9rKh|*)qbfowkRm~Snlrplf^RYLx)EAg_)`CYVZZL584`-7jYz#E!VkR_X<_Z zH7?UUYCTGINx(Vn1XUW8D8ckB+612O1LXf?MN0eye)5y(S@2rQ6W4>`-xTA+G_0kQ z`8s@@_l#$%d$;qBy_(HryKQM%no&pGvp0v?p+Oj4EWWb_drE8C6^HC4xsX`nuYEHf!nkqSzu^ z(W;_Q%U$n6|26pr^-bMg{d`@QATD%jR76Zb+>Q9U5nE#a3kwZZ1@_SgX;ms-RZH1T zIgq)F`R5ycEMG$!FIJ+OAkNrQ7$@BKeenM6ndRNh_vGWG*BC*1hzGngT=AYj?t-iVbTX4s)))6-&O%dS{r^B8E?>C*%G*{PGFH&|?{jFFczeha~{ucH8L;j1$ zC)5&$2wuLe@R)DoeGS(0RR4Om9W@_5+e&r}#^Mw@44jcuQZjXq-;Hdo^X$B1iRG;` zxKvXz)~+b7YCD+U&H6e=R?HLz7JhN`Df!iMqWH4wfCD*cL=*Q?&dSFss;N$DTd9JS zW0e-w71ewBTj(%u&}ZoeSh?Nsj`cG9SfPa41km@nGytrt{<6EuRvNRml8z2Y59n;t z8RqGgx~tlB#e8JDKg5ob?PHCyHGVVHEIp+*RHC$*`YiP0wg?v`mgf9X!b3Q2I~ms9 zTs_|%#J&>iF-NIBY;E~`e!G9VG?UXheS8(H3Y%J@`R&d@(o)}Vw#&j3*RM=~BhS0V zcZ^32!#$ijiZ$g&*AX9aCl7EnWqNA) zzsl5MO!BSbC(RQvjvA!0`VJefv+em9O_Xb${+igw|4AR~nxIT)rb{hU&+W721N~2Y zOO!pF5xP;Xv;I(dec_{TA=k?HiJ9j`et3<_I*py|_Nmo;qT?`kUJRmVdZsfcWp$YU z%)60h&uC9zrua&gyTowYTJ=KzW=3P~uIQ>rw}#M*{GrTLp+0kx-@-JZ%Zfw!`y3}K zq&VM9-&Hn@v-%QThv|iMdHDsO_JqratWL66D8n4}|6Wa(o%? zW3oY9u5Ys#;m**U6x~o9E*A$W=6VJww>x$yz03^9QZ7;a#xxQG?MBr$_K~}U8(@tx zb>`bRo~TmAEUBe0Q2eMZZ=Oq!QjE2)S8?`g+7et0?k4I?H%M67mu{qw~GQn{dA1?Ti1e0t?N*An>{ zWh3VT=j!hl=V_BGb>tS7Q@jTs`;~Yo7Y3pz#C+{bO#%Hj=N9mNMH`MA1VYq=F` ziu^?RRAAB5_?F1p|f|4!r9IK{S271{si`K zzm?I_oBU^hxO`wXf-C$GYu~5PD}8~Rn~d4z@6Bugukt-`qk8ak8&0(Vavz0gOb2=) z*f>#iq`v}KV!tvwfUBRQrvf>GFSxXVX(#0YnI0kKF~@~$RxjS5ry`2;jLwqA!j<<9 z(B_fgWRiSb1HseHrh$|AchdrxvB~eOe8!qDXcmFVCvJM(A73! z^-MNYJ=fs)yq}uOd<15711!i*^irxn_(=)KU@#P1-v7Yh8Hf4bSGXN_01r0@_>Yb* zL;VWI;RNv3rvQ&?j`jL|ocXZ-IzdKUK+DGBlcR`xJOP{YH>|nO0%c!~C-gUR1(3Me zXX=$yoBj*@R5ftAYjCNT5K(G|6@|p zI{1NF;gz`^j5|AOTN`jyH8{^UL`W(^UH%f@oGYPJi$gxdksN}zz_k0~uSnO1OXo$fp?cB7z`Y7X&c0~m z_-G54yt2^ezoQPp(NPA5SaqlrD;>lE|g05g=NUW zrS?-&h;&ORknW4?@no+98(m0q>}FtL=lvOUGei*EQ~yY)AgUCa_UQjPeo>~!HatrZQ-9AI$og)%?{5BXDl<(XB` zvy_okU@E==)36<4O?m#|;&t%4tI`jp;b=iY(j~CB;SDZ4lH_8tzna)iO6GeDQ=z`! zDGU=`B&EKNqHylmj@HegVbq?+OgWUjg+wU_=u{xd&v8f*aq zy`9n`KDHO?oo>MV(8EHPdk&hJV)$>*g7@SK#Hk8V>t7%u*Tp}Ao{N@9yoTGORQtg9 zAR1Xa@)2j)1VoxePleoRZ~{$-bL@6}ng=Gn6L}Q8XnV7usmKKCDmnIj(*Sekz#Zv!%CB0ota{T8vZ%~&7mkjt4HNSUxd>Mv>>C0W?_QEKpe$Ve*SyAoRw|5y812b`y$yo4hqWD_skn!yVsj->khX zQ_YJ@$Cg|y{#Z1*s2o;bBMa&m49-u^y`DQJ=SlXiENAAotV+4j1~G!(7PE;}k8t&Y>g=q~A8`W^un0v-jdFxC!e8qgx3h9S*hGHeOR4ahbo z87Btb2@)1hTja^8#XO;K@YF>ylz zm25~onQ<}aVnO@jY;za;SLZNaJHaU}X3nq%Wwh#*Mx{Sx$TluBem5nX+6C4Ld|*0h zObf8-rxS^{eaJqXGR zIv;dBXl|g#^w@YT;D{kfSENy^yC}yiIuiXQecK-{&E+$^>)jgX9ou!w{nEnXxS|yW zU-KH~2Ia`Jqq62^?#u{G4@#esN~b+dJ(pTDZASWzOg^W6VTaOA)?!CL&kkQUB0u|> z>9RlNqGF8tH?>ns>m2$C`rdk{ZmO=LuB`T?`ksnb1}m=1vt*;7liet*4ZUMO#cb6P zb&Phc?t)$s;14)q{2I74@J!&zpy@$D!M_Hd3A{Wz8UfkLZluAZp*Lt94c z*4)*$*L=`igF1ez@;^m$`Dl4<`2*P@#3}d7?TR_74w~}X=empfGltQ|0;4i8FHjex z5Bd<49@IMMR?y&}oq?SLbBw2qdjifFrs*5##%SxRXR4+nKL1HJn#;vp?K4z4(*zY? z%d^zg!_md|z|!2TDeYN&weU;9O5;v)6yB%$+oGkp02t4IsPwX6CBGt zkZo7El}fc;lcG)6F4Z!+%bMESRq74umxvgY%A*wB6;tKsGz^x$Ps4G~OUoS5+$CBAEr%Bew%nq6L zGVY{r$wn4Y?SS5c(vvZRnuT z@X$X)Qi2-<#{@kw1(=2z%M75YEG*9DCa5{%g@27V<-!M8F25ngLRYX z|C@Gk?-Y_+r|)RoZ!!i?2pJt(A-rY!?=f#L%_PXbvDgW%|Ufb zRatc%WgGb6*;&O>bzgm4Ksl2-C@~~IT#5>h$%tMP{VsY+ znI2J55qrbZLyCh|1%5I{2Ry}ikOJL(FWD^F1+F3@(TH@@(NL7t5byC$Z?GrWdBZ;4 zw%sz({Hdf~@z|m!1+NNRdDHU@xe?iOvzul1&DfkCmUcPyY*PKCobP87Z>Js38eCZ3 zdd)7lmbe=Wv#_4ZRIF5P)GRS9)o(J6G*s5L*Pa8@=Y;Y*C}c33BMuWT_*QxDI;+}swjGvV%q>eTMH31?=V#~5 z%|Drwn;V<`HEUgF+w_C!|D`@jd7g4D(U{aMVR=Gcnk6%&VQe&qy7_2(g#2Pg(%w&U)ENMYwi| zVX$#U;ERw+kt?IkF{9&-#N3bV5j!pVQyEiKXk=d4gOJHV`-~3_yS1&=7Uh1$WO+TA z#7#hvt!V| z7N=K8HKsI3*^#s>X;i|Vv~OA63l7+qJHC1cdFo1sm>_nuI!cqHU1Drv_{Ss~o%$4A zRrOA7ma3a}pr*RAk-RDU26-T_iigBa;&MMLfW4)4i&beIYyMDDQq;QON&YL`g}R__UW=TqSpyM`dV%>y<&>34n-iNR4M@oQ zRxYtP?NaVcYYY1u_h6TW-$4yzwCeZja9z!SjRC)#iUS>n=6Z!XU$b8QK|4s3rx+_M zrRPu;z)o7r|IQB*)=K3m1Lx;NrLcT{-iqXWbjr$yXJ*H9|rjF4Y%8ZR# z5YZ%bZs1a5iN2Y3plXov7e%bx$9-ewVxeN9+DbvvE3pycvvb@n?OUyh2}%lm9~l`N9rrTsR($>Vhw)wG%&}Wyuf^EPC?ci?hX+myh|q7-v{r|y zqLsxmC8uQ=;2S%ok#M8iCJx{kPkYB<+X>qeYrZ*-`0^A#C>&lGQ}`^uXP!4lnUkGe zE9X{b@ANaNi&6$B1t!%0-tL>@d$aU{>~WT74%XSTrcq;hYib|d z&~QiJN_|f&X)|@d=v%Cb3ZD)Ss@O;c&H&_p~ZJmt%KIymdvd)eySs#|YZ23Y@4`2e=x5~{+rydx%#Y_%=)R7lf#nS32zb(L#8t#b#cxt+eh1T=M0z2I~4idC2cKz zqCVc_FkT6{8Y~Z3rr)3n!nv#ICTW%{=5zb}n?yppW z)d^l3rZ4k8hL5`y_oD3Ma#>~b;u7P`h@f-PMd5RT!i<0GMD+z_wqg^UX*$BI;5^2P zQ{pQ)-mMgPemG_yN9-;eZTr(Y&7!bOHt#7eE^-w;DEzB1Ht$OQzML9)?{ns6wa%QJ znwgxCuq*Mv_m$s1ek)EIn>nbYr=z~Ty(`RJUqWs@Rd0P0!?u8pfxiUJ4Q*&VsIR42 zpf01+XnSb~DNVAqlqfV64tw)Fyqob-e2g>%ZZ^-DGO{zu#-@v*OQX!us`$O}k!6kL z#d6v4!{V*66XKk)*Q0KPPBxhgr!@`Kt&~F*mE~F7IVOx*0_A@x+QnlybyXG?c_+Cn zj$V#q_P=cxthBX?xl3uU(ubu?NzXztzh{2k{04c!xs9^RWZp@?pIk1vPtx4)vlAA4 zccq*1>)B5@$~k9y_WI8IZQLvk9nixt*CY?F6nZ%1t-iBdq zTR}hN&SRZ9!(cWJ3Z4`CGa0Zl7bVZr)ll!#u3iThyfR zdVWTpCcjf|VD6WUUozSyT}yH&Q3;C^;=UdGaw731jlFYpG&de*%Gd*l#ME{5$!Gry;Ig9%pcWw21?QATTV*AqyX-Vz^doIa$ zTnQsDVA(#vROH(6P3%vCcSlc*EmQneiFu`WmhN4)LD_btx|Zr!;+GOPi_eXnA1w#Z z3kdbQXCH3cBlH!T@KcbgvIY@?iP{T9>(%S^>d>cG5-mjiJ`IFuGxtDWB$1?tmn#E(k$O?<-NX& z8U$-u9vdPw@iRhKMt3ioQ0!i*`K51{EnBXAnJT4gm1Ig@DUn=sOH5Qm)ga!LWw+bn z`A3`rnw(_;T;&DmnjTwaSow z)cO&Wz9A(S>(>1EYXJ0~{dvTPr|&7x|DX9TT6FE8vD3%t4r`)pw~dz@ZK{!@eHm%b6c&gvjz8}$cWnW@9IwRaEt zDSSrZ*+sUMYF9d>-1hRP%eE+Ux5V3$KT6n&#}^T!I)*h2DD8K_79{NCc`k>IXS&h# z&6QfDx*NER$9!G9jXXABml4lw_qx17Ia*G~oOQWZ^Umd+%=wx%J6q0(_%SiHZ_>!* z3g4T54@pRg-}U7dQX)IXDer!J-|5@jAJ6i3GJK$&n&yJRF6`lN;*moDbgTX2~Q3@=0D8dS8UDav3uB`(ITzP33@T*mG7l* zq^FkqguAyOwaETCdrJD^v_2`9l0GCIOI-DR zbNq(*6Q3)5)IMx~f9L(5pGT$D&e@s2%l#F^44=q*y0xvHt6V_W;057SRJrgi!Cn0Y z*DSw^t`k6(T4w7i9AXnqntaCkV4mDX-)IhjoploX!ci!2UeLyfj6y?-b}f-m=6ksV z<$}uXFUghcS7K4o6966`8+JA*$+f^fMJ&smf)CLPPRY0AzFJj|mdbla=cC^9{_q_2 zZ19Z8Kbdzv?~mMtz_M$Y@9&;a=kw2B zfBhW(>F&p%Z;2T*a<00|dsHb^`$YD0skWtl!vpFDCWm#1m>sq_c&UGJ*AwShzbr>( zyChcNP5P#pqR-VLl)Lf+;Q4`JivHcUGN6QOF>INw3%x5cuvo2f&q~qd&z724Jfc*1 zjI)RqK0j=JsN(v=zq~ELet|8_#vnRy1$|a`=pkw+>6N#w`@Fj=um)Rs&v~*v;hsyL z2i~dP2cB=<7w+MnL=WX2n0GPjaz+)*--FURq|8fMo76q2d1BAR#tGlzdw<;?|Mc6r zk*HElijOGcC|0q==%ObI{ZpuXh&%LC zKsWyeeqY7uwi3(??hAZW$>uEWuzp3ks=9rGcQugmM*CKH<9!a_d+!nHmBh&tr6ImW z-i6W#Z!7ob{9bt_bAHa;p3x)2n^rrOP2Q7wKRG-pAgN4ZT;hM-Ua|dzM%SaOvs@H(c zy`izzoXR(L&G$~D)#Po~4Qe;F_g+fW;BSLNk$*v{-D;)J*v5t>6-Zq3| zg-IOErqL~#rRH^`J9Su(07C9_{e`k0xjAu~P1~w&SHD3oDAI0ap6|80wKUfo<6&|q z=aqw9i)0naE1k74YjAc@*7fW+=_1hRzWvyoUO6*1V{Jz3%$b=h+;LKbx3}^unF|k* zOu;YDv*Lc<=~v$oZX0bUwkx(czKS@X9l^KfbJ-1i96N`*#*AZbGgr79_9wOxj{U$0 z9T(a^sA_0_$o&u@+!gAG7!Z6g{73MrfE520euzX9(BHb>NnpiB}YIDLxG6H0J-VvZQ)Y$SSI zcl>WT1y^zB3Fk6@(KXC9+3}CxZ_d)TUaoFpuC2JZ$nIsE+wa+%+FlstY;CZjzNfv{ z4mdVZZ{?GEb!UjY4Y8)%?6qk%Lvu8cOF8Oh^THpROaj#eWh^}ccQ~FpqV1PXsTZYB ztdSE=_0krQD&9n;vc63EB(Ic$)UncH?YVnCX(v}Ey_73jH~pQv5?@xCF8@YVRVPym z+()UGn&GQMg=y#Lzer=M7yO&u1j~yU2Q28TW+>l=f5fn`;(EmPbShg}XeO+H9_j1Y zZmzOz7gD%8Ono86F;Jds*Tk;kH~om2>vx6B)i3dNwS|u1%rwMMj2F+cCv$J}e*Tj0 zS5n2i>*}OPdE)|?E62?wt)AdfF0=VS!Q8`*W=3kOybb-EdNpZpg9J3hV#9LPTn6 zOxQ#3Z0dkt=k$j$kG}(B+Fp%1pdHJaRHT8YSJvc^xymiJw(phSWWQSPlL{YFBx*P* z;yG5hSz>Yj9EU5lY~UMpgI<~cYD`linTz&udApgL0Vngel5yN%exmJSP6^+Wz%jOj zjBEPo;6CJ{bSV2%@mt<8x#>oX{{YwY^orC-S9xuovY1+Cd!y)%V<``!HW||;mvJ-m zX~liX#e9#0i<(R1G-a1#hwWr~8Llcj%r=-rsB4rPfk%9I)I5HlFC@BBLIVHJuZ89` z24}YoN)KF`^E+9J2eOK5YN;`vRXZYQ77nS#htnq*O@fY4*T1_-k`4WR&!Jl>NZ1MKFEwiSC&lU@3 zFXmTq1I1>Zx?*8ty(9$o^+jkODL&vIAcv_xhQE$h(ynmE*k3{)oC@gV&G!e zyo{P^inCV0;m<CLl~$oO_E4v#0yO{c226S4ZZgO!K?w+rr-x{`J70B*MHD#Or$IO6P>A+a?WNFve$;iWM>65(-uj6ImWN0ueH3zen1%)ur#x=bC}tV zDPfe%%XAD5SenyYkKxk724#Kr7Ndh=4`fDq&x&LGCi4|i=Q&RaL&z4YYwmN`G;TT_ zkXxLrirtj2XYbWJ7p*Pz&JxM%Qf+h2WY6_mrk^WxAYRv=`_GTqo!Tws?_%x!&Sqao zObF{%taxJkjH*E!ibQ9}zZ)0ZJK{mYK4XD+KWwdfD=FBv)>X&RE%|TH$zsj@UVlHI z*``>}z$u?Tx~CUP@b8@dJ!gAFo)|2B&7DO(3VB8f=gidCghknlB+;aJ_!_Q%PHDBO zf4CCBMkB?IEfa53nQqfr-glDs<0s}nv_J7Jx3whuxRFY* zPBC}=lCx>uUx?;+Yv0IB+Z64LscWOeZQSzw9&}IUI^R`$rLMLQ*N^#T^9M|sSqId% zmOK%>>I{!a_i_A#*=~!FirgUg9J7(IfGpq#=az5;=^L5f>1}3;I8d4)1^H`Q1@?8u zvY_FC$1_VCJ?tgb0Iq&|lYm#Q>bVW%XzJhKa@^N-S zjj+98SF$R4p+L{k^94WYiTG6iE<45Q(wf3UGl#Ro6}RIn`^nvt|G{tOW~R3dY9`b+ zUMWkZS0TlasY++x7IIw?#N}(u4ikqm%QLpS{9O~Y6wfxPYuI`=E%yW|9cb5fdZL}< zY$c6&?|kYw*TboKx*DbIZHT756sFk{l>oQR7HNO0dG#mkUc*n=C(HCq_7XEfr?QU( zRpB=)4r#nCJ9tXwS9TuLlWwOccoziJFbW$b^~dIE=TUW%dY;^36~2ydzhetk#0-d)E-m*y)jm~%{i(Ye4?kzB49`hH*)R_BlNPn05Nb9I?Xq|3$z$0qYvc^dhP z+0Tymt{1fG2E99ErRT>E#9mO`&1|yC!&Zwle zWCU9%@>z-*2lVa%BQ;GYVigwGkFWMTW1`rIk&~AB=qcrd)M=a}rU3 zW62=%0Y4a-@QsN=dI-9r($#p`n87!ZnzE&V^qIzYkp|dasZr(~W*>W=?x5Z$pU5X} zmFzXH1h#i4`r^c$#&5F9n(8&?hkDdrn!2S-;d9NMQdu+B?$O(`{WTZ&(mVmQ!&>Gp zDiDz;Cuo&@Y5oG6?Olo`!PJ zFJgQ&pOCNmW~MUo4)gfm*}h~rHJPm9i<+%jO-rT+n<-QlkVq@B<>l5iN4oM?w2$;R zqL{a;lb2=7M!ubXIxNF5lIc@q0Usj`5f@jQ8flCqCyi@nUvdLk%}sR7>B#{2%Fe)QSXN!3576?p z4usJvkZ`4ozDNnwqm**$9nCAXfDQj&} zzhVap8@Obl9Tz9~n62DL-o`iL*K(Hxhgero#iLwvaT>oEXyDy!>9!q0g#D5D)IO7c zDNYub3TMPV;yqCo9|?(L!fGXTqna-FQ)sMwewF@{dMHDr&tU8*?-}Ad=ziz^lCS1u z<@;x8`9?-?_T!AY*(bGCWD~{lcc`)aPIj-ox3iwTyR(#SjsHpi^Zvb@8Eh~++h{>l zYBT9$jMA@b**XW7n8I>p5(1>8ZaDRX?#o}VsX5`0iVnv5(2(}4!Wc@CwsMCdZ6A-%T%lB(JG_OP$l_@vR)pg zc)WwXtaqvBfv1f(&iy#Q3M_qPbFXIi&$*SkKdVUgCC^=>s{U5ELH+|`Ng~5J{}BTG zZV9>en$CB&ht7eB$t|m2P-h#r^<&fkQrT#Lh{cmY{0}g8Xj@1qb&fH4x3E<#!?om6 zVJo)Tbg?l!^t3pDUj{_)G#~~1saFMJLLj1v>jEX_FQXxPN@+Gm7{IjxXUr0@r#()H z7d3Fz?B{n2;leroEE~%>*`<^O2B{(7;wcN1>@G|ztgr7g*O>^`#YS;!`QDtyT?amP zZ^R+AVzaF##!^?7&TgK3}NFmIuIrkbVx$;1L1=@>PQp2S}j_lt`}A75WcN0$08To7~C zY!77AYKWm9g{WW37zoepME!;Fg;b@gnQ@344q*Etj`b~Y^;d9_Vus*n-zGK$rv4T7 zCr$^N`e?cV(3pJY3gG1IMYii6@)SI8Yk|5DkC^60$Zh=y6l@>?0y{j64P#evN4YZm zdj2o&B!3rKgg*%{coEzNm$@a}1|TPVU=A_E!1Yy1pGsRfs#;`Nk zee5_k5NP%3OgN&GdV>q>HeH&LsWU+0`OWkiMa-?hQ`km4x{WLcbI=ULe4ARlW&=}I zf5iSaQ~mK8rp{Emsa|!sQU;brPN}Vg%3I{+l3SW9)%TV2752<@k9K#=`=0+JZ;NN0 z9HbpHyBJ9rDF#x(>?D4auvVDKwcxg(CvNAy&^4(4z=v0XQmOGkE*WUdAQ$xh1b9(s zyT3>^)NM`3!*Z}R&0q=pjWhWT!bssOxHh_QiC_qeWGgWpfOp;*oRBRLZ7!NRkP60; z&!iGYeupuO#G#~z#z`R8UO}Ek7*iblVZEU3v$$ToUFanEiH8MESSK`rgbw6OA^z+j z+lzg`JYuqdvoZ}#TNJ$-cnTHJV_PvZz?#;Jeap7yMsU-(d%)!B#nKeJ`0gF-A?}2!^lB>hYlYmtefTZ%J9lA)z|8ewGLXamY`PF z&LPIShqgl9tK5;FOM>*yd(`{4=YyxDXOoBVR+Mfi8CnUTQ{(`bilO6}m7I(3!nuJ0 zv4Ppc`ExJW7s!WcO%(+KiAsAg4m{UC=o7(cv&Hxc5jPBKvl}yrEd_pzscZ=@o?FE? z7G%CQVrQpw#W{|f#olFd!2tIPeYi3_`B#7tT%I%rK0q{Zg@=JBu|lMf#y{Z)@GgD|7`D2iW(G3vfl#s?OnXhhvo`{f z?0IxWrYBR1EyMoGUS|{7v(S;VF^aSSXH*(9n)w59^6iju5(Mn2JLHT0KyR(j(L3q` z^-o$ha(FjtwC>dgL5sQYYos2eKhh3pe`<5JMVhRhQ~O{Ju~;6XB+HNGnzH6g^!@N2 z_a%4>`v!Yzcmz*`FIzsM7Y5SUHgE`BH`_DU+3ox^t|#y)`Y^4yp4?4NU=PyAk%QEl z{*@UC%#r=b+^7u>zJHDK=3Qhw+@zL(Ng;&Wi7`0@_SNp(4enq53_q2x#gE~V5Yw`N zea6g0jJ^T{EXCYyb} zETps<|AucMTo;ZAMG&KCfQ=#%+A{$oumfBSZ<(RU5?H|$1JmCHI<=rqy$p}hrarJi zuCi^}%V6%%n3YT*a}IU?FLE7nFx&G1P3-Oe&_IeI_e)n|p zE|n&$$Mh+{e)pIcFa|B5Zn0~*>s(cS0s3)yb~F5Z519%~1{F(P2i8(KS;=yAiz95NFzx3!I?7!_6<8-a6WfbFI_P=T({9}rPLkCV8A{6u8#loF;1@j@gh zqqhqUz+!cu+repUCU`Ce;3^_8>e_(oV-`Ce+Z;9zJ^?#Fo8N@(C43);!7llp$wP+5 z0s00oD1Qeo`97l}Mw~=&{_TW~(bN;b;<-XUtnblPy*{Zz?Bs+#OUH~(H?-Q&1@|Gl ze;`}Gu69F}6yyvj4b@y_fbtYw$_-EYjxWb>{dD_p6VBB;h z^8n-0Eog&3fa3R<=9seJro4^uvJCqabAs;B70FnwP6bA{g&Ncfc`R3q>Bb}S0PKjT zpwq`0myMlZ<>{g=P{*oev;q1_(ho>p4}d7Q!Kg>t=r54H^cJW|{q@Zl{r!vwz_IzD zZ-*~wuDTNJhS!xE>L_)%c1i7zDD_UNU0Ed;L6paz3a=eDO3+K$R=EEU;1U|Z?FM_< z4fX@uf_u#_VQwJ9;FXz;jBPu(Z-8Y3w7z7t`ZDz5!N3EIpkpw;CSrV8kD9oQp1Oz& z;*PTh))K842Xh9o&CTc>WRp!mZkY^B$6TNZT72;o(ES)7FSf#1b_V{mEFc{Pfn%~L z5X{Sihp;7Zht^@1DWHZpb~SjifTf0+RRSk-ySY(ZA@IvRWxKJZ*z3${@FQ7VJ#40Zjxtso2V+}#BP!w>kvdKow zKZg-yd4a37J2_3VNinoTUF2ZB1efPez&=^7hhU!YMXRVMYfrW9S`#f+dkWpVL+y@y zc}aP!G=UB-Dvy+INM323QbF592yjTgvlqGQ;9qQmxbbsLJ4D>aLJGd0&}4WmX6QBk9>xl ztyp9+9R-%+Nah3P3j<)OT8*oWV}D_5vk_=Z* z=_YhdwAld69OfXOtS_wr8Ga7XH1oiOyAy27bAWC?4;iW>vDX^;X|X^ilMSzt2{zYc zWV4+$W*cLS%0?lq0X~q+WF47~`A}1=rEDZoe}R#3ixoSohw7KLfm)EZ4ifwj>%+O& z!qj(aidKW1Muh7UdIM1ZcA~Ej0xs7crVdjHciMybm061sqZ4w1jhY5FdUa-2BY83g)ek6GZr;69{sZ|gLP;DW9HwG*bB&_GZCr3 z3%y_=whh?34944JLqL6Z14`vZV6IxZcSjNXOG8IlIG>e(bQlAMyAZ6KD$||mxscT# z;5X~b%mM@KVbtwD<`2wi*Miq{2Xg?7n@e%$QOpOdwfyNFK=1P)L*N*)&ej3J?TmQ| zcu5x4`i}zU?cXu}R5PPZ6?OU+>;#!cqTvFLWe8CCGK>e%Yoo!yXfr;OC*%Y<0&TsO ztS1XF9(IF&-a#_->%g^|4CI+c;A}0WN9bjc^HvD*o&|=}KeP#2W8k*b)Fxskf=n&_ zJlSaUHp`;tR;I_|Sr$Wkb;MJ4#JBx)3i9?w0!?!R*h9-OPr$P4pwA-5Gz>EF2+ZDB zpaVVv!Pdguw~+l8Q=5P`dJPN&DPT1)z)nyUlHD9qH4Lxe*t*k==`iFarbA*+LR&1u zIQtv0H^GSmtkbc`MzlDzMVkxOsoa&^H#_f{sq5IM0VyMxZf)%AqhOu zl_05G=ts23)WBoN4&*N~UG94L`}^#-tRIP`3= zM_+SEgYc+W z9QGEA{VtSt4ms~{p!pT>MimA&ZA!n+9PT4;dlt1DvU9AU7XHFB$3y3K z#uMHEo4NL2J+;Lz#>sZ`JluHe|Mz>!@F*&Ywp^S(fY>;@#>p4jh)IZJPRYX<3# zMZ0;Cwf+s50O?peJJ4E9kOe#xEEp594Zv|73ivOYqP;tSFJl~DQ-D7^0(H?G?Jodf z_a*w{DU9tqk#~L(cl!oqagdDKIA$!cCs#sVHeuXYi#uA4zB&i536O>XsL}Q)y&B3X z2`RFW+6zO!mc&}BAy%iIalPTV-guNU3H@|3Ru+GvpCdCFNB~taUQ7gnsJA;Wfkb!UzsW3mRp4P0bsK$109}$sNmCR*TooN`a>7A#F$eBNQKpqM_dh@)Fpvw zOsEgA!yJTO=u6c@uCM$LM8Ds~H;u)9? zkE7dzE4?4Ro_>KW=6YD?b-|OQp?!;kSEU`$v@gPLy$CqWr@^q{hTf|Iz4tS;(yx%r zJ_Q=95o!%t8aOu1LNI6u|24FQDr5=H)%!Gd| z7Ec)iZJ&c&^WA9AHmDgJ=8j1i@e+Yp{Q#H=OTgz8X1)WC?I3W@7dD*Go3TbG@W{`E z#h@E}3fB+;d>U)z?%Dx0QR%E4mfOl+iAZJTOY&OfqFO_LPfmbOt1vs3Tgev@h6t5} zdwe9{llzNRnctZvOht^oF^mrx>TT(LKx+7bl};=&svdy7aHzIhQ?w=E68%n&8Q^xJ zlBr#Y%uQp)v;VO{TqUk3EUz2clHh}GiF!Cf&qfA&Z`k}wpeCjPx3rnL5~Fc5@*EPk zKtH5&WENJ)$Bin$4>f@ZJqx8S+W-wN657|5H19mq%g{{Z>voG+3{V^uo#@~MTXMO~#J`1h$4cfCKW>FK344^(+ zc#2Q;H~LjQSx?i0h>H{}UuE>e+&fwXcB@xgYfWEM^c|Vg|ht@lWYwgwAX4z?A%t zd_%6IL?EAkl`ewQE)ZxXGFz6fAY=*y#r5J?kqAA6HT*^H8(R@^^$i$_&caG-D>QE^ zRwE0{9man0MlY@3!+2_HzrlAf8)(J5fn_oroZkJ|?XV(lU?;=c7!3Z$gP0j~V#3g` zk7C|l7@UQJ5oI|RQuwnG2fpHRqz>r^&hZJxVskR2;W={$_(p9(ijc@x1hVsdJ_GB| z7bt5!?7i*jdC=h(AS(^ioT4?^r>6nY z;GOD3?D92bw!B~J>5KG^_gr#M^ECASCyiHsHgf4y{-v#ov!h=lzh=%~?E^%Ko56m? z+_xmXls*m=iSarX4Dvo`2oyXKlHewu-N_v zOV4#Ci>bwqV0U2svk6R{`561}U^e+HbqTtBA^Q0{U=i0h6eHhkN$+F=VKY|vVM0?Y z6OyxY9P9+oF(ds-*TlT~6;_gGnDR_-dL*zmwgVeH4%o9@z!tg(jB06^%ihBJ@E%5) z=5#i2o(IC>k_yJKc)SOq?m6_cU$D{&Lteu!{h4+d)}t+$9Y2F@?v2LlJuyGtqBd5? z0F6FRdhbi}KJ)zKsqGQHO{H4ud?SXvD&BQY@!t~QA7FRUjz5GoY-74I#!8L)8ByeG zp)HFGpI0N;ZyN2lbn(P+lFjj2;(8riPteu~a zT<5~#7vVI2j(thrGuM*GnoC=+MrjLlU}}RgZUonoUk=-b&UR%=(sm&G`vGaW6C1_n z3ZKQAw!UC*PGi0FeBc01(64GIG5h(bA0aP|_E_gv!kF}%`59KgG9*i14J$DC*^J3R zB!5Kr#aREGJ_&x#g=C0cU&~NGsv+7Dl+X&byBi2F(P{&^nA8ms%H#5m=fr2%&smc9 z$g@K|NzvjVzx9D{f_@3w>i@SRQvA#kSi2jrZZ1~*0?b@5{|}R5yj3abtvfLPer~Zm zE$^Fqrtg?ioJ7&5xcTBM`(j5UhlKlI&cEaS<=*jk#Q*HQoR6H}9p!C1xJNnc_C&^!K-|wjH$9vo#c5!YJ+ucrhHNPyeE>h9An$+r%@`^TvB# zYOFrcpMjsSB5Zko0r7hyGm>4$&Euy48@Zb}M{EWko|k_GM8zt!Y8=uH^|&%m8L0ly zdIPO{D$|#xxCziK46_@V2s&Mp{TCj*r+kV)glWjZog(z%zpy_sF~Cyt>i_DkbSGE? zj~Ma5MqiAXVlez!UCh@&lWu@nO?~j`4W+$w6U34-%o_SH@VK-y|20N|IhDY}QWT83 zql}_tEorGIX^YkVYOXv$?l0%~>PiCQ^R34G*bW=Xv_SkXR+0^g7 z^QxngqqY63_=~uK|AngtuIopb0f}Z0Lg_;lK`P~)=Khxdw>!~OR{9AM14EEmlF9Dl zL&YOv1zT5Js(1)~$zI~`Vt_3{_$Hn~G{9%37}c39*Otp+@R&}h_ zTMr>6!JZfkB!t`CK;f|1*JiWNvQM)w1a`m$+cj~%5YI&-E+LA#WSoW1_PF*|;0Njn%g#qrg;)6&(}){{+^b>yzx-QX5IXY;c$XV7IyJ)VVibX==oV~g z)v$i5gn2@QepauAadZcHhB?s+eF-cRF|cz_CqX0+d$aUT`abP1^|@R_KI6UTE#w;_ zMW}_yRLaZTI5B9bpf_I+)9tm;+U0{g}gGlH95vQfn!9q>9oT zUwx^%ToU<{LEyLig+0KHhJhGQ<#zro&dgzn7L2P3_m-vJMXo4>%bh_;~E6+nV| z4z%`dxm-C#E43FIj;UjwE zHZV`0!Ww=daHc+(?O;nAZZtREzye>MT+}z~^Yx2Zr?kO3r2}=JZVfB-L$rM)Y^PJ% zmZ*(cbXWLzPJyxE5V#henVrBVI^8H|^d+}-AFT8@bbnHn3pcNTcK0yj3^ccp)ub{Ii|!)2L6y77i{*Uji=2R3T_bgqj{2_n?7pf} zHEFimNoxR3(}Ubtp`z`T?Pt4Tt7@wb6p-cIEp`=?fHAiwkW~HzI{s5VUp=RMle!}_ z*+h26SVYUmXfO3nrU(B0U)dF0E_aj9E^&~{K(;}>%%yzD@vx_Q7=?Mt z7G-5N2+ZQA;Wr3kYQaY_7PclG^V@SqM_An}5s9>ibx4Ax^d5W_{$P$cX`Tfm&TXs$ zd((l`X1X60g}!+Ke5yHSbz_myh}1%Ma<;lcU9MhH?yECoFR~Is)FIMd`L?uF+9j<< zM$H)ctMXj?m-M2B0eN~hw~JrOHx$Oe2mcuUw~kyntZPof1MSc5z`AxJq%Iz~;V$)s z{7T*~4_Exu#_&&0B0XU9YGQ(q1n6H+=*f`4Ul9`$2DBMgj1rtePi`k$mq~>Lwl;TS z-F6d^00Uu*WXLD|7FOS1p#z?1)AUB9xDkVuRV&PbAHqf&%YK3ldI@}gUiuy8;4fi8 z7)xisqQrnf`IAurGs;o$do&^?F>02lt5PkwZ_E(>2K&rL-;jL^ zf7WQO923uUr;uuDw!=L)G7g*ljY(jYjnQV zA`=*nKY(Ifow^5hfjY(zV-e;zEy;Z?P5W6Jt|rMz@^QJ06eYg~k9|d8A5T_`13P-N zD#4zxR!u;5?*zoFRiUf0&A6`oC-~LZ@;`7--PvRC1a@Si5$o53UBeKJYzwfit4s#! zQSb(z(z@y+$u05_UcL=z_c7EHSg$Wp!E|BfIdha<&-LXu^3(ZN{6@}T`?DLFM_4x| zVVz$R{^f00Cx)Zm>(a}>LKn&e0}<{L7{`y{c@6@J{|c`4fiA<;0d7fAc-Gp$K9O!Z znRQr+-QZVpeT7W^cVQ|YA^ZVtG+H3+5MeAIBizS)yD8rv{=;FAjxVtGYp^&cz)s!} zHu7RdACiGp;R?uY2nm9{^O>>S_==Uz9Uv^W!)V+RoDR|E0IcV`krpIY{{z~roW4sp zluUW1yhW0Hqoj%6(b77pzb{bPFHMuiN#mvIz+i~hWNnmD6LN5mtHWJDk8UO`=bH-t zSOd-!CyGN57q^{%$&O?f(mt%2zC%}b(EE~Q+EMMh76sqgAuU&{1FQBFa?oI~=3YfF z!mKWaoyt9Aw{w16Uv4gU8oFW)`<{EwOoQJh1=h6E@U1)`1HcCoPa1J@dJ}W&D{3Zi?qalaKvx{3u2nybR`o-P8-PT(`xjaRGjx*6dvRE*L1<(68wgun5P4 z1-KeGR-C}-yly^4d_jT{1>TW)@KY=UyPg0Bv!dX|@Pi$(Dr{um;CZ?PMw8j_hZco3 zFA>)1%ka}30)EhTst0`?)}&GNc=!-@0L3N}QDfy`@u&xc@^^?%ISz)d9n3?z3p<+G zifD=x@b1k378CMm!TD5|xq-3Q1zXt_*fbtu^r`?Wa9vm(zrm)(V{~bX)pJR#NbbRR zImN7v_%jM|oh)oIBj9<9Fj@f{q!<~atBASSrvIy*M^4r(eW88_^06M4_xi90Hq(Y6 zPN2CyP|E`oNLBqcEWp9~Ms*T6-%_;FdW;rJ1lWn+!k+UEQDaV6`16e@su}EFO~GUT zf!=}|D}ma33r3VXu#dgOy3>igiK5`7xd!CU>hwJv@dTE;*U)hc{U59a*Wvq^2U{TQ z!mtA{Ofq7w;$g{40fypJtPW12UWdYaHwbYA1DKx>^CZIG8NtZ#gC;ckFj&Yf^18(gL(nJ5# zY-04$kCR~VPfasE0_XEIVlE|P7Or;zew`b@+}&=$n||l)xQ3;Q`2o zr&2ZdQAzNC^atwcd%6m90sUqO?3IUMiC;(UgGR3ao7x8aw;ViGUtx`0#H0duvlrI| zHjlaRRWxT-Kzmlg*!mxLkZNiOOcK|UuBdOL+Vh{O-da(NCHY)a%|p)F#+#e<^QHt# z;uYm2oiFS*_bH{gt84~!TbV>vumxb2vk+)D&8Q}NXDWl6sn($GF~jKAWHw>=V({ag zFprQ!^bO{%)`T8m#u@^*$H-Mujm7*Q<^k=8+7b~rWt6JM9y%Jfi3oic+e`a~*vvqE z2sc@gX_g5DDpp0hHgin>k6NWY6&mT|wNzUXdB6FL4_3}{KXhPs<>hkmLXDlwb_eW&gL*zMU%x#896RAe59R+gUY~O{8}jxm+Ek z3A>+NLp@PHnQq%YWh_$;ertz#TP;TIHd=5&Mkss9YuaDH&N#u`uW#ZznM;*$F;ELP z2HTqYkSn2WH{*md-V%($PuDVxe2O7^**09Oyz`Fv+5`44Q^=nj))TXwG}}%yRc>kT zC%0C+_~om*UdGg!t%56`vt2XZ%FSH|J*w=89H!7+v@^p6;#$>4D~QJ1#cDoUv0Ft=y!x2$i&GW)s!Gc)?|=p|%fl zBW9`q2ps6kt6-)I0f*4w3haq#ot741fBG-kEMKc9R}M6KvAf*lx=| zGxzCxT8yKURN9=tWspMLOEuIS&TrEWiwAu5r~!^hNn)EB5$0|FuyPOa6oa+Jko%`V zrE0;nA+^}ndZh51mZc6t{jXDY3QjiC^Q&`*I?9mwDZUC~CKc@~=|Yr*F;h%dSJPkB zVS*psOztbxBW~lZa7Z$dhh4_k(NPZ!9DdwuqlNi|-m5hfu9G2JZQCy83VVotBwZJV zQeV^=+-_j%KQKRQA2^Q^&n{t0dK``f?JpzJ)+E2QcthxrKgn+DmGr;FwaOG~A5&8+ zYpdXU&O{6E-TnMnIiIS|9rEq8f7U0`G^gR*{jWK%)w|pmsbNUKnq3xU;ro0OL?i;EmIY;ZUK;SsW51I0Aj2*(t(N_YPYd zdYjT!n66i#R&%SQ^^ObrcHixQ2U2H)2#tO1oaMAM#pBqg?jj5AO>&c+t=T|NqHTtH zfvwID%G>IUqE@Td`1RzF*4ddRf6?!X^_7a^5p^9klXdAXZl&1@;8A9rY#a!_*whgoB0N^nZj{* zd2uVHYmN9jdJ8g_9}MeSUuLCNm>o|!$O-nEKH2QTZ&nQoo(ndGs9HMa1rvz`yxs|< zG@A*#V{0-El@;78 zdW&|QNkg1+2)lutqFb6r^v0s8te_Dh z?xl{xyFZlmnK$(^%x5q~)u8f>+Ehp8gt3{N;nKCo^k%BM@gAe?Cqt#Gk=^_!t-5)J zeWFWD7bD9IVmgsD<{lNMBNsv|!OsL|+Dz`U_JZrEmqCQ-O>>4`kKd>N!k*QGnD59f zEd*xK*@&m`B1-J0Ih&bHuCcpyc#Y^0q%I$=H3bJ-92v$RMTV|FvbMXK6^+XDIA%TY z_g`r&)qM$xd(j>C$24Di1-dLOch78<8v#{%_r61 z3n^`wU>dw_&IO-xHEIHUXwB)mMhB+65l6i+iyJ}gNHUz3%hY&5iJo)Wph*qldzea}zWCzp+Mo2@Lw)z_+gg z)cdcvT4$_WAA$4w1S}d?F*mWeqN5NkkYFx>EttjX%?ZEmPr#^eN9{JF;YF>D*>7K1 zIP8eVNdW$OOGIxkKvZcS7=mAdzqKfQ9yPF{<`7H3!J^OrKBQ3i^+nXzP;8~)y_{|y z2by>UJdclobRL1oq+;N}{TW)M#MpzeFU#LOfjx_L~6rJ_byGCxDWlg*yv3 zw_we>8SHsoP!n^2eYyiNaHGxF@C0=t6V1Hz?)FV z;iIpB8E`D>YByx=A3RlK)OR^>S>MEaJYpz*0y}LXM4G(8`sX1^Pk_%m2fq3YTq_J$ zRuE(H0DJG@(-y#X5CH3f^}iqHvVCw@#gL1(1H1&+@Eqf)+2%I1cmyIc?%;X^)`{}) z1jgX0b8yf8c)G@D`xc1W$O7NQU6i0gcFP0FzYJ=s43!J)`z%B*G=kNoA2ODTUWh=mf++1fV_YLB_*S@>|@0Ce~0Qo*9!u+_S}T{s#A&iu(6Q9WuCDI!aX` zH6dWBPDG7e!)9?vmjSC}G@kq|-mM6iXw3h^ftUXZvhxKJ;iR78OfTXyD74>G-0LmG z)@2~tz>8?L&$#Ye9F>PMDO^7dE&2^O_$(we2yN(vti<8^DLBHwzcTJ2zo0zpDYL=jckqvEU;9hLl_eV?G(VoHhRut#j!IjQouNcazi!y>BFFNGCIHaf{ zO7%xHP(BzC3ZWgU;CyS8^P$w@XscR~7CWNylFPMxBGfE)`XF3Q#JwY!1BEdhG`8(tsrODZ-iZpPxT|KFK1e(|DJtoR%mcjm+9#lDDEwWGciv_>xaF^l&A zY*rooz%h9^pMx$?asGFFx3mI5I~$OvEZlh>$`J5`c09iaXZrATF4Uq`D+=nxgMS&^ zr4vuaqdrA!k*K|B)R^VN4nlicJ+KH`$%%a@>N6BA5Q&x*a1@85!f;juTHS@C!tqZG z&MJg5{IM?-oMG)-Ef5NsaN<`dw2TX{qL3?R!L|N3W>^twc6>MRq*iN&;d~SA3NjJ= zqvLrvw5}CLX_ZK!1nYn6S>Q0lvnY6GA8H^6ZIO;|DFs_D-mQL~k54Q5$v~THsBMCl z)A5;)+Dpflg|jkHpFgm9@PeZstq4Fr)Qo}u6&!E4>_#S~)EC%@~ ziDSxPeo+B^w>%=RenM(1fIkD2;{eT9I}X&a!^P1P*5^kf;Css^c32*|uH; za%1Tv1MOnTw2D3JNJ~ns7L;+0brgrAOuSp|X7yVhSDo_M~d_i5>3hr3IyHzt9o<&5BTeWQ~I5Pn6 zmK+2_F0JeM;U(g_R(;y=a^TY%mn_ZhkC*korJF2$W7Uf_s#*6a6x^v3?d6Bxt+umT zkjB1RQ14b9%c${u)Qt7?$S5s%mr+O7xN1E&%#qNNR$0!1yK~`e2KQ}AfHh0A`dS|9 z(mFSzpx@;d>{+#M>2MS0TiQdy5$=McGtpyzpf4vwAACZj;umab&=5cHdp1h2bV?46 zw%Rfu`)k~~CX0=NYdZE>$tzKImI-vskYDIjOFVGVu&|AykzbN!n>)8LvPjEqL)=VP^ ztz`A(=z^^TB*`kJbiuh+8I{rN%b^5I|HKw(jc}Z09bXi$V%Q4d=$~**1srW@iQ;&- z{S8Qk67W^i|BaCg@Nr7cvbd*hy5`w!pxR~Cz*lcdR&SI*h{ z+0q@*yS|4<8vOh&tU|M3_4fs?tW0=yKKv{V-t!6kn zg$~Si`1>DF6dZw_(*t1s7ocq{ zf5H2IfLFbQe`Q zGamx`mjZa*H@Kru!jXLpN9rYf)30z&GU3nf^}8O%fZwaY-@AcA`2N4w-|yi!r@&Dc z;28$^UO0TJIQXWT|6P5*=c_;eH=g}oy|bXB{_UDzU{yxLnywJ`4Eb=(FT?Rm0Zs0F z=vrI=r^t0U&L1H2Z7H0eg`l$i3`hSJoV(vEncwcXhw$&aaAbhK_dm5W4N93u@EM;- zhcGMABG7?MrMi-MsHA`$6?)&WHcPq z-c4cOJRDhq--YhZe?%v$2R%iqL4XdM9F5nJJ-~{gOY{PNja)!kP&@^KoB37zR`4P_ z#C(1fkl)wvYl9<&y_`{|3a7|%T@diWy_)S;EVZaH=g*@FTY%r1x%pRUEfa6~o+&SNne?S4y5eb8(VHO_=icbbI zl}TwZsCkx)1k|Hkpt(;)>*EixG;BV8nAi?FmmzpfqCc@!Hke#aejtM2U3B8pW&Noe z^dGbmlo1z6ijHN@GCDab-^yI0&5*-XT~SZbSn-cssmNsp$q&nq%8$t#$uBYeA-LQJ zX+|q#dh!W5jpXsM*heH(910Y7CwRGT1j{|!Tw5GNo$2mouAi=}c9qpx5pLdNK4s0d zYOVj6o0V5D*Ox!5_*@=s_E>saSJ`dW=8i?~|J-f-sevKlP6WYbVRm9V*6Ynd7+UG#AeSmfWmfiM|5ju|=JX0#0lox!!K;Ao@Ha4D<2fO4DsY(V z4^;VkTyK7+G#UL5cr2Z$g|wSak~@{PHLY|%b?bDcI!tdezA@f2{xDL;L;Bm=qajT~ z+Nz%^ixr*auOV4G1Kr6SH-4^o0U<#dSJZ<=CoEFwAVr%5Ih=Oo!Xb)q5LtFhX z{ZGvn)j>9ixk?EHfz6cu5b}e){60{cX4{e=v1mx?)iS&!v9NssnzJY8R`%*#F?&{y z_Sg5!4;h=Y?&UCfGYb-PrxrFZ3wO+S6x&z2C;KTW11+Znb`)EOJ*1-ayF=}fvmRm&< z^|&=c1kBX)usL{FDxg@Y`A0A5efmd+>fys83ZtWAU&Z7`OA(fc_2H_>%i&JbA^lwq zqSnc8({bcu{Hb(_b9fV+qwW37?<#r#i+gr{;jjL`s$?c-InxjP{FT-!fLaf0U%v%xsdiC989x+1$YFu6S!Vc)y7>ic?3~G0;r0Y8HXOp;JsuY);g!Xlry_ z%!$}wba`Zr2$$iiHdS*@)ladLGLiGK55k8)JKq*>qUXG;pX;sL;L5Y*+RRSM-ylc{ zLwT?GOG0po8K)egJ*d5|jWcAJ+C=t>>L2TlEsO3G{US0k>Q_{a$kk!BLT%bQnk(#W zW+B~H)&Xr5Z0P^!Y-0_#^eUfMvY^13t;qV374e2S*ub6&$uUij?h@NPu4N1n=ZwjW zy&9iTsafpgsCr@53`6w)X+A=d*+29KS*kde+v)$?Z}MJockv{8V;sqrJ$9$78;rnd z!FR#k{2Xx()=sWgHPTGf*fnYTPoY8t7JDOZU#vg2QS^k^RdF@q>crlO42f_WX6e`v zO#Oi!NE@)9!l!`D-Obh7ddEBsVi>pPEX-PwRhm9HO-gP0ebd*s-(Dn5`}XW>y)R7i zh|kW{upgNzqkk>Qn_+%mI=pzCeT&xz83RuGGCP57tg3GOE2_8>6~8bxH2!U@wn~Sp zb*j#&YLERE^B}xe=n`WeUAK^GszT;8ktBW!cH{;G9{MJGV|?k(+UDK%Qb(siY;e9; z6=)VX$xlQ^Qv(!VlpdB2}Ws18MeJA!x%>0<7N=@RDD-DTGjNB2%>r%Bh zLiFl`%s~9ScqUN8JKtrr6<16sZj%2zYiP#$4B5|zDbv2a{8}yP@W%z8FMJ;KiAb9E ze!%CB$;kH=sday93Udpq6znOvVwnp{78R9~0Lz{$%Aq)hX3- z5=v`~tXjFsqS*4N)QIv>zp<09yLzPjHX%ua#O3^qU=RO5-+IqTOOY+f*1+>R5bxgP zz3)l#cNcCTJ4q+x>N#jhIY(y?^MnnJUKG7A@@AzakzM0&#D~Y~;B0uaPe+Dy+Bx`7_(G0d!@$o-#z;HL^I?TzO~w)3`sXT}t>~V^Hmm zwO=IIs!ff*R4G2TMpSO3Xy~PXs9X)B_26<0^#~rOQ5AMxV;ko)e<6hkq%{#?* zVlI*ptQLNiVV$bj@6xIWu$B&nv0Km!4lrKN>&kKK}LL%ZFBPx4pmkDe-f? zq}6F_axeetl6@?vXT>(pcwa4PD;`EarFyA%h4zW*AOC0E>4z9kWhNU)}SX^^O zjk2m@<%O}mVh=?3jW}j}qd{36@G%pJ&%~c-MBpE{%)Q#($N!K2s`scDb4~IN4gMo! zqj&I6pnB{{H&Jd0X{9%qmPMw9w+}6jni-K7v%FGjRMY57rf!D)hDy498ii^M1qs-E z9bl+;btPDjmZjvE<-}**_|Yh%X-ZzQ>1*LTVzPFkahE?Pmrv!zhZ%#A<`)cdhSKU~RCef1dxGC*J$QU%*un_k%W8k57a+|>wnbx9-nmnW4{ke-%k28?akQFYd&p!jeadlnV<3V$KCV~*?*aj zS+_Y$fha)XF=Wk<ZF<*YAvhXvKC%zLCsM$_tdyoxoTBBZgBLW z@WzI9+E*&KQYr7jJO!F#7r`2I2)iInX_e4jQpwSRhM?lxQPilxp&zRH0c(n8q= z^&I2AFpt3#J~cMBsv{x0Mypy=61r7C0?7`VgaK6wD#yokjQrO$RsR6C?LCyW*vT^R zfr%tPTa1SEnSQ=#Z$EFmcN6~yc+m^^=3;B)6TX@pDZj_&Dgv6Pn)TXxkX33kIZTI* z+x2ptLc^$Hl_Bh7#s!pa-n+(+K@N}Jv9|c%LRoI}+|-;iS^gjOQ-^$ApH%PLwvS)F z{P{Ki6Z-jRsyU-d+Jf}?drD5b*fB`|1auU#0=9m-63sy$Yu7f;)-k|?4D+eEydSD3fCI)4kYg$ z-&N3f%7xARU7;(ex&EXwnK6oqEXOw0B!tY>r5G-mI_VGTHis-^8=DI=2$zF+=YCoMH~!nYgrKC!m zBe#d$GRpM@+I8CN>T&XH@(~aUvcOjWE)^joX!NxS3=J+8u7W-`iBA_|q@PGz>Ii*C zu8{YV&sP4Wsu@xTl1s;jEYyrt9)kR_p7QSUHOy?{9OOHG_09`a4ZLxTvQ;UmTXM8O zEGo_$oqHy0RNCYpJHPKs+n;>qFKx4Q=??9 z&@e2C?1YTO`|(cyIY?9-D@8$m^E9C&IEr29JMcYir(rL^oMwixFO`x?9_l2dX5vGM<}!*JjwNzTpo*)e_qKPC=e$E@zFeAB z(5WQ8geaa^RQ`*~n3y>}eRXCIq+goztLAnn>YI_4JGpFQ#bJLNM>p>P_jYeN=OPEI z{!uBF6;zRGQOMlL?UBc#osrdyC1F)em!d6EDUtiax9H!7G}cTptkX@^&kxxt{|1Do z>103loboMvho{nasUKOBY)KnQAKnD1j|sAFG^sd1$*BP{h&iUb%rfRKl|}Y|75^`w zr{2IGLJq|+_e$4pSGJ>>d!{YYe6}Q`Xif2@k`^Vc^K0jte=)!CUoC&l$f%ijukd8S z>XQ4Ji3K~#j+*=KSIIuV33$Ms+ebu`xj+I{ zkYC8pWDQbA-Ig_!O#^b-LR63jqt(ErmJod9eeB)ns_nV$lG{gHT2v&K)8(Jarxf2Q zZduSVuWjzoe0A=OY<-R;V@+v`k~*d93KQ}}t(!ax_|2a2fpSl(*cHghTHSZ$R%QqJ zNZZ8JG{zUPOkdB?PJ6=?8-6lsxABtJr@E%RuYIl$F`QJ6Qyw7)$upI=6c?4($#--^ z8Ai9DOQEc7mKmvMvg43aW@Bd49(p-*nvQ3J#Ac=~HIhU@A9n*yk_KXODV1*@eCcDo z?cCiwCC<^#bC&#yK>6889lej>EP*e7jouM&kFvu(#tJUg#VOzrIhV?XFFjO(-8T`f=gV*@RuuC5Z znXcZXPSDiU3=a9Gz9t{ZoS^2(cPQ$z?d2x2qpUj7m-?IT%?zbi$xZ>8(gwb~WYP|m zi-*X5pqyBwZzzGh5MKLx`LaDPyki3AeW%^y93Jx`v(Xk~E3vh(3@@EroKX~AI;yOH zS=G`eMb7*|1uqNsmCUkUvIZ=-E2flxwEgX>6$DzVR8LfahVv+HVX7#jRM(Y96*_j8 z>WKE1p|)|M?uzz7$m@{r+Jh$AbU@$RP^8(TI~1bQ&eX2ehH76at1Gh=-`G_OPBBMb zN1jGkA=i-+WFq|*{8t8IGt|erf~xl}D0PO2bAeU;50@1z_h0ubeP{jceKmY_-2x^{Uc z`(JV^U}P(o7;qKcA>!%rY=Hg3&QUg0V(Kv=4!zz;7$OWywafJ%O!Y$jq03B7j3;!d zx?G(~_f}U)TUB#ib(y^&&yYM1_8t56^2lK1Za~!CN>E5Q`LwxOi>^S1EIa|7}yN&Yz`}HAL*E8 zUt_Ch{b+q^U1;Cwh;@~_9=p_@>fYo2(%?2eR7eJb@iLTwYTGwiESU*}+yROL<$AWV z>b$B{-Cj3ce-3h`leGmQR_zzv7Tswr71C4HNtwwWWa}!g%XiT}vZkyU{EcsbZ?+VD zD+a-#?`&tJ)_+NT=(3gy>q;qygR&tr-7%XH^ckJ=M5|i3Vd6B zl`x!-5i6t+{!$IR8C#>U3 z#rDET@Lh}ocULiA5BS6FI6-^@tBqx1r{EkQgPr0-fq^-Y%LzW_T5ybTH&{#P#RY+_ z^b>p!Jd&=}{54+KzaOEbhfk_u|I7kR6=2e@Lx;9L$9 zUD7gPE$HyS3QfQZ+XY-#X3@f##jm^|jupC#xxmV40CbKY;DqTXtpZx@b*Wr9C(Ren zqd+J_`a{0mwfnp3ulju3=Hu!RWf^(~`@Dt{-Q^2jc zDUKJmA=Sl^=v4SLV6@Yn~+0-LF|E6Mh0V{pd(x-oJOXIF-SO&XYHV3dW)8WqHq@& zD3;?AB#pYTN|5g0f_^~|>cPv97GftN8gZe&mF`QiP^6K#0Xu|`hb!wXxtnjm&m=EM z7eo~{0Wo1q#M8K$+lki0b|NdecsPHbfiG!8?xH(*8F2ZJ0y*+II!qcYp2So6$>JES zKtQl;aig>fh}xHt1E?V80N*AJND|kviDEc1Q^S|2TzDS;NH9rH;LHv~UrN=$L0}|$^AEu_+Z{Y=86b?gu93YS4jhO zx6oNyDojR|QY-8_vPKw(t>Atjo3L8ooM41$z!H23-4z4>4Xp+KvbxB6bQ{nd^TiL~ zuk8h1xduoUb_5ihj1Z4DmlSxR(1u6=-8hN$gpSiKv>g`*>!bpt26_%}!<_|H>}a71 zC>J{r&>ewJS%vtG>IuJnFt3*e@KdDAWM^TgY=tln59O55GutI5pxuf0(B1tAJi|Jo zF?a(wnd^vhX&;fxcLM562k{tJTQ;8WFVN5<ImuvBRxc9Xw}g@~KQ7no062s3jJ>?$ZM zSt&9VF`l)@`RhQm;F1JRY6#T&3ml1q%_2h;zFw~@JgG}=^L zAl{YoWwixCw1ckikaS5r4=kP+Xb#>&l5z3EN^l*0LcWPD$OZfabd$6keIV?G*7_La zEwt(j5SNfFWn+EOhr&Hjpgk8>Vz;F>*gC!=<^i_hS-h`+!zU`nFAM#U`odSN9^7?$ zpd<~!G~88C%bmve5wC;yq`ATuaUqg|dWbH3Hr87FPqgz(kY1?3d58w$B=iF4v9Iz| z$Z~F+*aqv0?&FV1(>X-83fl~9yLG}S?6~k%x{G&$&Ug;wOU)7@@cBSHuPF=xx8Yl& z6ITSC+-n5Q?Gnur+Na?Sj6(bRj-vsWz0_q%B)7KJh0ghcKdWKBl zP@x4{M3AThOZV^KLeO;Zb3DU$5_JR}7mZ424WUd{HLy_pD9w;P0bOhhBvINcsKgiI zW|YA9OBALR8iJouDQtmjW43rf+D_*1%cVi!Hf;6#(*c9#pCE=FhNQ4;@RAQlwqpdp z5w5Nq(jl}5l7VdpedJu+6gYt;1x>_yNlUWG5`cX42PNweS$%0Nk_A3KEx65|i5u}` zsW~?iwTUpoAt6L#QI6drx(06vBe^zWSG=JD6>1Y1d?e`B*9W?xPX#%8h?LOUvLX)L zto{Xj1vec}M)Wixh2fLMO2}-nQJ|XCL-r1tKwReTVktrdbtTYPXdYCFUa}2-6Ui1o zqfy|^Gx0ajzQP3J7>xLfq#R@bUxONv)xsUb zkIn%W<7MoFza*H5?vPBR6S4DA(Ca^m+k<7mbj}3sdoA$-9)nZ^Z{{UzKG#CpMT|#{ z;5b8ZgLDYkL!Z>l)RWBD8y7>8ETNqOs!) zoWl2T%LyxFBfRG(p(gY`c~VFP)xZGZGU$#|AlbYh(U#vWuEQ?E>S7Lx!^pD==}I2r z%emc1SMV1kz_qa!-nlLISkOpkgI(#C!dC7yu>Q&}g*BC&LMQI2bP6j4hwl%h8yhs5r4!bG8)G#l2Tdf9sHJZLDsN;`Sb ztCK$Qw763`gXJK{@Vzh_P72HyXMtl%Fa1GA0Rb*n7K04t)w~*e0Z#n_Bod24yOUo~ zOVBM5GEUqVJj9#8UCYUeiE|?6-z^Bh)x3??A^PwGv8_aHaEc~!1BDrUCgFs>f)`r} zd+#Se!2ARb_33b5{DXFbnR_+58=M(N=`Xn7@?jPD5s#1tk~O7g;EWxG^p)OCAjjA2CB%GX*i={n4if#g2PcLNrtwgV( zYq;5~v z2nX*K?g!dQCnOE=88j#{kl}G3ND2?RCH{;25;PUqJW=>v>7393#{a#bx?U%0gIdF?JHuXj2&}0^K8!mre#3U4P9VlSgik~v zD`ow$!?I-3BzD0DqO%C7#mY`|0^S2lfi>hvY!47hzKM_cd@&4bir9kStP~oF3PeZ5 z1ApNu%-2nT-}?ks6_LVU;sS8{Ov3tz;pk;P6pI(mz`ElR_#Mnh3;33yXm{oyc7OU98;1zBsGc&8-%;UD^LK9kY)+%VAs}7I3<=NYB7)*oDjnn$;qhIf})9 zU=~`11%aHG!50J7gvY$%Suq`0SaGoSK|!ap2)t3L0wtb;8Mzv|QmPRClVJ6VRS`v` zGrv!o$7g`jq#HT`(hA;#xA|Z2q8vc7giP>$w+DsFIt145;Ddeu4rY_6#ReeD#W(za zNFDJFsK#1Jn}Fa(ORt4L(Lbd}$TNNsutVdd*FarcDYX}x19j?=)Cp1^YeF7)Kx_oM z@9V&@BhV*87Ge@1+Wpk-PPs;MrpPr3~6w^*8n zR)&>-A7D+i1znUGG`b?VM$(YMuvV-LV`(8UAJSnjc^p2&SWzw26sw9+K!?2#>)9KK z4CW0lXzg-AA5{Q4u;HLhs|2jpal$5n<2yiy|DE;)Zn6ndvCtNYhr9QIgrP&g zv9lkp7EaoQXn+lV4+xcJpi2$}o?<#^iyDJx{$H#&a64iC0;n{I-u{@!M5Qu zfxSH+Jq`>y7tFL0Z84^- zY5y?)5}>1u@W1m>z8z4vE%w<09k?OD<~Rd{&cUGRTZ+tw{lGf#XYRy*K`zQ%A_q6f zzCbF>V%cP(p)3x#Onvb8SR>%todwPx3)fTtsRAU4g`gfh4k}0y*6w$JHPIjbe-Gec z4#4ZfalnYt#CzFP$XDq`R>*$Inv#iB3+ez_42-J&2D_{H{Jg9;^fevO}<1 zS&IGvjNTAz8hRJ%pU}@oze-`i3i>3Hkjirsb{Iz3+qM!$^9}j>{Qcmn;KBgQy#`A8 zyx`G5QDA7`OyIVEzF+OH;;-Y|>pkq7?iuKP;3@ImfxM13fi&OqfF@W8a!dyECxN9A z1su?5Ac$VWdgJeid9p~^TG9ncU&TO_j3u8#bt|4|N+|I!SVOEPkUXCPZT36tlNU*i zfkj>f$ylevC!n1C4Ex`~K+5QeHzaP$3S=o{Kk_71i)um#=mpG6rV+E4$*1qo3Bb;} z4}E~~{BU%s?T&FJCW2pHDfrYl7szE;nj>{uM1pA257CBLO4{nN~L4K8oY=>wgoW_7)=P(onqxmzrOguYq;u%2O4_2`$hzR`S-bI+Gf~qI{$WA zo$aA6o@hJeJm{EftFV5vU30E-b$6|BAM}j%zvnFcWpNiIGNzMi@^3Og=8`h`cV>+I zDbt?Wq<}m}c@z0WCX2p6>!^{kgV<$I%r+873Cp9~s*OsYUOB=wslXLq1X*_ab;H*ev)tkQ%rX z>=0ZW$n-Y|EcJKsVZMjn|2)Gybv;$w@vh0Pa_4a8UyefieEUjUq`jJbti7vaiF2WQ zgjeV91Y@LEbRm1O-b4ai;m4>qR1w|sw-B9Gv7gzd$^ffX4r7}uBIJM3T1e$yh7W-{ z;3B9RiToy5)vo3y!Z>*u#?6h~WpIawh%>-TIu08|M3J{C6|+V@UonGypzN#SRLQDy zst>Azs!6Kl%EN3!wua)o{2eoszD=zqN6RF94G>hkpuAr$6^pebR2nZX1SSqIbP$z5 znz8`tHl4Qv`+IG0bD$((3}yxTLyvWp|Al{OAUfdlxxL%HU%X-7o*u8ej=Pili>tP) zk?Wx2vtynk(%#=T*m}~s&4$`G*tgkL&M=S1TR-UIItp5(6B(;qdum-s=lqBtG20TtF9_%v0D`iMYP<+SRf;6 zuWS*q9bW;v%{QR8HzG6P?)sfX(gEt=HzBJH7Y0J!h>lO>dI$Bv0l~3>0r22KJ39*Q z^#A-`=uYJLYx^JidV2#NC-8o4o~G{Ja1{Ewbgp5JJNB2hIyT0ZY+Yy@V2^jcb5-|# z@VyVt=NYLpdJWqEBYPU@pw80`dNDw3~dNq9yF);T7x=Qeb}i z790hMh)Dr5u+<+McnSCXT>p!}QU7D#82_?>AyDY+=O5}f_&WGHc>8+#x%;_mx$~S$ zoUlx`U$j57Ew)axOf-Kqx3EmIG_$7I+B?>|_Ir~3gM#lMr>-ly4GZ9jvLR$uYBhZw zK66jSJNA=uoob{i9j@7hpdmY{7%u;lo<O?QGD-=Ze>ahcLPy0)_E2z7CLHyL0`4!yFs@72E~{xV?cN0lz;l za5Y#7a?&x&MJxe`*yRW;fuI;XCPK9HpJ;9z~{c722d2Oj@>0+&C z^V(N9zq*_HuLk|>U8n7dkc`8Y@Q9V?hpnj|}D66x7 z$$QZMk)vcMpwFR)wZ2MvDog-vlP=P6C$d57`>HnzJz>cPgi8zmK0#;g~WRkLhwj2rd z$4UHoAk;SGCI;#U1_t{FSNdmyVmc~#F<=VxBip}Ot7Sx8>vz?nfCm{Fi6KHShK-QiGy^9k> z6SPE#f>|$*fKhcXODH*(QTo}%>~MB4+k`!;NSAMu z2N?xJ)BC76vI16)`|0hYAZ^7yjA+TO~M{R^&`bC9WxnJc}jaE<4j8+$`o+)|7 zRQX-{4mkw$aPxs(#X_xiuGkB*8uDN?JP(;ZsCXRmOFM((;v>|l&fsN49g?HQF+1eL z6wz!Nt5()k&Qq>d=Cc>sN9;RR#y*wLk*6~AnLN5P=+5=TP5dDAXu`17KxC+kjFL^F zLLn7lW$=D*mw#5UF7W+6crL@Nmk|&=hT!8sLcrty=^y9r?f3!r^8{aWSG2o^`>Oka z=as*=ueo!xbGp;(IAytMu2Eqx?^4#K;z&h=s1!DO{3H{4mqhYs23^C^5OJw(u|MB7<49Ro%X}}YXvaK5AXF>ul$2=kPc_+csJsnY)!gC|zEX09t^4<*5}* z!N9sQ5Y2g{7QmXVP1&gR)E`WYqNXxZc~m)F(^D5?l8n_&O$`b9T3Wm2Uv>c%1P;nt z*yTOuzi_*Pow>K%1g<==5X^b=yhnZe{Yn03!S#X(J5P0GZ>ySW{92cxbLf}Qcc#lm zqv>d9pD;zZDl929D^v{Y6xPw$K%1g^s}SgN*$1q?ln=YGI$*xji$}rXvPfvbKMP*= z9d|n%4II@RTO89IR!1FIb8mq+)7Qj*#rx2e=x*igVMDE#%)>4JRV=qHv}RZnZC!1D zS%zEcR>YT7D*Tk!B(GP#J-<%jhq7^&J>H(d7^sOYlijA$<@cCNa2GdbHJWC+huXb{ z^``e>eIm|;nhlRN-Bfd!FtQptUuY1-gEJtHF43RlllcaE+e7l?LdR9xcH7_fDeiP% zko$)GlJ!=!2^nJOU^-%27@8V(D?BB9W;h!@Bcd7QvnDsrfF z;u_`w-h@}I0%=I~VfO6IX9TDFPkXnx_c>eJcUXs8=h^DmJHU#iyK|RwpF7dj(TT!+ zYcP}M@nvbHcS~QD%qXi_y1p#FbVbRt;){iC@?E*jKwquRtI#NCV=Tj_l4*tipw zn#31Xx)y7TS{(5pRA#uVX~njnQ;EIUW9VaYKuL^(ydw{A5Jv{<`D=O!cU$K*J8JJ{ zKVpAxFL7LUUU8 zF6TO5_akYT!EhK=DV5nCetQ4M3GVrIl9MV$!~ zj4t&=c06Uod8idB0xx|_J+nQR-RqtI*e_edt^KUmEVgd6-TjFNLb&t!5`%>vzWj0=n>lX7ik_*3KjM9>-&hqZmW#TB-8yO@u5q+~FRi^S(=2T)-^{Nq7pxgJ!S1qa?2~MJtRu|DW$Q`}rJstg7H%jp z6^6pMR25Fizm@m;mnyq?R;wIkwkOA4=&MMxeGPu*y~qOUusm1QCq&Zxr9GzKXM$rL z@gur@OxIX6E+cwQL|E8l?FUshbCSrxx=OS7y}@RF(MS7pJ%`*`cFgX#<=NibwXRp5 zuY4@hhAdDdYWkR7gsCGsM74~Lh;_wjE0L89l~z?AUS(J11@ZZ@ccZ37Dnd`{{OW0n z|L7*yV0~(~l*gH;nZK5gt5BIaOBY+TJ=sc`k5%ld zK+F1-pD7twG`ysL@yp_F#l=Nk3r-fc%3YK@F1u6C{G5W^x&f!1x+POMc==-pr5qXg85eiDk}3XI zoFYby+!#K@R7>Ag^I373=3)Ic9b1X6mv#u-xTk?m{#IU{%WoZ4aiDZq>6Owg02Fv! z+`Qz^^79q1D(;rgEnQhWw4_<_s6tErs{D++Z3PVrCKf!%pP%0>FDIMM&CJfr@#Ib@ zG?(_bhIsk}v&EnIBC4Nas%ocdm%5%tuKlPRZT!{S3%jiIzz{tbxa&LM&GC%!GJ*E|1F<_=B>PjoNb^e{YN{66EbL2o1o(Bj zMQx9CM;?xv9NjZ!aP*X@ArVX%Ydou)q{(C#GGobOKz}Yqrz1g{efsIlwewtp#x|ACAC-wM{)!N>=bX}(2Zg^yRYwBzIYRu3t z)dtlr#eMpyoo4^KqKCUyTd)}2W1;?u32F$KU?~+bZ%+$ve@#=<=4wV^IW>IWOGqY;gS5~`Q7tV^Y`b!FE~;V zQTSKksKP&r;)@E3R+Rj#7;by$Zsp71=1J*0^d_G`#MolJL2 zhwBFF)FI*O2JA-WIe7p^g;St@s{?DJWMPNs6dH?@g@M9O;WX^GD&gG-Ik}DA$>yv3 zYV&mm^xq66#*Ze^^x32_9W~00RgBjS0}b8u4(%IFOLe$1RUbQZY-8^7od0jx+AYz70V{9FBg&>qA|x71KigM3KghP&HOB(>x6M z5>gtnD`ZT_Qq53x17#=01csrvkYi-ciB7;ft_#fMC1@zRA1Fo+ z`56Vz#=#o;kb0x0KtqSz((KmgHQDNR>IbSCs>x!|xq%_*&1Af+HuTV{LM6Nc zYOH->hlxu!p=ZZ)>A|wV&cMdNkiZcCNnegH1NQof-cOz^@M2VTopJu__+byXSF`0< z*MrycrKOX_V*XH(R#8|n&FnBITRk?`nd>U^RQ68_;`|hW6R$}bNI5uLq5CPDMJ}Mu z(z(njc~iv~#YIIO#W%TLeweAwyr#R+mFR5hC6xpz3kx9OMn|crDbz#AFF!$9sao_T z+Cl%IbLmU;4!Q}QPTi!oP=lyHsOpr590t$xP#i_j@2TVP{J$w3^%&CL-^-54rUSp( zB-=%BkVm`#+YQHVI(TIw1dLa~JUKZ4J3ODqJJLJXbJxweTDYbs>=$v2 zh$3b|RgZu)^{WgQMvXk))aF6g2zD~i$!SVb${*ow@ra(twJTM(k zk;5qiQ;X>?KOjG;=&D$vI3RDX_^ohmEVnWrm`LVNsO+_eYEBdSDm|aE(_LT(2Qj?iiJvKDD1-55G3J%?4>urIlKX#93IB7sls?r zS+0OthzX4JxAKkfj`q}nYct$^!TG1NpX0rqvn{esx4yF+HIFytg)6h=Aw#O73<4ymBB8uY*5*dvSsB} zDoQNRtp!wqAskVOi6yEe$g}KTj45vsIVwDDzs4dTqkGc zuNemuBPXD`Y-cPCEk6KPXB6C(ixf)MzG*Di_K}hZW&?L zSL}yt?q$jIlD9=yic$*`3YQf)^S|eH&tH^}6($u;Eqi9xI&7}1-oFC(IiENH*$CZM zR5qKG(~IQ$;BM)w46*3?lv|sd5tC$ZnWs@>KcIyIG}jWc<3JVR;XXKAoVf_ zFwY;OAxJ++^e7Z&fwHlM__uJIJ09rov$;!M-JD09ryY@wb`G1(WUpy`Q?4rCTXMYk zRq>O;HbrqoXwiehf`WT_>+^iM_wvUTb}Bh)w%C7p+WFu58}WUG9mqBO1yO@8VN}Wr z)i!lNWmRUg@ayIy_)_-O`!J4_7e4o#*j?c z2b%{Du?E0rxeY0o8^LLq98B_0^_jeDJsaK2Tuq#99V6@yVGXjk;z!xUvWLa3OU4w> zEP7w$&)=IrE4ORze&If3m7 zch5=~36IE=Un%Az7HD) zDkc>=LfQwZ3bXiQko;UXI5W^B5b*ACop*G#-L@^Yt+octgUZuOTbE8RxmtXupjlx^ zUS57&{=3|w{P=>Qg?mesrFqt6&M5ESzCMALTw6&_bfrEqwd85?hYAz;G`#E+wBC5w!0IdOHfqO(Bxo9`ZYWyGJ+z&EN_Qn|5e(iI4~MRM1|)i%ge=Hh zv_IAWTM2n@l_b!Z15m08M+7`DLkC=GuIf&`z1VWAf`|R;)zYSAJxlT8tipftXXHJ| ze~~w}sA;jTSYNI(uXD_FeQ;-Z+6LDN)o_BmLWMB7^h0`p{Jwmt^02C(m5=Wo2|bq#Lw`s`nw1WOKj`mBzMVyQrIn==6{DX}Xk<8ydBy z59}z~YG$a+@|AQPeU9ozG2oIbghaBlaGsVy_Vx?tq%RZwpqDI!^v-@#ZE-nduV6qd zKFRxoX96Rj>zCmh>RsbL>w4};wd0NkHl_74%wwm^#+AjC?kE{sTu}U|xF$gTZdsxB z@7&>}+y(wQ!ckBm{f$4wmk?#LJCwkTVha>473Y~DOa~at7t7W3M)GgjGeU<4KzFwQ zxe2LEO_7hF%3T6E$t%#>_*vO}rir9c zx(DBI$>82Tg7iZp@WF%&(%4$l1L+owmF~{`CI85Lpbd0&>JPF%?ASU%<~)T7$RKnJ zszvVuy=MX>sSd+4h}{3<=q#YEsM{_+9jDLj;F6b4r3D0(2BoE?ySt^kyStSR38lM3 zI+VI~x?}R&@3&aXwO(H8oSFI06MO$QH&oF0d%{vt!S186<+CLUU)ZlmsoQPwSZvk} zCcj~r*g|{_g3bwkEO(Z93bjNXq;{`IW>p`gnU>NSeYdt!tAK^$?<44x6=Z%sX~CCfBEgf#OPWg;>~72bJt-ai(yPKg`X7UgKM~1~ZuX zjLQ5Q`g>*-bA=hjY+-N1kNJe3i5g!v_i zXU8T-Yew%!Zeg0bJ@D1v(tp4=58e1?_W|cCN1Ah_b0y|tlU!5X4ZUM9uiq4E6Y)ou z#wy8cpevt(-N6#d3}(?^-0!@P|5>;vY!Z7I;$nGmq;P|;$d}+Bafi6l&|)O9BiK#s zdbTDzgzb+M&U2vQxcDtd0qiCQMZ58eah9nJHrF0-LO$Y+q8dvXs~T2d({fgr~wyGD?HI1wUJ2X{s;bsMe2QJi}H)|L7oA{10VNBr$w7Y zH)6|PDpVy{A+W;l^u6{@_LM+%OuBA5J3Bu*LXIBJ6xVY1Qtus~-=7_Hg&stR_-5&e z{6?X*fL@4tib`u3+mEZjcN0pY`rU#1b3BsdRtghAS^AX^a@Ua8(+RzFE;of+%-!TT zzBzxIuPdw(t_h9B-i8;3M#e+NWYcZqV&h(86;zxDj6uUV!%Xxu5AY5+#1BTs<`iVW zc3{>(e|ZvV&O8)b#mFvDH$4G4@SffndcBYET5X0RE?X_D)Rcqq#CWe**;te4f=H9_ z{E$7g0^37N!0f;6b$R-EKA`Kq>k2!wT-RM=-3LA2_zw8@2Sx=8h3(PYSXpVkYy>fD zg5H}P0(XIj39wc8QGzB+6_X9u#V^7;;iPyBC+K*%RWpR&c>!+hD#&c4v4uH{&-a!; zAq*2+i{r!{hU&&T#^>0YPB2|G9yK;a#ed59)i50_syyL{uotJxMs69K!G30bVNTFB zX_6`e=gS_X+0+4Bd9t2{bh|syThGv@s3xTQbw)OG7pY>LK}yz%h&%iw^aOs&hk=Yh zeg6XAFmG>9vZsl=zpImTy~FHCaBOp|c2cfv*HHHc?@M1|V0iFD$P}3u{WacAeyPk+ zNxcVIg#M3CWfyS0`M&&gVGXLUOTyn^O??%r2_2xBnJvsj)yxV3z9Vnvn<9-dm+Qu_ z6ShDBRaY!;_}19h*x%%Yk9)9jfw75cChEAL;j36&+%EhsJQ4~CUAcOAo_Tf?bg?e# zG#q;UpeWo)?)iUeG>sIhi^Nc(k={*3Hlo>zL$H-5;QT860>Yyc(Vu{XITLex_L9M|6@G z=rD7G75HrY%h!pmai(q(kBisDka$G=PP{C16Mhs1;Bze&3JcHq%X}q#*1AG-G0D)z zFvbwVoG)g41*KEAd5Ec&Dao`7Jcs8{Hxx0f5$A~<`i&Y$CcVh^fzG`t)0XZ@H9`iT z2XDpuP`Lh&j6I520_E-mBn#cvW<&i^MXRDFA?t9TBmuo@jh+eTqldi^mp3mTyh`nXQ%N$ z3gg6)29JR@EjQWCKbfW)n?ZxX8Y>%)i!X&)!bstmz~L^YgfMh(5uv46)7a6t!}!*i zY<_OOX}xZ(ZaZvUZPl!W?YC_^Z5J%*miOj*mW}4R<_5->;x}S3@w%{;53-Y(UGz%& z7CNhH)DgH#v!P*ai4<(V%BkNd*>I>0zr8-bu zsbgdg2#^FS(F(*(TBqwQ`7@dpR+1CQ3VgaMg6%d{z9DgSp{tz* z&2ueBESoG}&Fjnup+l%;u5RLtl?-RapT!ihfl!#Q$*Jr{wiCOPxr6lUC-5d+qOKv! z{5AOw-ry7@)eb|)K40Ij1=Wh`E4dDmQQt%>MJs@QRw}qWaKbP9D)}mSJG#rb9Zt#F z-;wEPoqsUDbpD2dJ_Y5R_ngH&@4RRIt%Dy!*CSJ6yQNM_wl-ZaMed~>OddOtHw!^L zS2t0SerI@yQ}U`X8amNG_zrx8+r&NRlK9HJm%GaS!WsB9%%mUlUopdEK_z)%NHUf% ztuW0;kADo;Jp8tnm3KGd4OuR~p zqXmO7gkQs-llP>dry*Ay zX6R!u7%GeVgcYc@+v6QmikrX=Vvf<*pqws6uRzXS4%~~)p+>F*A43=70Gz7b@V`os zS;!GM1NXs7?N7C#(iG{W-ouzOvrMo(1k5u76!$ zT@Bo?+~vFreNO+I;K@+s$ie8^cvK>h;oDHZLTp3z7G~da*ZAQ=Yca*J5EH$=28TF6 z>@Bi*bBn@y{t34eGxIZCQ;r2`qc$exbJ*XZg6_?^uwQt}H$e?IM?8)Bpn@m$fEY(T zH%Yu9bQH$$Kl4`p2JZ!&Sm!_N>wrq`vS*;j;XdWQ=pN!} z=`H5l;GY#38>GV>BDbTzKx>q!Y*sU%?J1A!R1(uof-i+?WR{qZ`RoC)6W$eZVV6)} z*u_V|*Br++#%J8dHbbWAX=V#*@y(15W#CBmAsc2}a(-?de}Oj&5>gr!z8>1>VJqFuwVbj8~z^={;}Q*DygAb3`&v` z>I&?}GBLZj5}z4=6)P575v?9=A9)lugxiGX1+#-~184k#|7YKF@3&s7*Y17qdFHL| z>w)a|$-x1k%i()?Ys`puMT+B}YGY{myyO77I}+dLa>#=QqhS-~wPrkj%Lh7eC*e)tLRvP^KiF&pEW4>P$T#XObD@c6`!X$oDL#7uKu75&jsCvsN0f zy@AebE3(JmD8m#Asgu2B0y{wkTDau+pI|&>MXyHcN1BBvhVBQQfpLLcXlcgzXZR~( zOFlc$D0n1zCA2U6M}&)gjBS=eab<9tOUku&w zqnzQ5xB=DKErA!t^8cdS+Qa5CU6|8!5zLQ%r7FSAv6~!(%+~(67gwQzH zeh>eQHwZK;0bbb0w?U1*luO~B;dJ8JEldKlkgh{>^j>N*s0lUT<|&LH-Ka&>cIqDW z2|Ny-ZiTeHTXYIE{k70Zj>i445MGb@|Np70VaH-(MCKoSCkyD#^!IcI>f0x9`OgB2 z_h-rm9sD2UGBO5tP&-hNSL-ctCQrmYenl;UwEB)n%Umdzme)w%N!R16;}&RQd&Gp; zljt#YXtlx5bwvsyYNSlGQ*?XO7j1@g+$!;l@eUF#F8~|kw9-apkQqJ{6zCIBK))cn zQyywG{fMs1EMcx9jay|3vvpV#E3z(}!l&R=na(s}(wUq17q6qo;B?Kxzd04XRS>C` zU#U0H&)tB-c0aX?`uqR?i{v}}d_JX86sjN_olY0RbuLb4(`7M9%|dtHkgkmMH=0Xd@q`5#;k%>QzB^*h>lI1A3HtyGUP8%~@*<)*Spz9@B<;_)X)K<|gl z=@=%%hhvjtU1F7E#+VPEHWD?(ih_1|EA}~Ni?506@gCAgsg?XvZmhh(4EYv#8oRVs zpi*>0W^8Bj6Id6!D1q)opM-BG3$xzd__Q0D%gh1h64J>xGK-nkxQgYN(u__A>DTm5 zdL-QogyFh$A}u4m_6?N-x7{mb^W8+faD%!}U7<+2FsAB7>C))V>cY*_9lw4OJ&K-A zPo{sN8_+fA`gDx?m)cE@!81aEzdVwxO+JMK=Lh(_{-*#htlxsl`X{I(9JuZ))v-{5 zo7E@EIb{wK+Uh8UlxOlDbml|kZxF!p7OAr*q}5PLPm+2_12IjyEIpKp%FX3<@@HK0 zaf(N&0hK1B4bZM5{dXri;6K397(v?Lu>6-QMGr%Id3mNKvkad+moczy*@oHe57HDdCZ((Ds<3`=vrv&Yz8 z>|WN(IG8uMcUm#I^fp{^ksd@kkRmLQLx?%Z8!fFpRcm4jU0iw>PleL(N%XhqlgOyZ zv+(oK-caMXBTiaMB zmRG}pzq-|3_ zD^KO?(%(|DbO|RyJoX6=xYqGMluvpRS%a!i7ekJAgmZ|M%pI(UY?JKkZ6$4QY;6;Y zBq;XD35J9zwpte1)YkOXP(pl&3F9=jGV>!n5?-orjS zftf@8Z+<*cRpt58gSb6DA^Lr^L5z&GkG77Tk^j|J5hm&!^_lL=)Z)4uKAQ{MuO&Q9 zm}GxovnC8kT9JG-d0_J8MBcv5GSU3flwulX=p}UGiZkb^vt%8@vviFv2; zIp;Rl6Hi09Zc2r!292SL(U^Q*8$!M3cJejhfm<%J#tr6sP!8uM?6c?D=GYW-x~Z7y ztEsB-flv|7y94B2t(uytRFV6`m$^PxJ2E}AI&?oYJ+vgeGp5Ve)Jt#ye?t_ILzqKC zhWV@Yqy3}(nytI-rM+8Hr<8a~=aiXAQM@|emYSwe+V*rR-v z4bV4VlKSFrs0*h>zxbi}-T2)2ButVXM`lO%;;CC6IU3m+su!vl4Er}jqw>Z#+V{w{ z&iU9eAz#k7INb%0-B&zgg71O~l;_XGHDe2u8TujafG}5_#J9(0dA_-+Ei-w3vYuEa zVTElWCI=34BilO5ABHCUA50STxO?%wyQ8f@$A`AL66LAEZ{ozyE~sP&Vjq;iyY?)t8x4#C+mzj?m1V2LBw<@e1Gd=uQ`ujQW}9u;q><Dw3sqOu8$2%Q>ixI|0|?Q($VjMNS1x=Y=sg4C|+$^`W(xUAo(jo~F+2rW;u=+)4L(D2|={}JCf_ekfpf|YrW z+^Jt@e4UrS&^ae?AKiYlKsR5JVAE(3^$@qqFw5M@$YS!jzb z!scS;_U3`cRpKx%fF0U?ZHRhDipMTRPKL^d9)}{~<&iOADaZt`2UkQdg9Da^ieexF zn0m3BMbZ4$UOI_SewS1yF*Wg*gw#Z5VzSDPMuH4&LF}&kH%dc<-_m+!FUE&>Nr!q;)1hJ`~R!IGfn(<%Ru3pr_c&_Tmx$%p! zAK(|vj>N+I!*_z^;Ep-ued1Y*Y=gSaCC-b^!p?pUkO%!eLW?7L(Xr8wacBHHc_7h> z*(gZj4~CIq8YcRyjTKCXOohx=bDFV}&_*aP?Bs?q?WtnqF3g7%t-5v??yXe0d)ySe z99t{J4yG!FYRPE@cjj^_^nYe^}zee{}x`7&d7XM)xVUB@@M6pjI>*2oP1JJrFrnKu8yCN7lT`-LDU;v8qJB?;S)a< zeH;A_pZ#?tGx}q6cywxfsay;sItmFR2i5*+2lX4|WxgYyP#N?^WSOj>E>It+Fg1Yg zKwkh$aRa4M9P|=(sZr!MVhCmtyU9(crIt{CQZabLc+}NTSq~fHTCxYZW!yP_FnWs9>?t~ecE!9A&I|a6&4_Q0i;7)psU0xR?&wdN%c2#({ zIqWSDfuJ`8nI~h^3+RsutNHR!rGZjYDJkufGvvGRGD^pIYx!RMF?z?xxFe3lcS=>H z$;xliFUqv|OFT`-;RBrquFU}O6c(x#un)bZ*@-P$ZA=WSVs}rYvt0qEla;K6EzJ

1o>3eKO1nW=dJf%YQP9gu)5|~_?1lfXqPLT?*jDr} z)LLwEL&Qw#mTqB&=&fmz{E@k@)nfhtA8P>qjwj%*C4sP%j2z)p&{T}p9pq=?XY#pj zN9xUfaxR#4G5p+W5b#Lwy4Ju4{=WpdpOHL!9p7DP?MI@OdP6Njw3e@v7HtMOO}i<7 z(UvIl;Lp1T_amWA(hq27)NOhjwHruQjlg>Pq+i83<5UZ@nOY^HnpTx;2mivKiUHM9 zYvKvs z1080Co~4;F_1}lza~MfgQGJHek{qOF!QVGPuc=qnS`)*NDP3H>N2KW+^j4%IccSlU z)yOQhFY+fxf$VXZ{7oJSzS#+|ssBeRd=GzRU5O{FYs0ngK>8?#gj)k~PAdg=R$cl> z%|uCxlkzF`nSa#p=v|7JzN#Js@hUlw8cgibuvFR)g*EprUt!WC@^ zSq>AeabR<0(#?q8+63|uvbSI9P9%plm5WkIT1RRXa-RM}YWO)hjY`p8GcUDn)OM{6 zKDVH#UkdVC5rWp%QlGRGYB!khmGoDriJuU& zsmt2$^m#l7gv8P_)aE=kri7w?l9Av?nyD$6!Ohfaz=5X|Z`G0Dd~75$wKh1*Hj|3h zf*L1_M0vdvd4jm8{ek&wZ7qRmEa&Qs5&{cx0{%^-kc#zL->o9`2`Z0X%5|b6sB$@! z8R=2ApoE^wl_U=0id<2@fOfb~8%^e^h3T@|3b4Nd#3F4OS(`}Iwvt=5tJeu+=K|J??B)Co2Ud0hL==`yL3PqKs?m;>prDA*^t;tSJMwO z?bHJDANf4FPC2O$1@);N`5*P8{+N6&UstVaLseod`T?pvlTWTET$l^>QWqmR{5G=3 zM$&eoKWWfEkTvC7T1%X_4MDm(4pL-EEsbsg{$6u^D>;j5tnDHvs%P{sR3kX}muSd8 zrN$7GX`5b6sYWByS>egH)Om^n9W)a&w>{)$eI1_keM)&oA!u~Q>ovPR3&fi7Anbj> zwS0zC?yg#(K9Nh(2lcz!7UdjOMCnK+(3RA8Y%h{Se$`T>#QY2|-8Q-c`HAiY*5Vp@ zqV^1Pql4NJVywOw|Kbefz%ptxFxMtBYxQLy<_uSMlGD@+$P2R*hjI5`X3WwBs*1h| zIfvwaxksqOv|uIZBqL7Sq3hgy#Gofr1<_H`2n?X`&3(o_I(!0H5F%QH4xb z?U>o^mp4)owJi7z25Ag)l_*ZvqkAe3Ft2S#ep0XL&$Zs#S?vn>nOdNH(JW*a@LjX% zy2MywKG9X(PiD%6i2CYc^)1<5voqhREtzIyXTm3KCGNyC>3l7f?kEpoih+u8iFly3 zqq?aZ)!J%YKd;?=^wKB#rDEwLh{e!N~={T*tha@0)XH|j226Hmz} zYHK}{%u{0OU0lx=EwX;?5cJ=w^}y-sk7PO<`9!{iGi8)M6*I0w&~cT8 zHexi@Pst?{^gQAdZ6chi1)kT#I!CWZCi`mC`|XhNewrMnE}&Z?Uu-qFs|}G;Uka+n zlX^9MH}Q*p40lQr=#M|+IsS&aOV(HZpiAf;x`(z(dx~q5Li|Vcrg|#lpsigAPLhp` zf(2U+WRYgfAf+7DQlCf+&<7E!nu7`RM=%?%Lsi#EZ$$RhZfhUaA*73ZPrO8J{R3FU zwWy^s-r!VEIiHxP?FIj>FSx(ez%8Ao=BqTGv}B?Rbw~X`#nkd-1Ff5Sov1~v(D#ul zPVCL_veYI?B;%C@5vzlKikO0TMnlwOV>FJON8Qs^;wSPMGQ7sB1GTW;SFepZ3PBAd zj^M63tTrM{>QHsHRv#H>Td+SK2jp-% z;Pu)1Mr3M}Ai7lprS26Ni=WBY`ciD8 zmV%O;N8AOiXf~L^NjQ&M>jP2$bO1B!I(R0TWDU|!w#T;hH1P`g=&OiN+6S!`(F)p< zhF~!MFCo1mc9G}slolgD>#v|t-Kb3?lGP8|PI)#mHJrp7yt`lGyX!%!NLoLoRsp9i zOM9q|12HredF-!oT|0s^T3J62I#);CtG?2z*(4|G_o2}agowUoLA9;5YI zRZxTvXyf%xP_mVS(&8`NUH{9AJciuG(Wnr1f;qDe8}2s$+Zw8ewI%qQl5o$}hTlII zQ~5*se6Vm^>BEqh_`eKPS@#nJe02*O zkcHZ}+InS=+7=V=RmzWQB8W2&RSiC|V^A7e)QfU|*%Qx@M@UEHU*!{EhrE)PfcpI} z>hzZJ_SlbjBF`g-BXuJ*oC6)h#lovYq2S8kjbKJ7KlEEH5><&*4}-^(ABfVPNeN{?mTd=0}FV{3EZf7%Jt zcZPfXdiF0m36*t6?4F~fPL*R$AVZ>^z(d1UMI0}z0ynBFcNqKaiP&?N=L)#Ke08Yc z${lOb>C#kDY~#=96>%P1gC=$`=w3}#tI}557Ml|}80r&j82AgmmHGZB z{_H^aK+|9sXh(~|pJVg4_wDx(o`@^k9d%p0cLTGcOXMC@9`7>^w--(lQ?gV3O`4l< z&a&80fx8HlRRIkp)>_&q`4AP#-X;lL-l9B6Idf- z-f5{}O}EM*Usy$K>uK=7p%|3ZUD{)uvxwj0x(me(VZ&u8*_Ijxi46ripTL#IR&*P4 zlIg(C068tpzGU088<`0pdwoxKL}y?IX}1$Jh_k^d--XJO!Rc@c+`xD0X1QwoYlIA~ z@y9&v-9NbqXL;uXM<4i(a-3;!S$B24bKP>yaIf{8#MF3xY@E86nPZ%8|10H(jNC$H zvm%*w)2k#Mu{0J|5>=%^A=x{|eHJMSL*03vdHy@WM#y=3D%DdKg3owazf5$a&M?*a zmWH3q5v$K$GO=gk;e^MwfY~SZV&4<@ls{tG5e!VR{~s3HDEqauRCjKWm~7I`Gb~6| zG*2|WHgMtuz7CfKZMPE?5ujL~f|G!O0{F3VMMZ8Z(VF^$F2%Nn+p9c3 zgd4~f&_l@D`b=edd_^Q4sN*Z=Uf~#9Aml&FGv!nH+w$20wqR+&CP!D-RL=wd%IGPz zAHB%vPRK~lDO|mHo9uSky^5u04NCdmI-LJmdmowRE8)z`<8$lfUd-E6@U6R_uWRsi zq)Ysn+#LPkQfe>Lm!E6cW9IDTlBT7QXPT2lV3ZzETA7w<- z52Zd$d}?_wG@=&CbHXXUs;<6{M~=;|CB6mW`_c@8=l(Jrv&^w?O+1lEC9bqrw9YZ5 z8t(84Tc1&|6_`p)(_3q1wYBmYw86Oay}SW?)m7>obk_6f`=H0X5iW_Zp=`_$8gp69 z6LNsQSXHH)v5L{;a8BTrZ-Qr!tGDB9LBoP}`C`Fu1wRzT3o;#?bGd7tCm7fro2h5; zk1W2F7FnN)H7J=}YGBDy#a|a{momySf!(R>2qk$c6g11t$~lnZ%e5BFaenV9;x8Jq zMxV#?6j^^uWpVl9MYG3VkX$o;S>~QX_cH6JS4>`GttgJ6hRHp`FT5^iLBah3n`?_V zEj&`5OrGSQnQq(OCU#8Tm3%zOn)s*fqPetjuTY8`%JB4a@*wKC9a6)s$$yBBXS49{hwl$PBR1=@@GW(G3LUz-SsR_z`X@0z8taU^V77x_$9&j;zs1zgi(uMfK#{HJX3GU=+X(uu= zGYis(rM^oro7ZwL^rDEZuk2arIGZ2Nf8*HXc^M?7(L@0UZMe;r6in`#lAqitsd&O- z>u#i#7H|`pA=FBEFQTB^ri1seU1}vINC&0m@Dy*>1`&Ix`OHwHz0`;DCr31ZVB=lNtfoM;$tsg)20kX;<@A&pc0f!sgrcqo@VWA`VUTB~6|J z0`xYLVa%My|0u@9Ct@?vC}eRA^9`!Z!`f_hv+_i~AZ5kVqVq$S{65boS50SkM|nqc zhtsjqdDZE3wsJLd1)SOLv7V~Fs-e^ITY6`Hp>=GkKP#vBl#;beZY(}2t5?bfi_Gp* z--ZWypEx??t<6z$`sBSVc;bqAUj~|ok4JCDFUnm(uuozKiAc6g7@IOBeM08(%%lu! zYX5`|rkQLB&5g70j`LZ5R{p1gA?`bYS@9I20T(n*v>iEzgWw z#AEyyY==$sW>O@IYtxkn(%N|G_~!U+sg=?X-ED;G$X4Ws2)&V8RY9zS({BM&kuqWP za8)@Z7f4^?KZ5MCC0NVX#GUUD3##Q?@?YkkEZFDx&-uk^bbWA+bq;ZUaF%pe_eue8 zEK6U>HMgdvKFRv4`0niX*&m8pGhK;$jGgGy@}|&2FX>uQ@H+pO0=u)CySukipiC$n zei5~RR5epwN5q&%Vz#wM(#14mp%q!hvhHU*N*QK9WLQG40rk12ueYEN4P6!Y{!R<_w5%60k%GI$$9U#4w zbCG;;kRHq}fu8n^VZPxsG=}B*T42BJLS967&^LAYlJsl5QuMdbR{u|)@11D{_49(c zYx271S9C0Nc5xMP>#jYnOHk97!A7B@yKUfg^s{!0_gmMdl`bNdu$7!uVoj0r>0|8e z#7X3hcz*DJ_b=B1XH(}7uIKKG-Xi{Efg8b3p(){>k?&(?q;1-7ba!#5wQF*@jP}T{ zotxDrV^*@sn#-Rg?!|u(+Pp3?IBSCg4X zS;!OmXl*m@@e%SXa7C%`1Ajly1E-$-9lc6iVXYh%=g{j;FraYZNbs&EyB;BB&S0BhF<6R;<0Z+0&u)K@NPG1$yU-%1%qlDucV*YkrivOe zb|>sIG@*Z1)JSeX^WO3_@Qn8)dI?{?Z=#=pkA4<(c~?W%qnvVt>?L%zT2e-V@?$Cb zuyB>k&&h49HHDgFNcuhex4#qiRa<;RgR`Sig`w_nVZ#A4VeMdTX$@K~SOh54o7ocV zW9*}B8!cr_mxPZ@ZPKCorE{@o(E+&DC*mP&;92S+9M59~gSbuT!#`)&F)~$%T819G zH+E;|RZ*!FCnHY+KX@};UkYC2W#A5foA*b4QHRTU-}RIGk=u`&w3}<7YoMo8phh&L z{mY%Q?n+x!xNM2b*%yju7ulOO)$R}*Q-yJ*|AE5Y>RsS@1*XFhuf^8|Jc};AHNKht zS;6;_J92<5E&gQVQ|lFKRitUrg~-vBt5{@phofk0lz? z2Urt`9_M%u_Y_3q5%fkXPX0$8M$c6nsrL1hm$ACxvi`a5k&b@(A99!FI&*90_bfj6`sXiy4C(QM_OUY4%Oz2XzZrJ_@uY8VVgT~lj!d+9`b(w)RL_TMd6 z4V5{O`birs&yU}YIbtU1mHbgnB)U>v;dflXwc&QNKQeEq3*-&(<(6YQQ5+dw9WB6w_Q%?%Z9HfbDX)0@xM94&BC_+DGd~7=3 z5FO-ftD`m6#73A|eR8M*ow6}Aa z!pt0OI2ORw-i{hV9ihs=iMv5Rr<9B5hHCjZ*Y^dvxpK~GZ~?mHZOQNLc!eJ8H`f4H zz_Hp90OP^z+~EGsUoN&szbU-7=VnwdwzcGnl1GX+%_>O#)m)gHqc@h1#Ck?YMIJ=X zg}a30zzpA3Pjz=iSAEwe_ZzPm8W0cbrTG_@dC5#>df_XD&lNh6mXKJ^yoX(f){t@J%wfqt<*a|jg6{cMaa%l*T4VH9$(_BGx#T*}9~eEAb{ zSDDH z&18QguAt{@i{$h1&anpYR-Fje_6IzZT}vIs9N#(4JD+-n20leAY82bmR4nmx+WJD5 zvtAYIn%*nvuBET|km*BC)yrs)k#buFx&t}dGWu8a2c#%^^oGm`{;{E&<(;i};*MlF z<$cQe{PzQ^9op~r|O(5SUz2XcG3liUs{uPbpk z*@w({x;$A|J13ov%nO|NoN>O%f0tV;_fXD_oL=za56$!Emo8Xd5H6_g$adU!q`SI$ z2K&232CEmDapn!ll?!bwy0m!HqMHg8PFZiAEo9PRt%=eHtfgkL^O3Wm#(`DdlJ0HJ zzZ^#ARTtx15qc+$B>xdiw)B)U8L3%@tVbDJQ<4un`GiY1eYEB$5~&B${!71?UMKBVa&p2evqikb%piyA#gLxiQ4N|wYoVQ1 zKdFCd9%KbRr600?@K3=I-zBDkb0G6^?l1NmW&mcKW97cl&cTYlDXugJn*H3PIc;+u z<&?_xphNnU|D7Y>InC*Fc64=cop3+!>VdJbvO038tzT0{6}nYqN|8XJ{8V#7VI$2g zBdyv6r~~fDKgPC2!a>oW;ZdA(9UUAKokmZ)K$+No+92+{Igq$At#Rh+%t`4llN;D; zn0g2ddjr0%zc9y`F0YT33eOCD2H$T@Xg~5_0!+Gbp{+7BN`uosq;F3DF74;!m-cFw zLk5Q5k4zK4UR|56lIl~rr#xN$T|SNdh!d$Qh0$dkfp4j>fiv8If8sZOBRFk;l3TPb za+{bG`p19CV{v5_yw6R_nenyB*PO41@QJe=v!qTw2Z`$_uTn{gMx7Y%Y{ps6qO$#FCa@gqidiY*N~4 zx#~^jc)VDoVj#;i-BGU~FaMdNzGrP6zOpf4XyxvxfV<`?yPSx!nWYojjwxE&R)aXCjPz zhU_cMwa!YhX7tNEo>4!!v30hg1bd3?M$q`w9~3p-KQcdX)PuIHU`auSbH4j;Uw(Ll zJe!;)^tF6WD3Nk7WmNJ+sJWcxTgHlpT%jF5gLy~{RW?T@q=bJ0)3IpaMx>ltnCWNu z&sr&|Wa_K5`ROIom!$rZ(lzO#-DsU_dM&QxWNLvvKv^6w5qlL~4p!;`wKn0Py0TN@ z=PJtYh2uryZ(v4LRV*jY6|QmJk^Xv!c&ggvS@Hgyx;GkGM1?8z|p zsI6kzFSX8yOl3t=KqG zPT^vkgROmEJvY3E1I406Rf$?7&a_;%&q=D6G9zVj%BJKoN!Jtqwa>7&HG>$zJ)stB z{p4EltI_+>udx#HcGa!_MGmDq!{ZW!il+p37#^_a!ZFaNCh^zUrc4=Z)T(IZ?0Rv_7$A z>cNbRjI}A-ZSRd%{v0-RPoN+fr!P^i#+HV5dvly)3x3OQRgml|=1mEemev!ax!y+2 zvc^s%T;ya>WBK4zuya*-+ef9lW~gS8T!upOm(Uqk*t+dYDy)dKZPFn0&dP}Ef`$TDSu-@QOrkv@ND*e3=|54 z0$L~qbN+I0@-zu*QHSgzYl(%d*+~;p&ZT@zGTB9wUpUR)hx_CVIUJjEq@u>Y3%3uH z_4a{N*5>)scP5k|?bVwwZlS*EyxC}ZVJU9eWxff$8E5u_o^zD@kLs?~mi*xvfnC0m zzRQ96k#@>3`WFA2Ic7a+yKWzAZvv)~H!(Y@ZsHgFXv;1`7fvCY>x1OxaVowd_AcID zDXN8tebf>x{?+0f3pLs;zuij1mhT-31FXSR*DSm|MwCzda;iOZE>utA8 z`^4Y*(%i4m-O+SmvWW5^x;(hi*Vglsd$D_mhw^8JEU_g@K5>(+FVr?HGCnpvH#y9- zb-eYob*1@l!%6NtN>wk$2SzrA3Ib;Wk?`7hHMJ{slB*)R41bs$rrnl5Y%cqS#9B$0 z6WiE(o6m@=*j?msZLV@uY9uX|ODkQ}CVEL^20kU;6SK+MG@3(xh?oOMW(3T*(cDgU zKGPmMpF+e?wV5<1@=dU?pYrCqinthbPc>a1e7Agt&@)sQ<-;dK7Xv;0Z3EkVIo?Bo z$$`<~C$TJT0=-P|m_4?N3EB1^ZS&1P85Drn(}-bs^uX8Imd53qHqg(f85^ zHJv`m7I1@wf#8OWFy>mO+R7z(5&SPHh))6+V^0 zaM9EvM$u*2$w-i&gg3R9zr=mRKVlCsUFlQgIZc#X#2SPv2OsKToHCGi&&@amgBJv zk)NP;trRf$xA+?R@B9A?yb2x-AC0z&jtkF@E(ui#R}BshWC!c`8X>2omG}F=S-%h} z5uGNlB*JtdF=DKUt~PGnVqIe?XXZ?^#jo%I-=_8wBa}Awf>+m^-3aB`Ue(FN$vD5+F@_d~jyOR^KbwAB) z;ChL@jTZ9<%VOIhaLV^sFI(E0CW<9_k$Dbk(gY+qDWGSZKa;H(7#Sg+!kN%m6rPT2ud!{fWNtUi_r=lV`+-#%f2Gg?|j!4F!U1=uYr1+zCs= zLUcvg6fGCFhP#E#!L>M%p9W|854cad+c;J_zBr~k$Kr(lF)~mp(8kfN`LTuqV-52{ z%R4yn*I4ISDp^XJN(z;^bf|7N0iLxlG&GPEyouT1)?hT)D>66wmo!3E^+NO<_L86* z4w+V%$C*X*RYOg2G5;sK0IZyy{I`M~bp1Ch%hqDf(A}6Z@X=2J<)?`8xG@v-gB@Z| zkS=!#toTmY4kl(#ka9fmrru=QA))&-vSH^le=@%^^`W*+CPR9r{-?S`F)9=^=jEmK z$X}TjYZlFj{uCJ&ISGg5`B+1EX#a~bv8z#2m=Es@uJ&&UeCu83E$qp2)pGKVZuu?p zc^Bbr6+9Gsp*EswBYEb!fi?eTDrfoD>N029_LzH_wup1M0dzSnCRf3qy%Q&NIMg@T zCH!x=P^3bPjmN?6TTgtX{$}p*RfGbuzoC}c1Um5qUc!54x!B9F16@OR5Q|j4yHJ|n z0MhGrZnfYRz77vwm? zukFPq^QgRDj$r{@U#28id@FR)TcY)%PorItcEHB;$X^(&k@` zm&DubNxHJ$PwghJj^)KH(bchH@j3A=Qh(_aRQ_d^@6=lQak2`X1J`PQ?hDtN3vx4r z9AT(o7pRzXEgh_NtxGN6TN;@9A!XoK!*#=KIF|=ucR`BFdAl$f1c4YgmcIzDZ2{kh z-_9T7d+-j9M;;H)X0r{zin>4*hN9)I_Ji63^#8|_3vRLY(j4gyvL}>yV@Z!a1&ixq zY)AZf`~-4Ha-uC`yCc(Mv*10Ag_1%ag3Ek$f@8cJJrVa?=T4W&+09whF~c#}(>8D_ zURS$7Jb}x*fiPUmFsv}t7h4%eh*wP;jJ1sC#L09;qyz?$tCA<@D%ZdfD5>06{#E8n zpOlJ{qLds>^AcbIPD^r*u( zWNVYnZp;ki@|dE0KlW$#52iU&kzK;T=!CqvgGfRg5B*RTZH=;BS*L7Mw?O5! zP3^4AP%cWFq&4v!$R8OM`x5^RNpS~34A>Ms4bA+a*gkj>%ENc=jkT8M#Lh=YL@I?# zMJ7gyM+S#(g~tX*1lIPU_0fQ0ADP|Gx>YC|&! zJsZ5S0{w~HL5?BE>UA&$Q9*uQqL* z&lKcoPSPhr?Nfo6gG|an&_@96u-dVnd{}@hR%r z_<{KN`2P|09bi%vTiYEcZjQi`bIxKwkfU!{3 z46+YlRJ#!U+{QU%EtF0p*ZM&xOPc6hZ(QO0Y+d7QGOu!0A-Z!6V(soU9&o~lVeerChTPUE^>4~M{M*qoXJ$*_loi`JlNkVz2(ln@z`nMSLsdP(7*)DJf;S& zQ5Q=|eY<{*(!-Daq3k>S?WM)O&ThNF6e}Oz*H4^`zz6PA@=zn>v)%WxN}l%D+B=Oo z^1Z4fH<0cz+hMorHx$X9j1#4YOaI!2?~uCDU2KiO*~mXh&ziSdzbm87PwlH<6>hbE zOTML=&WG9^POjM%@zLW!J!wQ4j!s^VlSO;VW7I3t*WzSNoM>hsX5XBxJ)in6y-V7n z+@AizF44bHM@4Vfo;7<)uNbRb180<9D)ln|l&)|_x>v-dz@N^)`n!mjE|LzLA@x3| z)HlhsttbDHC=nV4LTAv}lzjtVJx|caF*wMYq z4#i*3XyzPHTgCfl6uK4eQZo>E&S+)*M`>a_l6f%N6B!jhs)eb8(n{l3CBR#{&s2P;FhU)=ZV9?CwE zKJGM)82%Tv9r0;WF+3;(ogqeJb*%cGJ3Vcv`<;enkz9rhfsxtOmCyTQh~OKZm>6!P zb;jv_ioaL-3MHt_l`Nv^T$>py8%R(8lo_H1Bz^%L6UjMk|=)(p2u``d|W6Vjcva$hL^acWBN ze&6nBu5o#OmuN5NZugErxwIxS4ZFW|H#W)lx=DGrcGP%I>6iSr;6hEMJ6!%DYh&dw z^@8s;^Ag{~*u}Vecv>C@EYz(e}f>R@n{TTN z(530+VL9?+pqF-K`cZoXw9SU}r#ZgLbNO4U|Hz#0Zi}DIe5PtuIL~-ryWcuttqz&- zx6R6owRV=Jnm2`qyXWjh$;Ilof#b0`$msdNe>|F_N_vjmLisp$yIL0NkI131`mT%@ z>xQML`0Lc4Q@1FuCzXu0(&^|je@nL@&?x?}-7l{^+mY+*7U8)opOHQ^4L!8X|=!hb=~*DYZ0xCEFOE9T8cAB zeacqn18co~67h9I+`2@|@SWDf>V2_28382~zb$lq>YZ#~#mUUmwv~KZx0TUWu5l!5 zT+O7cO0!q`EO1fg++ofN>@aRoUpG$5|M-83Es?Kw_W8BMR(ZPiGBVhPnMcF*-NDA4 zb~*N^e*m(hk$xaJ^Zac&w$;d7;qJqEE-NAH$0zk zz3X-h{$$RvS4B@1E~@?7eLtZW9I)R>jCNnwj;j@Q;o$qy522Fir#NB$Lgv}{jdfS$ zH1$27yez3?U*|lL$k1#3@6~1GK41NPa2-xuEx<{4Z-?5|d=m6q1ASebtk|jC5@UNZ zDTn2`YB$4`zVLnQG&Rr5mFbNcqmo~0yA$IwN1LIf<3FDoq@FZ;`0q~j@=wGmj1M55 ze5R6{y4(Mf>GLVhbaR8Z|`|Li4I-WHm~?tnZBN$hwf6A*EeQ+Ac3jFGP;Stc-clX4+!+Sv4d*S=FxL4~h9^ zCuy8I+V@cMHh;s&;!MB%#)VYg1f1NkKM~Zg4qj*FrQUKz=dP=JR~>IH3Vfv885yIm z4<9sF*8H0_UT*B(hZAf!=Uf)~QeAEQll5d|p8m4aSGmi#Ju*Jy&-fpq!O;Q1hwU!d zopQ1{E;~`vKWG?>{68kIch>1q`vv{}h?bj~_{+Dh?&<7r4Ot%LuC+GiepV_Ku8BPt zT4m3u3p8j{(J58Z<99yRcn1A}l1R7T^3J=6O%tC!@`VrqJ(v^C7?0q>m*q}}dFEm$U_tr)J4Ydc00jT#``M9}58m-3C9Rh8U^D@Am z?*GZ?;{IXO3;vXz)20O^Xj$)ip-lFzdBXQQ9q4N%k7*TXy2nI+>*c(HNWO;ty`U` z$1cD#b=c7TBV(PyrxFVSrz1|t&wce4-b6$+C$+|iIv~~d2RFT@jc;0$D zqfKOfc4_=2<(%Bv{liyYSDlln>Yll$a&6I~+A4jm)XTm)Jo>`C@Q2B}gBz^3t*0{o zj_&a-N+spH5Ox0&vglvR3Dq3c2Vu`Nzqw7jpd3yv2>fB+>t1esXc_&fitMTuPwhnYPMrn7OQ{2z=c8PZa6Wz`3ZD#*qqj-@rEqFFk5$vPt<^ji4 z9+J;UD~+yt`?x=QW%UC@Dq_B4G(s-dsQzBw4L zZRa-0ek;0Je>wVfULY|tx-2hf{_a@uV3wAelW}kK>D225i=q|CL>%Kwx=*G$DkW~c zV4vzg^>)~?=6QQadRx{D$pUR$sz=t>$tR5mus=rk)JAEq_KbX0x`8n#yHV{l-wF3G z>&nocR9p2M>-XT2)XyozTObCmgSAhar4BG>%GJj7Km+OX#42C4nyoBN)XR7_`f29s z$V1^x$&pwwYX?bwHky~U<3e6$z`0Rr@2pJh&i*6T+W#DO=J`t=t9)e?N_Y5QmB*&? zZmi;|5oY;Nmai>#&dny)ZR@!5$uyLwchi$ zPH(h(gj-sN%-gi#*yr-3GfA1H{h53{_+G4G5PQnF)!5~3UGQk^kRCVo_=mZx(^m(c zwj0AneMM=iylRwVkDLme^tVB6AmysJS~Jv<#>aZ8oM|0*4{3jyImqD|E=g*I*-NcT z)z|NG{xariyVN`6bE&2JuhwI}uIU-}*Zv9VOl46_%6ij!OG?ORQqzOe(??~=?yS8O z*qHjn{R-!&pLS-sZ)<-!W0g7PNLy3i(?;1(S)GIaXrZs#8k@Q~*x&kEi=y4tvGKl_ z%xZU_@|W7e>1(voZg5e{&f6BX_x;o zcUZa;vR^+-O%HWVy^n0V+w|GyHuv68+f*OvAOGBRuKsa4$N!QQN(Oz^=6GLCw4NT9 zXQ(HWcT1Bq`qj)1El01G=(FTq$p@tkfjd*5;H1ad+W17BKW%S!y1RE_ZL6U6)7%E} zcLKdr(*mR1F4j5Cjc(0;(zwGJVSfT^dqz6nzdd~}TC*WKmX4-IB~fbyIVmsu>%wmHXYC*gW5#Ic?7_$T^y5kU1%0>T@-{W?8*I%1`8s zs(B!9WVBRW?OzqY$=@{gUDgNjfSySGkGVb@k75O{6S-O#X%=$w9kCtuTW8bTM z<2z^WOdJb;h;@xD6WOTpAv;4ZHunT)xo6{-=#8~MlrAwl_+B)c)wE_>sL-0@Zr8Qg z+MEZY;n4nMUu2hkXf5-NNn9Q*jb0y;>}AGc$C0Key<-(w!{yy(Gv96Y<%(_G;ydKr zY=_M0zR9t1`p?)0?3MIS`eEs**sB>s($&&lWE{UBy`Bj8cT01P@ybH%Zq&SPeauW=(TD&cKs?&&sUoGhFWK(J?Z;>Es_rf|A;2S`Q`vf z&qDK3^(kqE^0?jIf5N!hzc%$lC|a{MJlN?TuL$%>ZqNEH@?&^+dal3HXyfK+HP|~p zg}8*S(rl%lInUapEOBns3d|d|dD3QCR^PQw`RAlsN!8Aq?i2DcXOY~~J?iglw3p`? zHTo#zm1Mdv2H&>Nn1A79xRA>?^=+U{S_7ZYkr!dVh)}4JGSK-eVdUNzo1pYJ zhq#UPmT@Ow#z$lZlMg5@5iNHX@MM74jr z*;{?fZRRfY*BZLrEU+l`or}}Yt?TruM$&gUHQslZxy`xFc}6XlhN}M}i+UmaqF1Qj zMVI>f8H0VFBVwm%uyN#e-yCz4e{=FjrvujK7f50EnBGEp1ZUhnDL<+_X01;B$9F5T zyxuZT`rflXlZx&B>aW^1<7vGvy*~JTe6YSJUE`Y~^^Q5dzwJ-874hrT;p#p&YL3S| zd6>LWTH-twcsc#LKH52ARU)hKxbJxUGQZ#a75nHs?|vcoPzI(qs@3l8_J8GNW)EMc zHC6rDeo(zl-el~vTIn;5@z@`*R64D^Z$IMu&i)&7J6&6zT(4h~s`9-Jo>+qH!A$j^ zRKj=6-i}$uCgi~mgjeFQ{7-tU`k?W;y+JNiW$c~MT^b7+x>@O?9uauPc8%xx3_ecAndzCUuinxER*)?U5PTJP+DkK#A&O{=+bOj;rTV+@zO z*}c>fvx2)2VlHx>^D6Yqc_8W#q!5iE9qJ1Yr8Ex1m7#aTUW|` z>W(v%RXiF{W%W+VRw zC$^ht{S9}YI?-O>F0togoOijmxP#Pz?nY#HHdg9OVY!cMBBJ$}G{GI=+~$6RjCR*< zq3Fmroh3U`U+1KIKh9^%dTcMWz{-{qF#M3!r@C&x|+ zo*~>f9TO);tjBID`;e=85ND9QCVh{b`j6Z$$lc$IJm(zXBtM-gdO-2$Qp0w<{=Ash4dM6@p?%P&iGh>oe+P*9vaQCzFCePL~5`v zz!lhiWCeC4?B!M?>sFIW-DcRM{A28N-9q{q9Q~a;3j2CnZaeIqyjgl5)@KNNx@<+u zQ?PSs6R8LGDGEzZxK)VVJ0xY`gug0hBSu#%Rbk)y$FL7z0=rr~E6;T;?9`Km9qsQx zD}AwV|FzP$pvbGRi^$tJmwONTy3$>Ynl6*DSp~9hJ4j(<{F-hPcMRIhkUvMS54j86 zU%+)A;UtzT;lIe2b|X9fI_`jgJlzv+bM$&0vUh(5rM8v6$Nnz)z{+OO=LRW-{^O)^ z_e0Qr6+F(B@MtD60_an4@|`$^JngE;cc1F&(m^*H`S?4L!}x{!wfi3_2m5O@LI(d6 zIHh~4R0H@olP*Jk_wRtgZ|;@2+E02{y5C&`oc@iR{Xg*Ph232qm&O4CzoDKd-5=ek zd$}}2nk~HooL>g0FGia;A#*{;v+v)NsXI@W>5_XFdrko|uUxGKf|b31Y6AYQw0Z58%q*pAmP?!S0f11j6` z)VHW}0chZ`yw(&+0E)Le>oPGaAUlPFyQ497ufNnE#a|0Mc&9#>pYZxH)C7XmT` zfVdy<@q_k#fS3=|rh~_@;VDKH#+6L`$^|Di2PL*Z%vTesv(yzbt{bq?Q|gY-4!GJJ zyU8@fs{!gH9*gn28R)1h$}|S0*F%XmxYiXVE|uEjPJ3Xc4XzWb{M7{gX@vgd;t_AgQ`fSr}3`CQ`njsPuHMi zIVi9c-_C<(Bfx(YwV8NY`v3JMP`imI4Ah^-a}Fp&MK3jsT}G*xM+0&E7q!N`nhabw zv5RTqe^1q-jcT-4=3YQs)#ybX%CVo25O_9$|29Tpp>7G|lmP`F;Dm%4yo>Qa7tqK6 zOb8#o2q_7?okV{mK*_~33VP)CV2urn0WYl90`6IO76FE!nrytt|ApSW82HS`byAYJ zmy5C)-a8Yuvn3Vd)bUNhH~wb76c7Fqz8h#Ij$TWUH7SgkQYDQm3BWOdo+UB5TKvau zMtHHbg&w*1FJq*vli2iOEF3fI5w)urg&&lXi{6Ahis5MexUb;9j8Bf7V~?Z$TJL`Z zkgY{ad`CST@J<^Ni)PL?^d8}1-=opmH6b7XV5|!+9`z;tH7($t2YT4AZ-uOG65+QurL4% z72_gj_&vS|14~)JHgP2=pg!=#U!;L-yf2o>$NeC#iq9-i3gtGrix4F>xM(BbaghdU zka0!4RFq;#1f1eC56^|Y?}9r8m67jMyaK3~Jf7#_BM+_Rdu>v8kQiw5BKHeS^H&1bqi9u7DEWj^ z!vYTwQ;zrIm}PI&l%fH^Aq<%Cg^)vB5wIn+a{zbZya4>0i~o7JQh;|R-a!wdnI2Aw zVd8_eUF0z#KZuhc@JES22|+nfg8!5Vl=!p?SQ}+V9&nb2D+Q27^>D8Q{~O>}4t|&5 zGYj{LxiD(vn>3Y!tAbY9Hm{R^8Md|tk0tO;g7&b15$+Qs@IbjuX&&=_Qv%aM!a2j9OcC;f z_~sLQj`Eq`MZMH>#Ba*Ow1KC1N9YI2BuWWNKYk}QPy!-Y(4!)@DCD8g7b##}MgJVH zZE6C_RJO=I#qo+HdsA8vZ-O65jg;1;P}TcJekBc(_k|qS@H}-RHHV;O_MP;t zqgCpAmg9@mOnpLH-6nKV0H5R)>N{dM4!BdJ@JU(%lBe}Z z=h#sSSaVRu#EW_<4lGoAl9@8U4Dx_lfD(he9K|~Z98oWmzbkRS+>`1;-c!EwIhKfe z7^kg3T}_>{<#N zuN~~Y4zM;`!-8yqS4-Ss9jvRwtE&nAH^SZ4DBB)(XIp%>fpysvwqqOEjLp4QOV47Y zMOuXN#h!InAD`s!B3z}EA>ERfsaf*Sk35fhDXBwE4)IgqeV)(WE;ZrBE+f1W55|TI`v1bM;Yb7HZ{Mn zVWOV2Aa*JHNi{5c4*KB&{+GeFD8)BYNI9;PPO9;Z|C68`mK7E!t&dVqD^W9@LK)fz ztcTyuqYm~)=rOiYG7|0-?%J(j9r3a)OowWqL7iOFE-IC$6beghfJ) zLYskFin55 zBZi1o>UeTIEfH#QN_I*kN_*mnwnzme!9~5#E0q1Tl2a%_TY&Z}pQm1-JfH?*Jp#XR zv_vVydIUX)*^7`mLgOGZ0)22$DE62hJi0y7`#S~rDaY_(g(srBQ)m(cjNsD)UIpuTe+O>yQ>g>p4`Q&!L(q(q^v;paJg zQa7D}^-3G|G@d?$|I}TjxN;iQPzmo^CG6Eo(B)~A#vVXkkEt1L^pW#n4R2aFqz%qr zO%LnT9;ydf-Gdk9Z3s{i@>0yTY4;0VOv^^beQJC18##$wO36w}AZL=sgdO3)XOe~Q zDR`1xc*H2Zb?*vG=K#NgRvH7}v@n|k=hQA8Kq1sx!s@)JmpHQ!_{>HvY$+QUEJEox z+Myoc4Ay{GJQFqP_>~DdbKQFAS3P`Yfp&7dzEF|{@oYBwQ;2V3UR4j}i@fKGKy#$~ zTs)bJr-Gims0Pfc0l9={pHb3^dA69T(muhq_JAt6{~`t}Jb6wzE@qDOI#5o~HlY=& zc{K}uCzS}BgA_pxKwhWEf|7yMh=nCQ$rt&Vm}MR0aN3Eiw+jDbxK4jn9Odi0d&Itl zUhyvtd?nziVcWF94ESs4MI@FKZ`M$ghM)0wYa%@DiQ{dOyhX)L=privPlELCcW-JU&S|mWf`|YT#To zg!ZVN3h-PbK#abRVti&`Tv;eb?~emml0#^t(L&>x32RD=THu4Sx(u93k3$(a_?*Y1 z^fZzx$~_t(Z__6sd=8ZGv;g>KIbk8tLrELq9O@{=ljlGmv`9!Nf>MMmp%;u+!#S@` z{t_OrI($YzCBhOF`iip@dQM2Q)JC*rDd}i23kzG!9I}A{VwJGS1ysoG)U%WVw6#fz z65vRyo$`aQ77``H`y{++6U(TVUOVa>>ULh005O&&O_3%z=cG52c4lMzZw$y&hZKP- z>Up|@)@&2ViWcDXW}cNvtF@^|EzQ01v}8$F^ec$@jnHM(3>=4;KMKDJe{r@-e;VZz zKgFC%=u19BYNMvg0`^Jaf`Vx+5iYD#%!^d-$xJ|iGZjLGI-E8cDN9)YLQaTTAT2u1 zApUQT$S26p)ZToNCg}^{?3FV>Y6|+4s0m6z4YWe&(V&!}hl8^edbCe?cWF_QYEFCQ zXp^4s=;@?Kd88lGB{dLxaZzW{Mxo50WMaRC9-%Em`$EjZNawWCXo*uBa2_YF&<-G{ zk$z}l2s)y)q|Bo3pj}7L1))Y+BIG)$LCkYFf8&@r2d8C1SW>%Cr_)bFp8#jCwEgKj z78ahcRE2+&@`aud+870pB%IsOZyLgNN)(pO1fGc_RF6NlMvH)$CcbIkP_D554(-Xt(b%dap6y;<`F(~_A2tfUa{Ylv>c7_tI(dN6eM)WZG;bLfl`-{rDPS} zk|ZQ8;mJ4g#`1)qkS)Y6wTPez&d+GSafTsmdD0u}Cywappj;*=Qu9+&(4uGzNg$+# z@JluD^ay7k)MFB`K|14nf>tNB3TcEjajqh~PV{qeZb`Z#UsGCeKEwGd>!*a`+?w;} z5>Ga96{p0jp#-vtzA~;9(R0O>jQV(zs~YSlzfq1+h6oyF{Tzqz<#KewlBS#>9yp&P zmD2hWa}~~0Xw3<46u*%!1lJ}|jx?_X0@o)kG?*_R(OPnpNsRxBw&#QuL;|q7@>@&J;-?`|BKlvaYq`UzmU>_ zGYa}CI1}N_fV!XhfE-UdgL4FOH>p7QJn7q_>?h|73xjJRN1-iFdb8Vum{T0`esj%& zwFzH2ec!_0!PyQi750oa27Aj{F8vUklXF(hIVI&P`A+{I_3$?-l(a;-L8(c|5E}G% z3R%y2KIw=QLCGR?AZ4MT4nkVU5;6Or%n=?Pt{c+pA>%(K2{~O*0`Wk-KwCga5lRnQ zniqXexgK3mx(FXLWdrY%)9L>f5<@&s{Y9&V7#A~Y$_v`ClpJi6k}S(BBW6Z?j&okF zvKg3uDAZ9!>6eREO2&^9>Zy-#{H#I>^uXbwS-^m9>L33?*k(mzi6 zqK8V%eW|gikAyWNEFsP|MgQrcCRbBSh`AJXqmKK+gCpjPlrrQXave1*Ax>%{FACfW zDA12VFHkNdY&}mFQyvKoL|cWNP2QzkAoq$@4Ei%!n%>D`Kt71iCU`f7RmoNGw)jTx zMPrm{=;5E3`i8W_ zIT_dO#cB#iORojJ$wCv6iuj%KpKxak3^lBfvb11Hf3!%s^5UQraYD=rD~w!89Y_wQ z1wma;d!6!&T2@#n0w*e@H2qPOKh=;CLQb6VIGr;Z>I$y5u`k50uns7H*S9I@Iai`o;LO+sf0GaV=szV1XR6fxq;^VH_KQAp&NetlrHw@y z&8J8a)WhUCLX)#vVISn+H_jbFuecJgq6Am2$mR5A(zc^LNIg=7dn`+>M4L10)xb4L zwnYntE1twHu}ce*ePNAaeTr-6gd<;^>#*0HTciR3cS!*k*F@P9`YKtH+?Ngd;wooDlx07J#Sj8Tl;HPCjH&=-Gax_Op{8bd zHN~q4S7g1s zaZ>x!lR}&40!pMnvy_|k&f2&V!?`4!*VEpkoKB+^O149|M;cN+T4IC>d80lepLzgC ztpI7-Bed9bJjovPMw}F9ta*TD6-K3^)L>wu1aVf)u%1Bdke_I=W}psglg=pP2W81P zi>@56GK`A7(Qvmda7isi?_Lz}=#IE6>K8>S0|ZZb9MJ$*J0lu!0P0cTsbZ`{4e-_s zZRL2nik4>yXj%g=bVIH5x{{iz(Vs%pO`l%|{MHuWz%>Q>gE^O>=4uTL90jyEPWGw* z<68Gu65xoXIRB$A zpyt&P2YCh~iQ}0(>;u#ZSDT^^%HhVqix2RjhdK|KqlBd=Aq($w=&Oc4b1q69#We;C z`(?F3y#YY)60}o-Z?%94y|0;odNJxDCaXMuavbrOUC>HPoT1bgS3BUlnDJzw?4^JR zrAQ~#$n}ptz-UvvuJq#f+Tto_i#nofxvEr({x`&$X-`17En24SLm91L3|C4+5ovlE zYBNxd@rVgT&0d06Gt}M))U*w7rS~s|nBPuboL?KX-bT6xwNzoW#lT4dQNm4u+2()@ zeZ8Eumx6X{aMBOCg!U7?1OEa;QAFYnK|eWL4N6gzs0PO~f{~v5C}Jnqx!>Sw6`rH+ zYXTV19-#)kfI6CkM!Nzc7eGDK^IYS|$9QWzDX{)A&{aiB-=lxecX{LP8XyObC7a z7aDC1p5G6>*8s5Sg_Z_@pZ~$tPjSnmj0D+4^saUFRZxPA+;kOhsvS#lgy#MrJ3 zV5b11?+z@`eya8;sWWg>j9MyyM_R&^askj_8>t`Yu_+*6q1UB=c_r3SE(dlBK!06O ztBKz`5ZU?*aCZT4y#l!E1Q~S-V`04J6`(_|$9Kf|Ig@9+YXM?$%h8UFp0z{#XH&FS z2UxH5xV0ln-+;*C1`^kj6x4JOHIoZyQ_`BOfK)#QKBtd0&*R86%KQMB;oK+ajLztR z<>8*5$V0&4F;I9G;;;V#mM8&3fLSKmECY@z(Wh?_1x$;SYj=l%+cWMbsILI{)X_$J zl>8mh#NR+~&>Nl!u4;_l??N>2Zq#=aaJ>dpmjejvfXQjpcEW9sF=v2J+Jh$kgp{Fe z)C{c}h}C`s&;NjE;visWBG$ewXrvyVTn#$Nfb6qzZg)?>mo`BQj5+~o=7>%}a?q!1 zp;gM45@5UnIRZHNenryvx8VdD7qDXKE)uG#st4h8SliVAl%uP?l^%H27vvVL7P!IzTi9 zIJO71{f)7FhzNQM5$sn0YE2*?eu5HyIH7AL=xr%v%XCPc zrqa{W)n07U zl`2q2Q^44VcJo1#J&+qe0=UTrz2@K!GB5D^CP<21z~r~!u>g8D7AM8sh@Qy6!7@OG zw)*A3?*+F9&h5Plkfcozaa%wuaaPW}3CfaNQHE^K8{9SWX#0h&evp02=-jo)JomjMscF#Z#u z|I@(7qoBMp)Hw`eQt|XYNX$x5;SNw{9WqEP>;_d!p;hOSx;84|;&9QQNy^)SwxQ~;OW(iA*(3VMpx`APKTDRAFS zc$T)`UQqNLcL=gm{HW;~;BhtZvmH7d=cWP*ZzAt!7_^Q7NR@#?b|A;*3E;awB;0N2 zc_q$UEJFWi$-fM)>42xk0lS&#E$6Eo-7nzAt3X*_;rY=hLq0!?UItO_MNs(NXm1Uq z*%{!r(!Ei72L01P!AG#q)ohH_Vf#ulK|6VMr4@t6TASB#T(Y3**t`FVT58xtYjD0#nx z1lkH3Jdcc{5#ZB&;ARVW?HBN2#2pEZ_ajd0O1QrQ?h}Efv&dGRfU%CpIlC{qFXJr4 zJ%DWqWLGwHy@kx4sqSPv{f&DGAd~}5s)8r}3(6}+zDi$UiE<+j`LqBSr_E3Uig1w! zGzW5^0^=&c8IU&s(k+0ug`k%?;KFaw{#fJ$`5^aBf_|QJZ-Dk$1zj){n7kJK%LZ3o zg`LV80rQQ)Gw z=zThi9`A+Z<*Z-<+T}dAo%9R(Wmcu%;B4vNkT*rhWkYUsfb7jd z>0co6{E!f@ArGklwpTvNm7yo6q5E85J_}eMf)a;;pR?Evt`0KqOSD!4nh!%y1OQW6 zdK6i1J0KxPVDw>c%=ICYssZyT=q3-Ga~!QX;P*Y)>!1%L#0JRRD0q4qDpe zAys|^y?zY(U~T|sVOK)R`~)cEf(~*a+2}8$yx0I)bq0FiNAUj5fNx7rUyT4Qe&Y@R zH&nZmz`ZHpXB{}YBRIY#=w%Y@m^RW9z#|E*7eofu&45rQC}tMuU@M?`5L_|^~d z=K*?Wp?N0aUOA}qTR?vZpiFvt7f@IMP1XhSurVle2&niA)J*$!2(WV_?2B`NTqU%V zf>QawS`R>B18CwPAn*~g$}Rz1nTkGR_Y-GjWfa{c$yNO9^)`@xd-$# z1hjrHWC=Y!X|#XNy$_P6AKJeb&@z!m$^<0!B?9lRR>F6zO$ zX$RQQNA(-B)}BM#gHdBMTsaQNPIqs^H3Kr^7*2K{g&q}P9B%_Z?ICB%A%{jm`{7)3 z$g6c2=LSHs2OyURD!mc#p6-sq$ZKGSkVBH7sZ7Y}1K|Fzps9ueZ*`Eke*>z!Kxc)> z1H2MeL^14_{g6L#;PMj88fdFL4y~{id23I=O6v~^{t$L#ih=%V0sIT8+#0y<2@Wa* z#a{u6G2B6*xl6&V6VU6v!0uAm6+giWXoqaY3Bd4=?tIj^9=v@wblro}Ea?Awun5lJ z`jcp-F)|=u0pHvMi{=<$^93wn6ZAb8wT%LdPUG1PkXid+LA(Zf>xx!e;g=5SuoTwW zozPHxq78Db=K;PJEY)jZW3YhRbr^LLr#e3f zDRPzcDPZ?I=sJKl2LmoofLgigavbecLPyfCa69Og5@9X+LR*KgUT9+=s9pF(T7fn` z#PjvQv-Cin1Qm0((;gBm0sCPS=y@?@@fG0B4=}>3QRivM=k+K>U*(j2Qi9CfCVjF`oms> z9p{1`9?I`Sy?X)oR?>RR=w5>K zjey7efXM-H!sCEtOO$R3YJLcK`UBYd9Q~ztI)Zm0K?{Bgdh$!Z0h31oQ`^(S*^ruPNW#&;MlR%MHf9N{ zfrUTNrz?Q_1n|}w5_Bb6-wVys6jaC6!OP(hISH*?2<%p&wvMnH4#UPc3;gZ^jo*Z) z>p{}dYbAk0c0l6X0bFr*&=b(ihpmuAPrKpH9&lqBD5A=pfO;7faUQ+p>O%!EHUyt- zr2l}o>p?=(1!TZh^qQW_ z2;@{RV7v+X)fPRdfDA8&9J?Aav;_TYh&K0u-Xo~7JxbmLu5SmK#`WA%;PwZMF$P}e z8W3%WG_s&eKvi*!uND%aAMTt7>`nt3QAphFXt@`3`3Yd3YbI$x_ynX#YhY|UMna5m z=0nf353uE2s~#k9OZ5I2c#P8ZD{x*467dW$aRr`k4ynp4(p-!p4qK#`M+s%%GIHfk z&>!P>laQ6=u*}}`W}JVcCI{`cz!UUuZv&+*0Y;kuDmy`et$>qXF}^*hH4Tb84?1)( z(vy%9XHnx~@J%OR@?ZEUvQWnXJbw&$?}JgLFzQC=JJ-SX19qKYS^bJU?h?>i9Cgvl z!;y!V5icLWf_UQgirde8uapFDe$>XGlaXlSZ`9Hc_p4ybbOUvF#I^e9<6!7S4HBpes2~gqmc3d?$9?ehA>gwq zMwN?^aJAV5?&D}nMEEwrCv%$V(Jz5iV7$RG&<3NMOYw`oOXjt61+FJ_qXfK$(RvVd z2`!X?G0|Q<1L;u-_}Abm6EiyIh8_eSzXjA2ptBwruZU#NM44>NRhWf-6yK_V@l~Lt z{eT}==cAA;r_m33krY^%CxKrF+AoIst3j)@tpjK^09soNZ0msMR@7As*>nQ`{g9Q% zVCzIN%E7=C<5;=gy#+P-Pz!x>TuW8~uWDe2-u!*I(-c(2^_V0m{TyJxNZhWFew;fq z`^JW5V@8N-%LEAE-Y9Q+G zi|cy=1q4%+lw2#WCPVM~f|SjVrvQSbCG_^X2+a2VCEdYg{j?0wx(vR|@_A zFTR}y_ix8<##NIeY&^FPG;$2(c0mVlw}Z8~vK>zx#4oO#{)0Kycc_n^@+hb`1N`+f zW)j;_iYu;+p)T`E{|Wl|22dbX(=V?=2K!M%Y*GHV0mM6j2Dq}q6_cwVYp($Y8-P|BGf{%8jZpGRJTGEm641}>p)sz4 zjm*B5ptd;X0QEht?uM50!5v)h&V|0HgZ8)-^>Br=7TDO0mkF972C4tfLjJ{2BS%8H zSPv5YZ(y1cER1kve1HSE9>&vL$IeD6#(-;p8X>t0uw*WU28fW#=o98Dw}^M5_mYvo z5@ugqw`87($k1hM2IDQFsFj)OT$^N$I=!iko8W#IT(PJ9%;<8)-!Yo&DEfOC)WH=4 z+J*Fx@#-1$SHx44W0ZSwg>xJF($Auo|A00~%Z#_9>^O%yc7t*+c(kE`!%su@&@X=+ zeIYF{f`U;WlvdmUqy_kjRsp3VBlReI+5$`UF>Xd=F*=jphD$-CjUY)G@xXYQOToEa zfRoPn#YiMtpNyGd#9TH;%HF5I1&tw927wOR11`*B=2}||P)BRvgptr(U9685dZ1hf zeAf5wHbAMKXoG8DO@MdC&}5-!4L~jRJ=!w?14?7Ax8y=v5dzHcszy%=(WgcjMLm>c zYzo)IG?Wy%3^K;Rs2F0yzPoQz@Sz8Z|D<7yD2?5HUj zpUeyv#s-Q!Nan#X5{@2lM$<7?g;7mhM`a&;c$%?1jA`Zm1I)k@u|$kV6T}W_nQ=v20cWHHX@N0+d=bNpF1D+Yxe+ti? z1{WXmFwShS3UKfd$nSmN?t^%Oao^;5N(HWha;1V2fw(;Ze#f4oxFfR1*b*ap$-Nvc zU+f3HB9xKLWF=R#-wCfQa}k+!AYzRvPuT)zUtBHyzugL0?qaSuBX;PEkcy=jhb zjM3zET4{{zD)OG8|B_L{lntadN_s~7F?x#>!Wd!31&S89#z}c0=#_azj7=qVGvb+} zp!LK#D@Q@;!q|02CUCrjB%>%8=PR-R7~3Wyw+Rcr8CAx3E#`1Anu54t>^?I9$p54l zVuMs7qU53;q?zv}q>{*VC1gplB14?fUtF1`J`wyVqLulb&=Zl_j2mM#kH{|J6C!Gi zvco~o*?-zM)F;F|b3#ZZq&3EyknY$UQV=PLoJ1^f?T$K)acztrVdOM(BgjF8fVGev zjF~4l5nseX6JScjM=`>ZIAZi0tq#gn_L|a$QGDc9+B=MlpxsHFl9wrw7{Nyz61Vh7 zFuIobryQi$guE`|uNeVCd{R=d9kDwF{V1ef$~}>3PMk6(kW|XpN)c^F&L^yyKS$e! zeG{1<>IvtNwvW1I*rZDxg$QU#R|5{y1zM6$?bA$9XRIhzn=%pPIR7o{(8 z!EfYv#!-?(_@7W0@`0s9EGMtA?-2lkF-Y_A@&BL=Z`UyjEp8lP$CO&3*`grV^l3Uk5oY$fZ4y) z4BShI(x3Y$iMA<+smn;8q)>7vcL5MGL1cz7kBG8H>}4VH1{kwTYggz9;+IlLcpaH# zKsD<03zRs+(R5whLQq^PjVo6SYS%* z6hK|ed|yUq3mGrs_Gx7?ql27986v(@W(aICa#qM~M)7i0iTua>8$M5bGAo6(GWwD) zks)$1`-8PG7nvn#8*;wP5~L$$n(zt6WK*t?W=L&(a}*-3lahtwqvoO|MQKF}r%Jjzc!kf)KXG#i(v( zoDzSG#TOO{rHRNBrBo-ziQO91L>mjSqCG)>9wMqC__mhM=tJ@ z!nq6U;!anT7Ni`$XlYZIkvarzlRHHAAgO`cpFAPtFd;%oO-N98lD~P4*rfGH`6;-M z(eKo+lmxU_X-P27Q)C1%4}z8k>mYS8)`~Gjv>8YVlpUlGk>Sm`Eu}MQkJcw8LL+aE z$sAJ7Te;s!Lwu*jPj3~iW!5MxT46beJPu-$aWNWdEb&SZyEvQ(*ks%nvt1ao!4+a^ z3C^52YoXo5IB446+y#QOFveih*TooGaxQ(2wAc*XXLdKO5o&VoN>>XiV6-4%&MYq4 zl+5U&7n3*;b`|%rqg*C;60)>!*$(GOv?TbZ%t!*$v?G`k!g&IFz&K1w5z+$p!8isP zbPVkr!T*D>RK=b+jQ=Ue?=rL*#a+&C7%9d%7o~uLnK+}!8G+2%F=sf;rsv#)QX%YF zW3;Ha!(k)zuraQ3rj~{KnXnWX^%a6of%5_4pO|S2EYY(dGPNiXT|`T?1x)E@$O9G7 zn?fq$ETleskcZ%jYYb@9R;J}onRp8C5=itcz#Zqq;CUBV3Zr~?coH}>$;7oxeCER2 zm4%rnGl4jtP6LabQJ(^wGcu+VO0>0 zu7)r5EFw^dpu`rrm!Bk1*sHf8 z*6NNqzhm_MWUg+FcY9+M_crWv{~xCY>vj{}olils0p9}BWD_F<80&5!GuvYxG)COL(Jo#;4>lx`rV9<~@;soDb z@?odcIVukWeKdlPyUb}WTW$||s-40*?mf~@={@HSKsn-+%ZTvCnzZX|lGfm?r{CT8 zq$OC%y58v}y)FOgRKYv`oaDp0-UP%lb&(xsh_p~Hx7*A2!$;az?tzu+40(eyA8Tq= zfNVSWN%@F$!1-MAp~bH9Zu?=Wmwc}?O6iLIa!2Cy)g4NOJq-~fm2P+WM(hk*hh2he zoncrjy%y_7^Oc*OGoZQl(juHWT_B%wE1eHy1-!5td*^Nf36mQWHuj9|=k4p|t8hwOnX<)tMs5!OP7~!d`-FR?jGbDsV{d2q zi2F8nQr;(bfjx1^o~G1;UD;Y*Wi?eEbqb}3wAC@>LH47Lp?+`uCa<(UMG^*G~Uvy?60V$GBPE3dQPQHEQOs*gAw zosfFMS&MT5U&nc^hQmn;`*rCVoOtCp31@oV4)q%!9uIN>g}+W3#J zgPpH^;uJXpRG;#+*-vRHT@7lvPHH6&Q->KP>T{}XwRK1P>ZO0ei9VM&9|bDnzx&>j zb(}*n&pIwGRreZo>NY1VpO6#gS6V=t@2GM^r43+r%uYzNltSrMYrHngz1=|sf%de$ z*gC2;RVLeGq`Q?B)_$$26Lg=}$}H@`ufJ+PrnN9e=$+(3Gb;Vp*Wc);Z8aa(z5;#% zZUd*jazZ*_?UCM<`>9(KxB8x!_F2oEKb17-`4j7iZ)>8TzeV~jt*`sFVQF70m!%() zEA<9AVe!1OTq#ZUmgo8#Czhxo^>d?EEd~z9sZ)~uL+@CNQo-P1>kms+K89SHszvG! z61K>)cnZ zCE7~&JM*0Xee-AgHPw>Zm* zxV0Ngm*`(OXWfA~H{m-s=qpbS2>fpBHur_rI|J&r7p+S4j~~k$S=T){3};V{R@Mj3 z)kqn8?R@82XO#ZAQKkQ8&33oQQ`83|H{>d@Or^JewYkpNUG&W9?m0{K>9zF&4V5G2 z8SS`PWLF1|MJM}ywr@0A1)sK-8qWjCw%RaT>`I}*V%a) z??u*Vk7|3Ig{c#ni>sf?T@}~NNY<@&-)a&0xLv5Mlq!rX!wakZ`A?;CBW)VbK6i^! zl-WI2thaF6xLfp(6F2*og&#eAdy}Ux9M@h@ZZsa$FF0G%Yr@y096_I{V&gNWBri9|=IoD*l*R-?>8srrv=;70V~KuJ4yU`RaqW?G z8|k!sn>$Dere+2b`U~Y9OD;?8j`l44BXUenIsN3vlx~U5Av^WG84TQ=ZY?jxd2fpq z*9vK-biy8p6X({}ew3MvKAt%~K1J#mOvN4x9x?UCCibwbKF)&wRvTbzHB*EV>fe5%{0cvkhURL3Gc_EfScCtwXp%*vQ- z?M;P)*x4vG+859}SNALorfcf<6*o$M9vzy|N1mT}SRZFE(PpL(D=%v+%!Se>rI$TP z9c0c>{DCF0`T9ULlIo|-vy0Wc(u0B{6IuR0(mR6w^hV#P)coM*=6L&K{d;Sq^Sy6= ze1iWv?cM5s@*5;xu%~I)>#M7eXKhPAFMpeuU*J?Y`tq#G%8tG(Ge0O_op-;pFA~pL zl71`RBDXABqrPap9{eq_RC>eTFEz@}&1_OvqU_HYR69nQn7yKMUg$=BN%hby+4?SZ zD)0?XdpM;vO!{*^E_*R=kbL6&%#sF?mDbAap*018&i-}h#uYuAZWz0*D6{%S_l=yj zl_eQRjXv&6{$t4j)|kw7iN_$D~LsYo2i5@%M_6I!@x{s%VcRCwVZ316eKIOb|SZbYhM|B%*LDt&njq;K63zU6K&V@L2%a*ew}e>y!uX{&FH6l9K;c0@)5KXW6_K;zm#wN%esWIqzDO_kyd zgvIjP$<4lU<6F6-S*=b;-4sk}XOxk_dq}Uh@vO3hT^J7z$(*s!*a;2`L5+P2TOB4$uHWwr_(hT zF0Xm5|sg#2>ZXH7mlBZcaG zeKQqE z&i|45ls(b?O}^Rc6@DeT)3~SbnG-{rOse`my`^A$WSMa`OE+z!Na?O!ntoV5=vJDW zf)^4msxM1hTl>ozA6!@!&R#2TjXe=+Wi6Ixo9BW}QekJ3|3mqS z+LwwRyKo+-?PX}g1H-MmV*~P+rF)pY+``PgH4lbbCi@0<8tt`F>H2{wN`7o~;8*q0 z#29&|KH6j1sOk;jdyso9xDU=cUCS!m4x%9QUb`4I53H7eKDNx^<^~_?|SEX|6lYs>W&AR~wxJFT@H%pUPb#pNGQGK_5E(^f#P->XsxZ)-9Zn3(#16kTO}ROi>+>weeckPuvgyBC*2afeVG z3dP-v7cK7Ylt1on#e!>GHoKcmcC#*5_kG@v`H=i_?L6nqoSAXb&C8b<3nK?3+i9=T zMBIzJ)mhR!`lJ5>y;N8M8{j$OOmwH_t@52%;4fge3r4jDd=GW`r*r$|+QLMT)tsYL zB4PvExuxALtr0KqB$b`4X+@~;U>`uYy;_OzBTBd%8BxZn9!(Re0*m_p)nv0NkvhTgG28t;r3bWKd@l7Q z*96Cbu6RmtxykIlg0{dX1T)Aj(A404LR3Ueq40)o8hnkHDR;o!CW^zhn!8Z3en77bU_d@e8!G__px(T7S(P1F3=^8pYc!-WdjRS-xLwC|!BRZKK8d?tL zA$!C7DNnEjMc_^1IQVxAlV3)(Aa${(#-Uja)t8>3n}sj5BRB#3D%B@Ggw0?=`5^6~ zEBNxbccECcrgV``_XcqkT`rW0%XCl6W#+6eMmyV|fE#2EpB}^+vrpAk2<%Xv1= z_=iEe4An}mpeyvt9Xa$0Ajb0`*j`YWQx=Fj(YT;Ei0*ru&vq z$AlE<1Xb0!OLs{*7|7vH`9CuoWxH=rQj5~|s=zfAPN>ZkJ$^=h9{mV<$X~<^?u$E! z=24@ATcbVkw=xOk!Ee1E;)dH_>CT6m5$5O@lnQ*OE5mnPZ*)&1UQGokJv=Ywt^Y## zlvp`B!_^SwxS5dToWb1WipzJvO*D_xdu%P&dNfym)zbueP7Dze)EVsmd`mSk+(PAd zkd!x<>g%+?cXX58eX&RCcBmHrqU0+nQ}x81&V$KaYJykv6S;S2b zHDwdSH9}*}E5+5`5I0!EhU=&g? z*~aQL{8Q9pqib|_tuFIayI;C4)KfQB{uajICv;jcgn5;rkwx*oy!O;GYF_Z)=oieX z+{BiKCkLVALuG&azs8>TZRG5wO%2cCfYVO8q4mgBw}t)}YaRSTb`SSq{*?5-FXmg} z)xNih!7>-bn1?ug@M%g38S%ZL_UIR)Yp4cD00(!;H_r5D_R!k(oSw4QwcnIp^_h~F zl&(e2^36*EEKN6b_>EtrFYcMupMUI?P^f9?pMrR){{_s{B+_2cFE)e3vL>Vs4p&rq zK@RF=+2P8y^P5JGq}(X%rCFQsy8IAGM;k!>;_Jv%a0N^oge+l*)H3;D_Bm}HG1&P+ z&qGT*uc$3J>rba1V*5(kRE#SbN^Iondpp7=YHnbuww7-zw4d~=GaQg{oxV@bU}kaL z=)zSorm`FKLF7$v5IJAD#i(9mM2eXd{998brX#nQhjzr!5SblGRHteVgVgn}M4M<| za6V?>SGzEf&h>=$C5SAjs5;8R)yN z{4E}%8p@+ZjqowiTSf?e&aIw+8@f7Eo%x^cV}a{&s`SBoj8F89)WX6JpM&UOccehUodszg$-@4`BCQFJ&_1FG&Hgtaivvvy*-qMbxHq4gy2m$3`pS?X}4Bd$uh z5eJ%r{7)&=j7EAveZt)ciw4GWky;>~Rt?=F2dbiw#I-3kBW)0!rVnymtgh~7=DU5G zII4M|6`AhskM_sAM_|nz@dkMuoB>PJbMRYur2GJ2!tJOP!4n|Ob0~c?Tt!CkYEl+f zMY$iWObZgPc)iUNR+g7Xe!(%wrfk-nlXhTB;Gq9Id^0+m47zR6ZD>z&w%rc;xEt^! zv5C--{}EY;tPV~_zvwslZ>yo`SvEIxEm8*?%vyv$0yu4e(;}UO2V@h?|04fFi!_a( zVd6)Hhx(&wvR6D0Imlz7Hqi-)jL(KggAVaON1-~G<#BIcds1y4c92afQWP~r{4}^NyG-Y@6EbJDOP*h!{{*GCMLs6al1L}bd zikyJ;>L~F)_)>U16dS#%$yD}5BLI8838#SUwtwNN@NBrgx*CQ-mgEr48{`I1I#pC= zX#R%ZM2|y#q*YLDkXQZ_P6C>(9KhW4!V7@Dg-|c61UeiNG?z6I_^#%sIvs8vU82cV z?!mi3wdbMviUXf2qmW(G66^>6PrV0K(%9jOU?b^4^b2%TP1i`+PWAtHv~)-sbV-)nb=8#+iTL5(0Do@)5y*UbD8T+Z zYpRMmrE9cM+8PN;Ns-oKTXAG?m^>=vj9d>7ilD;y@Pv>rgo&HOrbzYBsE8D*80;Ad zhsp!SV4X%|LL;E{dZlTI z&4dTwW^x4Coq<>jr_-*`mgsI7Dw`4wvrO${)|v{;v9V6`@YpO^5je6sK|)W{{s2mUG5L% z56gO$WtNgE#78z7ECT)SGe4oV*6wnY29Bquq3naR{r?H-?BRt z&CP;r)vV15S37%`>5$!XcOt@nVR_LA}RhX0#lT;RW$h^dC*FDhD z+-UqFriE(&UTJ{1D>5}47hWTDlrp4#vP<0x|3GVzqnRk1rJEnOC#6eDFx6jqaP@;# zMpYH7Y){Qd*;%1R#f25C$JaAv>+5SbusE}cT176Qy5c2t6}mlniTI1CK(43i(Y>kT zcke78S=OiYjeVZ=cEOLrwRyPZMt)MhJufReBY$T8 z=HZL;rX%DLG^*%ro@Yn|@#h@#RA>Vlla*DwdT@A@e*JxzqMsp(C^8GW5%{o>FhCnn0~skgL!4_rNqS* z9wcv1sZg;-MJlmXT%y^a@5P;E9CSNkFYZ9bYK+i7Y8p^5GynFCxk+WqfRMN@zs_2;|R3H@f$*Wb6o;$fP zY}r}V**4uGSms+%i`(M0UMziI#&~}VW(j}cn}{n+99xZ%waI$SwA%2z;;V!)v7~XT_6|LbYDOfZ-H^VTy3rkKk@7{E4;bRc@JL)sR-jih+3X$e z2baJf<16vU`Nv#;wkv&)c!>7Ubds)wwtD-xvK%*UFRZ%aJ4HJRhvcuxJ(ovh?aI~X z_Q<=EZ_f`Dd?{qDn+rx3<=RF$Hk7aR^$Bs1mP+KhX)Ee_Xsa7)m|_iUjFa?#={IVd zaJPv0NaN_y$c11_UvGaCe>H!JZ-wt$;CAp;=u70hn5kNj4^$pE*s#&OF0Osjpo+&T zPOo&hlBv>>^Fqd98}uNZh4DbUs<87P*-U_!iO+3S&jaS zHE}cfMf_Op7arwB<`mt9D#c^b2AZd0*WhS(TIpO{p5>?|uIPEes=U+r+wzX&wa(4V z$#KJX&*@cj$pRGDr;<)G;EIfq_%w@gFG|qI?gvYqdN6eMXCyWCOTK#SA zZ+auX2dqKe3I;s+p0!@HH^;+!=lVJXdW1aT%EAr#hGr{9F}QxAIV0h4MLhXsB`GzK zdM@=za?gq%D-4Ja>NXU%5APIbP; zWSSd{c2!N=l_oDMiU>#U@Dy z;|IqMGXH9DXeY2q^l)MtIs>)@-);`Cf|W(4(&D8pVN-Qkg%Dbb+%UhM-^ zApJpxFA6_^yP>K0zrO`#<3^v za@qXSm5wGQ!)&zmTG773Q~9NNL-H2pZpvMicPaltVO;TbTa(i2&Q|4HxI^4e&ct`&Pw^A@FZ?&Yn)Vm% zacxUoW8HKer=O=!Gwd+jHB>P!Fy1nfruL>8rkbYerq9OJ#@B{{hC+Qa{dnC9Z8tuS z+r!kMACj|h9VWt`G?SyFm1WXv;g9gA;0^xVDNkiHm4Dip^i z@OQLJbd+JL;gqqT(PLO_IIFLypQx*^?Wmo__u(U)$|dmzz85d?ANV(XRqZQnW!+L; zRQFZ?o1xIK$2iLLkLjW**R%uqJ61vK|RQ+6iP+!Y%$k5QZ(U@-h-8jVf z#jpgN*kkmcb-7wUf1F#*wq{IpZ*m*H9Q_;qOY>T-CRY*PhNlG=_}hAiy4SmomiZm0 zN-EknS=SbywER_6Sa`UAD43RiC$DYZiQLV(cpja9H@{-hy5d{5n+|(fC-*vUrND_$ zePNNbN4XoF3!g=+5KpLO%oOez|4b+*fu!yO;S+AEw9A59m~; zG4q(|&aP&=ab5X}+Bw>Nx}|#7ARE$+ded1GYWA8|n@Cg0s590!Ao}lGCx3xk#QN#= z)CVFSmyk60S4~T`r`%9vBDF%T0*`%@JcG;MI*Us?I^s$+_6^n>#p^98MLP@m!nguU zzCUk8o|@Mn-&e4t$Wa`$*C`$4Jn8Q3GX^V%(*>QpTIr}+3$H-C5(mgy%m%IlUthOE z-`7yjILENtkf|@%t=0|D-rzl4B4^;j>^HU_w}E@Zm2utpAm2}$p|xu#>m0hu`q%my zhWf@c#@j~1q#9jDyYYe1VW?!NtFNca;O}$Q*w=KFjNlP84KYBwqZO1Q@nK|N=tw~D z{q$s&zjJmk`{Ec}(!<`&_O1Aug$D1_XYedK=I_eO$~~GlE3at*RaB?A)Yib^ETh~_ zy!QeN!wZFM>7nu?nhq~O?07TsB0Zd)z;)FQ(m&N7H`F)$YFMJr*15H)b{}_uead9f z$@BpFcls=-a0fuYb({SOuBI;Ef-m4JYd2_LX#47p>&kTo{UH4?{U-fe&|hxQZs338 zHnM%0Ak~+wMNGh+BAelc(8Oq-!pPmlR*|)#pMkSJ#naK9>Uvx@sLf1jDfw%eqbAL zzjG_NBiw)7HvR%Qw-;zHX<^+)-8kJLT@iR2SM#g6No;-QC$)kcLA=8{qh{n8$WOng zCMzwZheFTD^$;5z;rDtwdY+V@bv_MK~p`QL@1{#tg)S&m-IxH#BO{$)(dTg}?bHFElMONgOI? zDqEsmp$XTC50536SylK*1Qpw7|v(t2@LKYy?R8}Y`1L;F{XJ_cLjVohM$1h4; zns7C4YCN5AH-1Cx%$P668HPGKiVrXq=sP5W^@Mu?r<-lc9;Ja&sFW&+>L8U>;?SA z$Z^Q3&X5j^Cxr)*mcn=Ocf}T+3Xevg6B5;$wlnFVSNV_nlURt2M1O-$Mt@f-NuKbr zU>#qqt42wzWqVfDA1%KneQ1|)@J)lao(%lM>Q9Y-49)FWbfm=QZ>`xz4A(b~b5~kb zySz>y&0AMjXI=H4DIMZQ>X(xF>apN|p03V5_V$)_`9t#43p*5CDaBEo<$2wK`VaSxHE;#CZ%3CY`V%U!tW_A0ZUBga(JE2JOKW zLO?_n;9Y>^kjX5@pU_V-RWuXkhsFWM)B30U5C$P95CsUxuUFTB8rC|{!|ie&v>nZ# znBD(Nlg|r2EPI>&s{Pw*AI`p?@bz=%ncNt=RJON#Ve|{3-WGhyUy>#K?3SIHH?QbYF;lXk{69Y<05=`9Dp%jUFfk+b zW%W08n%3)(_NjLJ8j;k5q#T`(nuc0rPdLl(^xpFw_N50#h8b~&>`>Z5HL)7RIC>Ww z);>3Vj~N{wmpC)&aMH_!Pq9->S9I-|+4v*4t7eS+KXHSgM&?JT2ppOgyzIZ`i7lU6 zy4H5Vay@@m&ZMk)S-mpdKicN(&3T%OTRi!rY}+kRac9>9SEIlNr5V=Wv?YE~^82*m zX=QbX*Mh5TNxfw##aPG^=X=xP;Tl?wdVU5*1hXRIY_%!YxB9LKB1}sYn^1*#_N1XJLczxm0gv z7+*_2)RYl>JK<2$&IsnV@Romc_$&Oy74?N=o zc|r}%0P?XeHIAydvuec})oNX>`5%zhSF6xGrlxKkRSj$u^p^LEF95drP#h(d$f;^l zbh%~(au^*$Sm;~4V)$k{8D~z2CY(%|5x3WTUf+m2M!m(Z!VXQg+Bxb}2W!?u`vbok zo4Qn)D%TTti_^pR!mYw&s6p_XZ!D-*V%beATG%pgd*+JoS3Z6Hy#8Cv*K0pI|9G9V zEWdGHne|}FM%S>=CUrIG)RVE}Ds``Mwc5ezw<@1XJ{0?m-GfJzaY2o@j_aD^WN8z- zTnd%lC_P?2xje>O)i)sAQwBS{ct33obCbl+$(>SfRasKGb;{Zb=Gb@oDojOUJ90O= zQAtr&DzI8#tpadJt0ocJ4DUm_Vf*me^mBHiu8H9<)47=Rn8cXJ#t!<1{7xnxoVPAG z2VMaeKvDQMFfI5FPlWuy7wUs@LarufNLA%~(p~AXTu0g}8H5|5nSPU}yK||-W7%e@ zlAoHNQ&67sv1oC@y29fH*NfH{XbYPb78abc!sYKgf5@MZk=#{NVbZ7+U+R~X|E180 z=VFYyhtvi9GBi^iEBETK9@W~6yPJU14wn)4b6o{LuqgiqztW0947bC zf3vIDX51|HG`olXO0}R+Vi3^-pN5|SsH1@yiEl%T@U7$kd?&m{nF@R&x&b#BD%ufh z6K$;M<(a}+aeNr@HgLJDlZw-e7Ut#T&BHssdM+MR37Zklr>f76erx$(I~p39!< zdKG93Pa*p7UE(w18&-auLZ{TIkQ3uJ+~@C8BZ$hV7s`Nks2f1HPYL4%#NW>+yZ$ac zTh^-VkMbt|s-YZtJp7R8#4pizGhR1G3`0$u^%Zm`t;~(!b3p%BO`D|sSF6+F+yy$G z=!kBF4Vqoivudf@OY;wO0ey=#ARzK4QA8MsC)iMQ7y38S8{38j@$2M&)E?#{bBD>K zF-jtMyeoPMo(!i$^09E`rzq(_gdXVf^Tspcn~jJC&2_;IW+`Vkp}j6yo2 zt+CFyNDgH>@IACWbtYXUZHRL)QPMK3&5gx$W0_b!UWFx zTYC+6_s`_om677=h$x-53)`zMhg#FZ|3_7JLLDxYhAE2_g2xi{8_dxWrIC8 z<@?GtDuZfnPz^D$@6Bm30}XL{qoFUKrK`;T%VtqK$$yFFK(FyF)F!;q+rYiqHPHFk z_0Ijs|0evasE>Zo?8ScK|B_9q$5dNp2dnTCbo=$|jLl8W%->A2jZ5{vX#2Ch09w?Q zfY4=_4IPR7i{HU8{0(*%J%x2ZS0GK%hR9QlA=HYa4Elw_Yx;TR6JFlrQCPD%ff}V)J_kg}SMGu#ZG_PRn-IuF+1@j@FLh zd|aWnrs1@%qOluapT9s=!PDXA(mWw8aKeAeztq1iFf4R8QcLv9Z{z{$HRYxHT5}1G z!7_;$X0djXp~zGigT&;SHyO7WhUho&OZnm8PCD6I+!OX1TR_XyYa$DeL!ZNI;qMTN zIFXU48DE6=A*YcER4O%`Tta@N){=H=0@a-wLDiv7f=)ga%tn{c6X*^k1^ExY4fTh3 zs1;b!ohA2?x{KCGx5)EQyU-*5eZS=G=DAYNl-G9RrCZ7#*t8{O*5<`F+kk@lRw@5{ zUT(pw%pC=da}O3?wTGN_BB>$<_rvmtCCpp4KA2{PY9X$#zR)n)G|+S#AT#yo8AKXf zAXgBd1{(xE1#Sh>g6`l8z@^j*lVKu~5d$#p6PQ zP^Q0)S1C_&^>e;29Z?!7rApV8)Uj5vEi77PIcFgYdKVqcAD<8A5}O)(Jf@=gkujtn0r;aZ6(H}D zO98U+fqYF;WDDXM?!yjXbMVE4NNuJ!vlX~yY*Tg(Bh%05#dH$=8HgfbdJo--`iGc{ zKfxAZAJB%_GPDIY7^{yqMcN|^;2}Uql&RX3xypUHzpzJAOe^*TwaJom( zQPf_BWzJI*=)F`m>NY)*`^X>C?bMIgKhd`1I#FAwxkM7!FbGiRDLZxvzk;@eCnH^< zmdJNh06GH#Uyj`XKT}W$_;$7fTAY;VRArppDl#{8IndC3w7fz2n{vr@xa_E7MX}qG zR@}Bw%zv1-JkM8f$k7S$M2?pA*ykgBEHoqi*5>XP|lBYdjn?JCR8o z6F)6>4>uc!kTuac>dWYRWoM*m&Zzg?k%I_Mn$Q>E9UhO~gs)&v z0k2Qs#l`}gEv}$b z?t5FO%u(5A3gU8}<@YFfk{v86aQyJLR+_3!)KPK}HZW^xyWxkqlktIRh-r>7QG1BZ zAi7{EHU)bPot6EeJP+?{?tSFW@MU=>c>nR~LaibTg%q`gW;OB$b%=YRP1L6v4g;oq zmHDq2I%bTiwZ4#Fr0uQqaw)W(`VUXSuj9#h5%vdB8UKO(O}+!j&>kY5wBz-dq55S0 zAEq^zORu6~E}l(h3%Pja8Bvp(L$^iWps{!l=oWs7sD^$=FQO&z3uq{I7Wts*gPnyu zKpRDCMuNA(FYgu33y&jj{MQ0C{GH1ux(AkxEqiM_ZU11k7fOyx`Tym$F8-r%hSgTM zGrwZCT)eh0FRx8uZeib|Cgm?&xj>VqfJZemPBC^h z&oT`#KGUAnUeR7O{cbp_Kf+gFx>GCY-{`qa0sD%T$jx*;stZ|%+(ezm&(nGQD|Rd8 zqAw6X$dhb0_BnHf`$pq!0?ziW z;7q8UdQDv`{Scf&?QjRb&Hs;YHek61x^|VO*@igoIy%{QTF2O5pFg(nPR`7t6nA4!x&H^4pQ?rTN6O%TsbhS+@tS_Rb}5_4uh2c!e&D-l z`*X>-4*i5YQo93N5lmeL>XbW-`e%9{dS8WSN0!NV)MB}|LMhXciTD$eU>EcCwJHbe zR%&JbB0pb$)AW~li2jf+%8#L&aszcKI#ORrH<+;lBr8A`kZYN7R7-L$Ih6F`JIS;7 z47?lQuNyOODFZ!`e2TXwx6*B?BTRRC7TK8S4SJgjl zVkPzwwgMgS6Y!M|muHEW!P)8wEfyY1FXT-?>_0{tqh`Ycuy&+G9j6Y_FX`*tJ)Y5B z(jC=F#;Cc!zCgcA@6w*pH#g=R=K@al4Hv_hKn?!P-sNgBr-^BxgIG(frMghxa2}^9 z6`x7=q&pKG@a@z(@&I6N4pTXJbGirh8X)d($p478_)EMpb^xt`ch^jdPKp-DRmHIM zQ-H%t@Pi-n{pVZiedR27_Od6HJSo{z^4?ktut$#yuUX0B9hQd0Ypw4rTdZ2^O-EJB zbKCI3W7c-o&Lvk%H-=xkjiH{QmXQ;oyr>p^!%QN_@ngyBpypNPkLvpBI`g=8EmeX1 zfF+Ty(T&JH!~uPk_eqb1t>J@_#nKFMkhCJwTc{tIAf_vwpqc1qbSCP?&tl!F{>)Oo z8$Ve)OK&iTbzAihjQNJ)rURz)=6>c}!#@2e?tyl!b`77!zhFj?Gl?!FMT*srQHg5HRpX{ozi|ubm+W|cy9wcD>i*@| zGf5Oooggk@Mc5te8*~p!m6N5}VkMcB^VGLWHLx?bRa!1G$9>GgYb_q79<@uW(=CJfXz<}*DJ z?|`9rd;BeNor)5-iRr`(qBqfw;;H3Sb8t#O19wnIW?-YScEl0XfUQI_H6bNd=_z#+ z`-{7U(}7z4&Yq*L&z{+CxNM~JW64PSNlQC>$ky2g*$azz79jRB);WbYN_Sbzjx{9@ z9m9(~#r4b1Sgw>Tb)I!U^1Tcf`mcuPhR?_kqYcT?)KU5|wU?fTx8oA@EddYGO}B>C z(k0}6`V06UOK=}D0BRXsD<7B2C7<|7P6Vvf>d5fOIzbk4)G5$!&~f+^3{<(;RT5() z_6pa7->+L`te`70)H8dGEn+&xJczNFDf2Z$6@47|mdbQyT}w8R@{_|!KN$k7*D&le zG7qhdzQYLeZ=xb`5$lgN!f?>N*PzybFYPBWj9gAmCVcohVmCGhEk}+(37SJ-8lNto z5ULA*2Zs9=c*nW-dP>}4skJmKYt+gV@O z@03(3JLVW_UuD}{*2;6k7mmyqCe^UjIL>sz_}&<=@6TRjE>R!I+qA+IP<06e zgOOzP6}}2@NnFF1VlA*-^eT1-??EC|FJcKk5}$}=qO-A+=p*zrvLzakpGa54HsZC& zsL)`5vO%63uI^>mod8WJvDjvpEVHz*`ikYkzb!#)VNpNZW9!1gV?{Hq@2pep`clY# z#JZ_?ziok|vNz!E5c(GW92^li8`%~8fZe8UP^YQ2pl8USQ9fCx>c{Gi^Y1~AFpwHZ z*P#9dvsMXe0Lbi9rIRulI22fw6N*9n7HKZb6SqZ=Xx6|^;63Wl+}2sY{Gxu3-A6D*%=vvzhaB8MBd{3HFHIqR)|X z^ly9+@rL;P|9j8aGo&Uq6%P=%$!_F3ycw!T_ag+_3n2K-;7MSY>%4F=QV?ks$q&vA zRQGgu{o|Tldd`_qia0Fxy@mHniY=|IsC9ACpuz&{Z7WjD+hU7F%Ol$}hu(I$WT>NO z$yoOmPu<|q$Zp|k*dSC@IqVUU$3CS5>KkQYI%_?;UPi|7r#6!cjN0wJ1H>(xKY}6y3v5^4zYS}3-^NK0mc|4&)&;Su;NRUTvW! zOIM{9Vn;DUxGs*Dnu7Pmr2JGe)QSLY?1jYRmGDaJZLTZRnAP&C&Sh+EYHD6#YHGS{ ztZJxj+HL;bRMDu{{mXUZhH^0HrDxC(RgpN0rC}Vt5Kq94U^?_EmVy0_KEbS57Qi}8 z_)?-eb_g4YZ3Y`c?GYS0q<)o`$eb7ukA{x~UIljh*0^W7JC}cS)p33(g&nJkHOSXt7>Wm$Lc7KBN;jk`-U;X&N~mtk zV)icoyY{JetD&WCzIGN%@!Q!2%oOGkS(BWPeS;SO7OIZsmRhX%ZXU0)TB^KxcsY<~x#u zBqHsh%0OjuR^BHr5%&pn1P?pC>%H5&d%S`2CeBpX5NB-p<_472e9kv!|&3)KoB2E3=egWxI(OS_QppdTz_RMc+mOudZK(@du z0q!~j`3rdpZ-R$InZPmRxaw0k%CxjxY7}lAY!_yDhElVAQt4_(_0s#U^sBQPhlRcI?bQ#C+ofMc8SV`M$@8(>ra<_`1A zI2->j*PK7W&SM8LSE!@Zd7=*99%%+!fa^$xx=P+4ZxZyv{K$&1H}WyuT{s^pigb_) zEz}+04R<=ov zq(-4B!qd>>un>CbI~g4BY3v{G(|KjzYu6l4jQgu+d-)lczI>tUle=B{ZBGl&FmH$8 z={u?%$I0ewQU9<~y4@a31>_?^#(};%18&nQagJ6l%s0VHf zTylnKnnYux@03{ivs6z~Kvm3hMFY+K%|rS<#u zrFtho{s!w>Xe;r}xJ#^;Sw%l5FXNAaL&TqG6|@@C3zju`;3}^~4@8rqBcsV61MHou ziH=ZDDKCJ-#$4q;rLs~}zATOuX@M252){B=ewNpCjddYz*mpE|By=tm8#)pGE+(o&;2KB>8Hg^%_7iKU9`qmd zDpqB4*_tfEWKl=RWU8FB5gabS`=Cp}k7!A>Sly>im*>a_K<()vsiIMSDl;JMqFl`d zi8D=**4T71m3C0q=|yx^&{Lk%zSi~EP0|axr#h9-;hu0Uxdq%awh<$c$M7)r|4Q7E zpz_DSIhp~0n^To3${;02Dw50+CC(CRg6nK8U}8;SL!>lvKQdQ1BT&K`As|!}n~M6# zx9|vIR460TGrTdpEj&3;31ADIz<%X9cMFfL{Jdwp$LdOO_jjN5F~J|fdg0E&4v|{I zPGvsupt%fvgbe5*tUb|=Tu+XocG5DPz#OKBP*=e8Je7sU^xnfYXuk5-Cku1T-!=QZI2s zq?vdoVwMJo9i$pUOr#)uSv(?Wz})vR(pNYZ?i)H4*%K-Vr3Uojg&{*Q?Ca~<v z=AP*p=!x;9l;84I^vC&!`UPLV@YL`;FblL&HmIA`T%c|7BKcSqycsc)RESe#I@yLy z1y8*%xeGslwZobsUl0R)4(bRU0u^bV+DJJrSC>1=iu|kcKyyi5E+2^g8ZAQthz%tG%J^%h%N2;{M?Wa}$`0bRX(ZvH?J|iqX4R6Sy3> zhOLiQh&EJWpm)k_C0iILjFd(QPZVclwWJRl0Z-W=Y?c>`XF|tAkHLvoLFgP#5Jp6P zh(m-I!f#S`I3au@^oRI8JSBY9f7SocH`H^{JJP+--^e%4z0$WdQ08yz!TfDQorDjO z_0j?5mHa|Zie7<@$TloMlwfAif0YtVa2fAJeqTDUc z%Qpi&!f$L5dxH*8M5Dn7UNM)r>hs_WFTn*R_4Q8Ws=2JMB8 z!J@ z2vi+$BvB*)rx?Yc_J)13QLX~X@>IYkXGaxjgjArpB4i;oBHcA9;roymz&th71hIQ` zpzu>Y9l@iYLh;IJu_$bo;P5Q9m3NXvhl;~D0)K|?2CIX*Slx?6UU($mzuw#7dBLv1 zPJywZ3u4beJK>esNA4k=l(uNbMX~4uz+U;_@6owfGCTtxi>^fapyT1E_)NG7at!R5 z^}{Mb6`;Wg2^lnvuv;1uxd?rR2BKNWcf<;1z{`mSSXXvE)`;wc)yFf*2+;R5p^VHN zJeN5IHs~XylSH}O=wG^Y%`@#hBmq7_`vE?ir6$n*F+>%KHt-GgX0!^v4($ijRf+Or zO#@hmR1$u`Yw!bsAaYh1Os1=!LY>fwniJ6EFcKYsu2cue_tg@M~I8iSI@2y3$_hHBi|mXyzyz zlzoZ}O_6-iUg39urVj~+pi$5zjTnB1ZG?VEpOi`P4tRz#0p#D*z|KeO$`ICBa}{)R z0eBfS2VMl7LAwCQjOEcp99EuUhY<*MiRaNEIZJ3wSkz?tYGfW;O$vfucapM(IwRdA z(l!0z=bFiIS9A>oDPQTX%3f$V>X)}uj>tn|E=cm4hD{e+l9`cXbc4`Ca=B)4B%jd- z>rt(QnP`UQUAUAU5E?{{4~)S;CcBaZ%@@*X;Om0V6=0-LofIyjuKM@mUxfcpSg!)q zk^RJc4JqG*fTF%z8RSE5X)9 z&0o==v}jB>w$b|ZEgU}jF?Y4T&|}KVuBPemgUdo2aV7MIaD6$(j-Vyz2_hR=**lt_ z8)1ahPuwwLM5i)g^s8T5A+7}ZNBc-FFlX5rY9DbQ*2TsoJM)rk(0{U@xSZ-XzPMp$ z8jzcMJO{^F_9AnS4&{dF^~qf38I9-Oz^&>XQ-CdP1lYU!W-bBe_|8m(ElFD$3^_() z$pCs2e-{fS$Y**yX~wM3<5`~h2m8tj%y(vsUX01k=0ipG20Vs_8r$^Z#7vtQzmfdv zK`yF$^xQz0^k)m{9og*Qk-}vVyYv&N9@yA7`d*-l3P3q>8MUw>a6#&@PF%eiX3q0a z^(@O2fala#IIp&5*U&j|z{phD+r_Lh{)OJgCL6fo>SauBb z4bk3oZR2|Lgj z$fm+mLWj!gAIvP|blO9^F$MQn2Oc(ijk$0lo@b22_fL#IV7;C(#zM<+E!029!wb#&q2%d3RDnH@J>Af zmf3!2bZ2KKz^O}t2g^Gn)35#-?%*Rhu;ZZ4c^3{-)#2<{1^S&2jlRqVXs~{R_t!%> zNsT}(GCdbEz0|hi9WtP^`3(A+WgL3Io<^#}RS)eVt1y_#Z zt`)d^p2YKw#p~t+huj<5q1kZls%VY*(5sw-sQiS}-2muz{(v*!TWE-?;L_d&?y3r| zwjJuI1%b`UbXl#A=Ujq2{2iW)b?}or(5W^oyKVU8gR2633s7V@KF`; zRe7x3ncnp~@QU_uytoYI)J!kLTzHRHh?V-KZ2I`fAjK9(I zmZLR9I5iH$zdO*TnfkT^5g)j=;`!Da9pUD+1+l1$YdwXQX$d&U<%Ra^c0BKWWD9;l z19u`+SqI>DezZpo+*w(?Lo=wf-b7!! zjq9(6VyzQir9mjG-b1txKo>Prc{dg9as|CGQ$5!L50z#ZH<_MBpV0!DzCFe8gwN5Y zmGBc0U)6@o(Hr!P^Uz1lbhpV0o!FuniD78azJvDX3#hneVa6JdF=qcV%0Vf#7_5gH>KdsSnU5EkIjI8 zk3vuPDWY^8(cFU>VJ+@o!SjUS=H$oebpoO$^11S)}c?danQ}3z_w?SFh}k* zni@Ox!t5AkDq9ai*_j@+SBwy2P#!MG$C&;kvy~Sby=ZfcG7t1&J0RQAky$~vp~p=? zOLsA z>8rO~Tc@WQ7un5vTjrwv7WG)3ZiQ|!<_C5#n^)!a{<;^MzVYmC?Hcw0E7*V8lI$p~ z8MZqY>Gm2DLMho%}(fvlESDWDW90S+(hg zRhS43%fpfHYQQ+eoyW@l6z8`q%n;6Cw?NHu8XOxJs$tyIc;P(w)%mnP)HAHY4Im~p zmAb?0wG^mBBy*$Gy^%tZu*yRBG=O^>Ng?-^+VcuA;OF& z^&@?$RXc|fO1WE1LCsB@=tYg9?Tvfx;p!Oq|s^f>B~Gx}Y+KRT4wq{EB~G!fdo zE$Lp$(ZQO{n26v1j{fV!6$Y^mp5u!#W7%9dEgoXcT#7VFD8S}1au|7NL8dl$l&fjH zqNTL<`azl>>tQiHLpwuHVh!qH%%feII68(c##oHW>`1stfdmx>M^~LI6njC$o=Fx&WOIwP zhH!|l#+|{>Gl1@(;xeJK&`KP{=Yk4ADLy~%H2Ycn{EVWB&sgf zdT2k@_0dDY4x!hKGk= zM=D3c(WQ~;VPEikaA`1K@L2dsgx7BCeVD9lId&s!XSb33!dIb$bWqMPH8!1)=m{xWYf?U0^`QDK45Sg;}*<9I(;9jv#4BnNrV z{(>e$K{B5$!>!=&^N+bI_+1^+5nSYo+*Ym@SBvyU6|0@mLT8{&_Yd?n6yHq$$Uu)k z(ZB&;Rc}YvQ^z;QR!19Gz%|vq!fkUoopJ6tu4L~>-|1jOHJ5Rd^c60Ni{#R>)nqo! zQgWK-f?s(ernj|><*WQw=*NA=s=?7e^pbilb#C-(#1;00>V@itT*0^D#nE;+ZJlR^ zlXu((zNzp?oGU$+bU6mBGQatdrJ(hbWtaJ1WgntwmH(C2NhPEv(jDl(5uq{vjGP1J zb1U1QG~$kPb@*(&mEX*LA^{T1wdcz4E%`RWBH<)2^4qvcWB>`WIS7gqY;97Ny$-MN zDG#G~%&3=9H{*T! z+YD;&<{IhEAKa|2H>Qwz;xBoG5;CQjADAy%j#&H0K8mdt`yFVLWT`QCj=86|RL@85 zgbzc}<5QqiV28hEAPcf|bwc$b&7odgn0ZQq!ba(pX^{D$`JUNr9%11vBP}y6wEKg z>v-RJj(A>s4!DoF#yNL7zS^tWd)pOz8~X*@Pg}_L*?z*&+O^i_2^TRO++U{MmOo=& zTW4BL);E?s){EALF$b-uEn}1mVh5bTZfe1>9{eNlz}Mef%G=)e!%vZ?U#wbm$r#MY zf3pL|%{CX^Hw@HJ5zUxfa85wpVpvY-4Q&$y#J zCytTM%Lh$;kgMjEg{A?fn{s#HGJg+fRC$-ZIWzIAjTcg4CLu6nL3&gPC=_D8lFHfP4X z3_YW@?RZ9h+vW5Uwh0+`>^GfDeDfkvDhUfrg{@V<<4(8CvJ6%Nrhnvs)LB|6WX1VA zNONk5k-?z{{?|UA=ac71*xN)Vlpdv&11~7tfjyny={FSGe4${ z^^#eVck@TFtJ|u*)Yj^ajCIU1GF{v(v8I{I2=iTKfieJ@wHKznN-uK(^F;GabF$fo zN_t%FKN# z=#AC5>XK$q zK&=C8!DGIma9vm_iqa_LdyC06fOtD0{SY4jhp~uL!JCN#Z}^5Jdddzza!C`s3dhyI2)VA`pGm-^m7s!%5+5~sEIyF`yRa>$sGwr4yj}4Xc89e zre~H%v8nNe5^E(EODGrD&014wAqDvDK)lxB3h{sOSNV#zj6gsV&zvT@9Dx=GU`(@<|e@44pc>^$k* z?Ck0K>{{af&pp#!)YI4f&b7>S)6v+m*LE}`A>TEZ0Ik24*#;gN61)aHaLRgUnraQ}!iJU&y2G+I)RhfdZL zQBn-c_t-26eG_S7H}F|!SkB2~`839>htz+fCnIGd4I{N9dn5Owi?oJDgsmh*NXSlcb&TLt@@4ze#1lK&*q4+XiC|(}CSZ4sy@<55fW9E>833!6rGx4dK@a z3DSN!R@rSXV%ce#Yd#CH3y_l_2`hRJ2(>A1;GU}w~&FGT;E%l$YwQ0}OsiTeOLO2WE zAk;KZjT@P8C$V?J%eV*TT2c;@Un>yF6ZwxbrEC(=`0sePdff!^sU%#XGS!k`sb)?$wUaxEIr&w@hBG%~l&4koe40NR2wOwl11Efayk9*} zJXO4(y&Zj*ec!zsyfNN%SE6f?eU(j1ADbSVu_=Stw%E(tbo=v+?is~0@@K5HY0j>` zZfXs7foWZA>qI?yWb%!~UvZa|@q)m_Mm_}X-W1=T&}dl`>7y>vFEi_zTVy%u!u`c> z=XZ;3P=DzfvpjZI{LJ{qafvZ@WrFyQyGsviiLl?w65&y)oQm^Y6|MunQ5+zb2D`IP zOj6vZ_{9lp6YD0fNZ1kg$$C+#B{{iaa8UV#e9$mda@UBvO}&9)>StYV8El?xS|H65 zD+`l_oFXqBm5V4n%{whYOG)6J3Y(sZDSUpijLBt`qz+A28Ff^oKzKv&rvIc*_uO$a z?s!*zR}B|&?{OD&X9pwUh5ai~D{sKbWzy2nJ4U7F$as^nA){i(f9dPfhNZ1czhTek zEfpEYNt z=oJepv1SreJ$7zfrMQ8yYb`mHv%(tkf}YoksFCO$^@e`X7)2KHwZ$@WO{~VdER$lY z#0`kQoiHETVeg@zk~iLMEoeS36%a;nok>M}iR zmESqg(ZaFAaofJjIo7_?@twB|dI7 z|2yAVI47)^Y^Ek=A4X8CxRf}5Ocl#=Q(mDrF{lT`wjA0SjWd=pYq=-T+xaBjmOCid z%q6TXV(oD?5;`X~OQZ?$@oi#PnWOSfp(5%yAJ{qMIv)}vravwDV`|0zjy3E9`omTE z4zRc$;QVHbxnvP@N*?P@>w4=R%M)diY!#PtZ`kFAlg8-XwOOh=5*H~R$`W|#TjY7^ ze(y57T+STMkGnNt=?|J+*!6^VDW(Inz(2ZBG9;ZFuUKUzfgj z{4wxH$0~V%J1K5gT=_ULrkgpBJfC04mNQ=J$MoOf{4<+v$h8%|3wf~5 z`cvvErOT_8c98I*?E>{-SlXGF)2D59X4&PW)@(h%A{0Slh+^h|L-I zDyD!nsC<>{OaF-DfqipHV@zJ~j9OXOST8_1;*b(_Qw9SvS+%dhG(_=h`WXBtK)|KYx>pnsi`SxacO_2rDVKFpK6<6FKf&0 z*q>1;J#Si;A3c9EKP;(#+N*p1Q+tq$%B#5b32l?Z@#!&DlwN!psDNDv?+P{yJq>fw znOa!eO?$$jd^Qte>XKO;PMC5@bA4;A*a5NmVijv%GcRA@+mJU#kPe}ZjY-TBHaoYB z>xFveFuos*uqX1n#5|_*mP4_>;$I}Xz-6c$KPYB^`LxsoSn2MhJa?HNDejhgD7(zl zEyU8@{M6J%j+e?nTPamoE_RVbxvI%;%B>7n$}5eOqhP{UzS>-CY+anC6~XcR82e;jlfj9Zb8D?o2(D);Z%<`ZJqiuW!F) z_huBaeN8KrTIbip?^nO4{@RxIk85J^tlmJJVR6Ko65qysi1jJ&#r)iEI$tZH9gZ~B zMyP|ePddRqbc1$Ss{~ZaOA;?sLnXV8^#OKZ?PD5SRi%{tRg7{?xlts{ejpmuE3a`& zxsKc>a+nk&uSf%KvQS<=VcrpQGX5A=qQ2mWWQ%QJxum4XIi$l_XYNWzb<$|loJd5UR+sUNVtFHMV;o5~~QKcz7EQCUq_rTe0fugINa{n#Dd z((u9SXPWb#*Y3)m=k9gx5}q~?pv~#J>ZoTwj#J(8^tAL^wwD>{_W$hV z91R`K97AmR?7N{h^zFx%AG3dUPygX~8~Cb~7mg}n>+!fUF~3<$EBS?O+%KjjbsDGi zYVgrwX>K}571Twj$d}chQ$Jf$*e^9#ih+Of#_~`(2_3Wh!26C9S_`ql8~!ct;f8Wq zNkO10t0UK%59(Z(xe(u3-iUcRf80s%4ClqJi}_?-ZrN@wX3mgvnFh=0$OZf%|1DpX zcgcQJ5v8WNo;fS_YcEWvp>EPmHlzpAAn@q6nI2G=_6%u2FSzI5Ihb_l! zrtu6v(olqYCYtpjzA;wxj0>(VQQf)gpSMyQ#NIgsf3bSsji#`+oG`iMu-R9(4G0wn}W(5CA(6}#LF$Dp*SnQ;M$P|%p_V^ z@1QP@#DulLIR9i{QC}I~A@3D$y!WO%(QR_=bJlQPa<0Hk(*@_?QuYh>ILB^VvaNAO zUE2fO5yvOnKHF;W9A2k3PkWTQDlNAyw=>OOQOnM;Vu~pwO_JQ=9Py{P7pL06;Cd>; zB;boDp;CB4tqDBsP3UdCH)71I~&)*-XL)Fy6lwwkV*iSGG zG<`C?GbJfcO&d-1O|r=gb_11FIk%|_kc}NplTD}PE7Ac}?)&qXxx1t|Ifc5?IGUn= zQ8OZi!=r*f{3m=C6g>+1`uNs)Q{BB>Tb)C2B6{I`?7U^aWPfIF=4j$7;k4O@+D_TJ z+HW|suGhPNHHdQ5*dPq`_}&1Ly#}Bulu#$n9Mcc8hl;pXrn`*wWNe&tf%qRBD*4vRRrV zYW!22YAcOH)~sUG~axrYmwSa6B3Vr`uZY2?p**-py?#Wr&J;N>gJX?F8iS z;%JV@pwRrFH_$gQ4H%o|fn?zCC;0rH_MW}&ckZqp>b~zT=Kka=f`4bZAGkQzM%QrX zZ1)(~2e;eV(KXib%D&sa+1AoN&|!7$@OJU%iJXp{*9Ou9vmoiOa>>qC+?%IH4r@LF_0ElIls5rAlA|aHiUrPeSrb8B?w)4a!gP@<1Tt`p6%^ z`KT(#$a&>txuRSF?G%z?<=*lg`HFl^?kD%e?@poZx{H$pGfv(c;eqxF8SAS?Uz}KV z)f4>^xfZSwZWsCzJQExTwWJ*WvrwPCyCF@T)(+kR|(gj&bN+bj%oH2_9G6z z?W9Aox3k-9cH24I5L+JmLi=im-F4ObpKod?HB>ZO38>5CMgX2SvD{Bi7VnDr<&*Lu z(<=EqP)Dnjji!aBwQ@cAclnjHLRu^B6mN;s#RMrTF2eq0q4Y$m1%A^p`9C1MS$U~+ zR%!-jk5Ak!?i4*@k~BwZFPWrsh|&)+C|ae>Vlot@>WFJFN*jua;%UJrxP|UwQ&ABO z;Wx3gI9G%khB#KdCQ@;>^hipPt3e%kuv{4t(xnm72{EU*78$BSdFCxx=+WdxJ%StpbJo=X^K3g7>D!=Kkio=c?;!y{I zuwS;_vW>CTvpvnoW7}hkb^hsk>k0Zg2DXMqM>?vF^ol@Z6lAWm^SG+~EHO=7DGfwx zEjFzLbGW8yo!kI;-}PWK8F;FyA`xqdWuW!A0-vA7+)_u$Bdw8kOUa-LSC@+mx@coiXsmrUn?;zSi~D(VGk1~Aj5K0SS<_^?3vj~;e_z7Fk5Jbe9T!P zhgd}%gZ#`#@ed%rJId?jH)ySI@(ujkPnP6~(oeC8cpMHgjrdkvF7gr8R2%4ztthYT>r{9R3r*SK%X3uj%{ynA? zRPrcGOubG2$%p03QWa#M;RXo4{de)bSWjvtZII?kpQMpeTi}9Ei-W+h>?8IUYoP~x z6S|76@O4cw7X4$pa0}5mBg_)Y2}6W-;5e)krXw1xP)NuwWD|M`ZvH#6^k4XL!Ybi! zsB3-{@}Lh@0xM(;#^nO35%`MJr1yxnB(@bE@s;_xTv^hZ&BuJBNmPc5-=XM@$h~mk z@P|++SRi;aP%hBhztP{+KhihOM|^*H%X>L5Y;4>M-Id%0+^WmrN^$+^8UtI=PtK>n z2Jd!cb@WGG>zVtmXS6TkH$tbwoz*#-qO0^O{4sZu`P>_voKA}yG3E~;A}L^X<&$0D zcg0}VU5qNnfBa{h!WZ!6_)Yu`K4i#+7R>yZUou}=K5Fq1^)*0r!u~^zDM5WzIfkYUvJ+`-+1pj?;X!v?_2jP z&mGrC_fPj;PkHYmAdCluT7{FNX0@u`ML$DxLY?OwYKP1CNkUJtoitNgB;N$uua$ff z{XCmgQ|cgg7I)&#PxIsXPEdO&!B0hYJ4wia`Qw3*O}qmYi0(pfFdTmH6ZoZkQPi*c z@niATO1=Z%12KKe|3U_rgW|yx{OmD*gg=T3Kw+Q5U+AgiK=89&C zHjg%lZj5|~FIzY=IWjTaE9?pP3~vf$1nXhey&T#dycg^jDjF&k+7rwOGNCuY7r_BR zB~&tWCFl$G4KBdvtzd;W*3bBSu~Uu`BK2ON6qXH?ku@X@^d}O4z4KHf}O&0z7G(qc5)v3=|)Uz zV-uZdEJj_q7+t70(;Cp%w2z)w52}ndROPgNP_?b4u8kg6cWPs`nd)LSHrg>-9Xwf2 zWJI(~q1!(O}+$wGzm%=yaN^=p!r8>uO z6~QTM#TVkfVrMjrB$863CFue^!#r#Y_67DBFHt3{Moi=&!?F1Y!!~9og4uPCDC8Yb zVUNi=V3h}uPGl7`lReH_xyJae2)lsYi)a*OzroXOI$437St4794N{#Q!|p&0q8rSfKX?qP9^6ddO?Q^r%LY^~bcEF<3LvyZTT4G7YNB_4Pn$w*sDH zgjPE`RR5t3!mF-nGt_-JCAc8FjVZ}BWp9$6WC5!~|M9(XkE9!gg>u-zd^OUU zSIj@8uCNtqPM7Ei{QHIP#wIam$$5h^Wl5%@MFnsd)<$=5cJ?4N!7m$WYIC}u?P;t4 zcBHLA=q-IZJfscP4t1sLYzi$7rON4W};T*j*@J2nL*fBv8YP$QdyqfM&9WM^G4^4Imqd~H1e}k^wmJi>%gGQ0G@3$|10u~ zEv6-MLA?+ax$)7j{7SYodtBw%R&**;18%hC$uGV<_0SD;HL{L=-bK%_bM$Tc5xyUH zg?YisFsd&Dm6{^F8wwm%*$HrOUP1G*qv$EL{ZC<|;imnyEcy#7G5g6&t(+c?e5JLF zzQ#2grfy>it)y2n7SUSTcjU-c(aK_V?UM0=W~e*4OVC)|%khSuWE{2k_nHdu>cS)(^8&XonmXQjScy@;q5Au0SiFhj~P_^vPo&NBLf`I~(X6wnLyq;^n0j|%-lVWNJA8I4}!V0798r?#d*JUjFz z>;$cxQJOhPYa+sPnH%a#by~D6mzQIiQpSs@3M7e3|C^o5tz`cud)d$GP3opM(BeEB z$F~OtC=X4JJkT$*yQtS_t4*RifLScUJ~2v=6tzI~LbN!rnigiG{)1iv7JM~*4O~;d zNEd3y=&1+H$|G%|`cn5sCjqTDM5_R-Mh5N89bjMUv*-h0;qr6yjLK|I_Mou_?eKz@ zV;j?z`cQJoSOl+}3=%O$u`Phw+JpMY86dPu0Ertkn$cB2$L%w+b5Gb(+$82-p^K>t z_gMW2(Vf?}%6-C*xAH{n>=!C23C zAlbo|ev9nvAb6fHB{j$p)(#c=(nJDsYA9HQJTo0TlrBaM?uy}O-ZQr_+IZ4MgW3>z zZG2$U^g(KxT2L#bRYd!L2cD@PG%ns!3sBQF=xV(<)J&4}pm8JG0(JYR@Y#L^mmL<^ zx_|WQY<=S=yq@>zz40Ce;n&h%i;tGlCq!q_kUqpXpkIvc)1Irt^)JSHdY5jalWAkU zDflB9`WLNZ^hdNmoC13oe*L&Uln&Erw5h&WbLsDNL%XHjHD0jW*{Hu7H{&DrRB`+Rt`T-r>wuhSECR(M6b3q~CbJDE zaEcXn6nV!iF_wcD4xK3=69#hi=pnVUHU{Xy32Zgv4p=Lj)Ct;ehL!XMF8ifgSuLR* zg2T@}UGP3YaV+@eo@#yna=vl1tPl2{g(fa_KaWOJN z9i^3l-VjUsscWK9bv2Y>note;31`%y(H*D@&ZO;WhW0XgF5(3DbdP4!hUx{>ui<;4 z>fzxLPW=$A6PX&^9#|a^qWx$R7bj+u9*X6}Vd6vahSWmQJTv=w`4PZ-GIYs=iYL=zklu z@w$w?Kq1|(Y4GtVp|g5N{keKw-L8&OD`|JM2iT$F@S>K~l2N6JXme4M7)_hO^~0kk zX#Lc=>L9gI^kAe_I2c04y9d-3Ja+B=RFW+}AsZ}LuL z`=4_~;lMS6Yt8q;PQ3!uRbR-T%_CyJB-oQul4d4YLlI8z_Ffiw;Z)c*-# z;TO)xO@)^Nk*dkHl~KwtrMv7DWxf_kW=Aj|nJCjA{4$;K0iAn^8Ega-apU2A2X`0(S!I{XKoxy>C2gp^;bB^%)hW$BuWN+>wpO262;e9NORGajxCU zj}<2JIrwbca&`n;6&y&+jkJxL5_W__;XCR~;}ppSbW#K-imjwC>WH_2FdM}+n$R*~SVBtNL+d@| z2+%Y6`IFoL?ltG*qTC#C-`C(Q@I*W)eiHLb$DqiQ$ycAm7svQ~&78n4t^=^jFMtn@ z&=0g9{i+{=vejNK3p7s^^?vkoWJDxuv~1L*Mx*DVg`^ zy#q?AiMIt%I#Krr&q+@&FT5_@*W9z5qn##byB)A+b7Xa&_iqGKWvGx(o@jb(5};97 z8tlP%@n4~qkdJ=>rSrawTi+Ia6D|rx*-jBfOQRfea=8(g{@mZ(BOsW%2sLrS$qR<5 zY$}MVNrYcVTCl$X=e&w>qxIhb#Xl1a{$Q)9b8>E%wW1C`ZW z*4Ji_X@UHQ)JQxbe878eC9B!-Ksq_NDMCppPW}^lr%K4Fu99a!C2OQqQ?&7fSGYN- z(7XhO8<=G}96o8Y;VxQ3KdSnorJ~8vmC=-FyJ*kI>G1W?u3-1zoxuIT^T30^00>wFNn>;Ms+fkD~P&j zibN16lPM4K{Oy3lu=7(n9rH>t;6@*U!I01BN5_LVS`mBxiip-;Lt<}{`@%=*jj0Qi z+{an^TM9wN;VOK^KBG>OTWX4EOX(;NN4sn{uiPwcW z$TrjiBl06~c2%G!H;{WxIMN2E=ATAUqYACAXMrwq!|16!(h&kf!s50u`mKE7+=M*QY~c0YDq_tZR;;s5P^TNzF{*md5GjK zsmVQtYU)b97jka>#4o5MK0_^0md~LE;Shla5DEyb_|x!iX+sXOh1qrJXW7BA_zd1Q z1q)~a_N(`}=0b7sfg@4@c_;L+n#zli=~0mD>5RPZB6&aZ$X@xW+*S@FQ<@BI+BoT# zcovG3`^8h@0r4?%F$cxP$cI@)KklP~a06&OANj#P1Y4~Qtp?WgK=n%0h4~9EgW+AF zpFuWQ5Bc&!{?oqUzHiGZ~uM8McMk+PW)QFh|mNYaYvynwG>FwU&y`8<6W4EW^e(V zr2ixv{vW$2m`z84B~rk)d4W?RgF5FcRCm{NySV!NDxj4YA+vK9xz31GRh}uol^>!) zSJu?cG{kh)wA?h>G||+@WHA*teL$Qi$X%dM6ojUfB6UKp=sEP^K9L7Fon;tT=`nq? zwnUv5Jrns97Q#(K3xg*x5>NU6_6_n?@~-o|bNgH`TrFJrTz|Nlx(c|;x~e*pogWzna9qB+vW9!O=R+v0cR zm$M0VfdoEB4za74_C`QIp{-CSf+c=5tcU&##fI7k4`KF-_fPd5_QnCFb;>==?Qz|5 z?RP1zBhK;8LQdY93tAB$9j_fj9IYMwfdM|~nCnP%9(R3lz4k8geGDXrnuq&De^*oV zH*}w&vgNts&~9x6)t>F9pQf>>N}p36D-V7@Lr z)K;o0Ure_U%`K?w+>$y0M_xiKE}Y^&aoM=SV5)LVcQ6a8!aJp+njg`d8h#gg5$qbw z5uAm*`!W9_Uk=}6?+H&iufzSjX9%#e)!o}&Wl_V)>MrD}y<>bcg3cfx85tdjs>Cy}6L@6oW@BA{CYF$PL$&ZPsDB?eErKf547iT`!0hA2 zyl2C_@*1ukn-R0D;wzw=-ynBBT9_xa6xJi7zD3B29Pwse=hi|sXDE4xJ<1O--^|Qq z@Dghox9E5JoEl&Y?E{Z!G-H4RIfmWE3PdJ1$P1DO*q8lacMss3Vw7hS{sOM+ijXW8 zLIyvKI?g?@uEe4yG#h#N#i+cEmHt8o(k50AON#Y{FUYNWuv45z*0ArGWjGrp8o$$i z`Yi2)+EiT(Rk78P5y<8(3-t(w0}BH^{c-+lzE$3^SMx~TY97Y(!u1%nz?!a(PKaYU zvpIg+FF2;#>$yvN%llTrDZ&|^8@`2X;$ZsOsECY7Ztfy?OIR!_ruNX8m z_j=H?*$39Z3gr5l!r7!U_=;cfUOm7S9KbK)**8L&DFqqo2mE=|e+*$hoWdCB{sd4ZqgV}2qLx+?J^e1K1Vx0qi1cNyED-4| zu%$o1R$XIkMi#mf(Aa%+QSYtYREMbtqFbW-BNd?d@jbLEbRZZL`V}}9O!pTHWO&>9 z`glfq*Sjuzb~{?R!uE*$y`#Hpm}eVMaO;Ba0%O7-Bi*8H_2&8(aBIplf3ZKvHa>`R zUPTExX5eU!NoR#F;t)PB_VgX0rnG}w&I$bA*az$eO1K)&@*nWMkE?~J3iCOH6@n^k z5Mtm7I!1`_r}!&eVQv}W$p!3VdG-i+b|sjD-~cW)=3@t18?2_kXl>e=<^s2)Dji1` zLdm8f^gVuuw$T~5-_~TWu;s{nLdk2g7qQ5R+U^9XJEftb`<|c7uR+GYDb5!Mxp-(K z3`U!$UlU|5MEGbOBK~1F>`xkV4 z?qZZVnEGs8%wZK+m1&RvH?!B-qd+cPXFszp>|88l7kdcGB1wqh9%ego$hE-vTn$Gzn=uCL)oDgW zaHbj=gW&MH4B4W!XwSL-XD>R8bg&wq+3+&qiVAN)^vJWsO>|1lPgTd|0&rD>NFei}-Zve*YYGV|7 z&+9K)K_N-rZOt95~vp(pX0hj`~ORH9aZVVTU_ zHcEosS|6vO?cn+jKz6zq^E+~H^T8cogj{wT%)u@nFIU`Sy7%v0tyvKVKXg}C4H;1b7xk(>{_(@gGa2$}OD zOdoKOdx4ca6z{i<*~=V4W_>Bxjz^dq%ys4(a{-yBTg+o7mHEjynK$_UCa(A%T&1IU z{VK#~5AJFnqSYR=Pb=JSKiqLE@K}p7Il+Cd1}17&urNK~z6QXR^@3A=3T)D4;2!^N zj0AJNwb2CYWJ7#b#lK~-YSu?fj6*w2!N)Ri*&m?#-~cbzg-;f2*b-pMmI4zuAMU?0 zc&{bUT5*htLHY??UKjY8-|+Ft_zq^g4gY^Y72rFr@E#1`yWk*u4ILb0@bJ+~9^)$$ z?(8~x6f^^je0XLD{@RDtrvw|ONnuW4tUSXjt1~&U7jxr%0!9pzh#B6CuWuP2(TXBk^BP|D5ghPX zT+a>Obu3wqLeL>{X(Pl5vhWEe~hkQR;`8}RQ0>3{C?#KK8UYEqQ#vO}z zZa-#{+Gwx3Xu%qYOg8lQ>bTQC@Y;4zK4^$mt%JL7^#2HER)5OkFBQZ&b9C0kc>4vV zfW2tJZ)mS47y~~LpXF$+L1?3A7#(jgJDkS4xdQQpV*%&6B7G@p#5O$oAc@JHo`|&TX&=irW0sib~qcWawD0Dd9(nYu`rAwFv zv>3AXH)%3E9jzSCm@&3)V{{i`b@~f4^l=FKR9#0rU9*p!N$PR*P`w$ME!i%ojJo1OJIAwPPw{_NU+q*236rhK$Z6jF3!ikuVes z65z748xjA6npAgY5!p+Zv$?SjwVQ5WLGSI*KUy`mtX$Z3zwYo!? zvpeIQdK>lMwqWX3!aixEVPR_mY5o+w@RCshZEmF+Glx3#t`!D=*BjGFd%?v?!Y0wo)h2%dqdP)4A+F({Ln*Nt@vsH>()>KK11d1*=`2)r`a>y5 zqnGqmnoiH+oYTn|sV~%0^gYn%{EXUiDJbDkwlC(#?cx-+n{bu?mrX_v{HCx={4Ni` z`TK%6jo&Y}HND2W-sAhRPMjkpp(^rmON`0b&39$5^CQ5N`GnQryU~h2i8$pWCv=V3 zO#Wqak#uqm+_>*hQG7+3@}qz`5l9pj){Wd1>|l4XEsRUl%nm{;9%WonT^p>kD%`QO zRgvOq@#wzbK;Uwkgoi}VM(zdN!QJZE;BM`FsJ%a;b_+h$x*)&1BKkD^faX@S>+|$j zwSpnytZ)fq7fLZeo~9BDSAaa?D*(qoj9(A$xERQW)@(NFAnl<0FqwsZ0rZP5 zkT|SBsjPz8rw2yE9z)_$MZ~~s zwKyHFyCXw!itMFV7fN%KnM_Wiif|O1^KN`mQ!g-AI!kTkd16%&plQrQve7VTUhb&QGdZ~; zMjW)SZqZ))RPqfvwPo2IaEH2KTmvV6p|;i-uP==Zpw+{xLaTwTs2MWR;-USXB5;CV z?4Pau<1g#D?w{z53)!{Y{`t;sp&Q|nfw}q-%^SL<_SD{~8SE_Huh+sI3TX9?;&8FP z)J{1J6kC1eC2++N<%nsZuw2?EQoe@iyEL1h#LW?2lM~!`VI&#O`VVT_6=TgprW zwp~CbswKA(T%`BLE1?ZDk587SY7@C5e5`SwFDbB42g@$BGs+uf#7oRTI*XIw4$NM;*Rfp*o&!!3EK3uA08qzFE#|p@i@a&pB^O;9p;X;0vulw0Nkcc0Q7* zt>?-B|Fj5fr{iR(^g!+?W`RRS2+kQNOe0K}O>@o7WI6!OZ;qdoh>8`1V?J9pdv$DpYI18`*8FfkI)bZzqhhex-OSe=8Jp8!ZI%%fa~8_ zVLxA5+zV9O?|dIP*X%_dL7}(UUQ7+6HmaU;wIOPI+EDwUmQ%My(&6{iC)zxGK9n;& zB@`E#5;_~094zN+8*~N^c?$+VxbwK64(t8ky6O4DJHj#Avn`a{|H1RxKi5A$yih9u zJ=?lapUrCY;`<5}`3FESUlqN|NOLu%jM-}*ru=3eU>Rlp+w4_VBQssZG#r_(M5&bU z5@^BS*!4I~w=w4HNA!W(4{*a;s3oaW?*P5KL>5>g?hN+}I+4B1o202qvL%OUxiuWy z!}@1zi?~&m)z-za*JAcUA%2bdmsHO*L)r+wP+e{u8>PwEr481uYI*cKk-E`)(U|av za2qg?p9OP4V_~bWUvN%nyT4`Nh`);8Rard1HN{Wcdci*)9PO980QvU zi`|bLcU{T8OYSCq)sx3JIQTrABdVwkwU%lQU<(UF`}3;sSZFD}g~Gs8%W^Q()>^V# zj#(3}*Wh*TP+Ur3V8;`H);=!$&8Kh-vIk9owXIFZ0BbvrNoKD^#?hwCb9x->MN76Q zA1|)tRe4lQ2~%g~rFDuW5_2(OLEPNfJPEgBSI2CMzZKsZ&e)SJ?^V*CDw#PPSy3b}x9cDl zlG#ra+u6er_Czy=gh3bXKk@*G&ygjieXc0(LP@9YO5LhwGKp zj=-$!H+qL}YbBWP@ck-<>=)0i6BFR~kY-Iar6>YC+RMf+OYE2MDXv4}@`MZVbK}b= zWQ*-=30m`6y33!yGU$X_c6Dwfc3L-$*Fb-4(%P%_;3_*dd?wU2G(Gsm``Opdd)nXG z|Ji%aea(^AbIqOZ8tAIxE@i)+vDcZF{?zrC?O?`oTaWYxHYvTg%i_4?y5OCbvD4## z25z3{LEtF+N7~X6#zVdhkQ`rxVy4kb4y9tkzPMMhpJOM+l!^Tr`+MT~*g|nvtZ(40 z)7eyC$ts_gLZmCZ5<8BE#su0xKM5t~`q5mBpdSyvHnQr!Xj3u}?7j{{J^6y*wFF~y z<$X+p1W&vp{%lg~|8aB<@Nukd6d%ueHybBS)71Rx-rB~kZQC|d+qgAu?bLRg+K#QY zN8joF<(D)`v%9l1@65dK^PKZP!Hq(~kpTD>@3^5k)^I`pTlbIFpc=zn=Lx0@+_Sxu z_n61!ikiUNKo|dZ&tBglUnlTV+IVJw;M>7f+P%uw*5Ol>vUYbZc6!_nv9YKX7RW)w7sAeOP(4&l1E;2` zuC!s4`AyI+b3%~DlowPlC^+bpX@+UGuBG;}`PMRrenLX5EI3VZqU-;ts;^su%MExoAr=TMtEhF}XRtUKlk`dA^xLQby zP`7EDe!Fh8ZZP)ub(NNRL~vxVTn;LqoU-13aUo|@2G~F@X3)73BN#^x%{2PNFnCb-8HO(~|jYD*$HJw%SxU;O6 zsYRuLBi~B+5J-eF`LM6JFCkD|>FHkK`36PP4Dv1d#ae1Jdl%lz+rhnqNc~^t=fPb= z{o&7|?uM5QGliE3n-q3Dd`9T?pi#y~IFA)|ef0_cHdU3hf`+wUz`e(});-wevA?w? z*fowpNM)Q|^r~<(z(QX@slSll*OH5r6RlOrzlyx~^|?PXmgd#Vv1NbABQp->{VsTG zZ3@4_EypfP^+2vzLXuIF4v>j-O`Xah=`-LhHk!VgH-`QdDune685a^4d?=`)*>8Gc zGU$VJl=>|Dk!``4>8na-`7#pQcl-B4ZC%Yf0qor0juh{BMG{*xTUY}9kh^>{?E=HN z@RK25O{+qSM|6o~iwAhQqPrg@)_Z{ zZ=~Pix$0Aw)Yj6_BpaXW7n^qmbv9l$2Srp0 z%LxA#cF^f>kaIx| zjP>-(G{;qc^Vg}lid$UZ^SWPo#<)7$t6H`fw9a3d7o8uKo0@ecOPAd~JtOT|=EKay zj74b|Q*NeROB;~3Keb-ECqtKUDYb9vzU-maVy?dm@~kTRZfA4p2iZ^6omoe<=44Iz z;HJniZC!7oAi^LRF0;o8{)b~`g_1VEh%)nyv!e>S1M6HSnj~!CVU7|(t ziqSWtT1I{d?-B8D*b;M`Fa1b!W1W;2nY||aZT8yi z$ZR3=Tv~k6z|@nOMX5!pZ<6zp-=P*F9H# zsI0qj;p!t?HT5R#vXIfCpM%>*wg|fu{5yPA#Po>W;ZH&x!B_u zy#;ZtjQB1v3&}TC+-n?}4wEa<-4+_zimuy8Q=CQU)D1Kpbx$>Ov{MbeLMz~iJ&kx2 zJwG}-s(9?!*h;ZcF;ZlFbghWi!Mj2ZhTJy}&|lELWi~@0GJ{wmtHqM;lD03lebyLg zzP9EK$|5skGNxs$%3POvFL{6J8l)e`rL_GO_}wFQZPM;vTaq>el$4$=Qs|)l%)|2lBg(9V7k?sT4{@2_f@>cSL-T+8og+d=rx2#{_NAx77Vl z-(B;BIl+FWW(!M&Pr?Y{p1&vF9qSwp|9YNJPjwXw-E~F3bePIichGjzi`obsW6TX( z7Ojp*jk-}xiM$Y1RIGgQUopF*KNLF=bvSeoQdAtluZ)unTeXqQda@y^M_EdHA;@EN zOtOxIvgT=FSivtGw;!2nGDc^3QwuXvQXgdYNbQ#-r*=)MlM$TpGHp!8htww-Ju+tJ zY_V)`&v%ZoPPb=xU&*z|WBfDqT-7Ufh&n5HZ%BMdt1wr%7{UcF3i}>jD_k3LBQ!tw zqEWA#uDzlAfP2eR-a#K!CW$3cCwL%@6Q>I`{m(q3gaKrAIZ+gpE(A|cWS#swwXCk8 zZDr^dyx4r$v^eBUXw9%15#Phhh93>t9L$Gy3)>vL!?0Oz)l|b=dIEQjx&fX3KY_ml z#hc?Q<*Z|`WtD7KZO?Ik=uq&>GOVb5;e5-$yzYhX3do{9Dl%HF$JAXjo z$-=J%g9?fl1){x8HHpdE8zS_mEpw!75rjmI1iq z7g?&=_SpN{yV{#N);LSLE4v$l_Fpp4*WW|vA=d-LHJPYQU#1^|avu(DL1TWDs+Fds zHbx_9@{o&IQX^;vXwPfTs28gU^?smKclr8!9n~h)R(=?N zi66~(gM#cdZ{cUF=BY(hIp}6es>ebZ^;q2!+MfNYuc{JohmTbm)XSj=?#B;CvU)S_ z5xbR11mE6^zVB4(e^gDfKO8(?!5-}g!c8~$6f{tDWg&MsUK%B)iOs~zLU|!SP#3uz z>HgBbU|)*2r}u!z>}lp{>?&|IaK*U3I$JxpA~lg3!4-dps852YLtmkzz=T}R1hZN27zW@UtOK=1U3e}_!C5yGBeL1_^GUDi6Blb?P(PhU&%Y<(jUV*Xkphl{ku(G@C%u$OV<BI6nC4n@^(d20HwxSp6AZ?)??(BvB8G#akB>w{c!N6_razENi-k-kt z?hs$RyM@Q^c6vO{VP2i{I+B1tdyvNh?dKd%w6~>iDO`sWgh*kym?^%Jyz&I}Xh`A> z#`-YKtJcsSDue2cthi$AGG;AYE}NP9sIQG-+Q6qh6g*cee~F#VMS-mINWDPyh3}+J zWdBkR=N@v&>hqcf+s$Tde=J_uL@_dSKu=lM0hO0^7D~`Zq&r;Vh_aT?oJ=YoLvO0^qwm8GxI>$?= z20y=WUG@6hExjeZcRWvh^#ZMgPVi*akanXJpMak8J+d-XNdxIzswdqOE`;Iq9O!^4 zC?{{yU#MS{p2?=eK{EBwKmUN5nU$=QNyFN@p{gXj=O?&A?gLWKPw-vT4^`)(_Nk{X zWG|_0sGlGhmAeQPMmhQvSAp3HC-Wfawcnw$v5pukHz8Umwd5gz8G_X>_)GeS`F$S2 zGu;#9dFy%z0*KR@=ep}Y0axd9M;+&0M`dScC+$)>-Z_VQCV8uR=lQ4lHGzRpm7i7u zsDf`Mwvtci^%Tiuu&==&4TUyzE9!WY=s~DN%!jA44l|f)Kwbg6lAr}Tj=s%SXY%Ru z+z37&^@~cHFilsrqWZ1rpi9&JrPFCwf*G|^Wl?|B98iBy)d1<`in^HUD6SXJxnv>^ zDN$D>i71Wr^B1ysV6EITa5^yFcgtt>#0JW_4?EY{7CYM5df0Z_R@x6(TNTx_MA=v? zWeKzF0h@2PWt-!k>ymo|$OF{_^8=N@%6F0<=}3AI69oVEbM*?%R&eePz+ug*Cj4Rc zU^TuU6}QFcrmR;Eg8KWC zb9iu*pnAp-!)^U)-30X~_6YJGFH?umDL5h&`lcgIE7qIhk{wGOhpitA3&9X+UG%tM zL;l(P?ELiH*xaLeo%62bjmg`e*FLXK_RZYg7K6=SRL1kt-@)N?b`(-cJ@c9RpxUh2 z!N${#bg#_eI^2JZm-QEs;5lD=O7n{uMF;V-=*GCm&Or)Rp#FVg(q3(#kN_3359acNCV^pUwLs(Ym@Q`N4{n~@NZMxBX62-!C=y1omEaGyqh+YG(kJ@k<8Gg@n*Mq;AmujyX${M5jMDCjQwlk=|5>_@ha%_89XM^Jm&L+ZPR zyx{M_Z6h9suMfQt`8H-{iT1@z5!0exMf?htLKlP-Gw;+**Elo_`9AboI0*!)YM_2# zt2fAb!tQZ|*+*Dd$fF9%=I+i)&lr>aDWgW#`mB@LH?r4cOwVYMKQp&Y zL8HPrM?Z0$W4G;wTtw7iwsUq(Pu)c7Hy42W!e`v7nyDGCrI|F&toy-lLQkp?U7t+B z43+p1ahJbBfR&~bFZpPeqFSqq>wX(YgpLTm8Cf+nGHh_M55>oUzE#)YmocxXYNSiv6WAF@_EvVDbS65h*_v7I6)^?7k;SOW|CGBivvtPX z^rvYn)7NLe%)XEwoc_P84Vj7gO^S9EBo+?!)D=RVJAGcbAs&KXd|PMIJYzqp6NB~! zx7Kx2mC;742ebFpVY)n~j*=&DBqx(R+(Ms}no`pMt1RLC+R}7W%A#Lt3^hyPZ6i-a z9uKJ;*}O#c5@llVMm36j8&w#d623jWd`Rh_&iW)Rt2@Y#N0yP6swZCv6#9mGbDdS3 zb)9>xuL}bCLjJRynwdgwt?X~94Kq4s=u(4HI%G=ep47b=l~YHhSI=vj`zG(cMeX%C zPFO<3Qdrj?h<4Ye3GGY%1$}MwiT}v{osmwE4-eQWBH5q?psn zC-+W`%1X&Nm0cp|W&U?}G1oea%XeB{N(`nS@>{fxL6GTdSOyQmh9xWXth zYF*6h=)7oK#ORQNsG8L^Mi`%}7SluM`b44FUHl^acIVrgIzHGN6sat~@(np7vo~gt z+4l4g>3!3uCU^MdN_(DMDWz;$#k8@>MXC8&|K?21-fOMrzFb(*@m9PI2VflAOO?vk zWXfrK2DJ}rVu;Xg(($@xnrjA~(ToJ%61YDN<6kom!LPIu3zZb^tmzILPc=ct;vthE z`d##eim-at*VJqTZw#rdA% zwyt)sy-m^Eg8lhFvO_ZFW~O8<%6O92BIQQX>6C9tEz%mKBGzLTa}{f zfX|wy`d7D|%H(oUcV&a8NAlrSLTq7sV${Vel(58>i#!w2x>(uhmr>afbg180!Pv#L z#n?+-gG;5oibV>Qe+RC+irI4)9@gZw}X799jsqNDONez>zw5Dls zNwFyxQWH|_sZTTOEqdiKGb zhgr__Y3Ys9-=rK$3(lC3UM5veZIkvi{bq(Hs~{^pZ+Gs2qI1@PwoC5gf&9QYWiTkF z-ROIKyf$53S35|@Ykq6n8QvRC=r@8ZcT5|ty@2ZdZg9Fwsv2qv^;6ZQwaraNL)RcV ztW#L2u$tj7Bd^^w_jLDH~JPr~+z)Y#RZk>+NCTV->qI4&a ziMc|(LDNG!NHbUWmoe2)&NMaTO9DgRV43VA`X;HlT){^7T&+Gy$` z6*W`$Ro~nYY*?&6r(dBTuJ5Ivs_(D6ps!?HV5(pq9sDD>Sm>6pJ`o)w7e?NRTpoEQ zaz>OGSuHX*{CU`q(1_6TAq~wM(|i3H?MY;7vTP#xMy@DE2V#7Gd2TvgwqF)!;k<&R zJSI0RyIW?bjLqpm>6xiVQf4PJ$+5{*ljkPSPX3s@JEdycMqKR{<@G2$X4ztI=Gx>N zB-ED6!wY|#DZ|^;TeK~7_4IzEd+W_N%pJ^=g9OtkP*S6U>X;52#~O`6Z19kf@}bj0 zn}kgYKOC7J<&0Vx)hKFlRN3f2)ZD0r5$f>jVIxCJg~XeG8EfgAX_u?}@ST~)q$2eZ zM*CNLpSdjdb=GA?p9_-mtLG(U7iE6UxS3uPdp9z5Me@(2IZ0EK;*;(qRZ5zhbR@ZR zYMJ!yS*3Hg?%_i7dszA^QbI)>`P%7+^w(<2ij<0H33PK|sNIWMw*#EG!0p$Q?Eg8w$( z0nv4%_Mo~Qzn^&nkM|J4S;`B&6FR> zr;>LiUr1h>dMAFd}qPs?ykNzIjFDf{y zOQbsDQrM!zWpqb|bZG>Y>!*sa?~?r@zRknRPCwB64rl){FK+XSSzuz%KSv z63O<=HEy;#SKCs*+CZARfqVNn_+GF&_=5Sd*=E)UPY|VdcW#gi{gS zBW6e5i+mEfIkIHr%!o=6kHT}p282e2>@*jdb{Xp%-e|k1FLPORDEUR+Bc}VidJXO? zj!U+#*7x93{mi%J`Extve$V-y{UNJ#*5OQR#@USa8PC%Pr9ViI&DfqXJhNO@uk1;= zbiwSRW!CNX(as8<&i=vT3VAWni8{mdM!l?oHd?3Fmo!W@?loOBy))J_#u!H!-y7?j zbfzc98K!+f3Fei-<`6EVbcjD>Pw2DId7+shNx}1j~X3)-1-?}_ki95*qN4d4CeP9kY)wa}a(oZ*( zH4Fq9|CcUSKNKHky$*c$3`3^zUsIkbB)z&;?{zM+PxLvFIhe zGu|^~=pm-f742Z-43tvc(!mVTDR zqR^sch2e!)3la)O73c~)`P~Z+6f`J&S9q|34DdQ94LgOl9UsLg*-KIgt#)e(`PPngJ(>>FCLr(Zp_<2@q^@JdjsbNS8vZ_ z?@<5!fM1-3biOL&L}~_ohiS*9bN}N%gIM=nwG0kB58QX7z=td1cVYFK05;cKc(7~0 zkJX+JQyoxqAja?0)zUA*JuuU7%#dU78O|8e^^~ENzPxS$@;f5bv?>6{>KA$~brJc% z$D|n0kJ71RmDQ>08S0X%U0hGp zZw>THss)s5)u2c0DV{}I$Si+vU##!Ex2$)W$L_A=KIod@dgn}W^njnRq@#?(=s0MP zvUjs@v-_c-dFZsds(G@!fB9ZoMqGiGqnW`PB_5IL$< z;hx`wiqalrwXC5MsjBcLp)N^}WR|gqxogOYpQoyzZif7aPZ;M1wP&?IFk9-V+oT<= z*{|-QcBxu}-9H}Hs(zM4Q>s;evJrUk2zV-gYfdQgc&Xso~m#h(4jXF%%U<`~x z??mNl6KiE#Be(rCU52RvK4~v{0A(X*Q`ND@chXBi-x|PN1`{ob8^C|#hk&Iis?8dM zMuX&z%9?nbz3 z9O_DPFZCSs-hw{oyWkZsLybTvD~Q$P6x56psYKL^c7nvn;tVUzR>QIAz|KM914X|E zzhVOz3^vqVqmXMb0e-R}%)g8dn#&GcPxN5EfwwjRHRPRKeWZ^sL@G!b_7_uu=}y0e z$IMCAAQwS@+z9DSqoC~ch$BTGa?T5dcwxNILYOMt71UsG%@tk9fzFo}!FR8LXX-9^ zJyVDT!T??0T5=tElY9(^*Dvxk61Td*A!&uf=K_2zCS>NlLWST8)b2Z=jGqD({w+9p z3Xno@4~lm;QuRXNa~Tcq(mA3K-XrM4p{kuho+8&n&$)oyif24PE++fIYbFvp_>DXc~vH-V{ zw@Cxl2~|WF_;b6_7rPB-=@Xp49jWRdu~w!=f+tc8-ndNU5Ot##QB$eLR4b}GwVXPQ zN1Q=s>`{FEO!=ti)IZ1|nF_*MZ6wy*g|mMf*keP<-|$Ku`_D@?6Z+TZa5`Lq*8d&c zXyfp)0RDt`a66o$pk18(!jgG(jTgN@&Vs$UL|>g5UygM=XUWuqOGG zMCBVixH_bS4x$RK-E=tfK7$(5ft(0Mq#oY0*YKo#!TWH)ORx=2DUeB^_2z-A_R4wK z|D)j(%mJ0DlhPkr__<21{7zm8hea!;EnHR)Wt9?-?^Vh(KY56I{7F1|~Gv_P%` z@Nr%Fywpj$D@iz>J0-1@1yxBjM#nR0q%>F_BA1l+z~y0vKh8_EBX^O#;VX7Ns-7L`AqwEb)8IO#g(Fgc&B_M7{t03MRMvIyhZ42?YTZsJ_l?QQUxQjFR96W{xv3>twysyUd_JwP*3H&8-cqb$r6Mr0e zY4F`V#MS0J&huyZ?g6eEr{J#mimS>`_;UWZvmRkgKE*a(!*{pf?s}-)#vc3#7f&jD zJ;`uydC-Xq!F~zFbB9A~_{aI&6kq>`k0@NDg5mjaVNAcqV-n$>dyeNli`P7(oW_;w zBCb`B@%1K-_9-}Y&f@s(htmBh(qjKOnvdZ(7w|VApy9QjVjG@fOaAx-Ps9JS6W6DB zxE2>Ho8UNFjjwxfK0e2LNX9;O;tD9kX%mCJR{lR1+n==7KG+vMaO|7G71q-!lxsULAn*sxO|aKR$Zn znR~#q^Cy3{A29^K83@--cYGayp9kUhL-4nQ@xPPsbtLh}yE^EG{zWdgU>SX)0Bgkt27e0%_ma-1~MP2dC ze;lNvF*0W1k)yGl{qXEnM=&aduxu@3xnT`AawZZ!x`kk6)FqIU&KB(U{Cq*n>36Y@js_pB98E79J!tN zSdUcX?Kq-4aD4CL<10SWaCH3e2Rg4vzY9JZ1|XyASE1>u_xM zz!iQ8qvR~c%vFrW>-fu;@OeGPsC|nu@eGf;ihJMl|HjiLjNaQAE4Tjpt4lb;u3~f@ z!85JG^X$b)*@d$ag%y+r;zbJS*ToT?7v{R`^sQ! z#lex}M-nU?HTe0DKdz@%f5~HRH9@pf*QC1gUOQ4UNid5!gczEB#=kp!Dug|EtMq@iF;C$$g5!D91 zsA?EZrSMxlW}4rzH{Aa@@owM@JqF+RR_yDu*k`M8uAaodKjGR+z`1)2;~@e2@6UDS z5`Om*$MGVL^JV;e5BI1mc-(Frv6J}e9-iwCj`}C;=}#CT$M8Ihacnl@5$Ex~e!x!{ z2Hzb8f0he-I~(_$Zy1N?aNJj6Ueq4Wbeo(7hj*j@t|}wp2bu%7{{?22dMmTSa8(K6 zhOi$Mk8U~0=ij-P3In?3cBsYU*?ZDj^0m|yNsE5DADm2YDM8ss%Sf9{fUC(%4x>d< zL2~*aX@}g2py5wmfO_c}v8TT=C=CAwj!Cr3#3nK+8P+!{Z|0}R4km7mm zr#wkW?CR~y7K_QdacqO6x*%BfV1KJ_^B$GRY@)1i{C^<9hydffwDetQ5m+2hi@$|b zkQkRi8<#3yR#enG`Zl+Uug#0#1WjT3F#&cmh|p8%hU80SuY4T-celJ-ZY-xtgCrxI zDEFYChzHR;0U40xkYYK5c#Ap2LZn{4!*LvmJ>HuDl9*VIYx@M;I}>qSSK=;r7k-*l z+?9Ic>OkWD+Y+mz@nkb1kC={>+Dwo&Qt7X>1QN|rcv1*Dh7M_7u!=t5DHk!M-Mk!J`w8qcDN8!u`&Zf9_i| z;hbHBF|b~~Bd01`g8j`WzXqRFf*pcSTux zC;iDoZ-8{P|$D~n? z?obvc`}L|lz8JDd?C%}H((*NVt~}4XTRb5wl~V#f$%$2Ak3enVv6Q6T1dAv@ZG&GS ziHoMP*%xdZb~-5!#`Yn8*`zT=OQ%L~cASZFa?`v+4gz6c^k>L&S# zq0Cz02Xh7<-vM+6Uyka)Z6PIbGTo4|lfi5aCRdJxOY}2Q8q~54L?K1dKS5CxsM%B~ zo4}UhUaCuLp!89XR6S7lRSo0)+EvsjHjJx8EarM3$@nt4Sz0PD?XbHH2?eNbJYzzcTwE`F2iEg!bj_;7Crgy3Tlf29~Shyjaa9r~&aZi*^Q(FW3 zd_m5DZG-c@bEf}>yTAKA=vK;Ec=( zW)}kuGk7+~kc%`M*Z3@H0Ng*H<-73UMu5#W3ryqd^lmnqx2ba29XPr^PN%t}>Zn<& z`N|(v&A}R;VJA^pIJdqre^W8=v+e*TYk-`M^R>8i8*`B5U}O;J=lo`Nu`%2_^((Fo z^BHLacR`+-%gE3nwITD7nB7|5D>jyHiJL{UxFB%EKOpcp&_XOGok#7;?Vk^V?l>`0 zGz1vmLr*507VkW-eZ^cB*Ae$ccPHN#&o*ylpqqD@Z;r2c;H&q%H_8_&kp7?E@vb6N zC{B6u{5M1eok!988Dvwpe2gfLO2m80Ovkd7pkm9RQ&C4Wg3I58eL~r&)8sx{MaP3h zUIi=RnN%7%0(B`Dau>U^r@4FJ14Z*ZYL2)0t)Lj~QQuMZ;Y+Cs!6tglc0$do67zyy zPhX&J(;lJ~l|jE$wvY#iSdypD5Hd)I@#H~jJhImWdMOB46_9+_9w`JhSyZB!({urR zDUC=kIS%gjx=JkQV~Xe&e~Kps8OeO7h2+2*;XHg}eL%SF9Jm7Y)DB^b&{Q~wxZv!- zu)sd>HBSfT1{V5P1(x`>d*6ZP*Vr@Fo$M;>p6qPusO@+Uu9eEW!8_hJ#{Ua7f)-F2 zd?V(PYpMOr4-l#ZR9x;O)uA*uop3UUN%rR)?rg@C&#- zpx8YC`+Wm7nL2~jb~PjH;s+`8X?(D7=mt$GoDC(~D?O zph3f0fYo0VSfeN9SUFs(EhY<@;AA%N_wtSLWX1@EvgGLqcPC~RMMWY`o#vcY`Ro}I9n3Hn>$25qrZUsHIZ2fep(9BC7!bHSvOk`881ahxm$r8lJB5FrO@5MH#!Z9 zU2!Cq&z3iWOrHt2OpU-2|7>tjD|=^p_JFH;#YMt*)Xh21@xxJQe+jqQG<&FhB$BpP z+xOcx+cRx#KwI5zZ{=v>T<5&uN_B7d9Pn=Sx&0FYO~mu!1bK!s8G6@klmOSX><%#0E;BmRJEy>X)*R~1WF)tJ2iGeH zxeBwvmzxPPRR)uW+@;m*FV@3$MV3`NE*5mhvp7RDkU^1xTH_q*GHC%BC0TAM&y-Gx zg~fHZaQN-397F6~>_hAqZ8^64w!6q~u^=PqzD==x zv^}uZbcl{(F3#<72Rsvfg}#hHYw?y8svH0-pf(iFC6S{t1}Q!>!4;jaYK0V;{vaoI zglf4jSDaf00^wZt1^85n$N_nSjEMlX9*Jvbp&_o$lmL4yht@%bIuto1{V>yPjRY1C zI|=967!WQ8sK%>!^>@`Fq^;#5-DC>Cn#*O&v%%1^kn}H{DF>-WsOW8{vZ!_RBMy_vfj4%i%wf#S)#1*mR&_R zEbXl~?3aAb`H-TB6q5 z8Ryn?`W`g*Rl#qICFfyP+>9{8OI3g@u@{u#j|Y*lG1b`H;9rv5Rxt2-f|t62{f85H zuWGb<8MNayRE7LFFe|I7{z0Zv1izW9&-TL}GEncy?%3bwk-3}&qI@WI3fk>*pdNN55tV$ZIZ~U&+Kt<&9iJzAl=Ck^XG>v)1`If26Oj zw}fY&>zu2e>n*N3ZNcb0WB+20b@aByS)-wawc8sa-DaGngr%7!$#Tj*%hnh4%ko8` zmKnBCm)CXGHNbt-J3weIw^R<0tHBeT8gAA}>%{$cx)X)2}!A!gG?;vcdxtuRNy9dxjpxOmkRe9#z8p~hC`SyZ8r8=TY#EAHTqp=&wNoQ#vb(+4x3_;EPJDe0v=)qWbgo0tnv02DY z?Tz!Z1Yd@W2KVVZaZ%~hB%#0#!&@VAkIGDzSwrgo^C&6pHVoq=yXvFi)<-v{bjjX*sox3UQ*tt zf@Ri$p1*xwf0F;H@EhFq9I6o;u0E>ytUaLXsr{}x2xmhCxRR6jHy|2}RZ=AfsOH&$ zn}R6XL}Us|S~#4K$-T)8I*F@_%&l0Y4_;LhnyuQI+LF3Dx`~?ms$<+`#zbEv60jbd zMjS`qZYiAva`Q22C$oXQta_;F4t4Y;!Yv7C$t+%|8K;bXiX&cN3Szv8pJ-k>L1>YtIzdL|cJXXYXls6g{wyDhzhavc1S3oqVx;AlyRav@KNdn%0sm7wW+bTg*Mr6TU%A#k{?HpBeI~jv9g2cSAn6(yvcSC z2f0x}YGylefDAqkxbp_O8-GbH>z)Rc2=eNC88UTUv^{i}^d{X_%>?cZQ=g6?R!aSm zlo1@5BD4bk{Fb}}tKEuVeHP9i?5T?WggmgONyQ<)6WSQk+v!2hjg`DgO`E z;hLx{tcBlY2w9x^K=h$iObGpsUCc)!32rlzzUHe9nz`x$>ei~Y=#%nnEBNy}P-{_1 z*#>X#Q*o+r$zR%U@$SVUx&T?ubKMpswWT|)_Q$qoP|a6xyU zr`#PmiRmpfI~7zBD@Zkk1yX0KFA?4nm&JZaSK% z3QvrB>*+T4>z0RW}Upkc%xFhMOiB9hy%1Xx(w`G|g)M zDf1N!iFE%le_QWi&qMz*A=UpNknZd4E$weFP3KywDrn=?TQse-8&r~3um7rh3m?;B zT`Bc9F2H!m?qKmO!AiwI_5cZ}ugqhnjob(7F0}?Oo&@eQI^2)=c3gY(?YC$OwarwY z)g$@+>|NwlEo3*+HJMlBYvn#V*-M3^{@Fs2@FMU#kmHN@boESk-S;%}*Ym%{yr-Jy z)}NlW>$!WB`>pGvGvA$N3$v!=$+;i0_hzQ%ZgYNPZU~X`Ze%O{RNfGg{2fDV&{dNV zG|)6o?NvYFvxrLcXnrp5;5!gp;JIhL8wd-=&N(8R;$-&b9LJc7j%6!$y{T4kupq7475Za@wCts*~XKEMo5nB=WY~ujJfkp z?h9?9PjK6rL);#|hPpc+4j=SG?v>^$yO#W$`hnSKD7%8b2@=abWwLk*tCNq?J)$hR zohXZb?=$i;JrG^3E=UDNKM~ijBxX4~i)lg|+49Ipr~nSmJmL`Oh%<#au?A`$2C1~z zPe=${!~7}E$M|wQJ$xg;qHuwp-2}|+(e7q$tMiS!qBGq2$&qTmVOeaMQLw&Xt-X$~ zjZj)Xj2xwh(r;vNbi(~ zwRQyw=M_Nm9Hi8c%E(nkR(yco(;zZ~z}NjxO}5ElxiFu#fki~;9=tADAltp6zPlwHt`Yr{21s3D+iQs@fk~^b{*piANIQk&g4a3e+Slx)y*Swg{C^2G!i2*vIEl zuOE*t;8D^9{_!jFDz5s2&>NkPj%qwQ^ixn*_<_D@C2A>UK?gAe4wlyFrg_oDnu-1e zXw0a~9zgZGvCZwdf_D zA@`wrd6_(iubs)k(5#B2gS`Se~D2XGlpUUY19?519@4BB+kekmB*-=sg!W}iqb_&Aj*ro$mZe?qNUVMSt5^@+7nlV zF7VuJA`Xacm0?g)tdXaPqm>+?j&e=hDG!kX(jw`WISl&>I4y+b#l3bC17qa3DC|0kD_ z=oeD4L_ew~@f&7JXlFkd26$q>P@=e##@& zW2%U|*h_-Myb=;i_iJjPxrAwHmw3lcR-jHND0$P_@2XqjSAg2+hlgK~wM z^b_GZl_zc>Ysir3Z| zNO0LU!0&llIU~oBcBv7$TmI7(mXY?9Bkd#lf__I3GtniUh^oI1w9G5YAoREzD?`Z$ zqX;l3cD#qRuIH65Td*lqKRd z(V6Ur6tkO94G{RPikyjl@j`S_-x25p!9V{1XUkFa9Zq3i??%OZK7KnMuWdl@^AY+F zBf-DBgT$hSs0rLd-{mbRn$6MGnT#GqZ*(SJgK|6sJ(cTVloiR~qsraTaj65f#SP^J zI_F(*waAo*f~h+Kywg@t6_iqzD|O@?sA2r_-%3w8SLq|)Qq17QRFbO_qvTkkk$eoj zkUDZd;tv_4(~^Qv@~A(JkeTK$RPO;IaFSUn%Ftz8R3)XVr{h; z9iG1E5odu1*d1H#!jViS^Cc_ULAi}9_*11MxgXgIb(HdC#Gk8;OoAp*g9ygnzJ=Gi zk9X~YKX(#Y5thZ#g`ntv#Pd{!luf;m=SWMr0)Iu;DIHMyQG=E4vV*QE4C4~~ zy_hS6OGso>{h|DK-*09Mu9QXe4Kajx#7y@Oh8 zGsSIGxHOl1Omvb`siERJ(5TF~k`&52l>)MsScmb8>0}pq&L0Jj`~&^<(%4@s2)k5? zYAT1L3;Y_YE(^B28lKZfHc{fpaT4;Z<>_!1pO=>@1V|Gd<(EVSGDNOMy5&S_vG|M% zh$o;6QHd3}_C7D$SIi`ks>ydQ^|?+g24Y+THMBLmrIk;(q^&=FI5i)Q1*&-8fV6Wf63Ea;e@${=*_J}F5mqpK9#Q?5hjfNb;^z1IIy9f`ix zL{(qu0iW&oOxf9d%&S)WW@_tuX!UY|r@E3uoW76)g{hjJB0*GCud?1Zl=Y7yBd9m_ zeIeJKTe^?~!uTvs4pBj1cRK1B5KTH!)Uq*)mR}!TBxtDvI&COQMasQPaY$XV`fV(Lk_uuj@p&KAW6lqUNiquxV3MX1)tR>CJ9KMO%#o_W(sQGG3HV{oa zqW@R|4*J$i|G*{h9cTkK^25Pk3>Q}eqrHkgFU}wq(*Me#;BDSPPx>6$iP(pH!>Op= z8>kh^PvW#tTKO%_26y#3(FS)uf*K*lkbfg<@^>H>owa|cFmAlGRo+D`!<_T5+>qEs zj3uuS^GT<)9NDX(%1p9{)QZ_m^ijGn%jHGNR%sg92=9%i+LDuzJ(y3%A)R3<(qVFy z{pjSH$YqpDNfa(Ct)<<}-%30sf{?3~TjRM~QZBL&(%QPq)s-?xsgK6&{4a92d{mqU z!s|_{4X$fL#4v&oO>}9ckp3GL`zJDsInZS30KFV~>SX4n{~Q&HOwJgo5tI=T^i#0l zJ}b0nQ?9e`5mwkw;AYhsB-7~7SGK;#16wfVWAEJa*o@`1aQY`TT zWBH4$k>4sl@)M5QLt+>DQH`OXt0_027-25`7B1Lo)K_u`(GeW^|A~mBRL+nDxR_(4 zFJO$v$zzcyW<}Z|ILY!-vJ_G_k(Y@3atM`6E~aeKNou(=SguV(D9v#n8c2O63{3!BelkcJK@E~FW2Lc6j#b{u*~Dc$q8O=GR>}{h($t@O_6NCujD${eG2Yi2r5;&@ zxPwf#&d@q;Axp@;2!r&Gw2zFFlgK+@BK4qOi0N=d1eB}tJL;$uMpuTSWeHRtC*qxUK&c*6R(0p>_^Q;PrAhEaz*e#pC}!u z1ac8_LMn@HqBV$$jlkKhD*I5C+z!p$QgR-7LjI2V+7kKC{n?M?f;GwkDiO1mIPqU; zrE-s|uK1BWb6MFb)j~D)B^8GIJ||UIE=w6ib+C-%ss7M&T#)xF-!aR#K(lj;*eS-7 zQ=#R2fP}=$U}M%&hLK~W)?{_D6?8*3@+FcXYm*m2sI4KqBuJ$%xlU<^6;CyIP9l|^ zWQ=IX%&I%}T|7=fi%2(@YEmJ3+ouR+DX^{yKx1 zUUj)YeyT>EAXwZ%7eOaj4I1D_%6DQm^qLdMCS(I;3^HUxl;hF>Jl>~tBwkaU(PvVT zi=<0f10KTp{10@6Kgr8j`;kz9HI$ItBX1*EsC{i}-IG7_Wzv0N9L)$-7Bbbu=35a>@wg!<%mX$3UH(@>%4OiUyzKuLE;Numbe%rJtI zti`B$L|j50XcFFk1TjQu2{q(Xr3HCSNkih_Hq2TuTZDe>F!Z;HZHk@jmJ+ za%<{h{SpdI)IzX==VGQgAH3?87+D3FHO|FVwGmdIwaKHz2Wq8KojeM%u}8ikf5Dmg zjhs#{!%VdU=Ka;7t*wY$fXowXvw6#j!19&UvWE1W(W~_T25JPb$&mywrbP&XIl-Agz42cAK zq8};^jpe_T)5HL*7HS{?R4?}?#$k<6U+#!Cz;e)^HxS3<*+fnGKG6rW>2jcYo0V!% z{uq@G%4jHqm&h?7eq;8DbF(bg8v8KvI$_(cWAq=D*FZa$j-&S++d2c3@KR8jrQ;ek zS@|KAK~kOpD~L+?pZZE$S+C5MmwUYu~OV!io3hJ zE)I)3ySTf%ySpq-ajUzwX)>94KJVUp?l~rmpJcqYg)VUs&}wdc#p96{2 zpD2QcM-}*Upml+fJ`U#_jFx_c@(%+d?=5V!E+}6RuoC64eg~@Q$`i0g%RpyiR4Lht zM)?6RO&K7*Dnt1sSUD_T!oBA=yg28es`Iz}M!qABSB}XN=6*alZRWtEa~~=?r7}_ZPid~q!n=;(ET5q0RSzIG(qXx_0GoAZ z;0OnxP4rMneh2-*TU0UlyBWF#yqN#eW5Hb$z&2vM*@uYUTtJ3VLqrCh$bad^o9mv~6L4)g&D{mxWmSv(cbikC%iDE4&1C*`E?;(bKXMRBdv8C)gJko~2Vb_?gk zbAp%nTKK~Ii*@)mV5z>!)#Sc98@SFo-a8sQ{icY=~BE3ztj zK}|M}{7^cVKgc0+b!i+}(YEtl`Hfr?uFzE;IazT|)nRbnaGY@5aA+K-Q2R-aBX+`Z z%w8GnQ(3ly_F8s-d$Mh-b%`z5nhN&0>CVMmA3j}7hm!mS*+(rR&XR*!KXw@M^hy0R zU8qjhtuVH;U|DWrIBWrW~#iDZ=0eZ2Gxy2$!6g6zmFRRr1J-0eXs>MZvo`_@fy+MA zUK#pYVb06;i}sGlifzCVjyAvsmUoSJ`8)18Kig8E9=pl;mM`M-q_5ISNszCJ+kyM* zPq~-?&2U|Wc9V7lkQ@WFEwl%8o%EFuo390QYCC2ubP-dDsia%+qzdkH>(r|7CDxZG zN{3}a=?gFFQerT8}dT$vNYlCQSsFYPAwLl$xq1!P%m{v?jI+urAXz9*!Sm4K!YK7^{g|++0%K~xdM?K&~#WH+tMtH ztv>eImO;+V=AO>F)_hwC38m`&iLyV0}&XhpP^0FZDyZqUIPc)mA->Ds0T__ z8#wEaVCvr@-2(UcC#fm!6AiGsug3l0J9eWR$=Bfg>_N9;PSg4HB4z;dk!=9}?R1@= zK{1p^hHOiHS;IbqzsEmD#iOIg8K4`^0{_uVTSv=i--Bu24kSz}<$^|*og9GYCSWF1 zu-1(QhOY*76wKro!Mrw`{y|TK_dOW)!aa&7Z;?URIh9dj89aP{+dkQ6TRHnY3pn*{Mps*BUqmJH>_x77jWGz#7g0wX#|tuAi!ZrGEviahB#Z z`4jik zJ4n}OuY!l-f}tETbVXgbZomGPp{}u&=RTvx_(9)OKSuXf8>4#+P1G<=BcMyh(RJyk z(8{YvJ|$=>gS>~G{1U8MlYk_pfV64~g@>oe{yGXZv>qf6wc~YA*9rpWx|uRmS}N5P z`XTbwihIZfxE4EBIl^rOw3rJl>(ElO?0v0qwodk6=EKft^IDg;y^6gl*V7ip7utKc zQk^550o**-6)^UHbd0stbo{pL7KU)Igcb61?z-I6l^`!yp$5j*r;@a}81v2b?U?HN zNm@d`-EdyF&KRkmt4q)}VNbFhnPWgYg)*CHBjrZi_QqbSIifP*@B@?TQs`xT$DR2z zSTFXGcFGId#Rf!RF4O*O1s$pH1_tySzz7^SSdAS$H=Fi(%{86&YG>+as^OX9@!5C_ zT#`o+rCz4rs$Ha6#HK=9FCBTKpAZvSLRUw2To@HiH2`zu0J;%QW@ z1F9@_2G$>^u2M$JmnB&U0GDP4_W*3k7oF>DVU88nOyG#dS_E65InM5HK5QLnziYnk zJZ-t?++r)RUU9aynD{jNb*_?Yg5xE(-om=pTkkn%J9XS$K@zTt0b(C%yFe*(plRb^ zAF)%l8qHT`6;!2rX=muS=-+|YV5MO`u&T-0wZKmGXU`%}w+fR%uf-l55eV!jPa%HQ z4ZES!;1q7J?3a53&%cMrB_|=TuPZVyds9=Gp4upFmadxNhkm<9SHme|FVBO<3Sb3z z?RC;qG7a;5{5Ddr)@4$1eyHI$|b104_*828cwvX05 z_Uq;!_H4^?OSt{4<&v$p}|h*LL$C=MHOCyPspN^^H)@83SZ- zS?+{9f{(&}pM|%h3>Cu;UW zP2a`*xQY9FGvXokj{-7k`ap}~JXGs;V5hwW9-LNCN_j=rgh%Qay9C;(tUgWq04Nfj z$4aBGN0c$rI0&3U{f$P?ogQxt{XEL(Z|E25dTEez<&>&`x8bH(M>xhU6)y4rIBH^# zv|1SI80swNdSf5u>}lQWtm7PG4RVdN40Z)OTH6{r{X$^s-v(gja9BAR{{)L$btBSeRvzhfha@; zY5Fv<)`84I+MpT3?gN+LYiQ5c)@XEXp=_L}TcK&J>#m)z>4Z4;es(((2b_8WbBq28 zE#^D)Zn87AAF=1n)G%rUxq$qS%q1Qnjxmau1Z||PR33H>xriXPA{ce96h}=XF36MN zsb~r-uA9(+$dHZdbh!p!LiFLC;x5-#AyHn;$8dtHim2h?isO?f}o(C)$^6n692CLVHOY07P4N;NnHx&4bYrGw1`r7jL8LQ7fUcIS4jw z9_-=oh#B-ou2^Sc67oyek{{^j@ck`h&QfKWBUpJ@?9Cdp=QOR^4?xFNXY|_X>_g2u zdJR0G$gigQp#4aCFO^AkVIt^H#3H&K`AxaQ1fhp6k(u%cx*Cxro@R&1J)}5IAEBH~ zfSxC(p(;{U;XYw4=+LfR?tu6sapD`GSL zUty*aER7cX%CJY}dZG#>=|-gw6rXA!PTCE-&{<>+W;DH+y3U@a?V1Nv0o6|Pglxx5 zqX#2G|BKiJv}8xZgI)+Lh#)cvk-VXV(_Pg&)N0wEoT5D>-65LO?KSP7wDXZ|3l765 z%0Wl7cd2CNfqF+DOh1u>Jn|5c%_F*KPt$LtdAc&lpV-PCR#s>ix-w0p#I~-s#&>9& zs;(*CANa{qFxf_XjvG&GX9f!XY_x)z%+1oaRYxf%aWp$w7g1E+@Qqv~-{!WEPjroo z?=j~zy`9La6c*`{EGJ~z6khaPo~duf{{iOakMWG{9e0dKC0#C&+emFBd@LdKDC!$u z#04m~G`FlH-819?OulZo;^s$16*b~`nOSoR@MP%FzkQ%`& z-Ag@7_f?w4NTpwidTV z-C>$2ujZe79T%PQSHnvQ+Ev6z(;)tucwaZplPft&)($x4O5~cEXl)SxL-F!itJZcb z_31(fJFe+pGK<9qsO6R^fiCh98^r@|9F%)ub_*(lxb#M_XdD zFkb_-rFfEv@R*`}Qom`+(#K^BvB5N%+bpgP?8G$}t{P^Uc3Jk4s$aI$-F4FtVw8lr zT>Fr>wk`Uel%1%ejhE+o7P!7>r+aka8|qew4Sj0zb%=lTHfj|jsomLhx|QUosmyQi z%yg|*X!W4c*Ya3Xmu<)i#Ar6#BI%0d{zNYaZ}hPTVDG<1NTM>i?MxQut-Xirh`#!E z)>utLwRSuOvdF2)((H%zmYXV;;tbpVN#o$E`l76ub4241*JBVtGS>(+cG(CvBz(BO{IK*|=`zw(i zo$RMQyl6o!R9@-ynmw*3`s>6FZ475OzL&;({1(OodHX{>LuU%9`XK-|_7hqb`6D8mYHw*uvgNZQ z)Sp7K(#fvs8`)3Elf;MW0eh}&5UUbBrJK&?@Dwa_7GiBtxc`)3{y1M=94AbZZksa| zy>;j8j@mGk!c+zt7oG(l!v^QdiPU5SMbVH;)&BFg)SXl*$Aoje2-`Us|r;ssvH zr(`!~3P!|arZ!PYZK+hEN02>~S!x zrh(xFZ0xB}>3R&F>EqNNjW^o?HgpMfpBziuiFfovIhK{^YNx=C5qHvDhL7MtkWYI*Mb_a*-Uwm=S3)60?+5a)R<$J}Hk>`pF-V9X*9$ zm8wvJzD_P<`(dZQO#4~)*>G4p-cx6E8oZ40#&5dbz?i+|E0UeWF7R2MXIDUPuN!M42SN!aoKAy6^(tWq z-Ch1H?o`f+u?h*4o>5%xswd_+cRDw5vG!^95c4S4HS_Sox8_5IfyHkNHWozXw=D`U zUSu`8WWK-fR94iL%o=0|-qWnrEi(-E80d4^Gzv;YOm|A#it-^YD6ilXSk7C7 z66XSYV^@Z)6ZWYi9M_Q(@L0GYwI;5UdzoFD=lU>rOntu=fE&fye|MffQ zWi`Du=IXOG*MOGX3^aBG5ujF7c*!MQ5vobOfQdV-m{h2CQ`_lK*2;`PzC#ieztfN- zpr9(cVn59JIk-djzUjIJV+E*T2CVnW}&5GLG_Z*0$uU9!drPG^E~qo zXMXMaxZDT@R_T}U3Z+b2ik(H_sk88yA_r#9FTW6 zKPNXiFCPq@6Y|&ObS)qXoW)yge$o_0X3OhOcwh5P42Ta-3wQ!Wq#nM5{9;XgO&|5` zwdL4tge3Ln_FDIs94_=Kj4N>Erx!jdZfYIo+{Rs%!_}GOOwAZwJ7a&ZzCJtrrv^yD z4MLnDH9}Vf{uNN(_lfBr}z;8~L00%}&!k(|y+6 z(^mtwHC5MLKS*0vvx?nL5!6|D^cw-OS6RF!jO7RMHMyORk&bEBCYBgWgOYhADFvBD z*}2{Y`T|ql`h0tK%e+I`@z?7)u!R_~Lh+fdsYO5axoNgsKQ&|YwfkECqOys1}bNvhJl!WuI)xk`Hv6{CwSPZ=o5feTn_@Ms)-w9q# zJhU2-D1>U@6Yi<~UrSMuA-_s~UhdO8van8ZBddo~7C%vEwW-FF-Xr|Z2azFV!j6SD z3~3SECAfdk#(>&>jZAHfYqXu%bYh|0LwN5R=BnV*^BJ(CT*@xez_!x0HB9ojZ+v0g z%QioGGUI0oXEI{ zpZQM-tg;LpxJf6YKNY zYq2TKGuDuz?a#cS?tt|s3~`XH;P{&dU&d4=Rca$%Q%9eIFl>q8tokDxaQ1rC5ROfnfyyL((uF7#J9PhpZ`q1 zQ@&@sR$!&-57f$_CrjEvPd7M%63TRqBde6eb`)R$;b(Zugwng2M&7%pF|K-xV*RyX&D-=Sa4i#T(4Eoj$@Bva3chmi7MtBv2o#c#LWrIL$@mElmX8B^EhT zhlGRtSJzg@HQQuMr;_Btl-xntvoeFz9I2gt9{@Y;q+bJnZ~iT$cPsepYOfsE&G79V z=^OtsA+Ls*I3b}*p62#^ss(|+FH-_;eVQ?bWYP`EJ#0^J~new?&^XY#dB>*TvdYC)$+X@+9Y~+ zWj1kZZF8OdwSU%V7oS;nN6XUO|Qkq6v4~9$gZVKR;K@t6%U}&Wo&DS-tag`D1eK=g%!l zwzkj*tBf_mU&frTFgM{+Vw)QEsz=AJjL`e<(Z3^Gb4#tZ zB06tGPKUfnh1mrw^Dh;QD!ys{QCw8?wZvpS?<^|S}pofPD2Y(PB{-q_|9UCO(ecfVk5Nl*KCZkaNUs;$`r<)I$>80}%?U24?FVmKe| z+~TP1obB*)JV(~B(LTazGRGGk%xjpXq(=NM_}1$4womQ9|4eO^v!&#ba85Tm=y}xb zihMO);^V}Sgy>49=(9n7CYc!``Z`jJx8}di^~&p4HUBYJuspQ3aF!Kn0(Dr%(8A9gwl8{8#Tiw75*{QxtJ=L{lc-j~ ztxTTGFtMubNKqtK7Goh(+|KIfdMYN8Tj-Ol1!_)pXm9L0()cKs1&Spo%9+`sf&H%PM}A$7WXgQ}x&3AB8_Y{EkaYFTC$ON?bM;1b9WYsdzrVu_ zmlL|zZRaBKznb6MD%soGciPL?>ev_CY3DuX0sf%kWZrvx_Bj~b zujQny zhJJ`jj~yPjF{W*~w8)syVLsyxr|ALmYjE!Gv2HN$EfLM>mI*e8J;o94JmUP#g-e&k zA2LmCgicF6uLpkXgHMKEjXYI$a%BH9;h}P%>g8#u&puHWiEo@`k-7N7Hr(0Gb&P)^ zzLrLVXXlQrRs4a-V7Lp2uaAIUQig4}-N*jc-ol|pUU!&vp4qn~v+%#7ibcH(?Zrjr zSVxLDTHVO}G_3bh{F?;j1;vH@7yLA^v)?o?PmiA39D1fY9+CUzTstn;8RCjUF8y}y zJ-=RBCO-pG@G))BJ@a_x+0W;PuMjXGcy7?Hz^nd?{Tg_WH)iTsW+8b>xhT~ZyYm*_ z25qFNVq-B<901J3MQ)KZ(rz{nExuGRGS8K>F_XzymOdpdGrd>3E+aZOsAxjTEN4q` z9NAf~H6{5r3A`D+Hp~#dHncb>JMfzSVy~&j1YKjME4f8-@o~-)>w8N9|aG#)9!4T(N$1gUxe3*ZvciGH{-Vym>n?rv4U+_xzSOTThc}f9) zgFESnb{@6Aw_ioZX`YL8P2@uO%UnevgfHOBVDG;hb9t!UhL&_JHW$CM^s@PLW^q68r~Nc39$k%b-Yvafm`-_@HO(+JLq@ZwwkLCs zDFNawQe7r>kpsov;uvwF;4KA63*@tkmV_z=lcLSlzxJqP8e&@S^}wsh>z1jhXOgjr zzPoN1czY&N1A*c@4E)C)(Nm~|*hdG~9Or7sYJ0dX+L~&BqcE?QoAr{LfGIr+Ep zj~47K5iI@e$N4JqX7Y;0$MC|q+H}?Xj?Y)`?mmZ2t3A6LC+oZDGr&l20ltdaWHn^y zF^VABq+@&l!v8b-CHSFi)L&2>)<1lWMh#u>&(M%81QM?1p_ zAWD8BE`9~7+jgKJ-jM=uica8w>mp)54msNoq$^UW)KeM^jlJKJR&Fi#kdFbmQY@E4 zro>N~RyxV!2#NEbf<*z{Djf1Jp`jpS(^Qs8~dDqrhO) z82a8zz>fJH*)=;6i~9$Nj(lL3PmzzvUf?&KLKo20!R#S0jaULTEnqG&@0e>qW_SXz z@r>R?_n|%MtJE|^?@o}T!7R77l;LL`m`H1&Y#$NF--R6I?vh?whw(5A+$_U|54@j{ z%qR1yyiq9R6+Q^e9AgAO(F?4ue$q_I19`=sN=;>x5(3JH$AK?+3_QnmZE8oG2dPCkWF9yPEio6o#m?h6ZF8Tx@>$b^TfVR6J zpO&A>-{ddyUU{{=1fSNIBY?L1EPatqfEBm{5Nea8ozf+V#IN4SRgo+DRcQ<+y*tSN zw}E%FIx!fL>;*uJX92?+N>%|s@C5Y7LGmhaQ#a82yU3;FYJA!WHQY{a#Jd5&{~Ewm z%8+Hqe6Xpt1!k{UtpJSnZ{T4x@PQjKXYPS1@CjngS-@7;LNELS7Rn(&b85h6`Wx5Z z0`AgsK>_F>phbFt8@M;VeWvCCo127RMd7bXU}sH4Nh471oygJc2b6d- zjFW0WcgG-;a2aOjZIuLWI}&K`Sm1MmiGS2~;6dC3Cfr0|J{e+v}{lHW4w^|RS9*k?vLC*)E)`y@t z=!bQ*AI{;2EJGmM(dK;+&oQ7z)xje<6`%LSDDgzjM^(&D_Xw?tnjS| zMXX0Q=Wbx8fiFOt?1WEAL>udgFPKgBz|B1m9NIiMxSQb|6wv-8a%WVu7l#~~8EPcP z`%}z<`pDFEARA;cSUQKPsl-Z*uUg2v?2fB^#YpUb=#g?}IXx!D+Z_0%yyC7;1s@f_>*ar9*-xLSK4zv3+B zi325ntyU$6Vzl(eO6v)1!e+GiVa%u%7%LAjvgROna|iHD%`m^h)oZF1NaURugDI%% zcd&vKpnva{GK5ZsCyK$SUJJx|6fr`5Npu9K>lI=kax#t*ZPj()aX+DMLoe6FUF{WS z$tD$1a?&5YvJ-#TLCe)g87d-wcD=d^?R^F9y%v4<6l>8XaHZxe*Kp0dgrYPhDw8{a z{f>vK$7GD(iNFlGv#eB??9nku_@hhj2u5Lo^N0DMfUv4CpqaVRR zi|8j%@0hjbxR_c^jn4-ZGR!$Lrlm9AS$T86(wsi}+ zMHR3r4Tj%11oPUTY^mPE>O2%o<97*CnMZ6TQ;uNwMIolM zh}O#WkV`QOnIQ|wkCK%nkex7#_#?Lm4?%xnAw8SOR4S7rrB+llZ4=HTwpZXyU zBmSaWiOUdet1D%yGwJTq4{8i>{S8SAvZeYc%gGVQEU8SSA$x5GrrI?iM4FLCHj!^l zE?~Bc3FLJmLis{}k#?yC)K)o*`bSD3%{V%O>ZD|o4+uis#6}V}ejQVmTq+UtE}0~= zh<}wO$e3v*Yv5zONgS4t@y#|9suS_7UaU!F60c-CY=iA|8*v;Fs(Hyxq)lX;^fxn6 zHW9DM6Uqj(^C;k!PRse!5M*DQ!J0BpIjbfz*AUlDLT-6i^ma>O5qLAZ2=&!&8l#*g zZv%T{6JxK zKzt-us-5JW^gxW$YM3WG)L`06icu+gDH$eN)E?Ax^|Yj8nybB~D|83M1aAQ+R2BRQ zg<@l}7O5l3D=mOxx+K-67ZH__fA$Rbz|&Ya(xvIZP5;A9pnkD8gaOJ`x;XhioR*=>xTQu{a6BAFhkZb9Zcn9~liIKKB2kGCcT}6c%C$=JYA~)_Ae0x0IjJzlP zQEoGbl^#M2UBu23hDvj^JLNZOu6#&2uK6P^;7%D<>ma zFV57b(?-5RJz*Jp2<2~x z5kw_V9jQJ`8#Z59?(w`JGs(|07rC3YeN|WiYQC({g++b6+f>$BMJTh11fxOmk9=QWL3a!Vx-`O?D0;8|so>EYpbG zC*@LYl;*@paAbr7xfcri{~#;* z5dDML3S3Jz)F%2!FNhWNA>?nDQ#Xsts3hv8vPJD94`ccgvxRkRb0tQsrEM*>LuT`D zxf7*Pao-~(peA03jhWu$PGK~?QeA*-lp{(tue#lO!=f- zC(mFPV1s4ZUHwHKkw>Ur^hsi~+*Pt*IbA9w*T`W|9@#EM5i7w~)CRk= zFG>UzK#r2izF=Q+ zR!M>_7l1wDBw`wNdpWRH=Be|Q_CS0sR3z*lL&*b*1xR%cOe#~GG_v~S?AP&Mhd4o*63COJrg01HXZ1Yx>oIpNAed0Xqpf1EN^*r_( zh3X0Ls4)0-LtvlVz*bAfezPVlodRSHzJ=v?0V*a>u?MbzT;=iD?|mV28)Cx(-!;RXxNf9VIhPdm$fcbW6Hpu z&BcDbB1%^e_P}pgz|Ao_>!6KVpls#AfKmxNZ$G?_h605ac20HtIv*vGUIlu;7wmeZ0tpbm|Q7qCfUQ8zzmgs|+D|W00 zfF%-PF-*W7F$6999lLiK)`E)G&4wLeKn|h?J@yQ*S!jz~cSb))B)V9Hbyo_4$&R2H)s#v=L7QdKCid?tvp@JW5bf z3(8fD*33aW{Xt7|$kedm$Xt{r11$tnAe7pUQVBTP2enu6z6kFLjL~A8HyP`-yZ6HI z=`UEopV5yD>c!!g-!YdB$hr+jKfBOQ7W9`M$Jx=IR(!9)IczA;Pn^>O@BPsVsqo|! zl$I?6PhqjNvXb3{<`(l=ap?q$2AqK4$h%>VImOvktK`qPSS9;iW zWpQTrs~^UCEOP%^;!HuXwB0@Djo)~nf2}BwyNu;=tS`>tK65#YkH)Z@n_%s%gLOL| zXGy@rtsN7CvEeR}9wU&(nGB_4T0vjP=rI!Kh(_(BPzMr6Yfy(s%r+KnWJ4QRu`d0< zx{0+Gudm@t_<_0n5xx8yt?5A9>r30yANDtmD-@uW?KmbIZS?`acOUy5Ebq^7p68`) znt|UHqy63bJ37?ffinOwh}J2Ozrrw%3vt{J{OURSeJ{>%QQZX()du+Zj=_Jiyi@`0 zEPN54a2$y_V?=8NqOLVj|K{+sRK@#H)WU?bxmS;$@albmkMlnK4EN!$`i|HCP&;=| zM&qhAi2i853GhF4hmuYUtjrzpH!ukpn-x&X-?+kd_;B8U|I31sUB}syQL03&3*P9T zZx~_0Sl2?(8{hHz0p1fgmhvuY`Q-m&+zVx5aL(4~n+_;t6z-IL@LYlGj)W(!7VZtP zC@+iNx8Q8qSWVq3F$tJoRZ2%wW!!}g@B&wdr*batQT_3|K=gWBtN@E~_ISK^Va=?8 zwWSfpT|Q=fI{dpH;4XWHIxmA}%_(Hce?V!^U}iqYO!mMv-LyO6xISqg4K4K+2 z&8_fvPvm!x!U*-l{C$u5r=cBxp>9v$IsJrs1*7~mvE!_UyKZg#Y6!6!B^iNxem}Hf zXVl8QLjA%@@&F^pt?C2b5ok9}MIZi!9-WOo9gWA|C{+s1)*pAbM_^*N5e%sz8-v%; zo!Oj#9B^=kL93_<*7$$njsB(_0!QIkc>YI$2jGtK0iNsz@Kb-mn)f&C4PT&Zmy%oI zAHPUW28&`gTC6^L@ft=@PnAHn&I~0+sS4)yMmWoP*cQ!D3isW6EwaXn36?aH^|4oP zhscBh`PLKA0ht$G+^xIXq)NEev~M>*($aFirCI-pICqxT--E-?_E-M;Ei*iAFBt8>R0`lFnUkau?) z{^_6aQ$JESquhb$`5Z8bjD_#IEXHdpR^UmP(YG)L+- zYD3_w`vPS5Jgh3U5R=%3l8i@O=RW+%SzzglAtOom{^AL|M6a4Gj<4j5_!Z4!1a1jU0~^_qD}YW zsyo1TSq|(0@5LRWhcp%OR4o+i%0ab20$LP&%;1pkhFmIt<}1z7{Se!#O7(^orxl*c zhf2J14bhoQ$-a$LB3iua3fN!ZX^7b#oGqs(%3T?^)aDe}U>d6{#zkVbeAo@5J zbLt{2&|+l~#&{sO9TG5R_bPs9iRWs2c&bep-}k_H*&R#?_o;M53nl1EY(l>A9kL=u z{1mJwd6+Gup?+ipJ4P|sBqK0Z*I~C+9_&1CpsV^1VqALUT~YLd57R zR+&oh;m$zY|3swi7&2Qgz>4z6zIzG!+=}`S!x*_pJOl&X7ve3EL3qOg;fRkIw{3AR ze2qCVT^)|zZ;PDEYWUp{tQw)n`n?NFXCqd&FW5t#hh@h?=}|yN`3CF=NUS%Xz>~EA zv6M%+E6gD_6W(BHN}`64umr(0k&67zV$`_{^6=JTu8qY0xCwT?ywVxy)yi@mXo>8Q zFM(~LtxVy$OqwCb$deI6tR@c!=SVlzA=d(S@{Pi)ofHweLc`RC%6mCcE)a)^3crQx z$klVzcU^Pda8>7)^ACl|(s;Qi_zI7b_vq$WmKtc*Xs&DiV)wC&X^G~cIyo2$Hv83S zP|0j29uLV3cKnWwayK$4n&{tHf{k_5Gl>^nz+n^Yl#c*3LxEdnOnXY=~zb zB>O{!rx7L6>)CGF3%Yf>6}m&3ST+OM*r}MQ;2=}RsrkwstoV^&B8gJlC^TZLPo+XR zUE$Q%u+lTgUr_3NjumXNYLi>ZvJ@u`l4~io!Qs^$>-H=3(KzrP*k!$J0fTHVxJSSX zqm08{Ur=ZwUOpoKBes#M2v5Ze!eybUcm>K=a}mA2%Pr=zoyCr_&X3Nzd=Ie?kwJFW ztkDv>M4-0LYM-$!*-nfPBG1RkTXbVE6-=e}gGD7v(n)(bmutTBxpT7XJZBSTNbeLc z!bYZ(3!#Nm4eSuf7Exg*x>#zn=dtc+DE zL5>E3xf$AufZuQyBG|{(jktIBM1E!_rj66S(40U% z_G7v(bzRMo1HdIMfv37icq`fDO~fT~I+U6BF$#V`IPdHd(9%75CLFke}Q$4hC#(ifulzwPLYn#AQ8H?WypsLVsk?ETRTt*^&6YS3k zWNqBx&w)oT2z#0nQ1xqoXq`b_FISf>;yCF0IE96xMZ6&AD}RA6OCUnvbM1(j9U*;C{1HL> zifrRlcCh9U)Na~nYO;A)`$(n}-HrCA3&~E@E9|x(D*I$XJPB>L2I56&2lf*uX(u~_ zy{GG@ZK*r2@zZ5O2jMwelOD_dgu3>AW*QV)s?juKCi;@C)h^&EkC%4=$1xOHyuGM- zw2Q99CNhiJNQPo+)8G)pUUVthhPZ+GmySEwb*u#*(AR4vO+*&5AZ!z>3h$w#^#~d_ zeqx+>K=>?V2%E&gz#K5q+>0j$>>Xz#6Lbd1^)1Mwfy#rdnrawzTwQ=}~MsyJ4R5Nm*?JyS>k*F>ziKs+XHhuZXY`LL8PcaRIDdB`PN05!zP zII_7gSqS4N@p;gee~eWvXzv9D zyRWpCX#uvX2haxZ!)|4_v+Ix#f0?Pz45nY=tj%#ZZvl(iZ9jLxTr0+-2blO4;jx;W zO->~j;{ToCY652~Q~)*uTeJ?l84Gld6Ui^Y)%XF|uuAO(rnMp1{|x{ibZ_jeJrM!l zq|C-$o>CtV_40utp8 z_UBF_8m+Ji$T|<`=&Z*t{J?Co2@Fghp(^zfidC2BL-ZuNDeZ%ixF2nk2;P4c+%tE8 zZCa0(n~1ElA!G|MR8$}%Frqwx!>R$!#O8Q(CY$2(%CLAu%;9%n%-je|sz2_*iLg(j zU@^zw>BgA2RZ#1d`c`6KXD7mc!^2L$1bcK0TDd83XR7iN``r=XRI8>S+gQ%Vn0N+; zw7c>}`4JFx8Nl*IDs5q7j8^7hH@{g~2X3?-$jaG@z4}^Zit@KI9j`ONo;C@Z;5~3` zBc&-GZGaRU3F~AX5NtQGrzik|trGn1Bd|+94@>43d@3X?<_g%+c0_y6hvjt;xUCzo z%ie%FI2G7H3tEzab>|7kjVB&@aB3>BR~`7g2zAfI{`n;=$IE!!0g~zw9?$Xr3;h2I zudnd^2jUmvXsN)*hMOCg)bx+{t6k^f(5p93itv)VlVU- zzJNQh{oGpA$6)vEK`h{3jQC}U63l_^I0-TC$v`O#gMBvvo`T7+Bc|Xv8}Fy#uc`QJ z622LUcT@5CLRg{;fcN?bDlTqq=YH$UOIa{oRBPq(7?PrN4Mn>Vo5-s4PP@pl?N&&1;ou9%ND zv!N6SzoRCeco^`YOGknP4}b&TxJtc!4*W0T6L*;zl->PLg2&Z_`UaHx9^Klhz-KOuZnvL26Rql2O-{wb z-7jD8*LQsS8Y%Zl8EQo_Y99 z5lZ0xoeTeiun$+{@y?9@-G0&nl+oQ^8R*?#xcZOMGPwJ=5WgzMS+YvY?4E%cII^&` z9C`RO2WQU35pGXyPU&B6jd!=&H;3~(O3UF^8(}ae8H`{WBU6Ou$UTRU3X4A0;2U=& zfx&m~f9`qZg?9Ewzw1lid6$kE4O%{;bd`qo;bVPOYa^d}5lYvCt)hq3awbOt}fn0b#`{1xZ;jJ~1pEyNp2TidM=9#Fa>`Jt>jj9dd+#E$yd zaDC9VAx>gMi}_%@YjI|Gyh6c9a_fJ);|cya3yr!NOa1NHnEe`*ycq50)*tsQJ#!4& zJPK#<#TgXD{}RzU3|b`|GsGP^$U~`Kp&hiP(Tnz_Ykpm<_SLZRHALxr&|mI5n0r*X zeT-iCT`*=@0FEI_$3y8Xlpw@F7pe z-0qAu16VlJyE2Y$fiHm(GnGLR`>ehS$YL`i@2ky zcW~!QL$5!_efuT+%1_bvzwmbodLj_VyQ4t$QHm-UB{dN5$-)Q%>j?ZrXK{xb4{vE- zJaVAzu^nETd+=8F1~an(yVjP52(L4A1LJY#hO&>dij%l&mYIT!PnN2R55RC)jjRD* z#ZNGxo0Yv{HIAg>JeE_RJUOp??OJLk{ZVmPvYb!(`Sw!|4fPWY?i0zC`bJWHsCxcQ zBQD20riaO6iBu|E>J4Vz0AjIJ9+p}sV2MsqOT`S)C0tbEgC}Zvfx6=ykC5EDb zK9|yjc=0h9;e5r;@?D-JZ2v5ca_@1?5q;yMIB|Je^&{Aapd6cXI{pvG}i_AeD z<7epH26}_`D4IL?yjq1?5PC-til*vMLgx>cY)UNxr-^h*s z+asDIQX32vzzyVoKI^}Yx%yRj%=E$6Dqu|P)bEhQ*s&y(nc2|!-wt<+Ti~GNG}@Et zMq`peZu%edI2l>eHu^E+m^K3%-jdKqW+#i2LvgpO0p(#Y%<8dCfaewM;YJB3?k@dBS)D zH<5U#DX)Uro&c|(ZP1WzG2$?ODj4}a!Sst@&9;TM<1}`~f6-kA{?yNuml#OYHf2Q) z8J;PiQ_n(nd%9^axa;TPoX~>EXQY^-vFmR#_91Kj#FUj>k1Tf)ViayEHe)v^eVcJt zSZS(A%m)3eIvApdF#|0o7QFSAdQHrg6^M+>ns_pZPdN%%mdXTY)FK}01@N?+;oW^R zO(#R9Ll{%pSci(zWmKlt;2bWE@x4Y2M~qz;G4%`V?v+d>q0-n7bw($Q$X#Ucvyzf& z6p2GV1AZ#yH$i{&`gW>we*ISK0!;>j#X|Z;-g#=qLZx!Jrd3 z1oqnj=vVV%7rkzbfu7W-RW{~mZ;<<+hSO^^qOFmr3YGvPe2uXTHLo7vhL17qgbfZY z4@_H7sVj_6*BceBJo+Qur;~{1`=S!C4z-z&xMi=zN%&FUit6h%R4zV)^GV@u7a=~G zti%%RdhoQuO}8F)_rgRIR3H3^;xbTuSYzZiczpMG%&9t9ZxS?R>v6NpV|s#U*M_c) z7^u8=A^MNP={gk}ppqaY4mDLJ#v?m=PCtlR-Ai!3MYyaL#4YKfp( zeKCzd?V%%-ZjXr1$aXcx&b<_;Ll5K++F|_#K>M~5$Bit+5B(0heX2q`)gSrveMSkQ z3aTAWsNF8&CXWFkA)U4 zg#3CI5UNUpD*hbrsU_z8KV8oroWgfaLEOEsBg)AG4cvU}Pjig(h!gVSN&Z1~?h8(~ zt{AP8P(4{N%YH&jS{tX;T+?^daQ(Zbpj%#P52%=DvM`vryYkncvE~;Avol4L?I%UAc^u&*ssK0I}aSa zHt5~)d>*5x>y?&>+jSHe<8}2b`VHsS;}l?ij~PL_v^ze^O6qhPen2n&s>{wgtSfv+xi31kjRmnctafTLxGbTdG^~ zTXvbdnV*TJ#La@AuggE>I&gp3&1@64D%=YXGMO2f?u-m&Y4qpSrx&AAa}VsFKJ+c* zd;62uiQkAm3WDTZLMyEnQcB7fq`Xo_uGIpo%IcDaC@E``B-Tqfm_ z?t|2AkIV@V4NnSlk)L5UG7w#@b-*d@87$@>>pSY%Pm8Ka$bky@jFj}z!~u8 z?q~uz8qAk>d^7R3ILVx7o@%*fS!x|=U1nWlSz{h2)aBOD--x_Mvid{bDP2W2GA@z} z)vpVYGE$1PNB*D;(O&3(O}oi!bbYoMUsq@*{t|1LE1FZq+2V6yF`SM23NB%;cuI^n zXE&2(Rj4A6{6$u0rZCl!b!^IhVNG0DE}px@Ca~?;nrv;*nPRx6TpE`M9%U!6Uj4|Q zJ*7v}o8aqviOc}^cNF?Wi|JKCc+8K48o&Wx;2>ndw6fOd!pOy{O&lIF)w3QMq$TxXI{@+U(Iksg)+^d zKXAFsRV=NnqBX(dGtV_QGT$?QG@rHDEVIoA1vBTN7}EuHokWDohZ+P^f{R0PxR>-o z&IulJG?bNnOc7!`HIrEn>gsWGbL%Ku3449}A6s*q&+^hd1I%qVUrrb*oE7?sO~p7d zw-Dy0vTNxx)K+pDJi_zQbC}ueC2lHzk^jQ)PU1nb&PjRcWle3=BR@)*`LwQQPnlTli?>LUU;OOqr3C>#iEwW zmVlWze-k#qae5MWg4+jf*hh8%Q<~afDy|imL!sA!fUl!(hHt7rORzzBjFg~E(drwi zrpn|&Dkn3GZO(rZMw`D|Jl4CmkG3tg1J**8S>ktoIyabo%}iluA(|}3&*b~?JK^|j zV$M@HNg95-R%$VQk`cJc{AvC#e&+$PHkz|@S1^Z~a?Q93TrJKF-)s?4>}mQe_{R(x zOT;63*{1(NKD48j3c|Wi;gy$?6gd`72>l5T2&#dFfp39MfkuH}{y6_4pX5pK1YE^k zm7V3$3;)BBkukv$&X6;{WK7QJ<*+-4y8HUdgt{qE3(+lvJLVR)mA1OJ2bOrVR~X4V zxo6x{j^p79%t~|{a;;%eM@78BJpS3ffVZT7V&E34luMQC+BPU`JCJj!4fJ{DCp(!h zE8a9Ou->ZC1u@AKEvK$uA@qM@`)V}7hG7Gs*$SbvN^qj)NO=8`R0SqEqdWE+Ho@ zYIi_8U7@Uz&q^~Qox*`&rr;NUz_-ljhr{(7jH1r?Ym#)Bm;mqPj^Lp!Vo6}w~ z${sb;{?^(CBT`HVauYZQ*OVV33>80%DtyNG^1IlF^mb}6vc_p-6y1Tj!sdhd?&+yyD!)c`RSA3(3U6HL5)!Yy=~Rtk(WsysV!8W zDnsOW>2i2nsA7;0^z~=<@A6IYZSsxwJ@g*&KJms7|G+oaSJgM!d&{G{U%H++zc@BIc4yqp=#kMkBPqR5#`W}s^a2?lGOjt}y*&aO zq&<2WsvSSyoM63Zv)QUzSDJ^3nT5mrcm5D4{T=bX;u!d}YS*Ma!IQqCp3d$p?rffe z-Yfp1U|!V*N1&w^g=pauF^_7%tmS437tGD9&uyE)5om3<+NN6;ie7#f7i7zG_qiNG zcCn(lw%H^4g?ju*_9e~^8g&Djj?md~y6oU~b3HhgYX#4YXCUCeXX4o9>}5DfwC1XC z273j}>`}}s+Mr5NV{sG6iCR@PFE0p~_JCUn&Q&_j71xrch9}IkGU2{9 zlH5l3Wv}tqL_{98Hun2=TU1l~9PAvO%})5f&*CTYS>PEzK`dgPD^g;6ej9evIh2I1 z`u5ajdJ*Gh4{_yrhTn#rb2M9+wXh^RoBhQO=k{<1xY=AbZW$W~w~707GPQ~N4>j4V zP`sre!t04U!AL~UJGIJcZTV{CcBou%T402~mfwoqW3carx3sr~`?SmAtm7D&k&y0B zf0Nce{lB!iX;spur#(++91GkF{oBI3wM=9(J6zmn`EC7VZEUS!nSGVwY0neH*S?1f; z+Y3gqQ5Ec!Z8a?)#Tf#@rz4&ifHOBicqh<85q=dMAvRN;$sVXvHK6X(!Muh*VTW6F_J_ zVf=-ftfM?B5)Q?OCIx>6YQce_eIUjE(3k9e=lS5K-B#C2=S@dW=Z}nNoGPC5q>MX` ziq172tN&E^kkXHs#QYTgSoUCMbg>t)y|K2o#8?Vg&YHvKsb)#I%?+f_np&w{BdLLM zKF#BEXTlCW3v|jX;n$J>ROb-?@EkBHfOvOm0OT4fl3xG@YFZGAYb@C}+|TcXgzff`lB)bYP}| z4qu0%u~%%Q-_wieS|EiEr8bZ!iEh|G*P)Jg0t%E~$bStpz&cQ$%EP5HkzL`&;cX!m zu~`Cq+b@8c{}j7MarXu16jyKON9T6ucE>75j|@KjL;97p0gm;~V*Y`lMD-qN;xzLF zTMPSpyUQAHtzzya8bS^61ZKw?v7PXXX+Ztfc1azB)4faFghSn^McxeK0K)Ay5mnS%p^9>A$tE2}9Ni}pANGg-1nLBa`zCvedfvDe zyLP%hxjwi?yJxr(+$SAZUA)8P`s<4JZwO^q@0)(I_ryB3efGUkx$L83>}k7$&1xN8ahCU2~+}eoogW0 zwtTeavVX95wEwnUwD9IC`~~LU8TyncLH?mcMq;mUi}=Z02CHJFjK-%KORk^>(Kpy) zTwA^+ALh=oCz(t18)^(_9T%wuh+4AVenie89wH8=eecbZDndvZ5wwPO&~{?ii7 zMTPtVxey(~_jDE<%P;aV>8SKO;*PuzuMD>h*#m?8e>@F5b6k1dOI=S~>)a;yE0^e6 z>1>uU$?+_$wL^3aaKG{G49`|S5#X_i!>zIQMYc-#45{J_ah1S{G~AW~LT?^(m}sky zms$rud6V4pF{Z^qI(-{x8zQBvQYLkPx&t@%`cT^}VbkDvAFv#-ZMCxw*_YdMdaj zAa&y`ElRb4_MtRRkq%UAsy2nLJh)cbK?!L?o+Wp~%c>aD1ncZOJ)c@amNQib5qOn4 zO<|O=@>i*fR3y?UYzfuCzS`3_$otFF%qw`ucusq3dCPjndoQ{>ySKROxGT8VxU=}W z`$~kSNSU>L#5U%Du-4+Vrr1>LJ+R^)3b*+)+zEa;x0CzK&Z6DKI(>`WH2m9N-Iv4L z&pX1m)-MIeB2wR^G*YYm)8dj>sDf-dXgxdOJ95ri$M(m1&1$pM5?k|~;MKCmGzJW* zXHX9K;k?{|tc?|NiF0xjxjbwbnzcrBQD!S!pBJ!;Pd1Mcn+w0#=8R0Uo6FDk=zz7=@D&!I+b10G2)GC)3{W-!&5jV#a3W;W9UC<5bB4ea<4>ThL}d_pQ2 zX&YV-CQu&#LEmd{N#7A~VP7xrSl>GDQ;*xz!sB+Z{OSuM=7j zJ_@RMTPiViwNapa072AU7IKwXy%#NUc#z z>K{!DD2|=Vy?_%*5jbre5Nn9jgg!zJp{4Las3W!%9|@g>id-SMlUNB0sK}X+!RZR0 z_SW#MZ-K7QhS1S{ffn%%bqt!g157fzn$5!9Ww)^pphaFl9YMZi3eMmfayjWscyIVt zXmDsqD1Y#IK=P;fS-;7j=vREj0@wW(5Y*CwE8wwlC0HT&3}ds%m*C0cUF?0}n-dDk zuk|bRIsT-%mUV}9kolHK@;V#OX6C-IE%<)$Y#GJ=rN)|;Xs4vu@ZP{K|6pGs|4zRg z=pSw?eOE^7yWo}jm)Zk1dJM~RHTWdqgL#Ivp1rR9w9R38E<9xKkVlP+YAWvQlcke# zqVh;nO%bv@dyCsI zK;c}?v<*4A?9?$T2mKI!a95xU2~d^jOiX`cODDTp+9Oz3-{Fhp&Qv zW?(^Rq10AOFzur6aD~M7=B?(3@LfyiDzfL8GtfL;gl||;dNTEvxMd8~S}4~dtHL+o zU(yAV{|ez7kzvX*bfSX@PIqU=aOqrWej2}+&ns*c&Wk(nddozM*op7Y7N@r(7uZI- zp-e)Z#I4NJ(u`(wUiPa{(mWY=s-o7l5DZX4dHw^t7I~<}>`<0wzd`dw&^yQl#BZpT zqV#FnckQPB3%TzxR9pHuyknNI8<~~Ne#CJLDKqw>TjW408DxbMxK}qKvY-}{UmLDm zl!is3!plOfLfu0vLiTX$uq(VQtcKTwvV_3|3swp8!K^_V8H7;qLFjX6O}Kve6mCdl zC_308P&*Lvj|sL3Rh5uy)3ccJQ(5U4)S>3HEISmLiVE~fI3vx14)H2I5K6In^nI!_ ziBs~q-CgIbK)zlr)wy#wc}H@%TgryX=(CL8ic z>*-E(CfZMZqiRDloxm(+`*0QcQM{k;%Llm#8^e0wF61E#A$Pk93?v16&IPF8Gef_6 z7FsqNoVk8NO`ij5@%LmAs5JLcub?zK2#vmr97294CP5E*81;iTsQuMJHuE~VF`ny% z^@-X}HLIEc`dhT}9=XNdas&CO)B()3?vbnER^gAKp`nc6gy5p!h~Ttf-{5Hc@Alw6 zWMIk!vjsOo$G0gE4^ChP+>`1{E#&%28Mu(lfdXr*o(9Ek59FN}5}9z0REB!B7PQ=D z;4h^!W!biD7PdU=L43B8S%jTyE7Oc=%A|w%HGtNs3sgg@81)@)NkhS}iYGgOB37P^ z2D2&&yR4V^NbqC>kgD=POVt$~+!L9{Oi{Kp+mda@)@3Uqcaj7b#;?p<<_B2tB%2+Y z-$IyS8gmbxEj?kFeF{3cD5yo(Pz|8N($PG6o?U5pb_egN@@oB z0P4b*K0lnC3Tm6wycn^<%3`?$2+|Ft6OmGpL*cyP^`RPY-dchT)#t#@K-GX4xa~jS zUxZs`Pk(!VYkx0)KmR2E5r3M0P9PU%-lfo{@S;d%DUY0n4CDj#uy#=UmeY!X3OFhqjIhJ5zt4jM;2{6KY%aD-{3}aK6VkC8w}f?Ocv%YJ%kQJ>zxWW z#BXqg(vTIehpORLaEmGk9j*p$FAFW20&`~HG>p{`?hNw&~K<`NdIK%Pu zR9d3@GiRBs>|8d&*5UebbGVb-N9>9}xR=}(ZYWm=D&XJP-`cPoyNM}}XEcnifJk!> z*l{#<8hM`Vm`k~EA6w~bOLQ z9s1^aOdR7tJk|o4jNQn8{3a8~fn)*ngKUJy^iS}{T0rqJ9~IxVx(B@}^FgG}sZ5Yv zQcvkaq)X&ccxKoKhc^MIWzFEV!1O?gK;D4YUo~J4)CGsLNuXWeK;U8^3Yz2Fwb}{>}8>lnt|irChi%WMO(?SDVj< z*y4Eq>$z{-L+&$v&*s`At2`T>G$fk<740{AE?pMq#nk_^3_` z2m2PepM&f;_-uNZWlTA^GMuJo(Zy*Ow9PkgXNjY}LCZQ3T*K2uOYriJp$aCNjzACE z5PgOf;JIE#>#N>WSfz!$P)d%B#9FN%4uwX+>+*JRcQ8-zTwq`zOMnfW@qhFG^jiZ} zz~FrqC=;v?XNM=jXQ95KYT=&YrQt!Cv40{}rJIq0vJ^>>Z%MD=hto>2YoE{)Gae30 zrO-E>o!9}!;8#4&UsO|kuH(#RIA304AHyfKEO#FDlz~`Pr`UOj-~{#|(-~3G2l_U8 zdO}nXH?>JPccZA?I3H?|9B%IiP_f<)ckro1PvR`Pr?!)&sc~R8QaGbs^b>k4_Wm18 zahyAwQJs0s&f=P2Z@tdVN`|bTia-j}XI%m(VH9Mcr=&R66-k=PROuekdm>_alcRJHt&P)!~xrh4<{a zz=YuN;Gw|nz<&5&KJd2-ob*2l6c0oP+Xm)^=uq=;z0iR0=1`yT=TLOyekdC{4XQ8p|~$7@g2p?b8DtXF}0s1Z?yY)YP@?DR2Y;GZ((xXkz@Uy&`} z$i3o1Y;!IrPQK^NB&Iuq8!&FRUP?rS*9W)QV{nV93eEpr=oz|TM)1%aSAoKKA5nn( zi(PIxRT*0H`!vVwVQk1Hj$x0pPCT2=@V{_zPq-*PhVRVxL_W4Cf0*lttkWj+7o{?{ za2IL9w8SY}n#s#tLDpps)_hTP4Q`~eQS))et{|!sYY>NYM&C*Sy&t@ei>Xa;q8^rv zNP)=c@QTompchfC5IB#X6vn^UH_nqd?dN-rZ5pt0xR`4Z7Dg2$MjD1!U{Hgqu7W)?GN*(h!tGVJ|? zx48Ef7M}>)gnq((zA$`!DQ+V^$#EtJNTki^Y4C%sPd`N$+(BwKHI^b#mAL?|{0Mke zRzh#tIm9>Ub)m1(rLZ>^VNYQvZO6^RZhD6MiRVzBZ@`z}d-J35(=PD4_=9{iK9;`? zH(Wb6j7`JOiDy`3A$HOG=)v@EdNpR-W_l1}l&^?%6><>y7228Vrpd-(^u9r_t`1kG zVO3>Bgh+bmeK0AIH4x{&s$m&Ks z?L30Fw5No(uIHn-sQ0_?kndPvf1q!8beNR$$}QD&b@M-;a@31&!*g>VBGr!URQ3X5 zeM-127{Y60t)7dw#9HD&v4_w_n9Kj*wj=jGgFOh3>z|0=dZ9KfQF)P7yMy>@Hr0a~ ziQGyV>LWa}W6_QG4m8!>bU|h`viB?4Fgu&u#1WW99r&fVF}*-_@Ere^f5R6SM4^F@ zPlyq+2+#See0Q7=ySZ{)685-i>@V~r%t3!)1?D3ipr0WsZ$ZC6_s>3ZH35YRx(H5c z-_-lc8~Ki;N8I5eVP`NdNCmR`3;61L+j|n+?_Bqs5oagoFx1G}JAXS$I~zD>J8iDB z&MK~D&bqGJt_iL$u2Zho?jG*#o`|=E|G&UmcqGT6=Vg|*0*a)$Sba_^m9E0}L*AeX zsze>kIV@4uS=MT{I<{T5OSToZVz!;uN0wHW%H|*9B{&k7;sZROJNcexaL8f+iP@4x9iq;F0q(vPIANV}KXEG;>; zbXx8-J+(yItJLyoUD8^l?@#-i{ye>fW191uE7?=lmn|TKo`yHdPNjgp$>525Btbu9 z814cm@h=6^>@}xbW?QS;`q>`aqV0BjW_vqZCfji9K}$tTDf55g3V8P#TyyRsqV}as zVdfCXyESlz{e;&wn8j2Y{6IU=$LX?6U*e=I%yz6`KclhT1q^3en3bgRDo+ z#re3)ID@XE*Q!JD%A=)RkxilQ!QK8T@b<~?&gLrbc$%?4{Vl2iH&Z944o$t0GB&k$ ziaSM3?v&Cv*_S*tMNf`RJ&@8U&7M9w`OGpw@guSSaKH}}je_OG$BbSf0 zc7@7G@4@`P!(?VJbBzVooZDK_cEf%vsz>yR=$Fx7qs&n?ZC-N|(a%j}W9cbm3bchG zy%7{BYrzb58+S}4kg-}rS7bBT2V4y1S9`G$YLXsg0|%q#xs>@%eIjJ&>dLC0q;iq& zq13<>{|;o_VJnA}{>$%(Qn&;~3E&|&2TL}K_S^rpTSw~qHSvp(x zS;A0()PMqHmpRV78rAxo=I7>}=KJQpsL|}TY_hzxbhT`@N zQTzt%ymwGJm_spCPvW|94X!EC@~6@H09Ce+C zoHFJ@IcGJ;@r>_j^HN8p^!U5rFY$N$?-xnM6LTe>N;;UbEMtQ6i~o5rS`C=`upNa~ zRt;qN$}#_8=8m=nE$`t=-IiKU1oUL(sPasz1J|@SQWfQ#Lc$~73BKn%ay=z5E7;xq zaPf%whGn|#wY^O*P@`dL4{@k(H z+1S0w-P-fkmBrmIqq<{h+Uc})sY_CsjOor6zRAI9%6zyabrGxE0#S=%ug4fsTWo|y zyYey;wr=)j>{9*Jln7=OT#+YwQN}i|(X{!P9R9@yN8nsI3oyM{}J>-OzcT+kMl~BK<(>*yO&+ zMUqbZ&6Dg-4yFxFE0J-`+1-)hn&GrKX1F?}*LK7^j-{V-?Q~ZN%u-5fW0^tX5c{dj z*E1K*?#|XSYsbt&G-D0%=$`%&*u_2&bETeC?|26 zYQ=sO)>(f=AIm%^+oQObIcMbZ=d2$$B70PpPBF)ASHx>@x$c1q-wh~EADakb2w9K% zhV0=Mwj(kiEikhiBc=-DCQumnfDoG*F-I|GFip|9aqpaD^wyRs>EX7)3f|gIYX+NY zPCk^h=y%>f*Z;Ksdoj6RO1`x7>3bdBGs-*SGQOri%b1zoGGkLlhI5wxSNMYZg1pA< zu`h{zon>M6gzO`-zsUkWBzt*rEqj{ytPN1&;2D=SV#Xb)kF-gujSkk~QbXy1)J8f0 z_rW$s7cv|3o8N4;#XQUWBip^WuDR~yTAA~1j^EiV3iud=<)nj_QssB)Glz5|@}_Ta8k3W~JNQrAfF$h`24(19Qk6oY4i z_VC}xTji}$nMUH*YKy&;^;pgwd2Zxen!juQ8~Lu~Igzti_UV}$#}u}`7B6vc==EfG zQ#U=RU4(m332^_v=su$xIhI<=RPRac{=aQV&b9Bs-FLsON93P^K5~uXxAaW$gIi*wKJAw3o&2RU`CwF%IFg-0 zEifjjE2LB~MOsJv(ln)nUYRIIB_lKUln;WsF_x_jXK61{*Q6tw*sP?>BcOT>Gm{;+9{-ey1hA`|VH4oOI=H=@h5)dI)4ylHpHB zjm*|I=iA()a*fW}EywMwmQ3Ysk42g*&a@%#p-*L>mRo5NNyV+3@#b?kcVy4#>$v3d z_}T@RN{0*^8!{h{dXTk74u8(pc^~C#R&aa)OMw&l#^lj*hvUj;cf{K4E6f$R7^nrG~<&4WHlU5}qE;-?ETvDy%UcXNyok*>WAxvr-cTFD@N+eeHBe@tj$#p$|oZWLU#h3 zZwGWuh1_AF2#Y(hYWp7LkHPPdX$HM1*$Q zcFO$8+SZzA{$Y)?oUjeBzqT#1Wwm@3ZQL@djp?JZ7;Tp41E&Iye6>QEL#&hKKzzDg7B-AN!^jNTpL}Bp*sS<(TZB9aw_a&#|Uz>_xtdxue)xtZ$9A zHMV_6)Rkm?Xq{sHVOeLsD2(O)Fh%HPrt129HLntktcjFI&9{KM-1L{6#_kliSu#ge ziqKlPmT~Y;yG8XwAORR>iVhXv!UAEby0~LNo_$<`}3({k7uC zC3w6QjMPP*C_Qv7R5@&mY>>;Vo74)>jlYBPJBNBtStR9`K7@0GSBGnd`UFeh)_X3n z+TX+9!8_8^$CK4l#H~7CxE#(~u7d9R?u(v4U@Q13Rkde&4q_zxi+;~7;ZL)Tp>6wR zwwP1-Ce~VFc4(eIvBx-)u1fU<;X9k@v$n!0Mm;Axv){Qn;!@jc+cEnudwp9)dqLZ0 zTXAbqM36hIDb{k9ZRWb>>S8NoJZ#)NWZIvQ4iJs18YA?x%08u})IwSyy$qj-jE^jk zYs;5p4|pK&QP=OrInw?5(Q3G#8A~8X{Edk&(;>HFXYAw=mq*2MVD)V z|8YY(D1FnjGYduxn7^xM}(KV%~qXWTEYqQ^tk?19{QY^Wh z*hJ-mtJze#0e2b%mO=bfvNC;;tVq>j_7JORA9;g&4PrqovljH0jZ`P112LUe^>eg9 z98*3qZFM(3Wp%neJwQq{{or1!bCeOrIeHzzC|OK3#O}&${h{85=p>g%|J6RY{dAHK z;coIMGB3hu34D~wE61tJp)}^D@lj6J{)=Sd6SM@mx4~*H`Mse7py*!E^1%t+q32?6 zD225X)I_+;T+{MUrOLjnDKjbZi}<7U})eH2$$X8!EV_ zWGIqOB@$+)NjRQ(LAR%xNymuk?4rmTqLQ)F)Rk(W%mUHqs`AiOh#9ZFMYqTvbhX4v z9hn5vA<0KMSVi7uLI;g0L~JJZ1&=VRbw9}XvEhr>o5mugF8PjnA8Ae>Q#E!@2cWw#lxD*^D{EFz!eNX1Lx3 z9PVnYDlepxiE`3&dbAwH<|T^;qd1nxC0|xn3az!urieO*t&Xn4Vn%oBj#l1O8Gh&k zpiucMuO_|{(@ZOb5rNL^!ccqWmDtepoJlpUr*6q*nNrGdV;_A*i&HsfgF1>nsXUgO zTK;IaA`-Kjo$Mb&-vis@p*l^qa9%x?j1AYNCB3966Svt{gsV->C(4D}F(<(<9?SM{ z-xqHZh2`h$b-5~&Y$Rx-2`ha{UI7J75w0KH;p>_L^f}|1G@HCf)z!}$uVi%9SY8JQ zfuy~JDrAD<3{?HA&RdQYU7qyJ)tX?7xDzB(xxQT7i!L(#L>D%QCPL=Xy`f>86)*bYVN+y>UMYn|l;hXvyt2eKDh8kvgC^`I-EW%xm#2c|} zlsd=IspXNj>`J(CKOvq)23Te(qC#?eB0Xq39UU0Kmf=z{n@3aKO~;@$Lnc|iMkd1Z z^pM&BeQ-6YUfMc23%7^tA4#G_av+Eu#q>q2OX_R-O!d@N_M{(K{+b`D3A>rULv* zU)k6K?qT#P{zqtz!CRI&CRu-wW&G1EQ{^+L4m}B7VER(Aaw8%^yF-ga7I_1ctX?yM z)JbWAk;)W9Jhz9{BOQe)%6p?6(M&7CHdks&pZRpSymz4H1s~hrO3_*|^1js7GSR)5 z|7=Q>t!%_JFq|aRh=lY>miNwD<`W>z+!0~|M_8Mwsk)4M1YJxLRY&bayK;;oGq+hiId)qLQZ|ToNUeOF3ol0-UkG8o=`oii+IbE ziFPu_LOyI@y`$!$IR06BGVrJrv#+ zUS(|`j-w|ly{R5V6L}2kXGg>F!lBSmHeNA7r}-+9CKe55rg>&d_@)|VOYsr-N!67U zkko#mqv9ibLW*aXD5sh1dS9(4eMYTFyjG7;2i5xU`(pI9bpP;hRQIOK^|gak4B^v8 zC@*y%f85pGoWQ67zfe^Az!p{WN^kjkMkf7}S{M1WEa(;VDmj?^@aSk0M*BX|Zw7-Q8@>kBJKGAtL*+ek^`NnJ4?E?bss4&saXvRI3vS??CykvW$d1$b?2$iU< z)LRg1X+i2j*wF9y7!DYuGD^s*TvY0jiwM#*9^|F3;o;OPeFicEy(6cn_0&1|U<^i39_o_NJJ^%6NA7{fFQYf~vXsOO)#uRfiRDs7_Ki{x6|}AN z8;}?G5!0wm1ZO%Wzo$9lirSV+grnj%>YmA|tyU`OFrPH4sl`ZA`9Gd9t4 z;CNF>U#*WcHBkq!g|+3#b-yw#A@+a*u^i{nb8;6kK{>0|VRx&G$yB|QeuOzGhv3L? zo6@y~xEqbqJFz#F;>J4Srv8y=rw<~>YLl2(*!wo?yNyQlQ?MQDYCh^Ua_P_E#oJBXyUREla)I}GgFd!guEypWN)FbwGAPW$Bk0j z53G=_DM>r;vH)N5lPoGo%_ z5xlDqxybZKokc!JA8jdx0Q*m<+iPkvXw4c{D5n@m~tMW#5C(z{T)x=%Sz z-PSDB6?jr58kebM`WG!pAFYXm4&p)9!5nnmFSL6-@V2|wVM`8WXC(oMeHOF5*MgUaB&%8$_zeIIirBKjJ&MXC(jr? ziAgvshruUe3!J#t>2twcj)M|oDLT<;bcT1KJCY<*RcS)-Y7dYa$`fC-u4F}Mgsx$| zpVty}FV#eirj{5r)VApO-2#sxMmL*s(s%XV)I0PQ_e7t69i1k<;3!-)P5I|0MMSAz z_1>n|S|%b452`!bbYiAyt6EmSU|iK!8{6oT`egdC{+HaOABSgCYp7A?8qdi_=u&rr z`I(!z1ma#FwJ_O=cmx-wZs@_!seb}9v?tuN63CodcdD@d!+1v?(k4>R;ByrVmcTOd zi1vi)tjz@Jsg0HnZpv~{3^to$$f2XRH;2oCd5(>Sxt(2WYVceh@5&$GQZIcjM3tHW^gFrwujuF zqECQV&{Vvm(u8U>BnzUa`mFIycZ1~6idj0kdW_K=qp}lO zj>m)p{p`1Kj;+Nrn+0Dv8TF97L`~cu-V!JDz3_Ug3=TjN9PlpdH}Og~bj1JT5j_Ce zIy%w`g-D|a^7IN%nS5tC=XcYD@+TWV|VdF`4K9rh&t#Z~Sv61YHk3$*+(J z=D`1m;CoBq(@Zjrfms`FO2cum2%Hs$Vuh`R?&=EWM?X+|j)Aq%2z<6a#t;1?UT+H6H+GPVnxa>~A!g1` zBgU8kDo#U?2J&F+cN=^4av(cw$Dj9Nd`f``o~Cccvsj9*{)6B>eAPc2_jORO^d-=C zi6A6JbSK{LX8frW2ygr0Ik*-iiIF%hisGl#!)q78JL-U_y9LP9wc-D|1LTK)KAs~m zYQ#S>wGK}&FjiKB?=l1QRJ4S04-dLZV&`tFJWtY*u5;s$kU= z#wR%my2C8YUtuM%k{%nyz(o2-Et?HZV@s^O$sioHGyTNgu@IcU_V~A{ zSjRKbi8uy|A=dxrQhFkF{|XteO)bmv}*K34*mT04$JRV1NWLkH%td*2k-PK<0x- z0bb=Yo?jdsroQ6mKEcR2@Ghpq6`?KWbumzxeqy&w#rrG-6>%%@b+TjM%@6;8nfQbo zL8Jiv1>>>^O5?ZC>)N2wPJuf28ff;%z_b4hs^krDM1Fu2J^?hu7NC$fz?cn&OX)g{ zZ7vWkydci}^WnLORr3UEsw75uJJ$OW%u_4=u0Cis8lIh`w=hlDMbjtT;&vFfbk4M0 z-*03^f9zn)wmO(;kM;kc+l|4#{6xQP6vF<#4%*=_VA0=!EACIMhr4(V+rie@1!DPQ z{C64@o)_>F{*7U7?0ju8k`2K)`$y8!F)n7zgA(BHc7(=pKX!w2aG9M2rp6}lZVD4C zh{vD~RfdvtFsTx!(MPir?tq)ghakkCBFlreS&TZ4uA?;cW=YU~BvSjq1z!fz?_wyO zCqq%x0UZEMz$7+6<$O!-Aa9dD$(iI*sM<{A6|gi{f^NSGoxpI(12gk8#(D)t;~dUM zhyFtsz-L_nR_IC)>)&h3;YNI0o2C)^0Q`BG))v{+Xl*=5qcgNE+H^P?j{`gPx>g&U z%nI5P_#fBM{?kmLbFlD3?tt%{rES;x!3ntozH>M@!(FvMYP{x9W3*zLsusrA2I2d5 z62#K$>L)N3>S`Y~9{ij{y$0S>7;dp=z_zS{Zp>B03E~;y0%5i!`HAQUZonq89SEv- z$u&Tmb%z#oDgJa3TDB?VJhBS8pX>`F_)u~K2&?eWgrc%N`I>AA-p3Se=?nRO!FTyn=1<`MCDl{{+HB=@vB%CiiNODNG z)a>Bzv?1S6BhewS96H9_Y*%RIY|Ll)gjWE4V-+(78oj@u^LW}$e{F@wbF!)23gx-HBxWJb{Kb}FG@Vfb-qxG_B>{oUZ`aJGH%j*I4 zwj-24*Qx$g1o@7^~WPD?L{K2|Ddj%$jCed+mifT@5Jvm6D*6 zjgZfQj512TCl!(>N@t}e(ln`+R0&no0g*cqb2vBZwxxpypn%kTJ^X_MLxW`_(efni zpmCQhN#}s#{0N($9fnb2X@xpL&88313z!SgT)$)TFe_;*RgIhmes+GnyVg$)DFc*^ zcqYx&qTqgv#|~1D*o$u9Oyn2L!G%O6xHa%}WhTg8;H_d?ba>k={moCr_QDcWCy%m! znFOW{lSr?pozzkEsazm=D9oCGQTxes#B>?#sl1q5Ux^s3(Z{5jVyPdVkYsQg;?_|@wm8G z$S!n3&6ov`WDtIJ<=1iZ*mtPMcESC1hjJ4=Rw<$3p)R2xq1zy%|0h>av+3(WHFFSY zB>NLIeAedB8MeYSLXntJnf= z*dxr8Ht4gijb6Aobm`oua?@sL+6gFQk5LuTNil(JOj@v_*N_xtr}9$I$U5Y7;uOxx z^Ej8sfZm%7hT%WYuz6rSOw`@jg?@o4PN3WJpj9(dGdI+xCOLO9*_h$1dXs~=YcE4Sb8GW7}NX3A+4W#5N)cr*1Lhr?ycorWBQTMqY=np+iSUuUvh9ywBs!?rZC7?i=gR z=7#~a|Gqc3@3Q-|=Z7oSGuM^HQ^9-3zdy1~T}3f`5!-nCkLZu~JhsASJ0CY3F}-4$*hx{MRpfnS2VGD?(p>qrGE&FWCXYh(JDHk{&+rVcE8XaN>^43kEH!U3 ze=#RQlh>WMa~GMt6hW0Bt|CVIiKxtsXsQHq#y@Z>kHy~%M&Ce1QYU}o1a(2{-Ia2Z z)v1}(JZP{_gWe6|BWnKt$%5`sqEeD9SH?(zi=Pfuk9eZ(wEseTy-b{3JItG$I>}K$&tQYyi48H z*7P{GZQHgz$;P%f)<*x>_GV*qV>_7`!?w%)-R${JpPo!+db-Q?zRz>-@9L@=o#x`U z<@Qsy{g#P_kS3nm8T&5rVh8b^=#;x+ypcjVsD0oWOoP+1DZ5act{ZBwm^PVjnMKn& z;|bkf%_f>9FQXpOUX?&Dg*U`Pd`C6=vpiP09|IGBs>D>rUoTZBtNlzKINu#i0M4_e zP;4hMdM1IsLi#}uOp+e*|3(udC&JN?6bgkLp_RcV!H)i^{=vTazS;ijfysdhAvX9Z z^w|F?Sl&MkYq!F)$Jg1j!aq8+UCc*hv*k_8%{#4+43~8Y>bcBw$_|#!c5Q~vp?OSS zS4Ihk!+c=5uY+%iFOPptkP8jqo(RZtfwJ+TDwj9~mSiVQ9eqpFOG^#gQp*a{DBUG> zCn}7QH;bCbS==vLiY9EC3Q8Ar1oklJ*+zOrmt#yfcCfs*6t!-(%r`YQOkmfk8`2v< zxo)%xlkTsk=jTqeTh2WbLtdbl_>+?gB?oYbxa}VBl0c^LcL#!d4x{y zX4KE7OZPzKe9VoH4vicPl?p!gFZ1y}r+13KH;(mW|M|cIzZ`rJ>>0chn(SK|aCy@} z0w@_c?)~EL@89d+7)}sA5o?*K@r_}!<-8$5o5Q|>3w;Y)RXc&5qdh_GAWTwm-WKT< zJQj%apA0kzWQM*+R`O58b8@=8Irbj3P!WCCv4$pQjisJtsEOAX(ONV+ppd8Gt=@ro z<|O%{SW`?Cip9PYgQ#A5r*^-Iw_GqWwx-ri)(h6+7Qyt@cuMyfB(7;xS)wU28P17> zl2=+Pu2!m`HkgU(69os|OlGgTFf^c-G>PgeOjS_NPBA9UN_AO!7?Vb#v!S|+ds8u~ z89KMUBPW8t0#gEeg8740f-M6#z1w~Nx>#R*pDwT?^iMcFa3oaM+uk$5y}?<=x5w4p zUoccSa#emqby2$vn~h%$*@mYYr_QDobomWa;B#N1X+xG!HbuGczrm%x_QCW1rGY!4 z(xE`40hh{c63z*EWeahP^0B=PDQ3S}W14BGtlgz~&0L|IGgatgWPqrTuI7E=Ky+L5 z3g1epNmSBw)$cO?upF}Vv>vcLwPaf@*45T6*2hrGbli5K)&^#}+j{o5!-;=wk7j7SbLi@FLM zqwj^A(o{5CM}QS?uJ5ViH4oL@nbGtdDjnCyzmegwh^xcxmn^Z$^bdBI z{(z~1`Kq~%Wt?S*474@QK&+TNXp(TeUrb+6%{`^@Fyh}2}QdJ zH#J?2&WU!6c7=9w{10FTesi;&D$wdCVDcsIPxRp2o{RG;uP7c&eV-D>uhCh z7Rxc?ZBVy*vQa*N18g)(7Ysep<%il!3U^yzV2j$1ob)lxHmJD(ejB`0^FEdy@S~Cz#s&?u+ z%q(?F%}(_)HXB#F^8oDzX(dpR9XuHJ@Um4dsk~Sl$?qSXt|5Lp9T1bP_Q`dKoMi(s)S^ z;c>gG1mbbe-XH89`Wh-5ts{b9sqUn^YMyPWZOd7E_r?;SI(0X8(C^pAG=4Z*ZXn015bh^k zh%TsV-BF5R4!m3e#T@xIN6Dh(Oyn`WK!(aV`YZezRpH~=NE(PDidN{z>q2*ee(yWi z7|&$aZdXOu?5uv--!t3&xcRf+*IB=|{WNENbH4C?i-0w$HtPSjl(n|78{%GBkH*n@d>6T3+J1@ZA_Sv^2cc_tKSPmolTMdPEY^Z7PB5+5}_p!Eh~hg}2MYT+rOm zG|(<$BWyC8sp-rbG?|Q(8b+Q%7I|~jY%37|V%%B_isQePgYt3A=@nIBvW%)YD7A}X zmsO%tLLM!ig3EpmyntNbwQsTipeNv7>~5Mn(e)yybat)mt(mPe2W8aG8t&TUt-?Q5 zJ*IeFSFP3X*<9Fo#Js_-xBt)a(IMJP+AkXyv4^Q4v9m(O$R+3yn_-l>%a;mzRJ!Yi zySI0cuVVNe-x8g=f0?a-Z?6Q=FvY+a4?%O}HP5gvvwyRBZFWnBK1Fwk=?A*WeS*SW zt~6C0TD`>tbRDRf>>?cGE--OY;h*jbFL)LbgTBHYogSN{S^;LO7g=HXQ4`w@O6L|O z5!H?Rs_DcQ)NHOOhe2YvAk^nqM!QD$zw4xVZD9i{Sp=!1jwf)_wFppNs$(*ztGu6xdX>KAnG`b2} zvjySHq4Z$QVBNsNU?{vPdYL~bl$HY08&p5t4IT8ZbQAOeWFG9^lh2t98A3mH~sAHJPQOHO!l6R4qwTY-hbS18$3U(i? z-elqc2(F`uTSOUhJyLK+QeTjQvzS(61k;f@ggI|F)Cg)|y!RSj*anyl-=fkm*V;

B>E6322$yH`6&y~eE9^NV^ z6akO($aLzhR715RTYiDs&}g~7tdZm37n&xumvj;*mXXRxU8I@N?d#+msf1h-NiReK0G`p-Rx5=?Qd4dIa5!c7kzGn3BQxQjsT- zsZ$7Z>0hW3J_2!gKm69yQG4BtD*pdejgXUgm*{|+PC8it-uxTz8J`5Zr!lg8D#0Pt z9@$V0n1$3%)h*PrjwwBe#^h9Je_i5Wp%N0$Zb7)r#sn!rtiliCM+uW92J4<9kCt+T zn9x>iCZ~bK-cqS0KM{wD6~rEr9Vs!x-W~4U1l>aCXS`Yu3K{d85=?@jxPFuVtwzuon44rSwUoR_^oAm?A30Ywhuo!_44%Ye z<%v{P9we`px1$D|7sq@I72-|cS5$%{fFK{Brd|X`=i!)3z6}*&-Pj8piIiLp{QjFf z#ckt{VD`T}vO3Z$v_7PSCWLzg_l79Hzxo zX0EQA;h^!lJ!Tyh_rUJ3rQ7;iTiD&{g>f@?Xpf--ioTD*n=zBrLZ-y@2UFi)!jt@tIp)P@!z8h|pXNGI8 zvrBHLob}mPGY@7R$j+a8&$T>wK3okdDYLqWAwBLdd%mR6No^7x2@T`g#2q)%#$md- zbO(9@<&eSM7AJ8R5QdT%ih0f6Mc$a}tgoa0V&HZ(RTv+uOV`lUHY_x?wA8X+wkE?> zG{IWUrnXeIt~Be+lgwp|uMIxkF3n)|dTK0sq;;i6(s~dRN<&lB2HlSZ#8Ij}y#!vx zesnsBJpt6e_7DY-y*7)Ofjq4a#8~nGaR&F71@bwur_h``8l4p>5vdWb9=Z*dN;EJe zv?mbu-S@Y5=eSGf9LTDWsr!+Yv)1{SCyzK>8m01?Ra(kc6&GbA8=B>f&;9jPEX&WIukC`vKtc}n`eBiugK;B3-n%W zzqC^B4~k6v$Pd3580Wg=ZswZiY~roqZ4)fRy%#PJq`H>wy1BKrsw3d2<+v2TJ1#SB zf1KGi(ssujHnlKr)VI@}R%g@b6(~hzgSdtNgRduC6D~*(Wd&}lF;t9h2xbDsY^ASJ zEvYDKlz(94TZ;S_{gyIhBXSw8%0-ndX^?O~dONZ?G(R{X81#kwtgoN9uJ^o4c9zY3 zpT$F(Kjc^N=hKY*?%}=(2>-dr)YEJ1>l|m2@1)F0yp|NRYiv=|Q0;zpG!;)TCET&W z%4YEeza)Ajyud%t_ebtoXXoq&xsP3s+*3n|(WUYk@`y%n=x$E3U9@kG+h@OR4_VLI zEVhl-)7CHMyXL3Hh+&|vnlrmOivOtb>*dcoIfVC4XokFss##+%@%qVGFsxygkv5nF+XmlS64r!@eH14_uc2WktP(@Q{w{gZw{HYR4ocF4QMPeN@lvtC3$MFvOmhYN%*`lrB)GtBus zyJ}|BUpszm`#$meqaUWsf1JYtb-2SyX?0iiRP!$5Cre`7N8?d@9n&e(Gs|)PD9b}^ zbKW!_e*Ka38)bkrH1a)gz&GEu#9hzT#qIRJ@(+kkhJ_CEhCgG+4#B`$*@O}0qaC|=y!3U zGzN(z3y3>p0-X;%kB8uZmP6&L1?bT6$O%4+(N|%)x#ULX?q%K|sT-*ksP7-`-VAqQ zC_6Em&N`Uo&kSeQ&Uxt6`$~rTL^n#06^2y6#JQ~Npu45_nC-?a>kPBkvcsx2SF%8`GESrg|XNhb!x0v|zMHq+~QbqC(Z>L3A8fonOKu=~Mh7v9V3WNIFqt)m=8k z%%7}9;*_|~j*%1{GT6rONmp{r`U(3=)25gjFtlHUc4Y_l3Wt!AfVb5H!ylM;tTx)&Q>7M z)9-ej`?mMn?2HXLpR+G`R(s;PZci8|{tCU0i5$_yhk|tSi5V+UzdwT2-iXbTeijTUU3( zblN)0ZgkX*s~o4bo-$W4_R#f)e>{)wBnTp{H8+_f)K%gV@{Wpwfwe_$5xa#^ekrKm z($Kg6iCmcC&}#UVPDm&0Omu^qs~g-d{}a0KKcfo~Hv8E>-QUQ&2!w$rS-Z1u{JQ?D z{m&;~_k8XArE{h~E75yeD9H80S8Ae-1iY1&ADIP6kU_d8H8KYec#Po&IUweV${UZLYq#+N{UF zivFnneap8aUtfM%^EvHH%GbRa_p?U(mI->H4d_IvNQFIXY!}B_2jn}Ikd=Bk4YEmODxBDDTs@t;b3Qv`t}E{K-gHzW76ti8TZtyysVnJCrZmfY z>tnky{!T*ml(ne?^62vXn;0Lz&Njl_6z=!N`oU}-b};o)6_oEvx5cmG8YLoUpuc+_ z`qcX90k0tx%;op0zk^VJLtPV9pG;~lorxL{LvzqEF>wXv5=wE0BBMioe}ea_ zyQBBJv%Y&^cIli78QXplKU02;`jz!%#*gNogBfRXuSH%f$GPgUQQ|JmQFf4dQ$h>J zzPukCUz3+4QgN)~xS6yMFt0G*)1{H=(xK=FFm5^r^9RPk%O$waxG$kDVhNN9rwG}x z67&Hu9A;|Y>1~$U4sCLgJSXxj%TqgbZBommtb~>cZ5?gn*4Ta;kLt4ME>t680km=& zqBn6yo+kH~CQGfA@3HkzRI~$Ys}tD~$)*3&*KZxE`UNy0a&SS+MUr<-c)m=@Z;#qUTek$fd(UUG8E`;O7l0ppUe+6X-FiHWIRk7BlVLDD)*%Y$^^NY@<~;QJVE|J^rniCB|*#TNqz?J zwg>ruC-)89~-=H|IzXD zN47n5AaLLRIJhhND%M1u*E-eiMKw>h?@x*(zE8QFJU=;_n22ltA^jL_RU}03jE$4V z32C9${^G7Cu9+^M`-F$_cs)NnEB)g_#kkYb5!HA)Lmkv!G}>&7;=d(6ORbf9B9AvU zJ@rHCu@rOacI` z^C?sHgh&F9X97~n3)73KySRTeAnKvUsTRu!$=t~B!C*sgLEowDD>-K~_x)V`EA%z< z>%jEtpVqyP|2+8Dn#_km-gU=2+B=RrC8M@t37C(^eYP}-i%)PToy}7_uPX1Rq;84Z z9Um=?Op@W6ZX#2iTqw`rQbPTFVPAXiDsO#MM(xgdxvkv)^Iiyci*6P4(jNJoYA*9b zztcL&@jPK>LYbskQe3J%@0Wa)^Q}r*nq-adhHF4QOG(pZ-Dq|l9fEc+<7zT=1%$Q@o615p9@u|~=Qxq{FDiQaPpCH$K` zD%Y&s`&nHx)_hC&IqGwvPe;-n>49(Lw?#y5R{S zwPrvle|SwKU5LhNs`naZnKxMimZSEY@#m9Y=WSWwN4^iKY|5*oo=KLZbBPAW6zd@4 z5N(DgKhuOB1oiw+IUw{C)ZByU5YEK;I1}H2AC6}Z>MtFo_R#cJK|QNZYzs1ahe3}R zQ*Oxb!Buz~t`Qs(eC|v4Bs)jsw8)J7iu}0#ZPnLWpRfE_`F%r#oguPLP;#MCy{H~$&mG`UX7&J??&Y2tfF$Wc3CcKqnLG^0k}QGJC{le`jBUJ5Gi zR_ITE0pAMGO;r6|z8--C{sYL~7!}DB4y($jH)@`MKheW@!0L{7rNrmGmS7GYP-sU)N9cfx?Og7xmWwF!gslcWm7i$Qc1s{KvOF~bjYj{q$XXIJr0CyNf zr-HnZpTJ)c&WU@ZLh?<~4ZdU|SC3o8r9?gk5BWRz&bZQC+jFPp9Liez^T3ZoKaXZs z{khcT@W$s%brGH^!RO)U@?zCPat-s4n6K?-IAmUB-{UwOcgv!SuM=0>UO&ExW4Eom zX}+PA_AgC8brotI`Bv#5Qv86(@X(TA_{<+XmSqkQn=IwT~ZM*rF z-DuwsXLsa}&#?c`+{)zDf6`yk=4B)5LR2f&I%%x9fgjD${FKPQ;iI9g!CArk!55L6 z+$3oOQuL-mlQ@CttPL6|>j>K}TLo*(J}7=%Qs?9r$@vp6#l5z#vt6>M*y~uUnWEZ9 zYKCq}{YBOzPQ|uLErp?6c62nS;kt2OqIIM7L43{=bwgFP+1Bi##w9ivtF_6j3!C9>sA zl}0^BTgO5fBTku z6N!Y|gboMop#kCJ(J%Y}bOXEMelbsVgYL#|)h8M|8aEm~8Vg!)#&w8i;~&Howyn38 zvQD%GZI!KV({f!wRtBedGDFcak*)kCl|U`5m{12KG6UBn+94SNi^|2*vdSuwGf1ul>Pt)IIYiSg^ z59v@@FfN@YpOO-V6I@XwfSik%qQjBS)kaJ}HTSaQP%09qX%o9jTUt*lqS**1o2#`dQj#8WY=5vyJ&k_Ei0nXJcEv6H7~@khS6l`*1IE ztm62c+$p%=&D==t4Y!(~%J&oYff4x*?(Oo?+;BqF9~u>07O3if;){BpdQ9%`uCAW` z?&Y2ZKFZ(Jx6ixVPlUtVXrZA{97i&(93_4-U)Z2_k?y8;w{AT$5kzwg2|HQFY{O`M zOZ{|1NViritB=vU$xKxE?<3XnAE~m?8i^;bqMM?5xo&)Ese-&n7$@$R>cu!xtGFp7 zX^r7eT~%nc#<10psuj>Um~|Mn*f2vp14r}>MV9xY65U3WggjyaDP6vWZulZ-HVdj; z)L8WsR$`B6*McV;)dx)Ttww7V%Qa(b<6~1fGjHi>DPoG(SJf_LEy&%mu@9Io)O@5` zZicQ}6Z=PLBQsJb@q+kTyeEp{8*!#M5!a@&aP;31{}JbiwZ*N_>xsg1aG^tdgxknX zh~AH!46ll8gr;T?2qPK36yG8|9f2SI0{(+RHgGp+3~!6p1*LW_t~>-rnCFPj^jqeH z=A3#d%d(%eYC~21K*JF9b^B{ufj+nx$zmbs?#9w#R1^O|2jnGc?d6nRQY-1BxJJ4o z&xONk6Y}ie$eUH4sclRFWWnT9yP276g1)PxC!6QanDPVzw77_`w)+ED;WZn|hnO z!u`uX0r%#A+}Y^3=%10f(P5E{NG#MhR5n-|>ZzE2ZD6Bc0pHjgXcB4_B*L@8&7vDQ zhma}06%W?P54O zDZAjQc!=6?8!-KJN+tL%cB%%Uqct&hgt$rN(0Q35Od7M5nWNdpUe}HTzcEuQX%Fgl z>8k20=*#Q=r!9v(HVZpk^B0JL4(21}hZZp}(G_#kO0lKT_CA(NBQ>ZN(#N(-honx@ z2khluVw_Y*QiF&c5fa65;t$a%t^~!QkMKr_2zU7RpybrUa~XFdYK=aP9D%~u5Ya@~ zNRwzYE)wm`r9}sFD!#2)O3IJa!^*NqUeOb)7TR{^zkXi%v_}|nzx&`x- zsf3=d329`BAV>zmH9f<$W-_7oEsnI{g&;GRqi%z!UK4Yocyx=45`&1@Dhujl1yoDW z`EL(5f)VbHp^9Gqq%4t{m`ARzB*~OwlxNA~q#3dx(#lw=p4?27#S&t3se#lKY}4)H z7@>?WF>MjIR{$!DX6XY zhGsGgO8TnQ5UMShxwELBP||+}ZT>0s1lrRQ$T;~1E%YPmHaZdsU^*Y7Hc_>a^YI6? z_ZB32{EY;<$>b7zvI>3iz38`=MknGQxE+eaC6gaY)K*|G`<3I$OXZjHNhyx2!*?(+ zvXs)XLSVecK{ehEDVe?DT`30r^IK56*DF_)3Fr;h1+n=9G|(&MH2A_B$yH^eY=>^1 zf>s_NC6JLSVqCO9zJuh<(n@vw%{-{h-AZ$~7EWMqzlR&h3WwKBbV3h<*u5Jn;tOzf zMUd@jCyJq0+!9>!G-52S3{#22&^2!(j-#&o1zqiDP&Z#h-|+^K07kwS|9_0{{~qcSTQlE-FG#pC4*)6}DTJ;)mxmS9yr7_W)k_ z2g-ZpwelGM|Lp~Q4*u0UB@^o$Q8eh&r@{+V7`^IBVA8jMGpaqhV?&|&oB=1&6l8lX zKo?~{{C{VV9egr&{y(qFYk2!!!L1TN-&BvYM2n1a7RzpcqQQ#2nfPC>d47B?g;en> z__s2$T5ID`7p{)-NNjBl56^E(ML+lg#v!SED!Ax#!7X26vi)jE484Eqoz^@N5mgTMfKpX?#LSxI{|A_frU;9}k~FGIS+cc;D0@hqHJx zaQcL>yaDuiUD%o-czZIS&&$Hr{RXeiM`$%ZV7q?+{p=;azWZ-WzlpuZ_I`nNeS%lK z#_QgpOZ){sm25mSu|Iym1@sdYuS~pOHlCSy{KU_}XMu_qmuNh__#{8}B!T7N@vk4> zvB-&6!+E^QOnz7xgBVV=(?$^cPUaW|(WwCW@;=8JN zWnpaD5_nZ%xMkwuP|J(s$bx*hZhRZJuEp zJ;&FVc+XGx`Wo;29iNqpZKGg|i&$e7mPd_!qy2Bc=&-bAEUy(yn}qIpGQKZ>r7w@I zP#XWM_W#F8C9G)`Y=esUX?3h;ef(P%TdM+ISpoaLJhnm|Bt$pE_s!rGYK^XbZEUC7 z_`d#s+pR2K@jK=Ice;IPY}3-%R=+=^6h1G{e~(lwgBcw)plF~2;IZcdSZ)$E@f_^w z-}^FvpXFf9GO=D>tg{PxXg7ZI3qSjP4E@BvU-4ZwzRSS&&Bf1r_?zE*-ifX6{%;TP zI1VGY#`&>5`2WuB-$Z*6$ASubD*{iIALsk;&-P-U{XRl`S5qY{i;#ya}nEJ{_o!iyp00>-iP0L@Cq-E*jya(e(WI^me`M_ z`-Q!jj>j+j?km>iE!N=!*5xzy@{`+WI*E%?31 zomlUz|IVTB_}>@&{0(0F7RSc#b0-s@7RLGsSVjZ(LP6|j8;*c{xc(Kvv0MOW?Qaid z4g9VW_Gx1r1y%4YgFV>@d$u`#))ITRDfW0D^eKLy{dI6PXszmt|Br>kXbkw}W8e~Mk@9eXt*ZB0kz!jV}D+oLG<`0s0m3E$IrCmHYc z3tP^E{YfLsHWgQmAdV|qm4$bVU~5Kk&$^3q^B(ryMSR8vT%mt^8^7W_wZm=2V=11- z*2E6uZ+78YJr~D94|s3iAr1Q#_Th0{(?i%tFR-Mqu?O#A8E@eo33Ry6;b=aLW9c}K z=B3z+C-Iz&_abm-Xc4OiFZ)H@JvPBndLMU>+i=01#j_?xGrOT{oP$^1_|GL-9LK^9 zq;m@xJr;oL_ce}#eYm%1ah6452jOKJhc%oRbHK6LQ}tNkadg#C4M5U61=nE}jM|&y zp4c)o-C*@5zU?DRHl@4~SplcleW=qjxe@<&_tb_vz$V z({SNPeK9YwSm~&ok|gO>>Q=#NZ!k^+#zMec(zDRIN2PJ@+&`9XfhKd|m z0uNLXvbJ9zzPny`w>4=z+?;xc(GHb$B#6&7sbN#zVi zQ#0g#R88q6#`1pQiV`Gep=SOWBM+;#k>`kD zY&3BQzUZYf1DuH&aB!B5g_K9JM2w`8RTrTqEJL=#{j^eSu5wJC7Tbp0wcf}rI0_2e zI0bA?m4_^h8B&xiL0FNsQW9=lPL)G81iyC(*&bH`2F~s6;CiirmSHtnPgMp?wVU!9 zWTG@7r&1Fr37W6-v1T%@>H=k%9pmI}WJQ%dHcaUNrqdY4!mOaH5Y<&BV_lS);5k$V zTg`@a;(@W2$}KrX3B?-1vuQy0tA(n!!pXfb_g;m4w*#EQ+Qhk7WoOuFofNHMNqXh$6_5 zN+*s%Te%V)ABsN7TxHCf2kK&)Me5p`E1CxEWHw9NL;FLgLh{=yty-6*U8-H6D+k}( zXSSSDklq#f10nG9-E#ZeKN`xMdn9A$nBnpZsFp#TgihsWri>dNo@Uv6XY_wzb8xItxa5o__^`36OJVi2@T`_ zb-an26PFg()xOrsn@1S0>N45yNJ!Mu86>!Q;7S~k--=Cy%E&;e9-S2Z94(AovB%LJ z(f3GW8;U%__q;5q#i$rB9T!_7w{0Kyg)71}CEq)zBv;EG@IdA+h{q{ree z>$>55pM4-ZF7r@E-mH&VLvlkd9|>qB@W}Fq{k!9BLMkfxBa+*sm{Qg!jYtSOhB(&R6QJYBH7zv+P$xUW z@Kh?fKs7!#4sPf>@-eBTI8Ru^U*XpB4fwhINT{M+LKktE*hw5DZV}E4{gA$v4{XFi zBG0{zUXR8_cECTdYwzS ze|UR^uSfC<1>t5|pyJ6Xnn&yu{eIH}bHsAhy3nRri&`y~V&?s(Vc>E;*Z0%zW6bnM z;*@e*@*vk>ZM0f+bfk3nWM~_lkOP8OgAc;>qA$72f>9o%dP+~xkh)j;yT%%pYW8^! zeZsWFr%Aq~-6=y-8YX{DTnwd7b=zlaQ%jnOF%}1lwYs_~Jp-PsN1z*faK^8N8+^Q6 zThid(tdjOiyTM<5wuFfA#@qwri-53|)>?<+n>|a&}BdrD-U1 zni@f`v)lUFHpzCx_5i-89p7s$<$^y%080&&9%}yA$SLBm1Oa?l81PR%;7h16Y<71=E2sTR-0Xoo>M++8Ot@( zF=Gd$j2G9M)Mol27%y+7rqKHKidKp|3nz#Bhdu{a1--$z(D%^f2+3vfrKBO?_oPuH zH52vo4DHRWto`gY<9j8%P0UFACuwSOZc>e8F=0l0Ag+zAi)Di8KA67Oa6jnAJi$F+ z7Qy4n{}67nFjxnpkU2r)?%fw2qVmLfB&z*R15)E|n3w1i8>C!=vZ#=_PSA4?BLyM@ zQ5lU3eDySUzjRl1{+;_M$Dchdt6ssMm5SkIu zi#_EkV6A?mey|VOO@@l5`sNKghMoDkX;8Xq1TITLkpe~Jg?wyLdEglVkpY^-L=vV6AfkGm4zII%!d zouq=vs+9c6sma3=Q{w-Qdte)4S!mjhE9_3TsU{C|93JJvAdV!);<5KXz%9@Qc^ktp z;=ipbg8CEz2E%J|KIKN`ZU%Xln5xQD4#*)f11|U4oGO|b-WseD7=tVpySqv5>l}NI zlGz)2HD{(T!<11zW5&+GIcOjb(P(To9E$MC;@ za*Mm?`t&$riwnH)NS>`)KvrfBX_jao7$%$C7N>QZt-Y_1^?n_x;r%VZB!c-68R22f6q*WfjrAULC5Y0 z7v)y+XN9}cbLDj`g{Vz_qRumFq()TGcGML#%r-tSoizPnYHBKKI%mvpynzJRn@DUb z%ZlnU%z3Imd4w3CYWSNsCpQE`rJ-C0)!oH%N10Uq1!<)}*tdf~6%jEb?hb;F8Qedw zk_^Sd5m}b5%gOR|X`EOOO9V1BzmyXq&d3@Np#5PN+{RUc-hdGJCr}|!AAPf0p|tR> zaB9Sl?1(4$>6R>``Yk zS@c1ih3mjs`AB?G4FS9E5a!oYFbAH4DvV3ji8x9;gs1)!*@9Bh<&giVf#+`=vZ$7+ zpQ*2^&tvpEKwSfFVNxAselgpS_g8?~Oq0k^`yYBKxgfR-z}emn&ZCY*TVgeF4g35$ zssiItw{fDH-3U3IMNr#)3I@+n)W`2Yr!rn)m28aZ-$*N^BB13LKvvUcut45moYVvU zvvHAk;X&cm;dAs`DNF6)G&(`q=u|B6XR0AIkz=Vw)W2jMbP=|bZn6}07Rv8z@)yW5 zC8+G_=saYDYx;|*f&BQw#8af?t{}$#uEwC!Q(I|Ft%!dQ+_1B2PwNLJTg7Oq}mA*bR1EQSdT|(B+Ru0J?S{HfHcxo z(O;tA8azoZrWTQZfyrc~pODE!F}e@cQ#A~k!e5C`m?a%SHgapKF8mz5h|Q!#6(jdi zKe2y36r6ghHDps9>1)BrDMxC_bWq>^#rZvi(1EcQh>Zam@U=V>)VDtJYE=sy>l)Qm zMHf4$)RQj6+CaD6O!Y73#<}2NXGyP=vEo{}D;!4I@(t;eI6YPbN|+?&jId6+DH-5Y zyC)ii?&4&zuW+BgEi^^;{3+y@SC*!Nhu;}2r@x^Ykfk8peKX}7;IOq)9H<2vFbdg? zK1NBTqMcJsg70V?xrXvXC;u1b!rM^c`IG7iche;DJADEPnoHDIxV}zBeflf< zFruml@tUZKBYY=vF_$q#L0bu;veBR3M5WWo%trb#J&Aq|cjjkMS{~v$aR#gv)HRsS z@NL$EGqD5Jl)i$L#{{r5RLmdD7W`~0j`g?HEYJXIk)u@yh%nI)6~KGw`PEi7fH(12 zZX_-jtwJw;1Xqlo!?ok}TyeeucL3UxV$tc*;~c@!Q7<>0vvM0S_8H1mf@WeNcM=(z zccXKoT_VjRL&8JDdBZg^&VMV677s~16&1QrNMa?Lk(-dlZGtmx7h8vIz*<=_k3o*H zVivv`q@CSRcOSvMs2q`?DguvV1yp(Rprd910WOJ{OO~Q_^aXl8(?Y#OQ<8OQ_vqT_ zXBf>N$ew-WX^@3H?ZY2R*Gnr>&`NqTR^e(d4REVxGGkNB%E#3{MgRk)pYb_=~7Z z)Q+Vn79|H=-x|^ip$b2r-;3GMwkC5R`;Vt2Z!S2DtfDP1${K1w%!8g=@&^O*& z-`B)j(ChFH@;?vAfiIC3d<(IVoQ?WajHpfbQ*YMzv}1K$4Mj|(c^)za&l>!?W!kFj zL3Lg_mz)Iua?2PiuR{jOD=r!hN6nxlU5L))s|mYNC;g;sQyn0OGCj2=3|)==%v-D@ z>?sb(5q9*7bKAe$OWEUX-{9yNr6+XF*t}|ru1?$no$wcWv|fw<4w=q05RkfvCJK0XzQsVsfYH$O!H%~ScVp+Nr z@wdzg)wrIKMxi;u>%nYf9gpxo^rr{Phc-qYaNos7vE^iC^+$HJ{=2b|E$H}=a3=9~ z(hjK5h9>igwH)`X+s)4nHeFxjPn;y~#B$~C(pvG5Knul%G~uB*NA3x|#A;#+d5tQ} zoMB!<8IY)6!YpKtFm~i3oF+=B1}WX8xx%C9qfp!6B42&aQfF?~&y1=WJu=>9w8|Ko z(K%yGX5Oqi*@-z-oG;xD-`~Lqw?-AOd8gZ9s$kW{MI2w^kH?R6?1|&-+mYBl9wlzinlvxd#3rQfz~c(4v{=k9Z$(K!L=VQv=Zw{N9E5-DO_FhQz9)e z$HCi~#-3*dwzW3D_9ANQcR>s(gKN7bHe6~Tl;*C5e+Fv>Sl@g1T<6p5Cz(vvfy^0M z6SEs&H_2T(;Ni+^%M!uf;K{0n_KzF~o*p;nP_G%v3f`$&H(aGH}Pn13`|bblE$ z&1F2T2k@E7ITtU%LB-dKZVmSeH3=RFT=t7d zXj|u<>do~|_SN*=@izDV@>KG8+}quG-ACLS&t6YoUuy6Qw^)7z22oSyp}HjN(=OI| zK(U#_X0bYL1Fc^>NH<5j6<*!DR8Qi5EEDtCOZ;xm%!ham&g!GmXn8(pe+QJi%AXkN z#u3+uSLA$puDT05QTt6-02G)?MxC)Qa{2D+?7Bzn4^1A8U7gO1XV!z^_5_aj$#BYa zBwnflDjn4Nk1&#}g?sY}c%|*gKJG;@#135FPouxkml#4^g0s3TS&+0q9kUGjhceg- z7r=S^iZNv=B#f)jgBS}6-%)-U*Db1!JPOSUHV@?UH~01QZt++=YR@PT!pixk`!Ary zIXV0@(vDlg>&3#-8hMv;J*FcHg4NZMPGHtCt<>Mt1vCko>Y8M5S1uwO#e&BtW)d?S zq`t$*vCT`Lh6*PRq{f!;43>j0vNE)&9mpFP6~92%@F40E?jCvJ6?ubq?1i%~pSl_N z8{5=R@os(9tx;E=&2&P&{UALD*Xkn}TUx0TaJrYkthqnIgBv#y-Q{fboR&ZrLB*~^ z<=t92jce6$Nx@vJEJlo}!d5=WZQ{Cc8PRpoZc!?FD>6M2Kw9vqa9sFhsC%dc^60r> z)6m_J2DN~3k%Z{$Xgg@EZ}3NiB&ge}OKYU+vK9n@Ga&BvBC3&UXk{AFcX8kB23Oq< z<`Q#?d5XN%gG`voKrQ_Yl8~LqE}ubL=*d(neEb*5h2#Wq82W-WKauoeL>(iyQxBlY zDTid0?KDSMVHPnTnF{JXYPDvNW;{~F&uS!1y5^@wMn(RlW*xGY^~iROG3%Mnw1#d^ zO(mW1$sa_I{WaFV8nSQiDt{=IF_wM_S4T6s9aJ=*q%#;hc9twS3(la&zzavAHJZ*p z=a4iVy#P(t?(lo$dCN!&DjVz)oF9CLjOr($f#H+k)6gS_pjRykPq-+&5rdM9y4Y~l zV&V&Vff_-t1X=5>T2iNJPHR>m>Gf|ce0au4Bg|71+Se*EG-sa9-y^_hn{U z()U2O_=@qV31rMIM0Jc}D?^*)QgPr6yFpN8K`jY_sk#Z{oiOwdCsA`rSE?!-i(2wBv0Pip5-{9Net?FIpX%52708fr5+uO;P?#~Jskqtj5 zvIE-WD>$D^NR?2_ng;gCO=2H;9HYVrU7cBrEZT)&PISba>LOBn&(ZJb1$1}%G=4i2 zl*s_olIhJHK|l2#msQCKgo7&M8a!&^d|z>w?$c1d5~6K}G2uIH2cnA`1o z>1u~OC6oE7xi$>KK8se4L-{= zy0YC{kZEf8W?INvK3r?dZRkIDcy(M%ebZqazkL(CONO7#49aoe# zU;r3!ZFxbwBV?iiW?Nx`CHE30kOX!?V|73tFZB_73k~?D+`Q=K$RFYEp>@F{feZdC zz7O7z=bO8~dylIT{0~;=g4}z#2XbdR8@lehR36cj;AaA_f`>yjBYB}$f5-O|pGmLf z4zYbIH^E}=I|!`8RB({4LZjFPS@i+T{RcB0KqqaEuf3Q-OiyMU*g$8Y37^0$!>g|_ zN$SPw>uRH>9rEY9vb~V@S65e7*G)$u>({QUq;qMPX=~#itea{!FQJ}~RCx~PmW3WjwWA7wWYZdRu$uH>_?XHtqnTGoU!MWWUq^K- z^(>@kuYyy#j#^@N;|esAUQeY#&HFX>hr&Uv>pr zoU5+$2E4LIa}shMW)IAMn9bxo$vK|ucg}J*_x|ub2ow#y3+tloxyIZa#Ksd_(<@4*86@$ zb<*aXPun9{5L_3SNv$m_u+INJr=BnL52otvOVnXIf*eeQzyEsLU|X&rDnHY=3>*5S z-}PTb+a;X-c^d8h8B7JoFjLzkSB;0ynWfJWvfc(m98%Jlv(9d%YUx8p)ypxspi7E zGa~Jyf5yCt8SS8*b#H}qBT+TfDFe*&GcUOK@& zC(t<1E3i1wCfJYjmAeK{<^TNP_Q2l2?}5*uF6CmmE$es{rZxJ%J)_Irh&5~Xsu)L{@{*a zKi12ii)}d;I*Bvpt?&tC0-J^I4$cca;_vNy*VD&-S6S=y*3_hAtHeLl$J`hBrG7zO z-`Z65ma6=!%HxfWpHeZtd_Q_}X<5&*-%8&t9a!3}^!n1lWxtjERlcgCYt_Q)*|kIJ z?dS)wH{v(3vn?0D!9z{)|Ju{31T)ZoWncs6=}qOd``1EqLg%yp+6g>|ly^eQYgnax z9e?Fv7aj?YMYFy|SG>Wa)M>mp&>>(4_E6bqivJb=+sHwP`jzMVhl9GCsH$`=`KvBI z+ZW+v;BDSqUyRJuJHGwYM0(0MlIYJHQ1DMuaAYshgXJ= zgtvv4a;p3XXyHrY%fg(}Lk+$ltHN&zoaL|b&L>v*f>T!tGOJTJC$CBjjSr69%FM|r z_1?M_HEXNCty)>xq;lx-rz@QDP313@KMS7Lm*ti_W&O+FEMHuacltm_>) z9(_ALGqE;FCB@7G&M?*x4Y6l zFZ2FN4)a55qn5KS`d?~By-0PW8(C$(+V?2luC&4#VnVKBP@xSl;*1L%+a|^ABZjKX6pPCw&3@55%xv`Pa z3{`l~sJ*@>r}~!4HkFeqZmO74_5`?c!?BN!J$EcvdTnW!vcYAe%2VZ?sSh})?vhAp zVxZ&oB>mmPcjcbWIpx>oKbp6e`HE+}e)}@&t!+)b9Pb++6hW{Z z@tyuR1MPx82U~~l$oU|5Wq#X6%?i5~{-^NyMiUEm<~|c@;jgq`bXuklvV-pS_}+Ny z9s%A&a3|=GC$ga3WC#PQ|mj| z&#Qlhbsaxf-&65Vz^9p%DclRzEb{U%oJdpyhR-@GUK z-tZmt@Age0gSOh+(|a3~xy7m^u2|tNvP$fooWS!bm4Du~K6e&m`lc5q@1m~cb@4}8 z6>>s!PUPvx=I9BWlR7%~QDjZ@ulUYL^TdJ3i;)GY2<4 z`(AH2Fp8O>R#5bf&~3p-{eSvK1iJ-RhPH-ght}t=WWUbP@K_?2JZnm3IrT+Z|IZx2 zCeLTSwt*i4?*vx+C;IwO73m#zvuv?0wD;I)?^@rLerD6D7d_0I>$#fKP(Na&)e(1$ zXMnc_^8SbKU;O9|L_m+Zt2qC3Z0e5Wv}8%@`@{oC*xdN7iDS`|W1mG|slPY6yFM6O z7<)OfH1g=dR6XhUzv4HvVY>k`e{^49$CkM#_{VC?KAgr;^SoRd%j?3OwQ7r%c;fp zYyQ^!Ub(|_3Im%xad%m2fAT^0zRlz6chj#qyX=>U@+Jpc5Zet6e&Tf!}-Y_Aen?iX(3ib9&N5<}%)=j@w<{ z7T!kwoxbmV^L>AC4!{C-!yIN!bhUNJUTxp$9Z&7M@BBCWn)-Hl&t?C2zI(bu`b|9YEW>1#PD2_FGK5w9Zv^|@3omKS%>OQEw zwSGY4Qr59e$o%epO}=N9|3iN|yf$Z2Zc%}!pns!p3QBWs3+?ef;hvDWCh-?J-*e*+ zB!9_FvD*6H4TM4!q4}X)&aQtPujNs1OJ80n9rT5J<{k(w%iR(D+CRqk3sHaXw3nJ= zGcqfzzwEu<;rLdQIWsxvo8rCH>Ol>Pk<>H1+cVHt;x7!82bSV9KE#gy;hbMIpS>ct zdHRu8ScrYBXYWg2D%7{2*5Gw^uHDYA&8}p=#@(G6o~caT#mwwC@paTX>yhY-##~%q z9$6gSS~oD!oxPopr;gSw%LMDci4BRrS$%Q+(#qN8&sE=f_(<8v!?hJJ#!gGl^}iSD zQ`n*T$l_zoziQH^(W&7VoU7ASu}u9n^m0n){|ob+z1bynI(dduJhkpb>Lhu{d^Bf{pxC#}bD6h`T}y@4 zcQ`HJMk4dqyhqtl^Asx|HmAFAcJ!#kPl;!e1@S*q{@DNG?TNp<@s-ha(evZ`BOMYu zSnIGXSy$UKeo;+^iXGF-x0iQ6+V0TpWiun!B$wL5!lxA6)BJ-Lb4#k5f0KV(&K$cW z@j&Xk`m^dzsBKpBOmupDe4=G?VX|i?l1W-ucqRmwkREnLy~vfM5?lfAFI zr=^RdA4Fe_JQ=;3Q@VFp|9Z#zt>C%%6*<8aXBV7>AAO(aNwPA70s}$|L(hfpRmINW zSl^Z2O|1G2I4`7!Ih|ddX7i^t-M+~4m}jErBI3;tJmWoIc-Q+*@h=Ws@B7U^8;mRR z&a?J&riar|qJFztKaddv_#+r-MORLYSJ&9@M+{)yETMq6x zIPuVD>B05?1E)VMYTM!TX3w>~z9bJDeqr>Z)PwbQb?f>mR;pS>KpEBjNkMHXCjPb1xuIw&jJ^QSLfXB`#bM`y#8+XI<`toicfJ0?7CEuZ;x|^ z_rLxH-UGp}tbKTD&3wLKrI)(e{tvw;`=9lE>+9q_#hP!Av~Ka-#oE}G-l_JdZ1ULb zy^9>m0OsC)^CXG#=h`vW-#*J)p$C`)eLsD1dPVZ_)bq*rsRht8J}>qaD|X(Gr(zou zyArQ3FZz9AW8|XvkF}@PSC;oV-luf=fiDlQ+&iIme)8>1i^9kAmbV$xF4Lk#+qVjK z=l$y05xFDzW!2;LA5@L4>=}8rep+HLJ91z5z}SG}IoaQx^UMZWZ#mz)uQQjjo!uE*?Tyw;_IUO#zCj(5&wSH;_u$1o z?(OLtLLJ*rn76&ou42zj4Lcah$Ou10Rn?n4o4j{=%e+3{a;p4K@O)%n#|peN-B(z7 zb5F{X`XjkB@xR1s*4Py$Tg0a&VxZUCiRWS~60gNZvtxCAtZ#gHYKW8mvw8apRDs$KU4d8{V;Of zJ>#{pn%I)mRQFdql)I|nt0o7F-!Gcp)YEuExXk-m;x_veG)L$H#nS9LIY0dXv z6dcV-+qbe-q*I_vU{T<3AVH+w(s#(Z-#IURM&d|fNP18B#70XC1~l*8$SHCQ&JHfO zFNj~Ac&Yl&nuY0Kl6yJ%rOekiI6K%QrxE7@J?{Zyz4J@0?vr{v8o{Bk@@H z)TU0eypGp*{H0w%$=KXqe9zVSYZq5#ss~2PYQ{vcuCYnUgE3zsH##w0ci8(G*|jCsJ1KwSwrIR|aimH74kG?`_V?Bc?1c{}0?`AtU24;{ zJ8Qp>U!5$^EOYKieH2?&JEm%Mng8%pdmr8D-O*_O)=V$&UCmx^`9zzd)Zz%pI_f+5O)*a4U zsk^D@=uNH6w6WIU4SnQXk$fyQA~VaA6VBwV$e-P4G;1=O0Ox_Zh`Nf@U103x8B~oT}ikR5q4m?rzBI1hCEu&sUWGd>G{RhFjj~T^&AP^l6z&r z!-doHL-`lveMps^g;p%}A@kE8X1;ZA_00B0?YGE zWNLK|)f1Z?370k5KY#z?or913#!8BhTYOUdP4n5!lf^GL`PRQaRPMgb8m-gg4^!dp zl;rWMA!Uooe=fhE?yE@3c_G*>Z%$G7;!Vx+o9xMN;vJBg?tH{)0D0`(X_YxGH8S;6 z>|=7^x5vInEp~f{+ZNtd^n8=n1*!0_zHN-VJpLfnR+h!?Nu0*YkO}ez|*eu#?c+=%Yv&g8G+e4}JTAV)aUEvwz{VKJH zqBKWJU$2QqXS+uNH92pvX0TmRalx}Wm-{BBb2*`?U8p=XBe=%%PO3C==68|MKDbXu&GCN6XS)VX3{$p});_i55{Nm`Vkx(Qr**B4! zyfyJpZQGg&$3Ln*SlXxJ!-G#A*i)Kvw}z8NQ%l16Z?Xc_C3j03xd0Tky zbb6*|)hs+_mpy(w7JtOkF|aFmz2{Qam!E6z@NVR+mdnF8=l+}fQ2r(1@xG6f{p;_G zJQUxZnM=g-p7(P5Vs`S`_820d#+>_;M@DRQ*ba2E{Hcp#=f`hOo<{xFo2}u_0_HBx z;QX*tGjZpr)ya2S=z+Xx%^z($q2*o0cZ6Qf6s2bJonyV~`^U{VYh8!gO^0e8geiA&-y#1F-K$3~{!&-CXE$%`^4$Gb)DCX-^PuW`Q2+@HRj zwd`-lR>#}NZjBs|{7AivpJHR8-&A)gKYXO~;f-Z)WrhSV&40S^)$k}L_yxAXh<|T9L+8!TK+c^5aR2YjEw7yBUP9$U3q$k^3eGhRW)2zIIaz=!k z`0ujL$@ETj=iJy&{I3Q#dYe-g`?h4iq~#o;-r-p9Cq#S0ykk78-S5&>>4%+t)-rEr z-!!|w^KkMrV$sg2w$5Vre(&3%jUgwuCHC^gz`?)_tni09H-ujgF7!3G_h(*Cw{j+A zUSrMn$DGO8#^d6bd}O!xcs*slFUVD&7`!yl!S@t<@;ds?qn_|U?+B_r-)?`&9TE3` z>}?OTXLT#5SN)zjJ3T*rE9d6sv41d+T4ReiUGHB`$1BS`L4-SuN?jc|ttpx6k$fcn zRPu{>=VX$+$D%}Wa#rfi)Z)yA=^o6$>}F;5x3Oul1$D=(pFjR@#T&7AJdgTbg_Bo1 zgWXmB8$#3b#x?$L{@1}nfp@)!I8kq~vm*LM?90UXR1ar~^K@cfv=gZINBRt3g4u#H zPMDhDTNU_E=!0OFzyYGBX`F4*Bk-5!GD-#8CGDYY<<}+-^Bk5O~j? z>^b6j#@pL>&{LT`!}1ktyFJ)@rTvrJ*tsq}BYlE%r8|+`m(4lP^A+oByD5=o%bcT4 zb4$7ww<98@pp#reAj6Nx$Rf`cnSMIX~yu2hSx_G~T<- zcXs%r&}FPS8fr~W&&+)8Zs2UZVN^^0E%UIopVN4|IOXXFGkdI?sGz>UzTIt|UP7(T zEv&2jnA&fL5{=V#YJ2MZWF|Q#v6xCkF1xKxCLYgAwRc~^_xw9HJT;&7!*?Z~ir>d+ zYX7VIxN1KqE__q|RPrz1bKbW&clg3oZSuaH33>e+^=*7cqa)ru=ge9cyhIu6g~nOG+V5uiWG=VM>{9=zyv;cq1FeY`*W0Unw*)Wp_w~&7 z%yA0XB~zJsHvM#ZH7kssbGNa3`XTlUZgZws>+z>6?Fo3tDd&LmF!f;0vKBJh?(FZ2 zv1`_At+smxcKP-MI|lCxTD}GCGP|n)$C;l0}@pi3y6RS^0+BbPNv-0gXw}U;0 zRZB6}Ox*6z58N7T7Z~Ti%(umJIT`Yo{nz;Vv2t^u{VG)?9bE5FhCmorbE&WxdI z#SCk=DvUd|)MuE)THRNu;&}^uVCOhposH>8rYY-$L+-`YQI4lip^DNg$s_E%JjF>S z|4KcREKV#)y_7gTy*6=CW?f=a;t*%~+?1S`{wkS_{uY^2cSGcGZGC-3vG@B?HAtfzrbqa4#&?J7kJovv-?G6S)!72na;|LP0h~SNX>_N#6v5Y z+wQ`QY$x)t%>y1}WiHljFtZ19gOB@`vVysov-rNXi|sEsIrv8RF6RrX3-wLAna?;K zI7ptu=LqV?|Y`;VRc}oa&OOK|4*LL{wD6b{}b;e{s?o>+nHng&|A$Zh);Sv?7KW{ZDqCR8ml#DQueb8?Qb|!;vCj(Ok#%S z8c%`OvOXtg{<`NLPJJH3>gSiKC*pHQvLU`EeJGPkPi7TIy}LXUVkP12Zj`gSU*YW6 zHSRd}Z?of`eDcl2IgQyF`a6|I=4QI4E0RunYkGlGi_dQ-yQc~=Q_@46b5qY{mM8l< z9XaLjUh3DbOTL!;J9$ZRNMapX;zw%=QlHt6c-lJ;#F|hQb&~&&;L!YbtZi&THKkST zf?F9rH+U~q691*{nPYvCIqIxRcA=VJojWx3O|p0L)AW7z2+x$ji=JCN>!==l%sVYu zltY%v-^MeGI*ZL&)v>_)iDv|rW&UO#b=t{wo@OPpte3&5rU$)G0ZQT+0$T)0BEi z7w4YY_=@~@gY%qAyt4x{3ZH1SBX^DG!t~ViZB(W{WdG=0>n_Dg9Z6h8t)$DW52(*_ zjjdA^e$AAn$GLgbuUBQzTQbcu#hG&VK~9tKc#`Q`Q!^7&Gm|;}caGD8%BbBqovTUu zP--A2XccGTsnYbq)R5HWsV>RM=_TwPJD%8>@)J3&PH#x%I)f9Jrbj2d&fTm?zBgGE z?V9*1vMjN*u6up5eo8uOy+IawUF3{JxwXXKEqvHN&2ztfk2RWHK?!p|r~6xYzCjLi zI34X3=XAGkCfB{%nV#B_TozlJ8kafQzK=7@ccf=!immHB>-{$ez6@N%3f?cg-TgN) z&-xZy++R?f8hjurQLzZg}In|pzzc7PznDa>3Jx(3ZOG_fmH=mt6#ZvwptVGn+awce6WxFXt(z9nP|L z^4T%Ff$9U#yQkY7IhXYss>m<2U$h>kmUka}B~=1mbPqah-Hz^~oJ8_IYcdyQ#yOVr z1D13N`@m1gJm>7rjB=BiN1S!+zPp>XKretyyPXG|)5!=t33hhQtj!F9@B28%?3wh9 z={GXnGRITTr`9BIO*Kk}5}jgyMu)^#$8L_F7MT@WnHY#Y8$pKgtJFs9*xk$~3zH$B>Rug|imgz6n>h=JE``U-8Y(V+?HG3i0+>v`>Ehty~6NxhU$ zirDky*53MGlK`Bqb1Z>x{UE2vC27z>2!Mz_3Cf;Y_hj_s`dF>Y1gt?r;GRuC*eu34@ z%c%7<%bsYDC-2?cUTZC-T1zn%GJa)0^$YA?EaW`&Le|+0LxP%d&Ut~`k6k>^^4&vJ zw0M?U7sog^zA52fR&+Vp(`4_Zs>cOXAL&3PmVc;|@d>N`w8HC6 zw-2_e1vM^;uoJ34HQ)J~pShf?albRqnc%$TOapVDaYj0KvIF3Le&5F`!Ew$Ab|&89 zJjf@#_-X0f!;ifFuEP+%3J)smReIE zkXadHS6W?>`ZjhGW*T-Qp}ok`Oti!HnPA=}@Z~99H?hwA5wK{E{V}L;6V=HE!k<>G za@xacxBnrBPf+`06q&(B)I|`E1UM1&59sqJ)n5)$*JD35L4wq8xtY_)exs+oM3J)!mt_hj;Dd1(4EYFO=Jf7c=Y-%U-cC7hS^5wDk^`e>)QQ{p&OI7qQS z_{+>xX!j2L!d}l@!x=AkQ_s&pJZ=c9C;%wlS znR7E2@a`t=m`CnEpVNt6fZs*P<_Ab&9%lyKi1cb@U{^f&0ieTw*k?Ez%C)dlXipg@ z4=<*s%b&s>~W8*jcs# z9y{!=>IddM#~!P#RF3IHHs@+8{k%Y(o<-1PE3&cy8ZEV!SwBJ3H>_FNT0Pf8rLoi% z8w!#S;P<)QcRm%L&SG!eIK1ndK<>NX$_L#4Ij22+2N#lf@*caMl?XMwvup=UiCOfE+L$@L-q@qg@cJB1{=1E{4(>~Rz z)=Hm)pB2FV97RLEMW;ql&8ry|f|B6hBI>MNO9U~>y&Op|U|07x>e1Envz)y}?>S4I zRjdH`gj2%a;C!zypw6exOlOv}(D@F(@oV;fzlwY=aW=xI-`EfSOLm`M%GpHiyfx4< z=r(arV*TzRCx!;K0m~|A{SehW25`>tnQklhQg}TBomQ2(0w^^Kx@^G$9t1TU>m=m+ zA}rw9c7J;;qqz(G=x;Z(PiEI$dyvAlmgX3s0e4Pr%hNaC!_FemT~m(5j@Ok&oYpu`aVXC0()A1jgE%IM2%t zl%r0_ErUYeL90Lb{}-q_1KQ1ok_()vR0N!iT;2n&-;B&(3I?3T&eB^srA|J>C?xl3 zzCWM6=M&J1H{r|e)W%-mEW(edpt{~3UI%&Qz)c%Jsov?pKFkrs?vt^Z*J95;2RC!@ zyK=EO{gKT_s4n#y*suYa3nFps&_FjwLB)r7O-G()5@G#}p3g_J?g8Uh*v~OH`jkDz zo@y`Tn*Z9@Glz2#{22`DpU>R`sf%+GJZVQ1wgH6si}UUNuzo}KXIfvOFHiE+aoGxd z|Dnd#0I)qyZK`%w25oCZ<)wOhw}%Q|E2;nV0=54ZAtg(|W%Do9Po3z$O1sQK=vVDRVI!ChFM%}DG4{Ek}cs@dq% z3FN>-)W)+I%^+%Y4WL@s0IGq_CeQRbnDI8FKWNo~8Oy-m!|396&|@*DzrBMDy~UMN zK#O7c0;70$JE(CozTIu?0~k*Yv5w%#2}oEstW$gx_W>GOp+IZ9HKy^}2rK$v=_>TMd2IQNyb- zw5Z~G@;YGW9bkMJn!b-cyG3L?LRN))IkCtPeAqh~T^-N-*b3Wsu-<>WdoJUj$-dTs z;O%`>8oS&29ojucjh_##TqGBT*lc@ zGT)l+o{c^1W4%N6qlL8`I`yC~V>djlKjHhG@as6wJkR>voyhlBSuf(zPiO7kDQNsQ zM)o(=$}XcGS|00-AHg^L5jt+6BG>@;F}E-KRc3o00bz83&hM1)0tSWky2Ln~|mm@PWSvN$SA)a%T^WlNHv(%XMJMQMVn~m4}qRN!u5I&h7A2Hi8y?!O>~dx+~;t z(G~7UF!(wUEysEZ#N5LvTrY!99jPT4aW6qa-$a6%)6X(0|J{HO-_3f9iZEZniBIt7 z`%@jVj0{g7yrZ+>)j(u>6JG5I{%S>4ts?anH2n=E`T?G|0QJtrvR=!T?X3bltN(x* zz0vFcqC5A1gqz8jPGm%9QPr>xq??Il>xLaZ0c<>qO!Y%^W7H7c$@tIVx=C2@GeHIF zvT-D|k+CPw*{ z`!2FSgWoH&zUFAEG1emuPs8oujLBtB{WxeO+}y$_SA!8XZV9wGmoXd!J^u$9eSmhi z1gn%uw29g0@}yZ@8724jAXGXLv|SG`7SPLo;rA#g+ksg61N3+Xh~cM~BY3q!azBUX z&ykl|={Cc^nFhW4u}1MzWbYa7?usrhwH}7*Q4n$t&wd43wZ&S^;M#Mb{W5y=CYs~N zGa3(_dLdiKu?0&QON5XR3xhnpZo^)9S3)Z!`+k6nR@JeDM(nzQ_i42Cv!(*Dw=La zb}K>Bj>u*=#;1z6$Fh}IRbka8>o0%6hy3k<*Y#PfEkyQ0+}()(^YL;Vt|N!dNK_@5 z#5LM+92-?YkJ}-g>SGYW=1`(7PuPt=u9~am^jX!6eeiAz zoL)}l)!(_&k2UFr+~>fl!%#;xYQtRF6MDB|Jb8>J&B*kbs*3iY)i%g^WA3PjmQiq3 zv^jxRS}+dLL6v$f?rx9QB+2Z{`v{|02bE)RSQU4>@ydrz`_T7ako5zMqcvl@2>l<2bd+Gd*0{@%&x4Gm6W6MaYcIG`%BX*Y zisksAJ!y9&Pi=%utV3#lWj>@CPrQiU4Pz`dTq)0K6X@HSevIPP5n8L->Tc+jrfO(+ z#@>?gslv0e6{=I+mM1pDdL^ktx{oKR+HVm(Z3;g<_`K`6e;YqFa7k5(TSB!6y#19v z{D$st;JXMkOz`^{FV%|POy2^~y#udSP-sWi7Zg(L;hAmeM|;|d(}S(NB8;v%;~hv( zyo~vG^jsBO3mB~`9Ji$(yFrdm`1}BTEP*Ej;9xb+UkqM+OY0u&&pCYB0m>hN5?|0e zNnDsVq<{H%VWsp$Ua6X_i(KUMnRLAnN(HcA zHcyR!iuJTzha^byTX9WKMi55IM7>6|t=5FI?dfMndR)M_ZNQ1)XzWQyS_v(y{&OR! z*@Y|n^P_st&H1J!cbx*aFT|JZPs_cyA_%g?>3M73cc;a+yes7D0$M5Ho@U%D`D%fl z7V`>|z3}sl5cH7tikm)uhxx=0XL9(<%ZNhIRvH}O`w;g;vxrd+RaF1D4Ec4SLoM>M z9d4jG+aT=*{Q3 zJg<@~RSDRo)kgF#pZ?Zzohl|*@H)u5XaYgli#t@ zt9bfhe)jNPl&g26LEGu$7Cx09o}#5%+N$Jt1wT>ZB~_O8V!g|0MV@65{qSXbDWs`` z)51b8=vKg0acGi8zT?EKEucauOI^-~vV{AkKiu63YHF?{UFx6P54 zLY|lANu9IS-~=Q>oNvS(O|n1AzDje7>06Lbg1nl;bA76qtAM`)Jexzyd48Mhbs>74 zq_5J$C{OfqpYHT=rLs=DQ4XT%V2BV%!meBa(A3+=xJ-dakWwoF1-K z3o$;a&eBX86ym1|{d1sn4WCB%)`7c4+0WwqU7n@SNTY0CE$vo9e{n!p2a%4p z{5Wt%)!|i@-lApcrqL(WP8XL|LB5{v()6POS=bE*sYZPnpI7jV2(Kz6<6ls07yLiS z|Nrtj$oH!Et_tiGJVCOms^A)(WIN9IDLDnNBnhfSFUuU|N9bM6b#?SjwZt_7KfRL9 z$fl~VX{clKxmqACOIr@#sIq=TwRxKs#aZ!I+L)%bG~AGO<}=bpXk{}-+Z1VP3wJwb zF+xv9N(qxTSmDUZ#V^B9izg_O7{2{HWe!ptyi>4{=${<(L zq9lJ)=#rm`EXNF2E72`e9ep=Es^+Q`uWGboA3RiL^(}C1AK!7bB;3@Q0Y`Zq{#6LZq5(KgFL2e;>>~bsGqoIDDqNV_9^Cpqvy71f%z_&yKb z>fD9K=tE=Phxt}=63j}EsGQ^p>bZ*+(iBB|pN06X1SNq(OJcBk7#|PN!$I$zChmDB=FbS?LY(t_eNUISZ;dAEsAAwkEWy zD(oG2YTK*~Ny^pZMtswXPb81I{4Ggtn!U0$@*~em+)roieG6n!URM{cP##KXt8XPQ z@-;m0N%Cc6%=>>bmx30;51mON+z<^##W;UShV+RG+R7_w&F><5pw{`2r;5&P>7i&O zYD)j4xjH96m?!TvflRU{06o`TMn;b#3#!6@fM-hywzxpEfS8}TxYJ86gXM>>3oQC=q{ei|0qXOm0#Tk2I!;=;X{P~b$*DnOecuw z1OVBlIK0x`2F;m_pcj(07_C)ut=dZ>8M1t8Mck{#;s{NuX|pD42@dm#kfs{RsHXKY zT0ERZfzs^%WxOxr=|UO#jyeHDk|Mt%nw1LiPpBs<2s!05*3bh@bU z@*vIcz^9$En%oT8l18`U_X%ic5AdQp5~282=fvog809wPZAfe76ACRFBT>cNC+%v@ z@8YcV>ADuYlTWAK3;E04slBUi2SB7y_F|wWIn+C zLR?v8X_|a6*<#6321-jOMQ=SvQXtDDe_U7Q@f>+uYCZ9Pyp%1dh31ka@kY`lK1t%nqYzCSD68mPJ&+a!vosV&Cb&ZS zn8%YEK~-@@RM)#^aH=?q29h-nK<87Pvm=U0UgTwo$3oR+S-fZiw-w`+@VUG&X?1?K zZMAA%;+W7uh@uhdeSUUy(go8SS;{c7sylS9i4a@T?Wf-we>6LC`DMp>FW*n}5KfB@ z!Ub_e^gqr!QCT{#lZM2>NS1qspTc5&EA5hmNQNXQ4OT`HBT5>sDT>lde9}MRkxr1Q z$m+VzK01ID$cvC%)McM0zfssB%#x2`bX?wvFrkzuOV0G(;F3Je!&%!>#x?Q~jDIYz zQxT(NNOnnCkaF5o-{nKgawXw$ns-8KA+7vPy+lnPzom8Ju=p!muKyMP89oXjq{n*6 zTWF51wg3yH!7br%L#9WqNiOA?=*%H~ugFnnFSX_Ww)`*eqh*%IvRKk&^K417TJOL$ z-LeRBB8buzBm~fqPs95ryce3uzmOj)8=|PjLW7OH7KP+J z8n0BEA$u)`ly(>kEH6N9s8wN!s4RVwl@OJM;`&=x zNji;e3Kf)3k__r4PE)TI4js0^cZFm@qe!`4Q zQY8P6|I)tTewE`pcpS(dB@*Ndbay$o~9xP>A6NMj>$WdR+`Ar z*hgiH^r;VO<%JM;oFO4sP^6~{bqOV=?6xwBV z-Ds_eE`%Av5uHTUjcZQE*VHL&@)ne3>War8O&3Ne#*{~@t4-XZYs?jGX{QA|)@+|* z+?KrL<(BZ5{1sUsGd`WuroNT%YDGJWI27Tuq#yEmx?2tthrWjOFd@1in5+uA+#3>(DnystL z`7Pg6Xd#Ovk5b?28+oh;c^%F6R!={~)0L+)zOBjE)XVxs;Qs|F-7V%Y@Xz*G%AaBD&-a>A9;|^mp7%#n2 z-cEGQFg{5`gYBo-b@nBnK-m!Kk?3KxNs*&$h3tb+RhT3#k`}8UM*d{g@)?c#FWnT< zD&AB-S zlr8x&IVSPbcp5?zMV`{qrd%f*FF7$fsTPdgmv*;CIuw&imKxANGN$-KQl{KhA=Fa@ zqPRh*n;|o-zhp^F^GT9;k9%?kw_dZl72~d z<$uY8ls-wTrK|GH!pYC1|o1%d{oQgu(AcxI<=s-y_WDdRuLQ;-j_pHFs!8p?4i<~+b};fzkul=jP?(ew2L^?oQZSpj4JmBo}VA?%QhDMzV%hvL>wAcXSb%29N}k_hz`e>CSi{j}gt*(lir z`KJb7I%U00#jTPwlmC#9BD-m_s`@R@L~(WX?#Pq(CzpvNI-zH8}|7rZh6cd*%1d70P5O zK9N4ipE6@p|1}mNkcsUZ-YLE_*kkll`AcD*XsX!2SZHCpJbWQ}ga0V6Q&^*yB4XtS zBv+Cq`QO4U`7COq7)}aDBwxZBwIR!(T&RhLa(Je^LTSI|ISgXTOA;E(#>lS}4hzSO zZ`oi`jAtOt5Vb@ZV|R`H6-rAp#0z1yv_|@(thVAb`QFM5s}++U&cJ`+x4b>cg0f%o zI*gB}PsDjqQZi^V7@{MWXEjanuzDj)EnODwD4H;lfbd3qmgMQ#LLYsmm%#@0N~17X zWpGLQF1eDFN~4ttlSMT0C9ha=E8ApZ28~9aY9vBU{Vj{6aVaKs`7PgCSzFOsTo4kA z;_?X;D@fy|`9cVNq8x(OI0zY}!HPm#qqQc_EU!UPfO1pHI;y<}-?wQNCFCa-^1b3y z`Kg*4k+<2NcglSmPgSkShRC8QS1RAO8=p1g%#24V|3u!Xy8SFH_m4h;Jp$MkIj+ltkV2!~EgG91oCc>7+$f77;XQ-pm3JFZCBKRVyDd;(EZpD6Nd@V4Y5h|X$BTvX;;JXV~cl^wDk zYD*r9{19V(OuVUBSH7)jL;vSxuT<7l*U6elE@Vw4eJ0OD0n;jvEhER>SC*cq%@N=HjUIRF=Zvh{l4kJDAp9F3HfEk z%q)q%Q+8BYOXFJ_e_Fjavz)>ZWwI4(OU9*92EC;dLJ37}dZ!$PnGJ35IweIWe`=xx z`IF`jgXyvxiXN3k7UGz@O-#`g6p$aGn6wR8(FsIQ%qp4FtcN5-(T!TxOF0R>OrP{5 zc`lMKA%R}9Pm*%uOUoaZ#WeAy?6=7<$O};1A+%LwqHhc}3}uuB*EP~_;j(CBJbDw0 z81GeC3gxYIe=;ipqOZYcMJA%YY>VVZ`6%hLu)z4LhVwejTGn4aq;g1#T@Pjd(*K&% zFth6NC=?4ThSVLp(#(`74pu%~8Bs;UYC)gKs}+h!x`Z^!$V$%4vku_HYGv2{?2kDW zzA{TK;e=86zvX$VHKX6sag&i2kHuA8qxel4EzFXKEBx2zimT+wn#_t>J!5ie#$S;| z7beJ4Qiex=nfWMTyhfzRTF9?)N`e%h$u>xqg^R|2Q4UA(o{7qYE|4mVt2WOzcErq= zie_4WAe~hnT-gLef8n%HU(#ZHNXeU-otNL)fbwQ0L>yE+BKx2?MzYu7`5RAJeu}td z=D8Z;V_{x{wbG~*C+j-n0}C_cix@vcR!hhx=~Z9kBP)_M$R`BUyM|o7S``INrcHLy zM5f}La9+5q7*bgzMTQxc|^qE|gf zvZ8#q$#cqEP^@Zv1j(I27|DdB!Em)9Lm>aiWgU&;4B@4`MEUnx_hMFrm^faUXJLe9 ztAqiP5XJN+;?jGQ!IVuGvP<{noyngR<`^9kikNRrro+T{@(X0|W&4kyv9bq-%d-EP zTQxYZ^-7WsW5vZat#wg972E1n@qgFqxrT%CP7M0Xb5=V>)+CGSiOIkVi8NDUB-7X( z&5@a$n`Y(YUuzV?0r{hfXHDcG{4_aQgA0-f&6tX!4RdLty>iFOLdmL25=_QYcp*9deef9>Fu)8B>9o=BFmy5lLt}%n)A8xD-GEqX}ZZ434df`OdMqFp-{)fLDKt%`Ay+P zL*`TWdbmn>Y0yOIq4&xnn|X}}|4}iCi9khfd6I@_#%i1Q@;v2zNFw#WWKz;AE2kMy z&6tXVnuUFbZzU(fEX{@*Kiy=sq=m}UiF#%goU#w1t!B=Q#gOi5rd|HH z;zZ%G$+{a4)nq=E6)!u6%deGoCLd_9LDY{x8Tnn}qG%*c z76*jihEBo@M+@5N-?bf3J6Dq58td7dyZ`t&?vL=j2~#CPSzoTrr#z7lkkK6twC{ zaf#+bb(M)y#DC3e=v!r`Bps3nStv8BqV;*ob;@&*PlN=WgA;tyek{Edcn1+r}VOgLaOZppTZjMTQVbMl26@nrc+vVM|tGm~d#gB8mt zHdAabYD(v&(UK_PvdM3V4@Pc`71Cew7)*Xdav~%!(Wi2d@(d*hvLfQ9tc!lDWy3*P zE`$4uNpzL+TxP|E=3gXNW_DG7nM|#&GFKaUk~C;7w))=SnKxwe8$3-TBchuqC%;BI zXChG(^@|?L94dy8$0wAPjn{mXeuv?inSV3BggivEE>f!Iwx>qr({E+5Zxl*Q5 za;CX9*#{wsB(uQ=7*2@)n&Fbw(zUWqTEC?=ZjxvDSJGZZ6Y`;?-}=O4ViZ}*XH~XA zajC{8iy`|We?r$9&(3&R^0duA{cq+^jAyD?L-?)eMX{eEMrpJB2!j$0euW}T>8;5M zC=xJUfSx1iF>)h~l^0^HhtNa#ZC>J|Btw}5z2u=9Ocy^TMM4d&?Nk=RL@1i)(ps=X zSuaAZn#fjtlvNN)=-KkG)DKCSY>$yf;hFpj6M={VqK}!;F2uCCx4c@H&7rrQaq^y&qPNv5k_dGg;^skTd4e~ zJ~8-dtfPra%rmtfNpZhyiF^wAdCD0{!?kAiK-Qw0Rry-OchKDnjw?@XxMF;Lp^9Wm z{M1T^quixc4qEeIB2Ia#@=ulBmX9GBFzbMg2ch4x3dZL!@wbsD;ffGJ-!({ra7xl3 zBr@MOWXR;nYBaJO%H0_r8no13qPp=CG~1_Ko#D8$(z5W%Et!}>T5GaO+J|9QK*<9$ zSS`vc`>oYE!hWr-(Tu28%V_1MR?KJ(p1jHvd8gTr4*b^qc_*&v!K)XqZmj;+ey|hx zvjs5ThTn=&BsY>D{v#)8avY+Bp@YFV z<$86kGSwzxH`z(i)?_M$NG5KS&KWBp4b+~AdbHYToG?rNh-fa^66UavKD(dKl|e8oV&wm6 zeU>~z*_r?EAu8J~ypbO$d@$BWSfcn-_S&F@FhsGcnJJMjH^hq4chOboDPK=?5$?$< z3w!k=Ez)?5A7`}0WDkY6!h4~GW}8hUFPSjDsL^+QFYi?Ts>y=Oc8VI3Y}rFu9C^{k zx~doAhh)M0=$TsGEx%FPDytzYFPaNOwANmlZnRU^8DC%T^b#iM4)ZcOXp?s`7C=_O ztiY5dYRH9)x8kxeLD7Z0P7`r9M9z{)leaechX0SR@+35R;hv~t{4JBcG}cv8YI1Xi zN?PBpdvvGhU}i`fEV3{}a-o;xUGI(j3uQz}MTTZQruKrE853C_;}0~?;)eEB_5uS8 zerPpLFEBzg938=l-h9&|ix7%Qg%a&~Y1ULTCvCY(`3*&<4Qp=9{DiQ|%-TtI4L{{a z$Ql}ZApA2ZXe2}&Hf#FjlPJ2->Jd>#s3&bPo`;Y{Iw7qWC5>bXzr|~1HN+#yLqmjM zA`{89JUUTUYs}@rh_>>7`#px=OnK1FKY>mlz%hS+%6P-v3jpt&% z*WYGsxuj0=B3ohXg_&1v@MdeX-iYyN4KkUyMB_3!UrB^A^+T{rS>zLkud?0{%){A!I*d1+yWviXu9ja0I(tgZY8 zlaCMz%Ij7{A*qnhYVusVE5M#ZWtrq%N*aVzk_G**9cs!$$PySIK%S-iN!^oTABaYy zT&sAkRkD&>p|WsRG|{?u%@ZnWE6VmqdnjaWwa(pmGMc3`gZ1HsiUn>53602cayl zWT#uxj`mzC!)v?*vrk9+1O4o-68>5A$;^XkExfFv@LxVf2rJ!W~OClwH{h^JW^ST`skN-k!cT&^w;H6QCxnKETHu4D0zM5 z@ul^ZSX1pOKAh#C?0kajWH%CgqTQT&jwD2ysb|Y~SI$CcAPb>cA^E0S?I_HYg;kD^ z3;0|6a-f&6E{Erv401z6*NPE}bBflq zN7rPMO|C}s$?bWrT5rdGVeQl}U`J;K{nY*z#cW#Vt~`)tlM3jyJR(_b&2I}=wfm=z z@o5L7VgTugtgLWAGfiPyuZCMkk#R}1cKQ!w$8tMHT?#cTY0vB{luk;^yo^FQEVW@G zQ!f>E%>Ec<4MivMOH~K7@9`hFau_aG^LsygX%6t-tQ$PaI>v20<9GJrAN)V+2dJR1 z9oad`)0C%|erfH|PIgrw+ zyfx-8>7^=2sBgkNNq1Lvh7YBMR@nPf`L-ChZ&bqqz>%DVp)QqV=6fRFV3tk`XQmSsai&!(SC#sa^e=g>9DA zO(BV95DxR(r5Eys|gM;Y8HriBt=KWMMAC%gZzDEqfiqDxlNwT3IfSd~jquDO=is?@DO9 zF>QvC(7o_uD-{R^X6p~Aw#8AdEYE79>JI5SeHh^$svYd$39^9lo^p9|Cwj9N4U=xP zLjt|o`fu6`djQ%i2OotWUg)Lvj`Ejgi4?bKJ)-0sx+&d)qVkeJcS>5 zQJvATZpflyzIrOc^r7;LDpct2PCT(SQlXj;CsMzm6}%RTDc0{rk37&WNWGVnslPLr zo~N@qquFC2uq=d!w!89_IDG8QRp)_CJ)mqo6~rWOy?A~u&n<<)emT%1LTwv2x#Qqz%M_C)5kK-Gg;89$EP?VuG^ujmZbYQgeNaI%h?B4I|S{G+tl zOZA;XP)yYfeDp(AAe50#a&H7~Deqkgk9NY9Ezncm@Bw5ZO%EE=vaFjbnr-5_b$l=H z?l^m|S9AS7+K{)PUEM(+(Y`mb+?(B2oJT=KPzJVfbvb9%w z(P#Nn9l_G}NCD?_Fs3F*hw{+M2W)1%ZP3=5thPzgx@C3S&0^jTT51HvRXr-7TIvU> zgQPhlRVPs-q{#FzmU95 z4tZaNNY7gA^g3wS3fz%@u?Mk*n*J-!RK$O`C>yQ?NV)l1+W=^4$`OH z4EvBO)#%FOu8-Zh^wNtaX7c8CC@FO zuGW0G?z(-c19cX!2r^MhjlZFcwyCuf>C=w(G9+{*eJ?_5&Y`crQdwj_+)`dmHa-dm zF694>^h>xt3c1O{rfh_#ZQz0`-=yiaDvya8+UHcvy#u*YR7pYmc5wbdS}#U_6!)D% zeL=PeI$f9Wj6zNE!%JY6-DWD%O9Davjfhkh;4u-`$5 z?N}b|Rz1!$XHg~XH>fkpx)mMXK();;sP?%9y*m{tKF-r*O}10HY%@>kg;YLaT>)b3 zU?ieV7kt68q+LHU%X%^Ms8^eOvUd0;&z55>;>N-_3^t zE9lK+WT!8++`feystq1Q{&JD&laQ7!=+f`hV{)>U^19K2DrD@)^1LmhkpDIWjs6}f z^FzO$v^^e~Q5Cnxxvz$ECq|?91FZFMLEPJmZrzFOU(Hy5qR!JI+CIoV@lxj4s2%4g=}jvm0DkizBTC16Ue~=&}tqMb|F`_MHXH6 zWaMN&J#5WYS5qZ(7<~O4q^pEpv#4p<4Zi-0rZ%Cn<6xferS}=Q{63yqIaNN7aWchq z)T3%mMbXbenVHB~2kv~7)>gS!V}UP5NA9H;Euii;M)L(#4Sn#vBU0R$-hYg2y@S5q zpGDry^tqgB#JBU*V)Ws%tiSX<75x@a?Qb1peuFXPbL#V2+PMiNI-8a&uvkAJr^l%r z`~{Ugec2j}0eEnLs+F&T5dmuXU4=Ycgaqp?r;e42?t0T~lZ{bpZ zDofwOU(K*J=h0W`!XvcqxRW?#UcIVnPlM-xM@L8~ru??twh$(ew5mh+UB zSWO#xe?WcEI(HcQ<3&qPN8?VT;&p}lAO5-z4cNs=FQcu+)bBnDviF1|`;eUFS#Mzk z*KLHC`|xX`(7kW=?_BKqbLf2^DDW{gj<rZ*Qia@nF2MgXqsgpmyi1Rz2mONv-EW z)akwtS)0ytmVja-(bVVAp_6H85Yn_1i8~6$oQhp}8!Ajg_AFZM2!?f~`uA9{Kb}OH z)?Rq?Bi8qFv|~2w%=Mf|vi$+Pi3D9_@YOj%36;pxk-*LO;TH`I?{N$6QE@!|^xo z$1fO(U0R3b`VlE}sO9}IQhyrQ)tB1Rk8-ESx{JzJgbLR#qq$Gc#WmQy!%DtFxtZ!UGm!kstyo`RlUL93nZPSz}tdl)=kN44P(+&{o0)ke-B#cR>i8qRQ8i#*i0 z-R$P}arQH4{}0*f#mOgeYWB{?614*#{=*1U)W?1aBzu9%ayj2s<>Hy$+QO}G$)KJ(%0EbduJT(}wvO~WSq zLc8D6y6QJR$+PdYenJXfVSMY+{7p!646#=*er0Z@W8@zsJ!# z`Eqwd>$gC%dEoy?*1PcjGOB$)$Y@T$=X80>MPPMve3lvD{VCvHU%Z+JxT-Z*{bW6Y zMO7B@AMoQ0aP}_iZR}bcZh5Ix{5mzIZy-_$pw$C88{iW>o8ER)IJX~v;$EuXKL>_U zMG=i#%PBjbV87eqr4HcygGoI732a+`u2F9Fb>!l0YAOFDJBVK~f^$R$K%G+fzYPTH z3r-9sUK+&dUW1_P^X{Yc?QEoa3|R9Eym}1l-JRZi1Bz;ftQvci!*_?!wI`wNYD``V7$GsmhU5RGx4;liI>*X+x67* z{tQeR0lq9mM_Li9T}e#$J%}|I-$C(F7b;ZWfmY5yw=P3gmxBubaMcLfnt&I%f?m{N z5%1z0vwK0gtB}ufH2+;rPPp2-jW)jlr|tnW$Kxj@sPBK-|0C)vz^gdEw|920-{i(U zfhg`2cXx^vDDF_)3lu0&9Euf(Qrw{wr%>EoLK5Oe-0kYl{@+8s-}9A+lH9vHyK{VI z&Up{`{2Da92%Zl`=kz>8sCS^f=P;*yWe{c@iaUKU>oY(AzlOKH2+vy$+(}lZ!c)@@ z7g^Eih(>xrdVYaE--3m>0Bf-d-i#uS5s?3Q^by;E(HaP=fqwo5diM!Bw*m3kXh_5b z#24w1gb~1}e+?fI1nlpJ&^;4I?k%{m9njk%$Z0kBy%UHS7QrG;Lu{0QF-9U%>kF?~ z1i59fDNC_4^%&mjmeLB8ngP1?#G1JtI%YJ7bX0`}?TM&or?LpR)(+_4a-}s8-4}ww ziIDS#@WCII>%4|=4>LIoJpb*m#3^`2IHYqH=34@|nKbwZAax3uL z_v6nTSdNG2euHiVnE&_q{T$x!AoPdfvkv&P9&$1ta;U=Z(;z9=A-m0BGreKahoJ{W zKHl#rw0sCC6##G10YA5ahq?h@)*DjPQfLjk8HBEQ4%n*3-!R5kKnqF`Yr?lc8k$2E z^jHxYpvP)>7ecDofagtvZxz6SF^ETQD$k*rcR|ff@Mbk}N1+nNd|)A(BdYm?cbbUq zAsN{HXr&}V7fxXvH3*XX0|0 zME4mBsGW&^ajAj~iJk~cMA#e~&`a(uEI6Iu1+v}3Kim~^S_6_e5IvUO!JjRFbyOp= z2tr4chwy19(eLdGuM2qx_8YnfZGxm8;qzdzhY5>eb*i$b6sw^($1>pNw?|z3H#BJ? zXfO{l&>ky}hme$o@P>4DH3L?4B)q~n=m1e}EVQWx>|ipaJ0HfD!|Y_=iBKH{hBAaP}pnvO@8Oy~zRBX~j7aJe-a>HHX}JV|V3U=NcV;%Yy z^<0Hjb`eg3FTcX))Pm0H8SOokV1;!Uo}KnJY0s=X-oKTwLV3^I;Ei8{s$CJIC-P0- zc!+f~M%aru;|=166R>M9VRdK)(E<8)6Tam&Jk!7MkDCygryy#bk633Bw9g$n6hrb0 zOaB0tbS^lB=LpNtD`O3;$!*2V4Bh=Xpy7gU8#{{j0Q1Amu*2<~rK znwg+%UBn;zFt@p|LOnp6Mc54(2A+GvM(lxnm|>eHL2CDcOVi+MGO#{hgO%TS_&+<| z;4-{i2x9hd_7^N6>Pg@q{6WJ9(6gTS5l;I{z^uIz1(&=q0*S_rVX1!LC3zLCwwz?uf_!Ktkh3=*20>^Hi+*s(^?2ye8HwCD4_6h*Ez7 zE&jxM%>;`&0dqZo_2_uW=O5sk8C@~vKp)0ITB_l*xyl5rg;#*Cw8r{~?nvE0nSRhz zPpq3qB8o9!)kb?sr{G6dVJ)!;+C31TBt!rALw5quA7&KZ_a%5YAKrx!k=H|t6-d+} z#8%&7Cbfm`h>Bk$?wJHxkAkH;fLI{_5_kg=+78z}fThx7?LG_oJ6fUj)oeu4i$Kfe z;PQFw=yZid5MK9iT$_N%r4=YU6g>EV9w#sGPQ&1DO<1>1z)DcT6S}~khOq6R!xq?# z4Ty9#u&h{(f%>_KUt56}zd~~2@U(iM@ z(IIcIaD(?z7V%xmkMPLBc*-hRt%ry^N_i>f_e7Y<4hTb)4D_~H5BoGv_yN7jx`CF* z(DkPdMgp2Vdbi~Xcla9K3Ee&FL)&Vz87v>GD^EzM%wI$9{$sx|^e4xhx?t=x@cGZ- zgByT8@6dfFpASbQKOSA)qcN8p$i@V$ZC8Q+UEyEp^jbkoMtKN7gVN+9E`pB_VM%^~ z1{mR)C3Y8_h*GA5G7`G+yu<3D8&*&=L6fe~@ZVuGv5$*>E@{yGB=AuOT{w%~EDx-- zZi4=gVTXI*x6|;i!?Bmu0`i-Rd;f;iUPt7z23)v?C(Opmvl~{j5qO`|0aV$S6eazhhlf06Xpqp4CM>^aK+15>|I4WY>o1;SGN83VL`WZn_QI6#&oo z3!?WykoZXKxcr5#Y#jFGCNyUtB9o!;!c$=x+u-iu;Pido9yW3U-gdokMcIjHjXHa( zAg@+vQ6Y5q3wRv>`yPN9kAfZliuYOs@3#du_c(S-Z)1&f1oRvZ+FLN+A26fw&jeeI14fyFcXNB<}eDT8%_xrNwHpKKTCztE~ozCO$y^ z8zA2Q5j}!F;JPg6hX*`H7TzxBo3?I&C8NxpK~X8nG9%J zF4nJ4s?7eEauNxZ{Aa;p`IuJn0GSWfJJW0-sF7h$9h~ z?}fw#AnJUG$h{`cgXp{@543CrueuoaGzI?r0`Bs}dTJ6Z$zfday|4pwq22#3m`7j8 zuN_abAo|iHf}*|Mm!PHqO)((0mJqvbh0MzMtSjQ*U9i|T?0`{)z%5u2%3k^zeBFe8 zQn`@IzR=b#ut4Yz51+UMKJri4v%`?jCzvg*$HTDm5(mj{0Go$gXT%Q&VE6Zf4h!Iw z`y*;i$Bt4Btb+c;?E7PF8V3$^hZLeMnssK zmHqIlOTg6zpz9s$w1&SaQr?0?P4SMi5zS=6y7oXE{}6A~7~G(}uF)8w3)U)~ z!O^PlmQV3M@8Ai~fRFVNy-~NzC6J0X(D8i46Y*GCK11)lVYu@!Y)@0HKy1+1+lZd} zAZGA|ZJ>3|8tg7z#0&?(_e@44=7$-r1Em_kM-lhx;*)u>5^W%r=df$gTUdw=q1z$Z z6&RE9qOO5vBe70ggw~m30C{nh; zr`3o5^??M}!0tvDjQ$0Y+V_ygOz6x>>|E^vkCU)dlMDZGRbZ?aWc@UD0+X=ww+kAw z2~Ru%D|8ckH9vq04%nVl$i!?!KX;+U?ucTK!-86rJ!5kPWJ~QA6L`;x~6l!8Pc1Ys72eLK@EpS3*F;j@Yj-LS~(z{e@Vk zT!tLXL!@#NvhzFExkF)N#)B3jJ|W9U@yA`n0N3G}7vO3&=-Cu@{v>F=2YXNrA+LeZ z+PbjVbmF3h?j}Ji(DN3YFhg@QAXml8GFXr@$WedLb2H-Ehgi8BM--xkhTexaIFCN5 zOL1LI?842#yXz4X9fw>zK+L`ubf+11hQ|0q0?s3%O2q!c_u$eqP^czWb4MZ7Igs^9 z(B2ua>^At?KvhA&&v2=%=IS`jlMjX5cbeTwNXIQ6B;POpeWdYTxW64ZJyxWCEN-gN18rpvu zI=Kd#@iU@SBWN27zjzQ@Nqgw$@NY0|>J!j49kOV^UJ2D@#9^;!2e?2T-#X&i*RVH| z#s}j>#uJ*c6eE3ynJk6Q&&A6A4BqDwMs0#MGK1gh013KmXIVp?BGc>`dRNmWXHH`Li>!8bTU@; z`@yFgu+kZb8E#-@<^(_86dd1x-`wE&vmpt!u!LefIu@@1F>$9QTec)^I;bRsc zMt^~L>Nw=DJ7k9JUVZ5GukZ!lSbgMyR{tS(_zU*W3Eu@ErmPJ*Q5UY(*mVejUVp~# z*YJc+utdp_T*U5pjuXb&44u*9w_hMhl&RJb5>Wu&QI{_@p3)wlzJ~r8kr!5qXDn(VQ;m2PJlhjo!pQ}q&&IS91MM(lr*HhXQ!F#0 z5XuK>irxDpXlE}-26YIG!mM^f4r$-I40k_Twqai}(VbhpJxv@ChO+a4Hz_0@Xvc#yqP-s^|ndAGE<~BshB?5sDKc z)Is>J25c);WKpC+*rDw}tAQBD1dqUmJOp0 z#_2HS{ps))t#NfL(2aVn^n|Bt2p)fhd@Tp%$R`t$PcbyRH~zH39X;`U+PkH0gI~eV z-q36j)Xaqqp^nnca93ZFSNQY=pJZJ#bHj83w_Y=rBLZ6MYMx;Q&{{ zF-}wX@wc$$jd3;Q3i@E&htT38*l)_^qYR`B%)J!%QRX96;!sCN@*h;?OPFrdv6*@R zTHpm}PmU_z;-CkV|N9nviH3Etf}5p?rl^h~5mHa*&J?pJLegkN>Zw+UbryA2i^cjU z1wUPetVY2@yu`mU#-b=L6P*7SKNA){RcsK7Hf4fGf3tLy!S)PNeghIp|9*gMCqb@g zubaku2MMv@O>Fq|CG1=x-uDK+^TJHx@L4(D_YSUqhq=~=g@1*y2@@g-*V#eGuNdDC z99F}+JcAt4IW$$DzkpS-;~KiE0%OE{vnNg%GaYkrh1M8gfhgCOa+ex{`jm@^{U1;+ z8KbuVkCRoEH9E6NS81zSTkN!7p^^0JK~iIgivRhV=-L)8ghhy?-=trd{vm8=F$bVE9NawGV$85pX z)UlI%Jmni)R?G4wIv8=a{?FChy=LYiZ-y2*l{@BsA)2Z+1{RsK)~7@;u|;z=d44`pmRFhdon zN)^&nSEzuOgp5Rdti<7<=aDo}MIP08P|c(>-qMaICx3g6Gx$je2-I0J5S;MF-EOcL z)nQ|+;hsQnm|_O%vPCE`RCiGe?-v4|H^Eh5xUM~PpDOzsA)agpY3v58(*f7J<66qm zr|O`?ariVNrsJRu*13c>LkSwg}m6Tw63 z4$EQDsmE*vC`+|}1^9+THb`Hn=9X4fRGCQKG074VvIb#=6+Dea;aL8 z;zg1|>H!~vF_JK12DnUhI)qqHxvJEkiBN^8B8_TmO7Ju7u~HTPSNufJrrL#A{67(& z7UInaTRR)?N&Wq(o{8#eNW!UdlydaRg3?N`81FzxPeex*W=4G%DO;8*#Hdo+1YQI| zS_m^rgPB(L!*zV4Ep=a_^#f%{R95O!^-yJHgbqIuwkTn85r=Bv4|RQ{ng^mDhT0?Qq~~FmsD3m*|1biNjN4{ zqe5sugqA>WPkbec)7=i}I@NbE@RBm4i35~LPqIzudQ^8w-N`81pQ9e`jpX7Sz}MAkd+uDODL;@iHLr7m#8AKJV7U(|p@coJ>P@O;3jwa)G>e5Sfii8eBmD!Z@NC*fd zJ7ha5>sqP)lxklIn}jfmsG686j7cu37L_Vr2?vAfPF0`?rkRke}q5P}2M^%BM`RpQZqm6+z#J%uX0$@&sfLnX!r)#z2`|5Ft`Wl0lC zIB}b*!bn5S7?bd>sfQ8O;Z)Z3l64?#HnL2V<4^G_?R{1v&pGgHx{DA3i28&!LI1Dh zpp0izwI03E|3?n1#EBrD5~u0^R6#)44}@D*i6BFn_M{n=cp=5#-hlXB`6+1yRhH3n z|A*T^ne&x3>{Lfcb>oDULMS}MHNxm3Y=%nA4XWp-sF$!Ah%59xp>q(f0-?#$I*!m8 zNdK!tx2Zobp#(3R0UoWK7=}X5{_VX{6-yBs0RySworu-RW4JNI(1=woBUVZvahar}63?m3-xX#sJBWS9EFl5eQ2ASFLawGx$e{q!V~BKkiEFKH*2dOR`$Q3SeR@mVIh!a zQAPCsP{JzplH%*iZ>a*5vgIq0ILO|RtW@&5vikyA1*#sUnhDAQsl*_p->6#ge<-zt zb4dt7bTxSdLK2~RIkE&)XH6Ah^chJBSq}QQ5@C+K6kS z50ZgODWFm8kCj!=9`}jUZ}}PLU2% zZ7-b#kpCdNOcF+#Ux{Ewc!rfIKUBX->on5L|6!-lGpX8_JTcijlCDZAp}Jq<9if3$ z_BN=*_aIyYvOr`FNOlNCkfexkVW_5@EICyjlf+OQOty=x7TE_vOr&VHQWEK{{?{3W z-t~XdiP!!1GkKXxOaz(_y))tM(5EzB<(LjUfxaOPple85Dp95?rGz>$k(DOS`+WO} zI;m1`9Y0tO`iW`@BOo6XArhiQCH4d1N(A6DTKV|mU&8yK5y%z<;`d5aeL|s>zI{sA zLL>(iV^DNUpHmDdA$F(enc@YSFQHFT4LTv7kwj{6WgsjNRe+M@R7OaI6GYW5v|7l< z^;D}!5h&d)z%EkNE@6XH6j{juqBuon6yp)XS1NpX2EL1jE~kPzU%)j&gs7|tCu|AA z9Vq;^K8yp82y>2beW>eZ7WhY)bHvL++)Ytp2F6Ro-SPO7jXR^Tild(3&%d2`eLQOyS0aqk4sz$sD~_VTEQ&SE_C#W&tKW1k); zGmy_(n8{L5?zzwhoGSqp(s7n$LB?M+;_hnLe?5(SmtL5g8_oa*U{B2nd#F@vTNNh& z*MP3j9+cUS{hXh3B3Ng06W9cIszr#uZO27bb0eib|U&9T;`oq^}b!P$@8)S)=L!_FRd zm$N}>9UF@s^IJlD_K=@aN|1l5SNgFf$ZZ*^9AV=?(^c5zZjE;u!#d-vu3U*nra(=c zRy^Y_IDP8`o-fDQ)fW8w5@*{#3hDeJ^T$rBGxn^HA@5F&oU7-^X?%gxNFVkPoX7-~ zGI6&13e^Jbu{-3fykZA=f1wBTdlRP}vnh(2Kf{^Mc=(p~N;5u@g(|`90_d?>PG{Y) z&n_z0*)LqfT7vf{z_;zRZvbtm#nY6(@a$X4b>u&<=HIjD$`y=fQ>w8Q$Q4+w%oI!c zVPwXyWfNpAyQ4%RqejQB@S93UMp!8IQA1-BN_kB&TwVx1c!Pdr{3iQO9?ilPnK31V zyC`qipTb7DIr~}phd)$2#DDo~?5Zx}M&XHAtVGK@l}>CAvRbpnRlG4DEu3a6?~go%45A=FFuZa;a$bs{CBn(+1gK$Sy2uv z{SR^$*U2pzP%7Bpe6})LtSLW*_JxVl#?5f8Z_ZoB@|f+(_$|`WiItba= zjl^)Ajklupw0MZ;^AzPdJFXNVN303xQVqK8BPPL?4H0_r?&2@v3gw;LLG_Dr6Fi?M zjaU5GQNa%qa!gQ!F!-lVY$|(&Q|w#H6*h@a6UMSeIICAfBF{mR_6x<1aIr4?p2f=f z$gFe+P1_+ewTlq0bdkK2Ypgm>(RU!9^{6QFUShhE%r4g4g4rnLD~ku$j1n-j*fwzTywII3$Ju(EaGoV7f3Og-coFcd>t7lv8ZT($7#=Wd_7Lwx`396{7+kB4Max5acCe_uQwEep~;lTycjxmfDtOs zE!0Fj$61XxQwhshs(hB6S9*&xlY$_N+cfTr7FF<-L*f*X6R#EA8TU2Lk@Md>D< zQ$DjH%2~0oA`9EuO1>GD0WIMhhB6uUI~+3ff$tXn6|JaepBTXx>swy2}|c;WMkqSEh6dQ>znIBvOGP-dA_=v5I@98s~V z^lxm&pDKPeS!&rdd@)NEE=YY;7gceZ z5&Di!1wf0AjL#YKQro6RWS-6MZYfm6uT66QE{caFb~d;f5}89aig{pb>X}z{ z_n*DGxioU!@7dq=rP~GX5bq?z4vk6^s9mQi)^9bOSD$eF!^7DNv7Y98t*d^oZl+qN2ZH#LgQ#UT~i#Aptvncj>^lx$RVgeE#B&|yC zofT1XwPHy@OnF1wOzlYB4%cO-oo*3s4Lt7$6$Y@XKK{J|yLmtJ-R*PFbDMh?cb$u? z&O>Dq53={D3e*S>Y$L4k6$fqi9Q{PC>IdbPm@f9yr5QZjj(F7a_~dcWYl44~@2^3N z1D*wb_S+P6&aa7ghL^)@q33wl{?10dO*NVAVj;4pOqMqK^=@*#xHhr9 zVrIvrf3E+@5_R|E`RIWk?|&}-crR*l^uMvaQ}=&)nRA;Zl^o5vSZ>lZ)bwy$;C0e7 z*>k9AL*PGQ{-Ixktf65(e|R;olIp$OpAxDp~_Z1)0|S`hHb!k zDnB7b1pa{PSKU2BRZq1?58pDcyM8zttTHONMYwxdVPNel4TJvixA?mRRrNXL-r00c zpQF93_F;9ECDuKqbqihc7iM3`{3R_T`Mc!vU*^XqNBKnuMem8)8u{JlL7ze*Cq(`g zcj_bm9P;sT;vaFRiVm}v1)Vbb6}Qx^(*5mz!YA1KXV*=pgH;+vLM2qbpJI7f98vHk|3X20Uag${sauoxB#unK zl3p#fTk5>{q~x%~Ug-l9XQXUSiA+BA^icK>R{)g!~#g`BJs zP~~*+=73C}5Z@|Z=}xnZ0md9%TdkK0)G-gN7*x@%;#zrnNm%~=?1CIsZdT!?qRj=( zbM|~)nA0SCbm{p5r~K!cJY!5||F63(9YFfX9J}aDZuGW;Cp|ks_DU5ejk17)?hWB*11w+arO4W z-9qjLwy*kI_y)hvfeSrDoi7Pl>nBt$1L`v%a$ZWIIrHFW;Wmuc&5O zjk5OT{)HZe6N(4s+X|kS_*nIMEpoTz)XCVK(Iu-}+N<1EUx&uek9?Jw6djWEI5y#< zHmct1!yh6uTiKR;UY7bOyGvP%+|5E;b%f}N=;ey>vgWdTrblh}>3Yedp36Sm$c?zSxXk#)Skt&&|EvBlL%7%D@U}s2-kpOQ2TTf^ zAK_i+e3c#H&4Q~1jj!@k)s?~byw3Ox^Ud(gc7EtQ-8IzN%eYVTQiAR|4kKnZab~?% z@~Pxa=_3m-|EYXI;mn+zg1Ed-#k_c0*|ee=sjt#6WWUe5mfI=&d`5J_u9T9bmS3x- z&icAIad+&x=;VY4u@e)@W7~h7T>e*n>y$akFY@CpebtHf0C7H#Gw0e{Ijwb_shgnJ z2_LjmjED3SG>5cv4FOI!wBCqOtBU?YAI&iRGpSU*WqZiGX`M`eJFPQ~by2$=biM1J zQZ2ArOz^NUbwrmMud2J%9u%>t#)$A^RgZ_}2hOS*=`+*kh3^{QIM+qSOD>+yEsTeC z7VQtxdgM5s7NYrR`%de4%P{kkiZ|vix+tL)tEuj+-7S4l z{mds=$J;Z-V4b_6P~AZ5tZrobJ?Tg3@8C3gnjeYgKs|Qs7uBtKc z_u!bIcA;^8o4xnAUvum3^3WyNE8TUIQ+4ezb&hmY9iR25Ir4l%Xviv^k%WBR}U4J3zO7!bOp{Q z{4NK?2QI8TtNMXz%c^J9yxX8-eUC=tYZli2#jj@2gW$eF&w@I6j(6Ly`^Dw1-ePQI zc&D>zPpe!syHsDqII*c}o@5k%!(PD>zFkmQU##b*C`(zmqlcqd{>VQ#`da&2Mwe`; zxKw<#$SJ=p$3MqEvu}nveOl6Y$qQnRCY=7V`%7lj!*_Q-_j+IWBIw@IC#rbY(u%Z% z)FvtKGlcYPxsTzZ%Q@%u((k(K`bVy-Jw1F!xvlbP{B*&*i47V@~e+tLi+LdCs9eUHxK%U1}|>QKxp5@Z9QNp@(WF z)$CGjPek*8k-nq60)36%F`l6=Cg%$0rcTeDP8h}-2I_k19_fc@7izwj9I7)?S9O?_ zBi0jtlKx{mg->jWtX3Mxw`5<3;LuxROFx@K3!fELE8JUD_H|EASXQ5m@#$@nkECpm z>y~Ux&?Ih%oBriqjQJ_avr~?wKFiN4x9Kmq?DN=S$k$$U z+UxPmHzF|2ROp&&y5i!dPFMH90egx%oE2$mnnNvLDlXV&sdD5KO0l`Bb-Vex^0WA# z?my!r^<3jE!x-0@UTp&_f_%dYYSoSSCp@`Ym1_T1Yf=4ujdfM`20r%b>D$Avp_kgH z-1VOGa>HSh%GAc>=i1&m#h}vna^7ngtZ!k6(8p=xHT5)&)DNYvq9RRKZ;{HeOSatJ zKyKma&2QP$E812xENf_9U3fGr!}ZpLBi9p~&$c-@mW^ z;kTEbPmkSscaOh2UHG7&YeN0FwEU;$JMweG9G4ik9i~NIXZ)>x?S1mRUVB#a*co)( zXP(wus8G9zPE4<>&9;|sEtpaKu5@n2adTREV3})ankAiA$I9rE`n3AAVYK%u-#eZs zeKmoLtLzMaUdylUjGC81wudeWT^S*V?helLU*l`{J>hBhm~FaZk_>kY@y2UTuEr@E z7kyRD7h|wfq~VP*U6-TlqCKe7=$#RTO;Qb3*VXhCoy98HMd~D9vx~N2=8F}dO81t} zE_#z^%b1&aJj0g$UCz_w`tcEo9`SwSZhUDJ6aO~wL)qK+Z=1ci^I+gT`uIp| zetJ$`?V?)t$(%Xmy8huG6jJD6tG+g9USOtQppVY@(LFt2h31Ji*6wScYHn%mXnkou zQF<(|L197RpT+k|!^`qYo>&G9WKh-pVL2h;Rh|511a$KZ^B(B7$vNJ+jn3?}LjRp}H|Hp$Rv&;y6l&8D_x`3* zQ#;c|V~PHvliqkl^HjH5-9{U%T`MgR7W1F&bL92*eia|eUKGU@xaWP%dzv*PyLpBs z{d{h5+S$aZNna8k#tcnZ78Cj5ulHduXFQ$!$oavrTk9W&tc2e`*dH^`?5JLQ|wi6Tbz%rJ-)2GlR7u`-4s4=CHG2 z>X4NYtwVKz=>fI826|+;-*PK9t#vNYcho;M?9}&m`s8FcmOG`G_PWHnsN7b$PH~;# zmf~E;<&?3e(``eRHb^&HR_}cN>>*d2w3*YW|r+)c5(N=WIzDGV<)Kgh!ec|%icrEaO z{~bS-zuv7Nz}Ib&cW1S);iayHp;TSlde_peaCW&zK}2a>Y1NX*0(HTp(wO3ZDvHX! zH}|v`V1Gv|jTUZ6Gjz9fzqv{-C*A6p0z4=Cw)64w9TKVun;DiJxGTsUm{w(Z@TB0} z;3YwaJo|gL_pasl#^X=ZSwtrD+_$;k#sSB9Ctt%@r(y2>y#Mjp>$b)8pVKnet**bh zt~ZS}v@mqg%+Rb59fF46wC%C|ZGBxG(5b7TxpR=y5~r!A2lMR(90pOQ(v1u{rkBf?p$U`L6W()%~?Lhi`7l^ zR*#~9r2(CTkNBOgveBoTzZJ4M(zSyr#BGu;M^#}BH2+<4w=BD~Z)v~cYDJ0p#rf{V zBMW;LZz+3jE|C9Ve@pMA9_l|d>-76w?s={7mxBs|Hdm=191*-JbbQF?uqPqY!a~Bv zRUKT_&+nagU5|Di@7-^^uW`HMyvnJrp||d~Hdnt#A8zdD6l|>K)XcC!KUzQ2pms_z zK6CbVO4D!AHPt4ogVb8d73yIvAfbh2<|(bjx*?&r*QX&$M^Q;(!Y zrZi0piJzG?KW;?Q>d1hYfswk%Hc>r4m3$iZe*5RYqWlso(t8wMHV>1#;4~;$chza8 z#|EDy?}mZD26YQA4~+F~?>XJIzNx)wsd2brlBPe7eEXMQDqme(yZC8=fAOoLijsMy zJ&NBJk1VcP9%&vcHx)7^hvuqwpZ>UEyi0)B>;TUylY%@#&j$Yy_HSr@$m-zaVBe6e z;OD{iz#QL09s}Kmnl_mRIbAY(8}I25`}`&0kD z;RoX;LzJOFe?%v%SE~%-7-1-H=lISMYdLOyTr#G(V?n3f<&c1guis~A(we6IlzJ?A zN8-7}-;;O7UQgH@;}Vk@Jv63M)aXywquk%yKTrAmIjUdc;p8TT&x$?dAr+_wd$@(S}N`8szJv!y{=vu-nJu*82((9Dpg;VZ*Kt9ga%BHo22hp!61 z7x8D<#-Q4MQJzQLHkcMU#TZ=mg}OhrDVh(enqnWJ7n`73ta_&n(l*oms@&1gwt&dL?;#@!sBZDr=)D_H5z#i#Q3$Ps)ST z=%nch|0D*)=f;L46vcLnnH0xAyT`2kcsVBK^R3u}$*WQg#a&8$tn(aW?0Qw0)YFve zlJ4EWC(7Ti>WiQUfsx)7-i7Y*o{K!x?i1ZQ=~wCW%uBdndsMNjqJ8Q2<|<`3%9fQq zEc#GXw`h0Cf$~%42ac8U7|~O8NPR|k##rX^yJw!yvcMfx+lRailR{gE{}wVU)D{vN zrVpK7bxxqapBe2uET%S2EsPr7K%J}BuI{5+#k|-A)Pt=Se-v9_cd@SOs#>qUqK;QD z)wt>O`kML)`e^M7ZD&<|?7Yt5XY7yd1MD_SyRs+6ql*5`v*bU>b;^5^C1)+omeWpV z+)S>X@+5h3O7E1qDVI~H#l|HY600YijB6KDKe==K?XSbq@0UC&x=?!8Y_?vN8whn& zbqx8=HaDv~^Kat&&U=kloVyod?M7brTu-=ll>+@O!r7_}!0J(iZZ+wouupBQ#y)@;YGH`wjb>`kWf}y9#xiFdby}}QEu6Z5_7@JoIb_p zbGziN&5q8ilDRd@FKcO5x9n3nFH)YRHA*{^(lWuACZ-Ka?fZ32maSw&QQZn&a>$xv zIiZw`2Xu9eS^5s9#vb#$`gj(*P4JgaiH9*c!l|P z$?kGhS&fo;#g7ZV6mQHaE^M9kXMRqGzPNvSQhr7Hq3khVFK6a{)u(B`HcRP~ek`+F zX0`l@d2XerihnQtWIkeDZS&!S#6Hp-{X)Y%r`gWgrme<$@L35ipPb4~Q=Er87wWV1 z73#s_K=Fb+oL`h*+rQWr+kdq*uwJmdv{)>o?Bndi9Pv06?1(zF53B>Lr#`RhrE6)p zqg&;qa~fnSb9(La)^(@rA-7R(9cST3cCo6odK9>LN%-+*+%>IxSiFH)FH4b%s z-CpfJc#$xt&FTUAE4r!L6r)aeR+_H6$;|8#t1Y|`_V7cnEnRVX-A0)t$18p8dTFlw zOjOtac1+TVze^vPL)@mh3VqtD8)I<1Byb`N@l9Qb=oCkRCbuyx+*$PAEg6pD5nbjRjb)7wMw;7x~IOvBebzn2PsXpT6LFo zQeTv8ViWPSv_?r|&ly)^buTp1&$89h*yLyWP}xnYFSXTn;I*{Bs*ypYpT_?(Wynul z>+urpeeHRjLH9zn+o^`MSZy@U*8Hr}84ilo#lh;Xx|-5gRzq_@^+J4M5A_^dWHI$N zk1)-#*3tc~Jyhc2ZYt-d+d0cZMyJ(wZ(I`NQ8#;=%Ui*#_^!`{oc96gUwT%ro%q>3 zxh%l2tzsQ>7e`ulYFbtt(!9^pxvt5H)oUxzcgS*ExK=t+nP{0P94WnN8K+vB;q9?H zd9Rx;GefXiGRuwpc*#N`rl6OwrL0sKl($IK6&F~pmX_FNIQ03e)jdiQ#e?PRSr~MB zrE=5ZA>|>}I!TzTiWe^GMx#xcB;FF&YQG9;hA45UPL$Ggc}fedsM~6dF_p-5B%@vy z3iJ~y0=&AFl^W-APgO(B^3rp@5A$<8s+!k0kGC&n6I`Q;MtFBC)w$O*A9A(`%hW5i z7v-ZSMNZaU7MDovR1wNWog_5W4dVw4F;-^qkteE#$_F(WjvZnf(}{w=z4n*x)_*V` zGKh{MWvVztzt%iSeN4SperrBvd@qOCxbmAm!;+{wS1c84RdZ;a996E zSy(zx^E2dPq4_6PP$9D(1x2nCGWKhZT8gZK6qsrD{1O}G=N-Gr%Qd%w<7e||$6)?8 zE5?0W>IP$Jwv(#R_Dk^tak9C+<$-w$uPC{~hL-M9|7Z?F%~TznS_j##>Rro|b@B30 zm5Vk_`J@R``)NM0Rho}#3riEHq2ZZ_u7`T7GSV=F=cr~w?~bsw>ISxh&Jh*MT=tnq z>eFo9RXtVL%ClW=mtS{+aHz(s^VuvRTnd+G8Mj#T46Bq{23cKdpX#=`WQAvRL8w=T z^`11&&6sDYJ~3fu1W#5+jEHX$8jyX~&#!EQL8G?Ii`2W!*L1Bc8xj9pk><-cRjYL+ z#ap$rY--(slJ|yF@<5&^4d>%DnPxYko7~1;U-=}&NJ`m7br+?XV~1_A^`iQ1;ZzwY zta7HKcIiRai&^_sZOf)A{^bMsKsl|Vk(5w2-Qb^L@tm5p#C1jLVCPZhCgwA$p~Z&{ zE58nPJ)U{h@Fs1B@#oSIOMC5-vZtyV<)bwVi#Hov7F98JEA`d)lE<)F25-kLV~}~c z!Oysgi(fk;$sDRqCHC6SXnoJKF>HS#Yq= zQVV&n=9*wsy;uIwc4rTywqjS-N%9urj0bIMQ%lEaV~Bl#E>FB7o>r<$Ug{FLfkq{M zRR1I9*)*E1Qpbu6CzCbT`AtCwpS=YeTsxKCXO~!ad7P@5+{Rc^vcZlPIe zUMg2L7nhx2=dC?TTdD_E6o}U>e+i2nHA*+it4aq5A!Xad3l*oumL0Qtd%lA|5;lk(>`OH@Y(3?h%4}I94pWzDS4$B_)N4w1C0_cT>ovXD2W`Hz zPkUQmlRZ+M=JnOX?DtjY*a`72n=X7Z9O1tj0+bcfblVGMsoYCCtnR~{QOwa=zfW;e z-NQ+GKOk^_RM*h#P)}Fw*JVndRUtwo+bRr}cd$vSQM#>~A!;suUgJjD1J+AcKkJ{4 zLymYy3wv$3$Pr}ac5baNyIF&A>O4t)Ek`?sO4!CstSj);+*mnq<=U~HLLn+T&mnto z7^-_M$RJcuJ*;7W1MlE0GX5fwecBfJeuQx9gRH@x$k!_1{S^^;7Jk5E8;MHm#XuV| zA-gP+fU7pOuA^|L^ePKa;zfd&_mf3*-bBM>A3bgHkbkZDmW}!ZHMCcd`&u8E5QJ<yO6oq78#~Kp#L;r z1$_Xr$5vz+4o9}+E@VUgjZEF0$Y(x=8tzNTn4XGU?JfB29)7!r3cSnU*IQt1zC#A; zf54+U0wj{_z(snBf7hYrcL*x!XXEGj`1W_yldr?|$AHf93R!m-FvdA#WJe)a&jjtT zp`!K}a(mw*bGs=jb4{qcr+zM@k*C%R+AocTo>E zSByzDwXIOoOqIR@KB)@Xr96E?3;vA!({f}cr$G`tke3~bT>FN&e^`7?%tqSZDd{f0x5&KM7iUBBkGl%L1PbOao5ME$X^9s ziw4<$gfv1mQm**c6V>*Cp!{H*wDdwAIb~3LKrb3&loDtMRmG(uk1ZE7ibi%@Eb@`5 z8qE#(Sgy!-lfnOTWV%-%KiC4inP?3{Mb{=gy$@F3_ou7|#l%CaU#&h0MMJWRJ%| z8mPXl1UOiU$hN0GodKW%)esls4kL2k>p|9hA$3%lZN&_}exry9QbIkoT7jNlG3$R( z3Go`XnNZ&xkdFZ90M&$+L&^wGEC`$&3h8YC2?ByABp?duIWdru4&d+<@MtbZ`X1jC z8bk}Ar_92Yi-C946?Zj(d>}{q8*hWb?OKpK>N-ai_36muuZeed!RLgCtwZL1I^J5s zQ(Q0>RYy~GeKd0LuR>>&q1k3o@H1#s4i0O<=ZEWCp81?P$UtLs z8BK)bQK!f>d?P~NBJjBt68Q=LzXMyf8kX-Ca?>*)1y?YyQ^>x4_Dwn;;*O_~(08ad zh{SkF8094_+Xh^>9QXZ!w`qgfUBL5mFxHQl>jC5x41ix@(Dn>vtT@N8ORHBkRK4SA z)6{RXUy#v*rm7qXV!*#=NhEBwMU5l5_5{^CQi zM=@6jP_<$|usUp~vKQ|lD4mp#;BTtXgzZ%x3mbSD^6dTKy`y2(OsFa7fVa(p)elrg zA@4l~S>V+XHI)EMGyviXtC!n2#ZTlbpM?Q#HP3C3FztFeNc*}J>LEg(x%Vu^$U7{^AlsWY^tk#?s zH}mH>F&bjsWS_?0D+SmsYb%~oBtFXapQX7~wwBnl_(frrcvqYuR$(G)NP5Wolwf`k zcAzP6JMRf+7*HM9GGs7JQaU3qI}z0q$Dmb`(olXQKXV+H-y_3#EbA+75Cw4$>m%Ie zMLbA$=2#IU*1O1p#d^|OO`#@1vr4;2b%?nsC5X>kJKoCu_%qObx!6RCM*QN$o5}~U zc9HBC99#JZp^>ylG>PwomaMv1E}WARq_^TSWbs~7ewG*D-ub{XN(GXZ08j7`s99V1 z2sy#Q?N1!*f4ws>=~_S9z=Bj9u+;x4YZB*lXFB+0WX89qk;O z9SM%94nM~z`#}43TVwn8j#cswz7sucB9Z$v5z+Ag7R7c-yH(TFwKe-S)ie*(pVbr5 zQrt^@P%WxossdDvByUM2trN$Jlf*Gtxdn&@{0$X{i_4oM%!tAj_~wm3`=lPn;lQvShx_XAQ~yWORe(2a>%|6RQrz^WE zw#xg-U(3AGh0?W>>EipaYInK^^^9ywp2MmnQ8YVn6YD26dfsoMwcz}H2%PzIs7~C) zT0EH?Om(DAVGs5usz69GnLI=cBaahtqytR1okc2)e`x~ve}sC%6@Ike<9p!|aDA_x zL!5%Wn*Eh^vPEy|X-qKmExTEIu%vx)m%?U6FN>$>>sxj>8Qvr6MK=|9VBJh5CWqBY zYbe@kFXNx7#hua_FxMc5PmpY?e zBR%`Mvm%1t&a%>H(q_^~=_P41S!d+ahiiWXRSb3q{i~U&>Z15ZmLd(2gh(2Rz06+v zJ81wH{$wFq5cqMzLtzf8pBBItRsn27L<^FjOjIE<+;^!c>NB;MT0@_pGw1ktTe4Lpe1O@Er#TgTX|J0soSJT_mHAR+59YH2_D4&^%KZ>mgX zZ`C7pu+|)0S0@h{6&$MFsB9{$CP`vb+3Dd$lTw#+lEDg9gD%MfQd>L5MieOtMn{2~FYh?JWhBipa)r|lj5Gq@;t zitcu>OefcU3W^ANs~)6UCGR0imF|(OhEF83`>8TfLtz~^$xn07cx~V$(Sv-CarQO% zXzPpSBerg1eBvJxS~gO;PLd_w3tK*5gT(E{^Toxil$}Diq^gi(G5gwrmDSdO2m3d@ z`4L>HKi%IJbQnW?7d=gYVM(@~v2-@=GGrLGl|3(uG`7~qmSvUfE(k05o-;A0S9XQG z(WPEXPyb13xO9eWAsbJ# z>s?8v+#>T&rf_uWW=jK?*EfrA3iOP^qR+ zb3?U8c}MX`!6-*6ZYuI+uf;E@{{oTxRsRz1HtKiM=$A|aY8U<$R&lqvgurv;7)s*%3CKa#Hi@L6VX3p315!gWM(I=tkrKVj{7b3T2$^T(%E$h^!;pEG*<5Jb^mS z`e4CqEUe*s_#gR}dd|Cf$6~t)81QMvnubS)wuTl4dD*to$|Z|_M*p1mV^Z$YAE~*e z1?!EW*7?5cqP?U-+@JVP%%py!&dM`16)HT6ZXUBQ;zm%Y;v>^XlngKaOrI5xMFn4u zYmD(&DOI?y*iyFHa>%*TJ;$@pHOaBZQP*|BH&`^4Q7G@I1?``~KQ!?{2H>yiYhpF? zv}ZI&HPzL}RDUYAN+*kNGEFEAwFTqlLi!3hkLrMm+A%^b*PZ`@8oRgjTy~4NHG7Nx zof<{HqOzG{u^^i-UoGz=+rge^no)D9iZsKdvy0e9tcvbJ28)hh6kRJc1)IH=SW&)e z6AauUVEO8K?|G!|d#)q)_qK25d&cj^6=hLnv-K-VI+n~SoL!Lmb8UWN{^UGKo+r0W ze!70Gxv~d4a@1aSGC3fyAOByhI2wF4;$F;y=;mP|YQ4BKeSx&n-K4{%tz@~P1HL?) zRIe`?Q+mtr(B^mabIo;~%|%SroWPd{M{eMd4-d0s73AJ(2xPt^|7whrp6JFgoM zx+-LgE}*ff=E>_xLnWCKU@@g-(v{-6Yy=S#_?uJln*-B`EA(7uE}KMmK(*f-Q9Ahy zc;9QXMv5&!wyYIfnXl{xMETw_O5RkqT+)&a&~jY*@hjc*=!LTsnQnMiu1sJYON8^DPJ`NS#4%{Bqi z#T8VOeW5h$T(MQ!PJU0;R(3|B7i-y;>;lYzvY5S$h5kUzCC`APc9YNzJDs`0Oss+f z{2yGA_n>>Y%WLmy>*E-2kvMqEKjyi{Cx+=oSM-+)4whJpzm^;=YF3g~FeUF$?t$#i z1yVQU4MF;dMCT<+nyMEl}$5VH@~y7z6jK7R$;eO7IqH%2a_o+RGiax4O>}ZLxt?H z_rY(}I}|h2^|jHVQ^F%dJA|xJtY_;}vx%W}xL6_C&r0c@LYQwnat8zWp2P|2B<60D z0%rvyR((C$H)50YFL@1_Pclf7EU7EKCi_jkS@BT*Rz5{SN@%tLU5$B0y`sO7573^!rIh#Lc+r>VQ17|QL(a~jZv>xFAGZ7{4H%u zZt<2l`Z_JnkDjgG>hZXN%*qqrkSr(ZE4DN-Pzve$!~v6!Qo3JYTKIy&T) zwvT*0d4m7soa0ONHxf>8!`(NR`KktYjQ43tkp4a#95Nr>Ss!X)fM zdHECm%bq@{Te#r6AGm``>Bgwo?;J3TmQ&NkRVA;b8p&SC7*-*XNI%H7DGs2Lq@QF7 zdzz%TO)=q{n z(*l#F^iIN418Sg$%D%oC=9Xv1eeQ<_)kW|5B ziX)ZjL2W{fLDyxK$=6(pXRg1Ga3^q7DD%9uJ~xIM9vZ${S2&DL&Nabt$NryVp!=~e zEMTD0=*Lo>l#pu`5wcOrX3D{uUfL%?e{1`zYpI$mMDpi~oeICgq^Kc1#r~kXB7f7H z5EDVf-asc_?f=7X^*-~x0h+(POqOZaQp)KDTb1IKt>(}{Dc<#EcxYpZOy9yo8+&!FC zU0dxUTU|@E?V@F>W3Tluo81y?>1S$dK4xletY>zW1WYH*Ph1Oq{rKIS*BnOX|_(y3mE9np7&$S2PG@Q7>$%cFr-%De?|+ zi`}Kp9nM+qslKCJs{lh%bbF?heJ`#hJ1GlS)KFYh?NZNG)ld(RH~4)9 z*D+HhyMZ?TENvv-#*P+$7x$0_$(~7Hu`QS<)NsuIFH!L7svR|~;G`}cYQ#!8XPT3oC zBUm_AQ2QRZwz!t~L&ze<9<@uGsOZ7Esl{wSQ6cC}$SGZ}B9@*?^xYeaU;kJKOciS;aQWVRHBL&BXr2D`E%M`H$!ebfGv_I#vEc-c_koCMb8x zXxScV3u&fosdSHYj`$q&hWv+|MfsTpOcs+(d5MK08hbH^$p6SY)KQ`k_52+InWzMx zO&_6VQcA_sa4EsBq!6dLNStILYjpFO#A>vnTM^?`M#$?f6S_;(SP--1^ zC=ViIm(H)^oqQ8Mm%HT8_Ojj(Pd9h4r<3Q6>pxc&cZ&T_M@QQ*d#T-LW$e{#-OZ(z zR;Krso~Bc#1;%swHTq7bTlP)v23%j?IIgC%16EjnvzMib>XC{!vOCNtP+K)sWdz?1 z>Z2JUo5tLx%+xM+jif5uk;>rWy^md`?moU0?mu6wFV)%AG279_-Q4${>nhB{u1~&j zCGd=_59Gc?+E%h%s*<*qWymJTH_MA;8hJh0CFwoMX7LsF6;qcP&)j7Ouy>eZ`X6#F z`IQLAKFV3T3-uOt?FQ_Ky(SuyKdI-)k#>RYsuCXJ7VV_|q0f;A2|!lqeQaxShNLaK zgRRaEWco-Zh>Ijm*hsb=H3C^{7oh`^{%0VDf5oNvKly-A^P0RrJXJiqT-%%%?b~d- zZB?zu?3*2RZ27i>rYq)G#){@{=H_;b`6yNp+e(_Fa&omk-8{x|)a!DLbuY5scW-bG zBu=r5l$Ye4(I>iM?EPIaI_QtEI=UTdk9Zi(Vb`*%yt3kfv_9RKn+9}Vn(s2^q<;R1 z=d-gHeD|!g3wCzqaBa9*sBIkX9~Y=fd!Z|32U|uG)rqu|N+|416C{B=Sg#vq zNfYSVm?wXvmogosehI3qB(=p8#5dS<$wA2g*$~MX@ik@yc6qjv@2G9$QdA=^;wy9G zePRCNzJ1;gp11Cb&K1sAwoKbW8*dwC&9JVu_A$+{tTDYdw>2f$3N6hnTP;DRzs;k| zx*Kx!Cykw~`&~-^->52>;Mwd7^KJ^%6%UYGm9)GGYRamx`(+O`TXYXXay7~Fc=Vh0 zlIAjtvV;5?yNBqMV!7wOQlFR)3fv@yQ`@OD zDv@5o++t%TT4_3LbW5^P`me0MJXl^^wodX%>=oN2=VguLwEU{%Eqj8g&a_AW$`qew zE77CKhT!W_5VgqL*t?b!1yDv?1IGP(gr0anejz24n{v{1u*TXVKFqEbUlxB9zn1Kf z{197#E5AkOA|70%QYZ(w4|2F*PUMg9KlVMxE>$<@Cws1Kh3%Miv-PcIygA?W!We8E zXgp(RWZY@!ggu+u=Gvxarmd!8!&YNcLyGB&G0jwLj%2XE)BT11BY_cw zmJP7IrKyts(uPv4?2>$qYKXd#x~b~4LMCr0`&YU}S}JKF(Xn&s`#@oA1Ea=5VFN#k z8|Xje@9%%_pW$DFJ;Q_8JB;VwgW01uPIH>k4bA zRfXNZbCxfbLd!#IQ~PbOI*jxz^q%xd{3@;`ze#{@2kE5xF#Xxn;>MCf$zJJwX^1Q! zjg~27Po+zN`o1l(iNCUa*m2Ar`Z;1wFgXmlkvh;}T7+E1Yhg0>TN(;F;WuF)4nc@U zPUt1FkefxEXbNEjisfI@N{*v4sG4*t9nY*_lG)2_53xhsM6yS6Kyp!X8$E2LWP+rU zc_l~38I?}&WgsR!5u22V%xB3fs8!1%^ zJI^G8$QQ&+A{H$y1Cxgx%sP~CkZ**woRXXEFY&eWm3fzXn|oh+26`^LySnpSn_M$p zO4-?lx zy}>dzp1s8MV@l|;7_VMX+o&$sTQ&n1bckF|b|Pz$B|yC#C6*BDi7D6-AA-!;RN@h# z#oqHJ(nUs53n)9ahi=N`F=N=btV-NRd{yiduZ10NNRCM6OJ+!VNa{&yflf06H4Kf# zZuSX#mR-O$Wl8osvy$n>5KIO=pN^*QQ==$?T93WrpW)l&OUAroikJ4D@eK1+_1N5x+{@hE+&?h~ne1xllDS?vcRBky zrOtPbIgTlgHjZRRoWpD{wBNKZv5&D2wU4mRx4*KtbL7I33tZjYk)9t|YwYu__5TH* zj1$s?T~Oyq1X1cUaw6i(68anc2eY5aXLw*H3B;8W?5t~9fiYrt{SY&ksmuiE_w;`H z586sy!W~t`J*AO{$r}H-IPX*Bh$!s$n*X| zpQL&03{+=Y!3!5M>zHfIJshvF>fgZ3NA98}lgQZUW3WkIx;?zaL*2uk#Z0O)C8J)z zM{AL0B7->-Z9Q;P@_5lsF@?nG9xgiOZ!E>NXZFtQ6v>4!AIL@?EuMoc?q3NsON zzJ<&PM#tpg8m1s0(gwMe@^2qPHK#}_m7Gi_kk5&&Xsak-yl#NIZU`_V8OYQgM%KFq z@I$wSb;3kM96kS$KgxGP?|#g!;o5RBoQ%8T-w%Yp#-Hok=iBZZhuzLLKEZp-JI&kE zTOUV~x5#tWv%>>4G;FKIp?80Gzja4>rg>g^2(J>z!3y4*-ag((-UQz=pV61(U+MSz zH*!hXQSOM{`OnDqjRGE_8<@1^aUA5*E+Khz#-J++paLQO(MTaVrBwLl$C zqUKPjIHh(`8}Z6E>LpbGyZWe)VEj1;pFK&Pp{`P=fi#>&wWR`NBxOaU%_Vn{D}Y>- z0#SFFSchHE(ZGQDMFmi)y#s{YdSv;6fOE-$Udc1yujV2nNL8Qt_P_QYgGE;O=lCc1|MXV|mcA24^k{#AKh!Vx+k8CO zf^vK>eGhyX<$YItnZ7JvK3+@r5q=gG3JzZoGF5&2`#|kkm)p#}VT%?@W6UhBWeJ(5cj;4>P?MC-&;WqrFv2m zVXwA0A5A$hyJL}0yhT1H50GaOGq01|f#%#r&c*j^#%sGVCap#^?Lo$o4avr2Z8CLG@jkt^PS4D~eh{i@&%TP@7OI5|n^D6~bEJVyA&U*QL0oji))#81Vu zgX6ed?hW?^c+_A%i#x`BilWEhpaq>6GQMoN&A2?Or+Guq)c-hCP~n>*xr@*uKu zW65@CpX0YEg;yAGyI-lc^)0;}HfG_q%{EbJ?Ch(9DY2dIs0adE; zzBt^mGf{HfF*Suu#v5io3Dl}`fK>-OJNN%WaV}oj*NE`L4n;Ao-gCw zypc}>4_7LG2=-jT|ILr%hw=juWoN^>!};I&s(c6^j!4l4k9vGv9Cdy@x8ytGeM9)c zd~bYa5x<<@%kRYFEPs`M&cEY7|FUKcV|tKXrca~64I zpQt9$A3kyw8F3}7Fbe*43jX(<^pavqi#>rZ)L?2X{Pj<&|F5GXa^8)oWSn!5c^K*L z!I~Rj%?X%i_a~deGg;DzT=-k!32_E<(S5{bc;RT+uszzbHN3DQc3|Yd;uIiPdJhbk zAECVZpXds7(KZ0rITz|x{ej484O~Yx^mqy8lJBsSm5ThwOx!~=Qw1ZNJ5UL<1 zH^!&i3SHpigMo^jiCpF~U}85TZXN|@;VSS64~37&bejYbu&?0W0~&7`GT?`SgDM36 zChQlcY&d*k2m1dT;C%eR2UR2*V$|+~xzJR2?;2t+MxAr8$z>cjkS)GTTq8~qSMb** z;skL3mfA&ZB32P|iOI0nV4@dB{6?@>HF#YxxF`KUIuxSKUO}<_Ja7z~kyW1u1o3Pj z8~Ov0)&UW+KCqqTn9E>rv4Y7EIKXmV;WXg89sqZG4v4+|K!a@oUTh_-IR)sCiBL=Y z6LY7&fgU(I!gHHqysD1=ABQ!Q7W|YFAb$OV8wX0egi@gZwlBi75s&gikJpL>0}eZI zO>V&@ctPnb!DoY^@feM7N&$kcInX}?p)@=N2-@{P6`ln81$vXm5&2(U6;TM#}mWghwX^^KoRTUgB*}GKfvsC4~ncuM7v?nx#-n@ z0$*{D9Yo&u5CoTT#Y-LgNQi^BkWMzV;^8- zx&iak1}lmNfrj7$O@R%n0K<}iF{=`yLNcD?aKz)4WSofyvMd6xSHV&B*DJMw@cIp! za`gg@fvs+fouxm3@b2=sV}YxgjsH3y(Pk5low$}$Kpx%%+U_3KG0$02c0m{gwfU{8JJC=K)2%9@vcCqH-3$WALk+z%PL96szY9^v^UP zW^#X7sQ_M`hvPe5&j7yT4MveiQ2VZ>T($QA5QjTq z=`A=m<8I4;+m2Uu;jRzBr;p%q3<#m4@T%iznKQ7(3A}m|QR4)ViYM{+N#G&R{5rn} z&qwh713;?o!_RYg*U?|!auT0Ci+*wikK5=m4{%-YfMom*UfBQhP`iNWELQ{(qkom_ z9+Y#F#{>6N4Yn;;SZoNdDaRYP09v>c?!Pnqxd(8--O)q4;k9yo=Dv8PH!w-f(I#{8 zR~z_Q9Pro8;d$k1N@}c98{-u<@RJ_k6oU`}3(!+O!D8R>^BJz~8EjSl^9-)_B5ZXY z*M9)^I1bF`Ay{V4&aA{py8-WAjZbXG z=eFXnO~0(W74z2eNU;a@Kl#hbslW+e1!D9TP^yn%mp5p=SGb-W#F%nbl5z&}a<=$# zC5dt#_j2ZZ-LJVtJnp6(*Ill0TCQ>35Od>}uv80JtR8IA6bRkla5Vg7xi;{bmU#RD z3wDQH218G+FXq!Df#DtvT>3aDU603O9RBJHn@z--Q8+ggzxTm&cf7ynuTS-Xr3b>( zyW#ae@OwwJN4Y=N#eZ!8E0^QEQ}A3JeWP3}AP840g^z;$7uGMwXZwKI4Z?ktbB&iP zK2SJ|k~jP;R|_crtz0FgTuY-|n}Gs)*^ayazlMta7h|~y7-9#klY_pP17FVvr}sBl zEd!6wh|KTt*C+h-34dk5qS^4b%wIm2gEK$yu5!Nj|9>hEXUkQ=O7QqUuX_%9dG@bI zE_@^p@5@D;`ag^RH+Y;B#;qK_`X&s6l#yNILZ)gxOlI^0E@9YcJ$gs6HPpLm3KKY(A9Yv;W}kG_x3zd*cu z{>uyA!86LE-!6RfBSeHnhz3UyNB+V0zeJqbhq${2NbbKe{*1vmR}H@{14exo&R;+@ zIe^-N%NQYz@R2OUr&RnbSGy@!`lto;d>E_~3OiD$oCrq$DUYGm@Olzjzbo+hi18qsAG>Jo&nd`-;%|0CK8sQ(fjCQ66`?70*N%7~9(@U05#_(j#R zi}4ioF^j=dyB(e~82Cd85T50g0UNRQ`GWReia6_qk8D7jkHDVC8XPHL^PLTh^g3L@ z1dNRf;YGE;THO$BS`AmqV~$d;wPS$AKEW1gu&Eyw)}zm8fyws56P(bHEJj^Sx%$(7 zj4r!i^()u|n1P791TFFb{rn|H;{;qwO+=+`@RT+1;`taaCt@Yh3*K1^aV-fWMi%NI zF5|yJyBpf|H1-n8b&dZ(H0g)P(-i%!ChA1o`0hr~E4QJwmm%^`L>sljwbVvCOh?;I z1NwU*=6310f@K&R#-W|_ab%e{VmMX+&UHfXB zEQ*Z&`9iEvN-IiTB3~%6oOv8H#@okpl$#eZHa4o#zxKA~IQj2NvE1WdB(6(x%J=RV zHjz9movzzV@6y&I*M)W>U#jZc&->k|+Rx!j%@@hDo*AB-WwG2Q&u{L{&Ro}K>aJJk zYVOV@y@Q=isvD2Ck{^z&u zEb?;TAAYpZNmxruy+fDSAGk)b5d6MzR6OlKRpe#ji4ZB&zz%SIVJlB^@4YU+f*VAd z=nj(pV!xD;oRn==y;NmrE^7yb%np}?7=xyU-q&o>Jq;QZqz;;+%$4t!Y-Il>PXK+o zIdA|>38Fw7z6yWZ)!K8%!}})kKiw+ssyoDS%^GH{2!5VQWvxm~`L%y^$^GkltGrJ+ zM>EIf{rog4=lbW3-<@B^>81G=Mv2L_vAfQHdee56A4zSS=usHMa5mFh&sbf^iJ}$@x=K4^nCDM z@tAy_{L6g>Sc99HZfuhHuw9rLd%MydTQbX8y=ZyK ztDmpG-pY9TdCM!yr_Qf!pR2r|no%)-@z+GhKSHWV7d%@V9Y$BFl+Y)oVf`9)zSKRK zvbH8!Wpu^55f{W8By$5pZ3CQRe;)kNJ9}1|H{JAkZ$?2*#E**li?(|Hdvs;#iI5+W zD-xNycNijsy~04%?a+Oy0C*Z@*lF`RS9%6n8~zxffB5;s$49R% zub$k`deZt(nW+>V&_N-_+~S)SAGqvveKp@!x{+pcSpB%6)zx+TG;i_y zq>iJzt^32?VsDdn)yF1W4vSDdW>vh{(Lg^T&-raY#@ft>-x}m)=yzCVct-O|$y%m~ zB33siq)F7@6$5cZj4Y}s_+MpjMOUgWQRF;sli3I4*D<=%t9{${!SS(0dPL^A^t3P6 zK2>^I>3y{qzPBqr6r}5l*Sq!Fhsiag1J(YCq!U7u4^}N`=xwsAPV-vFmGTOeq_MvII==1mV z?w{v>`lm3$G=^@INJgYqTa?%+URsN+?5#PzmM$ryO5YfNg?Y+m-X*4-uL+rJzFN|2 zeQ%wiOK<(9-iJ;(Rf_b6Q$o7f6dtJDvUAal}$=w@LjB5PJW5w|=gCuwsgH=0(FGvsf=h9%jn-l;LW=BC8d#P$h^$=#}ts(z=^_6oE{LU6uBcZ~J1WsqgQ zsjp{~dzbGZC-a}+XVE2+uDUN#g%SH>wuYsL4pco?jTf7-%TnL7%{jsyWXv_MD=5gZ z7X*L$?Ps@aE?-`-q+s=rtKTPo{r&x*wACNtzNCIQkayA%L!?wrir8KGSYm9XBx!O& zcIA(Cqmteyeu{NyUa59Cml`BF>%P>^8v6OS{F`~FzK_V;_H9W)qRrswFHH^p8r`qv z&FXah$U28Adn*l$4NGWL>F>BB73b*wP^=LhcD{EUGDTVsSo+ysaFKz#bZ_P{_>_IeQI`}g7fx9>cuh2=(~w+Vm?Mrh@M-~mbg3Fm2@_4cUX#g zB-6}xx-c?(apv96!?IuIk1IG`s3@41ms@(?){%cBdmNr!DKn*YExK+|jmk;8D+Wj1 zj7^DE#jcK!2j3Gf;)Co1EZ2=L^GC}n_hZp}`m(s4q>}imbglYc@ba)~Q5O;-qwnhC zRH@8R@}j7nce&%1qmAi;Ey~=wq(Z^ztdgvM@~aj+ExcUTr}#xd%#S_kB_Ec(>G$dA z`>`40@-~@@e5azQ&^-w)Vv8dZ;~K?RVIr>($NO zm?QjX^z(TRU1&EgcMX&64IWx?L)9(Suh$H!IWhTrrBBfVA`2siL=Fyj249ue7GBvL zmXBp`OtTKd-oJfuqc7Qd%!~a;?r{S7F!BuFO*zPc#2bkNtKs?@*D)TA-U2 zfwtG{0bR?#04? zi=Wxgfl2YGUXj}{qt~Z|k8{)BeSP=oc8mqk{L&5ChUwk69e|5bZ) zm2t7#BX%m*h+U%BPNL}1&+eJu(xe#&J`-81^9E)cikexo-7O_!)L|9sRmx9lk^H#! zyOg5py(^W)bPHb@TB7+NUnqqhkEfT_ibY=^v)MAzljxg7#>rMIt_E+?4G-HFQ!Q>| z#jWvmqS}VcS67ngNWbvX73FwjduteKd}W}^=9Mkezb;vB9&Al>T=5O$w|ftJezz<$ ztSXAgewI@`{Yb|APfy+)eB1rY_l%I7hNjol68Vwn@Wf_Gu9T0JhbDV#+^nomN~$Oe zsTAZAEwtyBtjzA6*6iKDx1B%OKbL;X%eEHvGwpM%B&I0BqxMz~s^O{MyK!DUeeKZX z{_$NS-63l<7WpKJGH}>E!`!cIPRZ$#|MVU04LwVVhO!j35FArsQuLwN3l+O0&~a-c zlcS#mpH@9%y9NkXnEhXKZQ}z|7Esg;jifCc2=kt<_P)8^O#g0C3$6}d=&osg^|MOO z$j`s0|M%v{hvTmvz8wB>=x0OjBV#9`qvA$|6^X-Z+^RXe!Jl;|);m!nrfOEaC#;U9 zK9lFjD7ljz_^|4A#`7Aln!i1qcKPeJ{Gzf(p3&4I)xn6@l?tn0sVg)QH<{k3sE#Mu z64x)>pqe0iF8bjsvOh6|mb#153pNzq)E_ZdbWayq6njISL@5*3B&ic?#-+sdj7^KY z7V#pyhjzb=qo4W;oH@3i#sT`P=IWMzt>LbT?(?3Ft{(oDP8sqDF{lg+r3Q)C_#4<_ z41Inc%eBX<9nL}K>xPM?`$~2drRy6T?^ye}4hJqu=BQ&Mgo*FQ=tJD1`kSM;V zObt07-6QU0LXS${l5WQBO_&=sE_8HIl=Lk)X@dyGop6mlsna`SM5-#1++J|={_nB)AfCIOnZLb&lG<+ z`|jnt-_o9EzR#tLcUtZ{Qv(L^NtHjeQnV{^W%BTpu{F&pU6U^*oQ%lP{v~}Py6-+< zQ5cpNzbc$wxTmOQN%Jy=sm$KZzm%%4C=RY2xi7v(<>tu;tA$jpSZPOWi?9~j)v~uF z#gF#%wSF{~l;#&_75!WExp;=Yv03E^_sV9cHP`W68WcK#&ROY$Nk}r!g7JR<;x!&iSU!%S@{F;`v>YK4(ewmDyD0yA8 zglCoJSI)1JoOCb#Xx!TH-C{+{?TWSv7j(B2E!bVcDEBzav9fE$j|$UiLs89Qvce!I+_ORGg>cftb+9NnvhnmhvoHi@4^?aeOlm z*S{_7|1;%>Dd$qo^X!8;eRGR)-u`S~@U!@B$ty#j5|d?;VY2auv8P^acxSog$s(ty zXK6P^u;J^%?}JfyUXV^b4lIE2;#bTj@jzxLHNw}}HN%inTDK&)@J7kUk|V|~)+yeZ z!r#m&@qOtw=~dYQ~pW5&TUT7vk6L38sQ-2Y1nyXyIP{F&Om5YWJchQTC8%+Jo!KN9;LuFq|&l?HzH+v<1 zE^$lRM5PEiuFDH{1+CCDQKrhQKs;R*j}!;Q@pL4~2n#&Z9Vg5+4N3a2(i{3M#=6#C zt~tI5!Y5HU9nba@SC=Kq(-e{F6B-5hiL-P{-6QRKm0hM|$C7VQb2;4|foj&P_M46~ z&dIJzZWb|V3V%YRVw2_XwKc=FQPtwk#@pleRxFEH8`4tKUs{!UOMDY-+)}^J7wx&` z8tlxrUAHx|Mcd1qGG7f*YgVOf5d1YPHBug9i>71VMs^IJ9eO;-rw#_|w?)!PtfBe_ zw)-XS8uo1S0mDRnlhWeiKT3C&R@YZ3OE!e)zn5j|6HDKeeJCEQKV3qUPStDmOZ0Mr4CQ5tC_RYO7vX0$h#(1ELLJvS$yA_$k0|>ojhdR;wYn#eE@&RaYLb*2q?6c0 zLMmMMRB)HtPdU!o2RVICkvqvd)~n?kbdp86{Iwt04~k^i)>A_5BNF8*Ufd-cHKP5a6$j2^ij#x(#!gHW#f!abDnvaWveyQ zbk@?_M4Ifzk%n=GSB4>m1U*|;r#PdqVsS$072|UIH}^BX7d?$Vs;sOD3UUNVf|Qz6 zwL^Uu2+jt|5%Sv7)$D92rsQ%{z-`*fq%cg@@6*pN`);gfaoJ~j#&M~luG9c#GiszJ z$h69VntR%wx&tAXLVJZphwav#*0xYql^mwV1^oWYUcp_*b4nlrI(gXD!X2$MTOlE%PdP1ORROg<+=H(X^QESQCFrk7M4~nyFFBCbJqB&Tr~Ja&7$wVA22 z(QZsNJ~cP7J$L-=DfDmTry>Vgfhwd2v7@9l<>SGi`a!cS=(pgE;I%E{O8E zNbfmEoc+3Wnq`?~y|tljfun(YzUPc@wf`dbT&P0)K`vqHNajhymF?BjwReIW=nm^f z1aCzWD^a~)*<8L%mH>q84w*p~BL`eYc1r3J_h-k_U*IE7U>IizdBQ&IUq3_+xIM8G zIr8Pew=Bm_@k~)KBA)!0nnQ!(4_3@z&#)S{BBMoJd>=XhjO;PS!3<$Hv8&mS3{>Ap z4Ur>sL*++9-w`0Ex4P=OjygjfBOD8D^{uzf<4hNgUkr)HxUw9>GlQXQC~Ey*8eaoJ zz0{CV*1jy+m}Yrzm$}pZ9Rgp70#+ftEUTrME`KOnEpH;9rpQy)Q65*0S8|Fk^6v7L zk~VZR;yP#Y#(FBaDm%r_BuAEGp6gF{N3YrYn{SaXh?~WS1e%lEsCjHF=`eX+b${*b z;Fh`(A$HwoT}a5d;Pjxj+EjJ63W`3;MvD5f&*JH*W2r(q0+;!T{u^G2_lD<`_o>(A z^MQ}S&z(grs!6CH_$mAcUf!idA~2j?fDJs${01a?h&Y4!2l_1|fSc<{_XDbWF^~ri zWxIC4d2>4jo5;dgc9>V0>Y55oYmIkI2Td~5 z40Dn#)|ucw?~mu7ijGrrn7<^!vcs|hSwQ|=-bp!2C0C76*_9iV>*P!1M^TrJsuI-y zKSSMCPj7$EPFJq$v+I)U0FZ+NJip^kX8Kz5PUy9k&^^TyWFwSgG*Lm{f^&3vx;`QO zL*hbO>h=fM4l-$HXmV6)<#d@=QlF_#T^6+o3`Tva&ez%-TT(b^Th)>*PDyu zw(?T2@m+&rsDaoA7J$3-ec)3^05R^R-PBWRGcal_v?eL+OkW4P#R>02-!9h}x7N`P zxbPviDpoJnCKF7o?VMR>eQ!2d4w&wvUh0RTnz^>&is7x%Z|m>;;)PP7P?yqE)5I>Z zNmfJJTE@%rh9wG%D)mVrk*ezCF3Cb{z&yj^;vyLy;-wGqYiqbJ)xzw|Eg3fi(;{Sy)?i+ zriW2S$Yz9;7zzHsOWa7VkN=r}qJNNYv2TZWxS#W1<h29Z_h#nGG=2;h-WI7#4+`FU;`=j56J zan1R+01JH>2&nGFA@VQcA=Qs$sC~o$FcwwfqD0-eqx=$xz<2RgamzioTvMFS96zl! zY-JX`xxV$Hkux4BTdIF-_-0J9Zn3K&Jipd|M0g|WOF5{k%yjmvw5`0WIz~HIJ2Ys! z_NXRa{Xw}({!HR$ztVmxj{1XW6!@L*!7AW>`KDEFT~)mPis-@D8^-P;uN zp<})u$Qu1CBn8YwKKX+_$j%n;lM=Gc@`j30ifPKx%F3#`sv^}U-KqSguV zRbZX}Bd3w+z&Z35>hm4AG5&r&!uQ!*#T)5;=BWuB@^tT3-+NCze@m~LTjsO-1Tg=0`(bSDFMBF$A?qivCd-#Sk%h@x$wx~}(ll{@aWD2K_;@yvBcZANSKy8C zO1REH;5DeZ?1CuTPPoIrLG4R+zy%c8BC3q)OJ^|e8I|~lcz~nE}QLdx7fB zCJLbjFowtyiHR66ho&GmtUxY%5wOLThz&py^2G1JT(%)wl5xOj{sm+)L9d`g={odF zdN2KeZh>l?WIBy{fL^hHg|iz2ZxPYl;FS6uJhWzh9C%gCF= zOtLFwAg+`7NCo24+Q8hwSwTa)UrWIUop6!}4D2v*xgB1RZOP9XMyYlk9F2(yV!LS4~W;SkYT zw4RfJzip$agV0sT5OqWzd?Viz6_|!VG%&kT_RAsOr z9%3qqeiEC6b?k&dTjF)#G<_$~OjMivC-8teF63eLSdFeDya)E~FTzTk<+{^f`LWbS zemwaEwUWJ|74LxUcZvStcQNO<8kC0LO?D!x^RK8CTsXB|&_bbkif9%P9NPsAG>+S2 ze)B_gS!e)OudPCJQG-CfXc^ZIo;WvP;%fj!-~~h6BhhW(FpOA_Oa=2=)xZ`}1ED8a z4zFT8xCMwanebL5<|~TofD5fvz|5M?&+>=e{LVg!+h}z;Vy3s%R&SWY-mE1xs7Z!*< zLq(hl5ICPm&K5MHCFCjLV_*|861b&>f%QZ;;RR_4j1yi`3;Cyn7l=R1ZTYu!F+Wte z!ZhWZkw^J2lvH#}*iHrfHabOQ7m`H3fm>#x{|p(6@BJ%KBphTWaI*uoPdjmcNpAR-~1%`NLWI$qHCTEU<2j_Ts{-CggEEk z!Hj~|;d}QBx&&9dN$_!rOtL?ZI_23V9xHV4uO>4|^k%*lOZkRT|KsQy!0TAr@Xo9b zUoT3^q^W zY^Sl=jW_TzXQ0_uU0=lBAyKM>3nv$i6!H(xWVuqFQ)`f_ob z+?z~9^qH-XW17&jq`6X8wA2AKEAH~gq?!JWyNsOcGs!|8lN5E9x!$%=kJWOJH{5HZ zKhzvrdQR^^%3&87%wDqC`pgW}vh#nTv)5G}W?!NGhd5P|^ddpp33eS@ zL?5TO;$N6~$xUzv7;cNob8YEbwLJ7~$J-`IS(&QrFtSiTVayY@X^WvOd4uP*CRPM( zPZyc9wX?h%v1zP%jT@kb7<>4-W}1H9>H{B-z3L%m0oPjVPhTr`TZ|s5C)o2zUD;6c zg*BY*U=$%9MCgasRHLiW&pypxo$H{V(HB^t@J+pPW2zj%B`d?eE0cp;x&Huvm1!?ol>!O6<2K<=O~YJIaPIUR6C7EF$} zH8h*bm4XLJZtIS*N9by6^s9eJFk4lW%0kl~0+UJ#)$o*QKOK?mqB} z7(s@b+d|r;uTlRC%`T;xr&K;T#&f_KBA?S*MHEil?p&05Es{vDWy5|-KHwgm(#qM+ z*l9f#ho!!72kR5XmFWw)fwo@shd0K$AD$!iBTmmSXC|HcXWL%mtp8O0NzL75Id{5) z)zC@l7?{TK))CYLSyJtL6BZMaGq3KE|BOwMd;ejbsS?A4O$L{#`PFi&6q6=k4sF8OC3 z_D3W)A;gyWJZj^QME5n}N}@0F%U^f+J_3X&? z?%j=`v*Fg~Cs|wkyq@)C`a3;2y#1Hqd9J0dR^DW4_4!2p)nC1ihVJvf`execyYGle zF6u1rViQ}^MDjHFSc+j*7x$96MkTpK$m+B$d<)l~ueoww_Z_opaZ{4s=1or+C~T)) z)QD`wU)&BK=vpW5OHHtC3;FV+WQdN;{#o+A@Hqc_ewdN&sS>Gq53*&s|I!wP_m+Ej zKZM`+xW!%QgC#lIbB(33fp?i!D}AKu#%1n`?Re@du5D0-K%Kz1kkxWIqc?xon=^=K zx1~O{3D}eUaemI@Mhvl)PKk6ZV0J6D9oYkG#no(KNh8-~5baoBYN_yrv))jefg$+N>oaAoDeN)Dk+a8lwele$K5nZ3^xw|yIXoH>zPCri4PsO(d_ zhzlGgl4db~gF28A{`Fd=;7eL8IqPq3ixJlQ$J1`|;VfT&=XA@d`@&xO`f|Ux1L^ZK z7yVTzbho;~47VK#banSk>>trq>E{<6XUuSCp>NOfy-j&$7Gn>PSz;~-7gaJ(3$@kd za`~Vw=|w|Yr4(><(a+Gj_Or=F!~FgV&dSLftd3cy|9S||fnCA3yxFCHS+=CwnHNfQ zP<3|3?_yyo(kNHI^jiE~$4lvp_L?TyJEmDy52_ef-4V%4g1-0%I`6BW{PQDM|7l~d ztDO_#m>hELpmqN)WS;MH8_&cn{sk^BeMwj?e|{~i`-VM9m#)tLIWGi0E-KdFA!S#0&Br^1I z4kLk}t#OI*|DL z82J{qAe8&fC_S8BVU1xX+1z942BB@*N9$C`I=a?-K9JE@*68OHH?z!jQCuOdQDm#7 zqcPhq?YOlf^F@CNpFoO)ih4ro999+VBCAD7tDPKQ*wadnuqB8^!HV~P7G_yefqOZ9UeT{J=vEt^?gVobmoKp^k$BQuJ^grmd1wg zQnvO<1w4g9hPmXx+w_~(q_7-psFEkSY1jwWQvVAar%OV&dly;x%pAO) z%${C8UUjWnWaL1u5ZleJGyB<(p9%BNqn|AZ93Vzc>Kv7|bR+W0MIK4X+M zp8H}=Q?975+|lwkEz-=-vF++gkI(irI0aI|40|JceQ4QX1(Q4)3Rz7qb`4loWnc0s{SDlty<=NKD+Xb zZ>v`)1FRqeKGfDb=B~7lUvEw2U#r82%0+{B_)D3~I?QqUWb&2tWWR$q*+siRs#?*c zfZhYT#~X}0!b8+loE6SBw(&|O=J+0G@sHdBGTflfu5=0Z0atR_O~0B{-aup-W31YD`h z&??Z;>TV2TPO<_QPD_E?9L4n}%d`^uZ1xs98SjjE&d^@ayjETFs+IMG7FF|><@DukoFFeMR&vD;w-t#xUFvN9`HLyl5s|T zGYk4W4=ty)1p3n*iEhjWKQI-Hufb+*GJr-|Rme`OGOf(CgqFZjYXHq;eu4(`O5-{y zVA`1(dQOtW{4z^G=_Dsm{kx4Y=CRq1j5E(N`^_Q79H{S%WgPk|)Z3$(?q+N7Tc?;U z;rRYp-)-hZJ^jmcV%2A&KK^0!X5vV5vWuB*lmf4^1!-q2w|=93jb!qfabUFeXKGs; znNHAb&rLEWHF0b&(ucmcmb1V4Q{);t%r~HdwT|gxjRf!VJ9NQ@SfxodatU3mLEzo2 zw6d9LV4+O4I->XV%n*V>0`}zc*=7+KyC2ziBsmY&GaPu$x=b6xxyvqBF@CrWtg%XG3|r z5mw?c839FS864e1;7i!93DO+z_-f-E?xoajoAh~^40J& zo@L%KICG^DqtAu5y^0Bh1yznPsN6TjpPp!|q0o^7ocw8ef^pNxi%I!D%nYEJL%yQN zHG|XukG3crm@C1f={O0Yo#|9E2~YSC*@$)d!~_$8_QS`S!L`jwet@4ckr|J=`8^YZ zZe&i<4>}iXaYmz|a&Q(ry$r0?F6K4VbGBg}vO}FU6S)8u&Ooqz214`s8(2l* zn3uk?jxe9VV=01FDhB41h^WKCy=E>xmy`Jh22peH=zf7svlCzO9=yfz@PDcXHG=y1 z)gBC*Pnb?V$KRX;kEI;PMSJe(fl$T6_AbCNma9aDo$f-XQt6g`CJc#nz7*zH)4)8M6rldE{; zDR|xW@cKGJE21o3S99#%&d?!xZ|M3uqmK~-)wk)!2=ova7}cR>k=Mv+6u@XnS)g`+zD>kIC= zjHe%k@Aw7p=OHvM>SB#Tz_MTv%v$e0#B*d=&d$K7xcK%!E69U{{m3$SpY`?hna%fydC!TMdPe-9jXR4V++<} zxA7X99#xIcdQD@T{u4?PK~R#&WppwkpiVK_}ur(nuki14=lFc*fPxe~%^4uxjLT}S9ati<0qhzu^78}r<^OdUMu%w!)JGNIV{=ke^*z$$A9o?i(t z+Af2|brX9!(JBMxb~AVw3^AL*r%W*x8U>-};4~7T4jrqn(#J#ZDMYt4yM9e0`e!JO z7+M8=s@_*WpdZy=>(Rz?ih86<)xKo#Tzv_SsDGaZNC`%wH>9eNDepf!*S@2Vl@3gzI%l?9q#Pw);d zgE?6T?7Ujws%0=53xjdi8T`fS;JhY->wXySb3?#YZ;TUL4JWGuoY`hWNgxOOC-PaH z;5ThTp}~fTeG3}-&&{rwC)7Zv@4d;Q6YsRX!V%#NoB&>%7oj0=A5)zeDCG1*yzdJ& zmmSb_c!E9h3z`D5`3n8j%INk@!``_M$EXzeJG8{>yA73q_slOQ7hZKV-a%c&rF@7- zlQ1DUi}!d1+6uRzb(26Mp-q)e>`+cANju_c_k^NKH#!sgjQeRFtWjy2l^R5Y@&}>U z@l~7PkJylm#ah%Pe(ZQJ*zNams=G7C(IL$a7T#%`-!jlfu%LesWqpEz?O?=}D!6-U zzHMEDG)p|pQn-Fg)RjXbR9Sw zqoI3n2OfK0vCHe2M~w5vI-F8s?9%r_`Q^60Opntm>ZSBV?YVXzDpH5EEt;gA(vsnN z;nFv2-?S6jCFt7S)UIg`y}9m!u89rbYZJWbZi0=u2ob0a?%r*PSZAUC@CYhSE~vOH zpbzNf|Lc()qsdr}yHL)Fg@(>BI2Oj^BMVdL!BB(o(jjy1ysVWLGxxOElgAJ3U?!-uf<)Rp`+o(3RZ(A$+G`P zj%y^11k*#nJ>7r{Qem8oadZ`S(>#1_GS=}B`V(Qa7%7B|DTF-4ZMF_<>y?OY#lgRB z0Um8f{5c4ITc^QX_JhG(0+^xBVCVP28C`_@CxbtK9-PyD4W&qbwWK&{jM-)gBF-+v0zHd*K;85-swShCpX<02H_?Kvya=`<4DfhF6PT!@boG zaq=~GqYpc=22N^aMD8KF+34?{5!hGQpv@)v; zKls*s9rSuna=E!E?kJSMYO)WY-IbxFbP$oGJkG!*2p0pAqLmgBkC;KbHszr(q-DE&#&;OSM}nrf~!l5zL$)j6ZLG0<#bLEVYGfx6@g z?h=&y9tfGlIMHrfX%p<-?HBA9?Kkb6?FqIJTMqFPJa!&nm&McNh%_kk8w{Q0+0X#{hDu`yeMh_CcLN$$25uy%%R>dSI3jB+ zWUq}dfo+GI@}Tw4jJ7@+-NcuA&UZGO{w(1sfaWWrsp7JQZ#v8&+e zmOlf%04=FxvY2+l%^OKPOglKP^sPy$epjDh6;zG!*|pXZ2tN2=Z$NRE+*ZYS1o_cfmdU}iann+>j zQ>1bW#WnV3&Uj}}*LY`Rr``5LI1Y?nG*kl*0f`_p(dH!GuR4{cGArc??C_72aw{G^ z-ZIHEHWy!7XdzUFYVS^a7DqNmq%*=b(f!RG9wa;KIqKUy{CKVbT-fYn8WaTlsCIXe zAJEjw0?v0DJagmOl~B~PpscqLab>aK=M}gMIk@Za3|fXW)gPF{NV*xXBAr={9H%(W z=R4~@a`2u|ld6c%w;=g&ufKpMTm@*B%_hn4I-6}y)nBRklnv6^K!d~A-}|!!b_C}7n*|2>=J-~5dwSEP6PQh|6W)n)?3*1v#|TG^ z%_n~1Ub3@jbut{gH;UcI#2Itch4O6anBU8)5yzdwU2dX*Xy$R}>x# z-$k!&iT$W!x}%+=u>FZRhW|!ilEVxX+~{g#sd30ZD7{5ES{R=h2(3(!f59&oON*j- z0TK5wcLdRUEpD+XP`o=vA0c`t{2v3e7BS0?%&#Hx?lnNCZi8ypWW?EyP-WW(y{E#s z@1{~hzadwBfSohVyr3`8+Nv87HH%6orQTAgv^JpnHwPN{cl+=8!hNPU#XHK^+PA~s z+}~9S2}A{4{s~a|ESWBNJNl05f0KzBYMa zELRewi+`FPg0YO;O4Tep=o-L<)J$~i+{^6wrvwnKv}((^SeD+oG(;?+t?z)uq^c1nz0pt zd})9RvF9%lPV&JDl6;TQVTKM<*%lbq7>-|ywhJk8OY{lE19#+OXdwD*!s!*WETiO#izDEP!{-N>j53-ihOzQ0aWFhK`FNw zcadAf{bu*EvACg!auNJEKBw4C9An!8wb({Bub3i?=HKEby3Ri0ZUI*n#J;EH=yudO zWl#wmMeg>Dv||Hg5W5N!*fF>VTSNO#g%;pT+Jj9XyP&9e8N0A4bU|6Dh+qPa^;`tq zy_1;WJyO0&C8eqUHU6@J7QQusYW`*ZOMzwn*8W+zf!_LE661dzs2kWV)e2;h-uV7` zu6sk$wtIU8cIz&4X#vN=T`H;~`Q4}fEt7rN;eY<+;*EXds@2aMHfn(9`TDI#*7 zV(JvFj9yDGr#H|KYa!-HOQL1?BN6zE- zEu;rsfqSYajmO>d1=;RHavZtwDbk%LGRxpSSe~9l^x8mkpj%gsM4&VC&_w+LKE(apvG1l83>iO!^eA&Dw1DQQF0_W1Fd)2fh zX>q>Ufev~HAj&I49kf66fRniq!euJ+1G$;N=06eRgf8$HXlm^=F6dp<_Nq;}kKFLA z`UQ%!ZME6zYVDn#8Mv7xh^@-ifu3s=+X_`-S43bpJ%S3+ zq!RlJ^}-}N9ILb#H`;b4H;F^9E*p7+DjA)+gQ(z7J32ykt`Jek1 z$fMMf>R0IRt}@eroq3FmehhN06xvous`gU zieNvIGBg@p#-6AGnvrd&fHL8R+6O$|D#MRH+*tiGlmsusbGRvfPx+J9r)=3p1 z!Bt3=l+60m|2xW@GHRId(BV7@mzvs$n{|Lax4328AE-y~=Du?x&caY&CVNA}w-dV& zN~=TRH(rRGLaz>JE7ZQDu-g;R{|%6C&<_1e`?4(e0$z&}cMgb6p0CVr;UD3nEWDHJ z#4Y3=f)|mSYlF_xYUIMh$syVqRpeF7@@vpP=m&Kqz2M9|f~-I%bT>I|enPM6qqPNC zywT=U<{2~;?+_juJYPr^XmZ!5&nyC0xKUP5=!={A$v(K%C&0filsmvWIN+Mt+-!OD&+?$Q zsKH)BXH7%hb)J?%_k0GrN^$6`?*T&TBu#=U=rQW#N+E+N%QdFG*rwbjT7?_SR-wJP zr8L0m+;rv#)UG#J3%F{u4x_MlP*e6IK46A2RQ|32iPaoajKX>xNwMPerL?HE#@s^M z>TRJ1mep8DfQB=|NKtd0S%buxKa9dmep3a$LIo=ACAz!@94m5JaIQg@si(o?-Ya4Z zv1;fC^<8GVmZ}Rz8C}*tsUc9XouiI5f2jd&g)s^G%kA}_N>_9u3Ti|3CVFMHG%&7< zwDa0UwVL)_9i{F-ucCu7SC2H#p&R$b90rtQ1LDZgFGV-5JdH-hc@DbDyWn|N4-sxF z+Xz*{9D0VffsSftZYEHzdC@-}jo6xrO@cS%1^9BBa7mF`p>wnhLVk3 z4EvJCa3S z9b)VWWGbDkAn1yx9H`_h)hM08Ag-i=w!Vm6PXP}M;~VwVaPG~2h=2WXkYBt)#!RwK&SJsX=4tV z9j(dms%>Ol$8M_tHO~pqKWu7^hW^_JW&}wF{`DkQYdR2iuQ9aQeRC_^MQdkT#8nOsi3j&M|0GX#0m~uL=Ey z+5B$!#jfSMqV`CklgMU#uUxpvQh?yi0hUq}s<6FKoDD(7eFsysXu2O2+i_Zh-3VRP z!gMlL@Hl*QKSE2r6-mTA^BL6Vu0tPfJ~I$m{s>fZQA}5KeRsmIcMW>qgMfdIHm^WA z`jS>$tEGmb-(gpusK1p?>LPfzw3Z^JUeaDEpIlN7lBrx%hJL)fSNiSm5_sdw?jIxf z1bcA=X+jg>E4-QMLppPt_*}w&!gOSz31S{wPk5sIAuzW?KN23)m@m1l1@O>4tb9|Z zXwx*#_-afwr{JXirYqp?a256Fdw!g_+g{kY!r9!}!g0g)LmVbGUiMn0)0aHIw zYpuRhUdrDjyF5zL0vR4pM+26>bfB12&Och}fGXc9weTBKyl+EblyAN-U*M2j4yvIa z@XGgKf3IXpu)Fz-;z|2R`$&h0_ZH9hWP7q%n6>6C^cC(Rf=XtBenveg{|j7}kE&CQ z80HLV3N`ol;2A0WAF+=Oeszvi$0yfcS9SL_S7pb0TP0yHABD_qGQS-@w-KZ#rf4(3 zdTN4uce@pW3G-@ZA=yMFQiGj{yyY}h->dMe;fZjCi)XK~pJ^4=!X7<Hg*`P7T6rho5FKeY)B9>EN+NvX@5B^ZE1nYWenhrlyZiZ;)0q-R4`Vyw~IC zWY$AJz^T**P5e=OIa_mQKG#dfX!{u9ESr;Fu`!BpSO*^5!mEQy!1!4ns z<&0j031fe7KJFfO9UZe8!X&tc33N8+IO-7rw+iyjZ4!LA4V5>@a$~cj$ZcC*HE> z>2gdOmjEF!kEXG|ak{5)=b;)b|L602t9u%x|Cc;Gd06to zkufwqbEC(a7)ig4T->YQS)Aa3W1(0b_mTsQYv znMiLM#kJy2VP4lC{p;HB2<<00ZSnSF_Mf&=;zYhKw-LAO9^501`K){-yrqI^8FCx5 zs>jfCE{eHk92o*d^083Ke~;-;I??G;_AJ+jAI*pHIl0a_*>}Ll`G`BYB{va{-%o*3 z=nVcw6U<(V8I_@Fe-ZPz=jvGHjkHvH7f6wEg6mU6KdEwRQlPGXkvGy?!B^I|Cvd~R z)9?1BrQc2sPCbzPE6EMtx=txco|@7VeJ}720|Yxi50jelY%U>_V{Fix;4C4|pk4Ms z{9p3UdV~k zbFN124z5X#Otu2R0Vkk#{bQEFjs6yQU4LXAbJ->AR7^a=xpXd0pf(38qwThPVj6m( z0ymni4W7_*)YR*_scdn&5>vqDq%!%9$?tyD$meJ$_86VdR%J({o9toT@Z2aa6cN(+ zW?1Q->|)vrHPb=#pzfdpl?G?~>E=PBCUom(7zcn(=wtlS_iJ_4<%nGaWgDC-8p~Vd zBJvz5T&faC@?G;5_MS*Da{qS^HW-I_gqa~aJYOrtE zE3^T6lX2{JezFh)hnl&zH{vKhmOF&0YBTx|HN{e3L@trFmm}u2VVM0c^7&wc%FF^yzPC3eWyLI(mkmUQvM~K zNs<$1rGQNs;Ej!BCEs0ii)X<7iedY3tHg?q$L?mq<%2u8{@P}7W0`E`Mr|3UZa*+Z z*#cjSyy_2SvAk3+f}e9-i!`z_UuiymtyluyN}Ftp#D(G!Azs{JxRw3(~ik|mZ%#>QN zSLtX>O|pWYa~)WP6{Hs?@_pg*vImYWv)KmVwDnu2u)_4{m z9Cm^wbju7kH-LlI7HH~dgE6k^UBTl^R-36S!C0)T)K?ZN9hJMv8JQ}2;6>oKzdO32 zPH$CT7B70k-U*%y>6=nVq~=cPlUmFBD)3RCM=oE*m1O%cS@F8B>VM4oOcZU%mE)dqC)rp`CD+1>up`@un*q0(%tBAB z*HXw)l4lVEg7H93p9NKLXxYm6$e!$hIiZKN~+x4|i>PGdd`b>>gBh*F8 zZ%mn5$qu=K{8`!`xag<82HpkU2c8JuFYkWuHV;i-mbTSX-(N?4Y{hWTu{Vl}t@v2{ z^t;>%F}riRyM6E@*H%Yop&H$5{lePx#%#N|*}yod=hZm%Ae1bostvSux&;sU&B%DS zvOD?T!U9{UO%eU@nyM+}7fOn`Y$a`v#Bcl{wgbtJIs1640GY}L@r96m{N(zwb&%)h zrH$E+Ts8g_zmtE)1;b169j3ksG#u4eBeo+Ri?o{&~G^-qeHvL}6;zpFC9 zs!DqbaY6%rIoBU^>+9?&;e>sSt6Wfd_Yub{VJ2;1jn`*rGvJ_-1SCwjAt9DNm$S=@ z<$_9rI$p11PC`95i`|c@&Uj%cyr2AVknRbOva4{_>SL=cp5(8vPni`~33Hb@fPota zdKDX?gJEL&eU52PTF{;BSe_Lo!|7CI|IrDUVUI<|l|uhR_w_MQ1YvX;d5X#KFm!cP zoLLHQxG1b`6nht|7s2LaOW~Z(U@PORd($WItjJI|$bz{~hS$sp=AQK(9x^R3?{xsR zRT9~9Ba52P^;pb0{pjIk)>r7I^#j@{btZg#Wn})RrBRX|SRR<>FXjv~^ z0(gtn$zJ-8D<<6L3-bnQ@knOAb&fekYP02#q3vbe^bsbhTcE)F0UXTdV5GX)PM9t3 zLv@r$%Ar@k0^N{v*lTAnQwPTZHCq)pd4!`feT11>#`NPR`coaC_B;!{mu#pi79jr} zW6d$wp+@=$G~ZdA(FFJ!?bb)YkAAzd9aHeLil)2=Q+cE8m1YJ?VLH{pulknzclf&a zzIsZhcS(y$iA{+~ed=v0I?+i1&~u#?#b8P!)Vgf9k&>*EM_6dwgY45ax0pfu9)5P7tQsx4RCykGmS# ztMWl4*|-UxvdO?V4}lhUY4bIB;^FcFxw-O9Yhm;N#&jBx#VeSra3^UCx2sNE0=JJl z#|OhJF_AA1PeL14ke!EDzLFGXA8d5!_aE-2Tz^@Pth_ z4*}&+0=L6axCf2{bA24}qpRScbDhk9e^y7<&F-Y%;T;nUETZH8h`kx)E~f5jzzTU# zLA53UbQ-bd=uvM5I$#+D2`Cp>`K8V%60q8SAD_ zRxY9nZxG1qAL(=Z)_G&RM?B3u?VwS?dHwzkN~n2@%;s95vZ=CegTGDcL%TgUdJ31wTlCJo=6uweiuiB)u_{-aUnPtdvJ3lBYs;vXhXZ9{(u1hGTVrpOAmf;8 za4A>}JVSNN4}a6!K=~B{>cAq|P+`pg=4%Bei?dJ_497hhMsEW{&h^~>M;xP%MSkHXLBc5f3-}HQ4v;iw7+hVj4)w?!}M4;%p+L4D+gdB;2o#5Sy7fRD>?3EBhirWkV|KHCNy zduj%o8mgc3P*EAGRRp&Ft9DAip}j@#BTacFFO)aobTP_l;9e%mb#dqJMz407Ap#9J z8Lwpvx?nxXdU~IYIz?94H2g_#SH|R`e$+h8|%$ zcEvj&=w?tKa`uyKE%+jGP+@4wMk7~UhFKYK?dZeKMb2LX9oU5=g$bY+76N?Bf7o}~ z5YvZ3BdIf{K5emvRe&3OjU7~$3}IyG>t6$F0sf7c37nxe+JiL2eLN6YD-r$KFxl%#&dCx)!Y{x}F!$Z}%C z_eurg(t$Nt2yd%A*r)k{JZny$p+}pCWWtRy5MSAtwxj)UQXc?K6pdG$2K-+KpxO4} zbqqm&cPBg?3gQek!@fI%-B5t(2+VN{v!>a_oNrR=1$LidR5Y^d-?VC4XKgKhte>?}M-D6mf@=~FrZIe7=R1^PB?P$`$Aq3B3> zkXOYc>Lg(`@`A5E13po$fR?@lY`hz{SO<&X`1ugkz3)<%YK6?Sc&#FExEX>KsD*CPIk-W@Vb7Gt>)`N}I`E1kaXU8!c49xy>|5xF z2uvq1C5B+oW-J3AVE<+ip ztj$92a{;rS8GUeV0tMj3}=$sU_X4a`s3e~ zKz<*FyKM;2n8o2(a|m;zEWm0>=tI{<-gyLGL9gHi(ii7(H1>-hPyPU$%`Bj9>cKCi zJ=Uxo-eXVr^xcAI8iOq4E-Lc9%r#6ngMfZl!92Kxiar+KQDr>f*CYZ9E8%`Aj~-}q z%m^}(1bn_a&Ta;Z`7fr1{eZi8gJ+WmNc%3p*#8B7at1J?9f2YpfY)*mj!xT=uMCAV z3K&ngVPen$euS@hkDtCAefZ;;MfXK)UWwD0h^h2IoP*1FXPJTBzJafCLpkN2H3a)^ z4*sSgFbp@~Otb(0@AfUu$zR;`8NN;F_;ngFdpDeNj)PB-4V{K!r~*r2j$I76_iXr1 zAvj$eZni3TC3P^v6`1M3fk(rwZ3>>%PB1&FSQ+=*PgFe|7+g_MNEr%-!c224Fzr>) zjTLa;-2=y~2hP`N)F_=0=}jXDXR#{o*vtb(LCIRN~AAE1;YfsNXP7`_RC)B@x8oTX0FKXsJDeXJhHN-5UzzQ| zb8f|6YXp`IdgHk1qe)>>02x3{Xu2)Mx?qaPjKHhxiQE4^R&op0XEUNqG(0eCATmc= zl@Obs8s&fj8jg2-#(*y)*h&KWhQaV-tbp35i}6E$uRp*kxPxg81@^xdcsi%S2N((- z&^b)uCc{CoF(!M_;1m@^UP};dH{kvWN1QBzebWMwGZyPmom4`s&WSz#81M5K_TfTu z6j=Ftv9_wx`cZ3I3`x_kxQiD9acshNx=HfMW)pq|MChLV(XZ^q$50(I>RxL zMgDac&#oHQ`Ydv*%ec`GVg|hH|9ILF;9x{z*7n?70L)xoM30m37CVjUOcygdG*Vt6 z138UnwcTq#AKHUaN{R>zbMX`tD;C-|R-J>VwDf2j3 z8XbZDFAdaxeQ>fCnRBu8dtx`n{Xfbd16$)8_=F3=7vM1!KL{tZmWaYJ;FdnY>-dWs zdj`~gr_3gQX$f{R^sAS3c)Ki-6z zQy!)jv|?kdEU0rH!wcyZaIzDy5t3`+$xsb{B{J-aQIM$>9;+leJvzbZ48UBd(#^dboB*VZJ^q9TCG`xZP zDio36tksn13H_yy_!pI|-(VzOv8s}qc(%ov9@Zt?VY9)U=){<&7u*Q2;_!ZYSnHVK z<|cgo8{FWVfU2ns4qH7;D{3*ZK#MeGsz9fSVrPyZt+DrJ;JpmOtM8BA=Y06B?ZJME z`+wH%WR`(lC4l4g8Zmb)^9Qk>V1C;HJES>ox*N;@k^>pN0pF=w89SJGnWMO^t{@)g zL7oswx+0Pm!QXVk=`aw@4#T5rk`)i8NdwXjXq7YAS-F{d@UL5rTXhI_S9!z-hKWNR zG}@dASFp9ynhY4EqrpM#WM;#jt8Z0hOfYIIHPn~U4pLk(s)A@q~A2FhDbWMcvNfM+9LLtdv^!uV;a2}+ z29ZtXQYI^TXwG3C89LV$o~CxK4%m~S%yexWT<)`IbB)%h&C7!Mw#kfTCezdAU#x8@ zy)y{~PwKGM9`pN3dUM7>+cN#NoyI6u)dnDgwvjAY-4Jm3>X5BEbdzu%vg6zjCT+EE zBoYj${04FYa+WD+xr`~+deVk@s$Hil^9biNleLhYV=OlxnTh5R`VJ9fn)!hp1%<%I zOa*p0`K`@DwF41B%!nT0{O!gq>p?|wl9aP}Fc{OALBwVz<8Ii+?A9MJol#pZf=*{E zTGOfvh5qT*N{$B;ccOX4b{0S9kQT~6q=ghjK-!m7GnX56tTpr|b6TIxBLG)lq-8R^!BzN@sFe z-9h)ECN4wf`A>^o^=m?Hd6U}Meh4|?iohIu4d|jI0-eKgv$+KAC@RQM>w|5A5<;%) zZRrKF&$!D}Hh;5M^o!&xv{#~x2DBu*R4Z?GVh>t7%$eLIbhdb0f}P1z+&K-8($3*cKk6E$Bs$NgePyzV z)a=36_3=V;X_tGs9LZ0TA97P1KmVn>O)1*BQq80E5I#$@LMlj%$r=SH4xxW5!*gHx zqv|!ShBd-IPcN?>AWC3QJ-wW5sC7zRPO=-bonxeV)>-<{tV?f6BgL*-&Y(6)Ui2(p8g>2OZ8NMsw7yb| zRO8U!;+m#xa#mGR#9hjEvkN~ISSCa3Y_H+Fx{-zjIyiO&1?hTVRImyZ>Ww2VZ8TUT}A*o@=T9rZP3ml23cHhq37!>~D0x!MWZf zrbRGn{PbvPcAo#atv7$x(?)1!z7?){ z2MPPh2#wftrj&B$*Y5CxF<~f9N?0L|3cl9tbG-y7lb)V_?s??9r@2sqU1GL_;z3sV zk+YqYS1%JPrc`jw(*M(^+WH1g1^54V)|Fs&G6tB%w8Oz|eHBa>AF1uNho>8kebRY* z4fP8;Sh2<=`w6Ln{zCX_byh|?ZUi2Scg;c4TL;xzsF(RKMtL^anB{vNDWzt&6*lT? zKkT2R`L0(InF(n(v0F91{>tO5F5 zrak$t>LKw7N5jLp!pY4dGbjFam$S7?To#mBOW@y0@!Dh$Q5PL!! z?QHGY&TTPk*jK0L<;uAlCpHgG(Z_3#=^Xf(ZVOy@HAjWCCh*TRE<2 zd!z-<<62#YXNSs5!5$eV$GYAm=M4%sUdSnYR%x-Tk=Z*?*Zxyo&o^coq*V_|<8S{R z9z0Qf%a64;{pk;G$!?TSvp2N_X0YAi7npbU?AkkeS2?He3W`;_nA4$Sy3JlY@X44U zR8!B9^K5w|S3U79l~v9kd{wP4-S=gFo%3K9k+z@=~L(_ zro3&f(qD}f?(s{thUO}5lh|IQ-gyRVdxHsRN?J?dr<+NCVb)X`M+$4#B9)=SCvlNc z#~VwdKp!q`79@3T5#|BrmhuRl;V7N4M7|Pvrxk?;N;h*ITa^B1e7EKrtGUATl)Qx- zC;XK^S(4GwG0+(2FC}ynll@m6Ii(RoXPPVhml(@u)whsefxNav)Ixo&bp1S&ZmX{o zD4_PEpKX~_-&L=L`fHIY5H2!LqTeADC1F$`r=1z zo4cD_MDnmx=t28sxvRNcGy*&LjI-Xyu|&F|tYvbDPtiR$^oneEshqLdHl7LyIvy+$yz_ctVe6Jo+Fq z5bV_jY_Q{oaX>z0JQ7n(NnOT_hu*^+lGB`QmE!KR{h-ml9}L=E;Ooq$ZtJw(#aJqa zlj#A%N5QY;f%VIX5pL@(^qjyGEO9(&1|QHo#5p^m4_K`+4tsUQMKZgNfxWM22-bjAu{J>Rx4vG zK$|kJ%wft;wW~g#xSTD42&rkE(e3;*V1xQ;)08DfH`c*~VP4?i$}nf8T#C#!6;A8Z zlw$luYprnFNGIdfjoNX(l(8TGqM7Y3NlO?l32{Ss9 zoNTx?MTj+)Xp_y-#&5F^Kh}IrhZ?)BIL>5PXzq;DHy{s6HIvZgOA}U;Yu0?)REuMK zkPs_bSfp;Y<-%MnR@uzVgnj)4Q}GsY+&$=d@NGnE;@Hp}3Mn77GD zdY{1M+-%CNHyuV0rX`egHZNNTXg*Av#-Y3JH&3y>j0;Tc|2R4aFgwmD3=glC$;P(P zG`1Ss=Eio?*lHTvwr$(Cjm`3M?x63s_t_>{+@1O7$2sT2v0e(Tq*=je%Z+ntccCWL z*HFb{lmvgOVU#f@Nqw!`Vj^90zBAHarG7B}aR=g>?}Krylf%j8z0}q=3N7$sN+Y7F z7RMR0aa~w%^;0kC<<0DJ6!_2y9F&~lZ*Oy{Fdy&SwOqvxaE9j8m&&95IxIAWu~Ay=3A=lv z@Q<=so(?WLlDNQ2ZRRWJ*-v2&T8uALt> z_zF_P^H8H-OqJ-6QGn;M%cx)!FwPi#&A;?GR1(e14rZM3(#!@saFY?DN0@KTxzpfqj-N&eh{3R(g# zij9LYTGwDB?PZ_{Ncl(YuJ%K#984Rm2aJFRyy^RQb;YL9#n^*b^k+sBRdrxc4lv>Gc|QLNJw)k2F<8! z)}~I@2!v=K&$%*{^O95sTMDbFE7hl-^@R%M4C=PSsKH(bqj(Pvdx3t-Cj#5ynk@0J z^ndon_@;RGdlq}PxGQ-YdoFsSyp24QJZs=r&-1nn{L*pam!e_qUV%00kzXol?6(}z zt|V6;S3Bnf`zUof3dASqGLn0~_q5L7hmzB{z<5`|NnZ`-&<7cHK)>>dQ{kf4hRqzJ zlu@79syQ+`k2!lg2RQwXp^k!%;`SkSAFQ35=!Neqvht8~p_w?DT6jrRCKu?t{()Zi zXW3_7u!`)U`14V*zsE175E#!((C{TFmzLu6l0YS>xcH4KYia5tt?_-Y$o{g~_($)q zcMG0Ey_Zq@i627?{|jFZUy^sa_n}w!j)(t!3=Q!L&qhynZ&mL!Zwqf9?=Rn`KyzKS zc8gh1Wgb@g!@%FH6joI_5u|aRai(Q(=uFVLh;s2}gbO&tmw&dtinCTGPF;SRX# zEnIYFgVg;C?q*U^u3^40KH@DlHyF?sYS#lB0wn^wailoun}I6!hIh5MvA480+B=ak z8Nt5_dNX(@dp~*G1g->MCwuM7&6$mtlrk{#)+v|OQTCmVI?mkAcXmrXi^eCfP|uo! z(@$yh1`bb=^b^Sn3w0?<&n}uHIDnb(#+1=kyqB`b`IK==j1sGMv$^2nZge_b9np6l zv){JeQIi#UQFL+I0_$-CnoXPNq1qe?^F5N<WaQQ$c&$;!TPG~gS(k}r?X^zHE7^i}nh@SgQdb^mlf^1SwK&`KCbgbq?reu09{ zkUpUGnj()^8`-@&kFiRT< zjW{VyPT>Q>el(5i;dp0(vv&+`$lrLRypUTc8yMwK&VoYdKIgEahDqbVJICUfG6Gb) zHMqoB9M{T&%Kky6_7g5%vy6RoAMtA|wMCj+>w*63H*DhcK|#wG819e3wWP6smp>e> zX70dze>wkq-&Ef!Uvb|^Uw3q`t^E@NQ&EJ!x6X>0=?#>Toq7`c@#ORCdbQr4lG{1dycb!ZmuVzXBfrF*q%s{Otqf1E${zv*j;TK4*N( zd;{@&S&LG5O5n0~MIUO0P>=>N1g~}NvUU`Wor>GjOX2Q>8 z7MyW6{OuGdfKw~Ql%FuLpThkqCA~-OH40UtLl(isZ@_M>10Sd&^W+&Et9$tCH0I1b z4cfQ|AHnxTHBG@6mw_3lf{Bn9+_oMuk%mK<4m;xk%#|)iJ9fX)Mo{1X|2{J?cmpk1 z2DsHbgL{Ktf(L_rP?4O~j%Wk?^NjKmH4D)>A`HW2c-?9*c*^KKknoGG=~%LzL|&FJPmfm4jlgKgA(PyA#)He07u|t-h;iAM!F^zgp)T=$|pCM))0N| zNAEU_8POX@o4&*_)zD%7jV}BnOuUlt7C*w#TMa|}J0m-v_~aCSPC<+R4__IH7JIV% zR=mzTmc=h=6O7N|jDJVwRs`>;h8Dk=ebsdhjWVk{Shpq4WPTD zLAFMLo0S3?)roc1TZ_S|T_E@+LFG>23u@8xW};~^Vj}VPVG#C=?C_Jpr;D1s%*AF# z(91=JLJm9>UfKv)L}Tb^R@nRo8X93K@LF=ivS`A}nrC*0%T@`zWhqF{1aQU1VD-CE z!lV#Bg1j!Kliqi872KD?a4QyC`%qW!#0Rgh@X*?edbEb{!4!ycJMi_!R$F4RsbKG$ zL4$Y0PADr>`=3XW9qic+hh_m-a&1=8dl(wU(dYgR0=!)u3Nl$1{COv@jluw8yaUz_ z5bOrRNHTy3#>Oku01scC>}d&`;TyagC#seGaATZeFIb>sK-wZWBR_)keT2o+O_;;F zxJNcJ0hhtAph}T&)E=>qoVV^;)PpQjxQ^1O7>;mlxT4;mcWL4BqyUMJ1GByjp7)0k z%h}mMm;^ts0C!7mgt#J^VQ~N$;yGyXdRE~Zex`=73oLsZKa-zKwj+F&EX4kgVM7)G z<$ecx-5lP_Cc%Wkae$G3!KwaLIE$vc8hpcfbOG@T7eK@4)i$71su>A{px~ctbvuZ4qtnR zTmep5W^)I5NeZs=yD%H}(sXO3SP9&96THR;Tt{m%={WJJS<32QOoS7%(wZUm0h4?K z^P?^KQCHYfN364SRq4(1zhJ%=lg!%mnFuvU;%EAVXIl{VL@bU~De2Wx#9TxFgiPe1 z7X(Fo#T7TC6OKbDZ)Otjn}hhBHuj2qV5Tq3hV&5m4sJdGlr_CL92~l-b&Z)ajPaRF zUz`9A2rq~wa*)woM=N~Kd?7tB8VY{cRh_IsVtrT?$@y~;D@^=umJu`Za}CXI(h_)~ z(Xi6yTR&k;sKO0%BAriG3Js`4l*XmE3Vy@K&5u?Ku*{1vy4D+L4vik-e4fAEss+RB zopl;^)JqVgqQn}aeivTJKH(gQT0_YSrV($MC+LQ;K~8C9H5W4*8WXRbu=-k8l{`WU zsU+QoE*PQsu}&v%_M5ZB7M3EGg~6b)cRR$sMuhkYM#*>cCz#nFsT@!7h*?OyDioH+ zv-)#W1>J2u7T*Os^FFEN55iFKlTl0RAQZI9;)#3!ed!USw6e>>^U)l}EY>+MQ_**<~&V|@d{M^9@p$$(fwV0cDWHTAv8@l7v;RLQMhFC|$j;M~7h;+>`--uQ9^fGvW zynuNQ25jtU6|9I(((K zp2}V5e^dg+{uI5Ykly?oHCVtLBmFXZOYhLj%o7I*i|`*<0{gWA9j8L^@jhc77w(v8 z<-f&ZlGBI~^kBRYLszVGdKKw%u#@n_s4eW^^xGf~XV(}NTxWIH2MW4*NGyzl`(^VA z-CM4LV7(Bp>E}!nXY$K9lg*c(neUm`Tj8;1pyS^yctx_AQFv%fwCL1{Q%O1;w{vm6 z)|2{~DPaL^6!O9b{fCUQjY$U#x&^hMK9Rc9a@r#Ykm185`=&(66^M z%L+Ti%w}<`juIn`l~zc9Ss#SH`e{5lj9_JbvV2JxqO?&p>w~l{_`yhN6f?fCE9AHS zviCRtvh6XqTdl=@L^xf{V^r=+bLTd8sio#|v4gdTnK0DcA~b;?+}4`P*`LVC`4S%c zbGQ%T)>Ls0n%I&=`svMAMwrlFcIgY?RNNDn8bKqaJdXb6*`&*6EOEwJ;f!&^YHWbW z3$OHhatA4oP!$eGTeBXxdrqN`v`Vmqx7K-L1rra zccdOZ!#qv(a0=dFU8N>;vAcx7*I~1k)}Q)?A$(G9%NK29tnX;g#^9Eih&%W|IAbX| zKc)#wtQlqnv2?I9YW~5nvd%H%st6;^ma@biAC%JL%N;J<(z~!X<%B^L#p*5v`!yW* z%qH+C|A8T!(|TZxHPT9Fg;Ca@W*@<4Ze{1WFTXL0k*{wvn-JyvLm%M0uw@HUIol)^ zXMgJ=oYgbnJKumcT2abt*0PlmhD$%K8^#|}5z{MNFw;>PT_QG=zleF17gi=A6;979 z1x;^9AHt*NR#>GYq^5YMpSK!F$-YpZh^5k7JS3J9UyN)xi?y`+QfF=<*oecspld#1JrOn-+eL%#D^F*r65>ts zU!sy3cow9S@{7R2xDw5y6Jl>zI;&8rIHd!EfsS?{QP2s_2Tgd-3AzD>&uRSk>fsE# zSJ;6Xv6?m83?r8MU>%|()Gt^>Z*bZlWi2<$TVc$(`M9mzBfja!`Yg#g-cuaM%+3XG z?lBC>#nvRa=6!^z;y|vZ6)`}Y|NX#&=5#@%e`-}6xr*aX9}X+6r-gSiE);3Sg4QZ? zjq$@6EaVd}OQP7(^5R|DfEoM^UQdju!QgjueosKLG7MhfdorQa=wTWX4gMwq^uW4| zrxNgj^JW*_(rPdlmcdkPMum3;Jgw102e;^oXc9%YBl4qvpLG=u*C}hBkN|rsH(m{G zIpuQ`e?1f0v5#Khod;84DIydTQ<3*g64%1eeTh$F56i~6ZdmtVv2AA8t;z|v5e9Ax z=JPP_Wr`T!E+612I~Bg91O8PwS?X4lI3@BV zMosf8``uPtmR6x%%m&MK4sq9ly&cmSSy2>sFwT<+)j}iFhYpSf%sF_o=fyu^JZ{FBP?Z(b&jmY? z+aC>{(n*Slniq0)yHU_gr@r1!{K#{eD1WDN_nADgpu7X^@-nd}IbK(IX&` z#);aKYd=Rt__~xEf6yxQ18t;ESC>(#7*0RXylP1tS{mBUsD0JHl{!j0xrQ7m?nOr#GMH1seTUNohTnf4kmNnxm@D zznq61q4v&dR<*iX&VJ7E#`V%w#nr=JLAfLRYeeg3v>m}_!GiQ5m=o+7ysU+g#lJMd zh$QETGw}Usi^EzO^^+}NZ|N-V>_Q#4q*_pUg)^^5+AP;ps;Irx^4#?-WimD7n@ST^ zQ>WPq;#N8t-`P3JBGkGCsZ30kx8iItlrB!?sdJsD7V%ZgCS4ReNsD=>T4bI@$;- zj8@-^7tJ=oiGdFOO}+`f?CAH4`>Oc{`=bMGgH4SOoH{Acx{s0oBwtVVp`5sk+3f#*au@)TbkL0w8RK}*t0jhHm%3S;!nyHg*WpM&Jf&MwSGKQ+m z2I{)^a7`)-hO>}*R1DF{5m>F;smVmi*Ch>xeHK{ZpYWqe_F(-D+j%^-&St1j6=@Fm z${ZNkbM=~d$YKb(=pT8$5qF%MeQ!e8? zJkqR;PkJshmb@mr9HQvTXmzY@g1v=fDqap%U4J^aJ381e*t*!xtK-y@O1N@Nnn_h} z4O!7-^pv}(o+bS6oqn21%pdYGyh;t_G5w?V<3*cZ4a4*KFZ`W1qamy(Z4pvg^^KRf zm$dTd@fGn_^|bK3OZt*zPb{18EWT~rt+-%Z!?;uQXx|;XG9e*$MEu71IdK~jM49GbhNZ*u;;Rsa(r;aI;uEp*dmq9)ApS={Wh_lprV7i~-^O?Z{dShd) z*+iTxRaSZ{x0FBWBizGQ&%TNJ_y#b7$G8;duy(~n0ng=a z=1KPZ?cxgRFEO%-MrouxLB1*PR%)m%)!AxpwXYIOt!9gK1O|O>!EHtxpMvGJDF1rz z4;*!CClyN?og^f#N~#neAMcHejQtqDHojlN-=6$QBi$)HSK{9$h9}q)u6vGqcN=A- zPRct+G)@vX)RS^78qc=&ZO)QTk8_88DZ1C2!A8M$fxSLeo8(`ujn*y(25Slqe@y~u z1DW*WM64&-H?lfkIB)y(%Uh~InMRL`Q16jK0vJ{Kcu4Zoe1QkbVaFY^Vs&= zCfh>QbZURllVXko&b%QW*EUy}voLe3y)ppTo;>PjRbrk+D3zsw9>! zeRl6e&lpet#NkQf6ILZiiDeUuB}KYZyLC@z&pdaC_if@G_vs|JXLR6t@RxW<`lMb4 zzsRSI!q4Z4y20MZ@zZ(UX4rIv+u1@*)FsGy*)(%J>8_{;jH_|gX2;96Hg zh!vlrL|=;1^O_u|HnlBu9&t8xWp|}^Y_WGyJ1Z6*2=PiS+W`AwM=AR&dbr=jiFi56 z_T#p`_RfwtTa4OKxr#&6PU*V*4kyuF_#2i`YU4mwocd8KaTl&Re^5zptE7<2!U$eM z#|yV~MJ@mi`BCa9-QkQbh?3$Mb-K=A4Na(8<~N5J*Kn*au6+%34va#7U(s93yDe#V zqAO`M{w39t9wc>3dY+UwX--lPJWa}bX1Zn1Z+ByNnD>A$Ch)_YEYy=0%e}xRic^nn zsEz>x8)ZLZuVG8WoW2D@*vcAcX61A~V(bSCn4~=k%<_x=f`MLIwB80M?{uI7SA-kl zWT~lgPfcTA?#S)@;LPLd?=0&y?UU@i@v<)FN*huwq_^vi!(~6K9z#9rR&&@I**d5T zIB(apPi90L)|)f8F&OL{9Im zzN|QsYFjM6jjd23HK&_gdEo_|#nyUuuz=-Mw5>pnKsEnSZvgfFXio{xT+cU8YIioz zyrdfLFwcRc1D*lylAcAE-n(Ae-ypSqh(i6PVgL6k%OaDRmql&sfY$~zNSCds~9o*2>olYVeq@QUF*R({izMt?rSALQ4+!5rZH|?0(S$AL1!;% zAGov8S`c5*zj3oVjH^gcJ02W~bJI=zs6JUgp*!fS@yeWIwWpiKDq`{Vu?m*WVtkzZV0>Z!Kj1>L0vhNozWxN zi9z<`BJx_kCf}DI$?th7NlfcxG*>rY3jNNoETK$Q@khI!;p2ItFU-1w{8eJ*_6=-`@}%?mOzs zS!N3`!zdguuTkYsY0TFv>G}0Ebbt!izXew_iXynm9_=`OHS4rdT0S(@s{@+@n*t94 z5nw1&v=iD}EqP8(4<-gn>W@KQZSbYv(_f;ywHn>mP;!AuBK0EH`V@2&FR07sms`lw z@YDZ>gF_gv2aLmfl&=4PFE2&|Q4^fxK1%z^U`Jg+iZY-j_mfvSaJCGUzEd*|h$&E9 zj$!mi@V=$+5zHmkW2KEonO_@EQ5iG|oP_A+JyG0Q(E16?oIgYn+~hv%`Yw3ME)-Dv zm_fy1FZfa6*vNY_qgAX(O)6PYYLXq7p~9XL45}cS$8@Oqo{{7IjW((nD0u{Ez!wmd zR-n+jahG001MpESh*CNVP1-Ykt-eBEf;ZG;{+y*x(K}IlD5Pi8|J3{Pna}uhR40R& zWb8Ix(8Fi~3Yb4|jJ(K|)B|-t!MHUiH+f56j_#mB`UissRHWlpU+%I|G3|Lj^6Eu8D}O^wMlkH_{sZ-C?-p= zBUYiJQHm!}0erbR8SHGZ?+aX$7cFoK@am!{V<&?8oM0uqZDqo%Q#jJzq? z^F-94&8Z(}Li^}qOkRQU=pZ?_P*gZItZFobbD!}*n) zkAvt6^RcqNp$n{o!{9Zl40%{rd&&3fF?Jvm|F5ugc*EZ3)c;LI>LzF7cD|=HacKvt z9^Fwk*W|=Zu$mDIWF}kB0M46-Czz7@&n)`m%m>FR$8XosH|FMdbY?xS<1DzvZ+ik7 z_?;DU6K2mksxdeD|9*T!UZS~d%d8yD%ovhWb@l&OR6nm6Q8Si8kd=IaFCDrV%2EGyFyiRB%rCisM1-%9|DG zP&dxJY~{fH;RedPJj8j0VN%90K7}~Fj#Kg3fxhx4Rq%$4Rd&vgc&n%I%&bZ6ct6h4 z)48k9=5ZKWYndSqtD_kO3YAv4DBeIBTYx9Jirl=hxr`X1s@RrFd_SWup0SOrcq+5m7b{wG z(A(7(d$AT%n?uN5x6n7IwU`#afQ73$?_}jPH1} z2qRLQJ*_=#mCmfl7pR_(qx3JueB5iU5QD~Gah|jbUGNjeW*c{Kj~Unwmy;y3pD^E4 znT1zCB>n}L+-%;WFU&}s3zxu-y9%nb*SKLsn+p5KZK{QDg%d_m>x8iYgt`do!qU{i z7MNYk=C~0|riWD|uH);?%|;FovsZ1T0#Y01>?I<2HaWu%}#(>wiv zzBO3Qh}F-SO^is<1+95OoehnOvIM_1UJ`G)XMW`^0=dm~DB*JJ+B>z;LC}g6hosDYJ1R{#UoE=q|VOi)}SO%u= zE_U%Z<{nP{OyFhF#tl8SUN87ayQ-bn^1^UR(E8xK8RKv6PxM{)WyH66lm8aos!IDa z`f7Nmdj@%0d;9voYaPu&!XcbR-+<=4!|~xJb^n5j&6dhh#L>jD#g+<`9Rb~8@T&2QWr?4UQ*o3SUQqOM~Tvl5^Gq@%+-NhXSVO78O)4uE#sC^eq0 zA>lZ5Hp97PBEGwul!-8zc7o`P=8n8pMdFE7xC5oe&87@KA^6?12cPEtFNtMwpczLl zP?CCKJdtONC`qqijDA6J{Ee97D%H78oUZrT`=`NDzQwsK@*YcBC5K>!7b0pZN>nwH zhko#Som z+3!y2IpA5~-=qyUbBLjGTcrcl!D-@j5Qv@VTXH(4yAForc8zq#DHG*9LJ>{}%bX3% z|D)&wGwN@g*Xjkk2D$}j1l@XDd|tBRsD6Psu%{S`Bi2iKj%?zazC=Bvep6e}M`eI& zSGF^I9^rTvMdfuE9S-^`A#zP|n-#|?y_(KguUU6Xq+%^ao~&d#eZXk)GpCY-1Ic~Hgf*OSXgufu8w##b!Ww0cG^^u} zaSMD|#&^9W9_qK@_T3JCrQe-K4S7ZIuV82G0u0nC{`3BZ{_p-R@R9rATb$b0)f45( zM9-%S?t`9w{w~@OV;Da3`NTNR&0Xx;!@)!{d2*6#&YT@^e#U%;_Gf=}!% z#;TukSt&u^I7|6MDI}|sz*#w(eRebas1Z^bsXb@T-^7Sv?8(c7tzr{qe_iPqu|z9) zH=Bep%w!2>Vk!ADam{X=I?A#8Jrkq;??PhXB3D3sMK0f*m@no3alj-xs;1zK+d@Rt z*6jShgVrfx@Dt=PUL(=?gZ=9RjIVJ{?;_wQ>T9;qVkq;r@t z9dYchN>79axD}U^W%&fP3!fPMe}~~N#HsbkY0`?{!Br}7YF{IFImGKYOdxpXM4@Rw zl-e+l>Jq_pVHua@wu^hl}(^C`d+?vfKJ(+@Np@LhiXZ(no6MO zcF=_Ns7}h^rtZeGB$UoCrRaLO1QaP?Gy+9h4ukB9se$Y71JR#?J5Cfjil&@n6UbF| zkPWNMlO`Y*C8W&o-m*!xa3zYS`ZPqEi#t&x=1WKN;^9&!#;+GyaxXsnH=iFyc3qiV z`3`*cWiT1n;h43b*mW51?}tNiQ@BljaFKoL8YjypBDC3HZ>xxBFMy(*MjMcdy2E>_ zML}@3BV^PYshAIBL`xIH++^q9$a8Efv<|4T9E%yrc-x$!l_#AH?&PnU(pN z&+?GNJSM-n$~fI2!&`2yCzkJM<}m&ACy9ZVc!(&g7Y;~yj0kw2c5tID8+;hOXf*^~L%|toa7C)!=;5 z_}-Ju)b+&lb6E-FIm-r+SGVC+gKVZQ9xE+T&J41qa#Bp;tN#+|_Xo}Yn{#6npP2=A zJDl%M_I}Ak2KB?Vp--sFPuEA|Py`lsexkVCoRocdpW&>tMMVCy;7pAs2kTF5a4!G) z*P2F_Fo~>n6&0k*WY!VnJ=J)&`N*!bvj1ixo2koq4I*zDPwqQ}x?6}rZ}sEe zI+E24A=??qeUE0G){qGu<2~Q-B)*WX2YD`T?(+<<&EV|p-hAgUMzI<1)(Y31n*6OB z9|IZP>0HloDh|ocJg52FF`n8@uHYsdjmKoV4(=c|*;pmUs}dM_veVEY*leRYzXnoW znakZC#qMzZr}@1im8?)I5z~pxr&txGm-H*TZ{J`$WV<8R5U-e3^~YeW>*2i*#<8P; zbY7e$%rqMtXY|;gee8t!65eUDM|@vqM=NIH9@MHw z*xfthmy?;*(i3jTbn=P3R2%}-PG?H%s5!J^m8@q^$t~U?^4KW`$rE=l6RP0f{TCed zrSLX(lh3?_IhT_fQ4TVH1)V`QSROlgb`O}xeL%Bovgg<4bZN+Vr-c!n?8*^+gUNPL~_p!I}0 z!#Yla8lf7gnK4kx@jUIjj{NEt5U`^K4jS;0q`qkc0poxU@%hppuFk$BQg;oFoEQd8HXH>%VaU%n^8D6yk_o_Hx1 zHgj&Wyv$-&p|3TQXy_i}@e1W(T6jY_*;T*bq5K(kY$~xDD^-T=aD+W?IZjh~m@A*@ z4830LfZFRSyT>rr%2np~eR!_T8N+qB;4I_4Jk0poR8mH=+qLJMOhv6HBaDTmaES-9 z|1T3OGH+AyIXaHg1E>k^H`_T0W2wE(u`Zd9h%Hu`a4{L#0l^9_!XW?dZ3zd!^MBG0)Ukbo7 zPf1=nn-$oB{rfvV_nv+AD7B>d>}BN{!M|8fv8?o6>?xgL2%qGMv>|Ie&AC;KIN%Xn zk0eeoS=>(M^%@mKPO|qGWT1jrpOx|l+2luI5?qUA!STY@U^IG>`cfWgDN4`fc@E2AvD(9Gf)sW*GeGvSjtO8jPIr|Vr7i*C>Gtah+hBs;3V6V9NNJ;%xRMtBDw z^s3nuXSlWOIZMctqF|_wh7q+9joWJ?hq6?-qFG5@lD*tuGv~!+VLG$!r_q+#ot-DP zjh%NkeLnkJ@2Om!r8*)BsrkQy80Is&yfD!&+!Y5x*Z!tnH zf-nC>da<;{|9l*3(tqfGn*&`#1$@HG2v;NnoPMETQa@cxwPXRY{0;oQ7sB-&g!k1T zIQAXbxxD5Z&bHIs!!@cUQ>hpaqpF*RlPIV4me_uV^@>hh>kKbASWc@EF1n^!h#Ke- zvp1gjYYZP1Z=%N15EuZ9$mU*<89zsd5h@KLpW7s+L|t$Y ze#irPo<;5L6V zyWXImYfGk_gS>bqRpnz~o*Bh>p)`!A5bDaS#I5XtUdC<>n$y~x_&=G~yVxC{vfnl& z|4zwS`H2(p3(-?gMk^ic|FmSzU5Fs38!OCn)C5a_D5PT+3^LZ~tMylUaU&3%t9J(L zIvrf7mkTZlt_|Ycq}3u@tcw@!TTnrl)>bQvO6ZpMTPq(Nr0)caD1;&^2UlBGUQWei zuF{aIj2ql;zw`^lDhy7ej0cTK;g+JJh zuM;yJB!*Z_UG@}z{!9J+FrK(uiQhL6RZZliD$d;6hO%NGnA~el(7o*O^~l4I<9RoQ z{opaD(QFWe3H1B9Wi&K<8<+7+?SuB`E{?$qKzrXBary*fl#z(D{2}@o^`eXC4BgVA z^q0Zm!J#n2b_eeV^3t&^I`}5AFL+$b2MSvpKiD!vkUyxHIHcp`ebv;O$~pBB>bg*M zlH!FoVoGaJX3s@Izn$H7CQRtvTR+sFp4=2-g@S2NY4mHVRR)U2kz@jLG1~VNe(ztekWx|dmatedJ;sZt2hq!=1tDF-OQf3#8p!{4<8Xvw&Fy2XJusvO5|)m z!)Q!L{V|fAV2hQHdGvvL@)^#FW30SYM3YCDQ-2dxh-5M)Ky6~cvEG^0s08})p}a!= zF^{b6FJpoJQIFQA>(%f&=)`KQ$p}>d@yn@a(NmywwBwY20zVeNCI)9Rr!s4+ab3;{ zMqN{@L+8NxTA|=O&8FAU59n^Axw#Bq#STPL9i)x&cDmKAqhn;GdX+A}*{HSTl?S0o ztHUY00<=;i!hs#c#~rRWHF0D=qV5b-0yf}t*B&JBBRwq|qXxes=T=O5TGmqX!iMw6 zUvYpRAm@<}ksr2|B8j^;lM@wUmZe}{`pJp-1O8-YY6uV5slIR)p5}~AR{Ni%N7oDT zy(jDguZROKG4F2i|M`q|W2!Xi*}>8?$L4Szy)s9@DK3bwp=ceaF1ni7M=_Ux^9ATT zry3LVt9p4xYpnhtSQW43%lO%z$IG^HaBDDvo+&AUOTnHCXnJ6JU{au6;HN(=t{h7O zd$g?hnRhd`n)#^db&}r8|0q|~TDBlOx0d$qbc)J{&ik=aRw*d2Mv0n5oImx3y1IaZ-GR%L7*Z<6 z>I$(DjISQ{e2&!4jn2-_!Oka+pY{aXc50v3siSzrdpz48#I2L8Ta3>GG=nLvlk9!D zI5E?LRoSQ&G{jLLk1|UwY8!4Z?}%_taei_Zb7gkjpp))o$5i_vTc~XvXUS05rUO9n zV{x926Q@(RSSRKr+APe!3z8EwAs@ep`qnE>q<)yp*qO@CngK?{F0`{JV5#olZt z19a3pSK;`dCKJg5%6y#sq&{l-Nk%_oiC#FkJ~#;XiIBj}KtumI|0@4be-Ri^E&MqG zXT7WZeY_XYfXs1gp^@*yI1|edQhC?dH4X-RE6}+I&Lb3|#`Jc*^*W(Zyn0 zVoXBDcq2X#-`U+YK9hS)+|0zm?uUW5S~epaULm@xOGx^##;zJ+_Vg}GSe~dRY)c&DZsiMMkDFv_)-S1=+FN~Ruj1$y@+P!m#PEp0 z5s$(%hK~!);oM<+A&n6x;jg#_mkk>_#IZtV#&00%_jh!IYNDov|F_ou-to-Q(b?24 z+Z!nLq=2{&G$%^>OPb6Mlz<{ zu7i$7A+C^aVZ~i*oQ0K&Vw5H5r+lB>gOfb&0q!lH?EdZE-M*TER(g!73w!Chnb(ys z>|XfQ$Yv2cB7TO4gqI1M5c0!evqyu8U9eUgMfHERQTh|yaZ5-!aKV_Q{%|-$vN?-~ zCxl61+r#TdoOQkq8Kt&SUWi@g%j6mDt*)>LUYn_nZR7_xI3MPT2b7)CEY(ok+m_lQ z?CESz=?GWP9%&OC)znY)07{TXQo+tE-LVqM8E+XE&9gy!Ft7g6pC%CQU*%? zhrHGh#=~lf5%Q?h)WRWc!nDw&$OU2Eusoqr&QLlA&^_kL1-3>be@44XZ;2I*t zi`42HN`~}Exd&F$99(g(EzH#hK6|`l4{I&L_DUW{$C^WmrruXuOZTZ`Kd{L1w6-=CShbzwNr zU9c|Ly4sG@sVbc^fZnv7g=*$l=}-8Lf+NQ9r}Mo%!Ewqt)iEaYsN;eCv`cZkP&zwy z30agz!gVEswOpAcmp9*u^NmRPtJzW=A^cKn$anGA3Bz$jwEZJT+wAri_-3WFKbDWm zIn;+@IoTzRLC5qKSFokQ%G!ZIlDCOJi?ww@U+r9%EE`b=?xvez*x#lw)D z&gZU$%6@sFyo|bvL`fS z7cvSwn~n{yC;V-$u`;U zlKuwcdLoC>6REXr3%x9RfYt1@=aMxdgd55s+eEnY1Lec&bnUT2Wm{53Eez89OsH-4HF{`Y>0s3iC!u%V48A`+g%ZP)a>Tdzy)X9Y z@14IJCq9cG_$w-5XY8k+Pk+b7HvRE3wodHiUt8iLeV+p-@Hg0GO)xqNX{A@{vk)^R zD(sQ7gfoM~;izq!WlK`$@jPUqjZq^wI*{2n!<*=9OyrQ-Kh2lmpG6Lx(YPsep%eLT zTY2V6Ggsk|vtc75q9fBpo(+E;HaToz_~FP0;mg8Dh0Sy9wO6#&w@0ua`)u39mnzjE zx>cN`LMGv=vk|3vCDf1`#qDTgJW2`s67{vx*zrXP+FIFt_HuYN<+Y84BRNi~AwI%W ziW_^-nI1^vcl7ED=yuqqVhq>>8m0IUOslXMl-8;e^o!BYqPV7f-*|^t9 zv*Rm#-|f!#xk22KpRV8g6ITCP7^}qCfBgPc`dk0_W!~S0$9fwGwWi}6lUi!3ri*+S z{y99_q1s~XlIjIDHbPQ6??_p!g~2YttJ<)@BfXfmJ5bL1E$LtP9?whPkU&5$LnPLS zT3cQ*6MB>UA>Tve!*7Nc3q2Y-KJ;q1711kdNb1!o%SUDp>FS*3co;e%q-n(bkT&)v z%6N0UR9oDml#=_2<)lu*mBtml461Ipmn!OMaXXcm!?NhOW_za&aFn&#mFJFZwz=vo z=OSBYb)juPeY|_gv*f;F9NgJOaB&=Bd2@wcT==9fF}@i!wGw`l-pAXLg1(eqmnXq{ zETL-rjJT5VZDT*jr%qb)V_DLqAH#pmi;0On_@hj0kJ$OKhqP+Jn%YM*j~=bTs;a#YA5-lt(m=sy{ogTYiY>f@MRIzQ$9$QH)Yny_u-err$&5?7@oo& zku&_Z^O>3;%j!aUf#y-3;rDbq=m|t?&#jL7eDiOkK8}Tg(Na8Yz5rW&B}S{G<@~lM z%5KLEb+av(>zFdozTM_k3fp$W*KZ_Wp#y1ptA^MVW#Sr6xV~DFKj<&yjo~T-^z1#A zbS^$FUXH(%crm6xT>IE*G4Ag*;=ld8{QG^3@AI1oiM~Esm?x$8Y4C~C z+*Lc`ao7`kS$k&nfGb->^OP@A)(EffI4^$^A4sygMrkVJ%jhpn53o+|9R6DVoW8!^ zgKnq0mixFj;9sHFF%zvMp)Wk0?MhX9u8=>&J4U3AJQVRdLXNtXVsPrSsh6axl4@n7 z64@&}TSP)gr;t9byy`e*1}+09t-1P2b475BwnZ!Hucn<1EXEDyhVf4L6)b|kUM< z5z!IvBkDy=iL4S*)-g!^3%|M_;w>t!ZOs|McD{cCgZ;C#1L*T^1o!)M_*VrU>RsuV zkU~BKbLN-uRNSZTaolmOcfN2=v3b=9dj`i$*Qn4G&gbeK`4s4JHhhql&^@(3?31d& z`t+|HZ=5rh1@ypP?;rk~WL68cNuJcc25xWCd3U?SjJ~yrqug=vVq8@G*_d;26Moins!}UUrJoRMs^y&PBEE+X4EYLc@DH)ABPz70v!o-9vfpSK z%!}6BqE_5oFGd&99i9l^bM3i4(vX6`{0;nRgSQO3RaX`jKdx-G=^I$Z7U_KGbi1y( zBu7feO2;>+Ep%V#*U(+g>W;B!pF$l`j#iFww(|0GK_Q3z7p{vgh-MMvQ9wo;qk_+u z<~^rrI?U|!R&OZ{P6ru^BShh{dSh{YfwkdcY#lzqGB9{&8XmPvvsh(c^U;pDl48oY`>^O zU)AT~VfGybWO3tdV5fH3Kh@jaKg!d{_r*I>YZMsd3H7gV_x7)$3fb1`MjbnsRva&c zF5*-5iTFmIj~?{7xK15!8>i;7Wt56Z)x}Zp;dd)Lq$6NB%XJ-A#}xW%R5CLgUju3M zcfojLDg2KpvxKqJXoz=y6%;5P;Sk&)XYD69R@SQ@@d&%1{3e6Hfl@w6IjbBZ(|il# zpgonLb;4${$^rOQ6{N2EML(slq#jrden=C#5T`|PX9EEn1FF-As`g&qe=6QnHjup~ z#Ei4#$5Lwfk>sbxm6N*Z9P;fZFtyjwji!&_qT+wR@Njh_=*uV?70l}NP3>$vM!Oad zD`K133s3&nsEub)3oZTs5Cc5wa!@;Q3Xj@|6wMMImZgA>u@#hS*U! zf&Ta<>iVW2#2!$;^7N3Js829$`Vy3a+dv6Cfqh0kX8TP&B~^9ZpGyx?QNIF17{6`soGyXaTmidjgF+k69CC;^@I}55S*Am??gUpl4~uj#x|BL#aPQHi zjG{hY63lJ_C~9RqZC-%iRHN24@c*h;A&{Pbi5;SN$2I6L*26iNOf4#f*b8Q88dTCJ ziT;XEb6X5cpaCq6vDCP4P;;COXCNmzSV8LEBe|At!Wa;h=TjaR|G#xNabqj5JFZrlmxH2NC5j8tf)-Wq+tzt_UP|H7$z z4p-I@!YmX9`QVHn7nXuLwTFH4N+`+%B>9{6Cp#cgEzBg8R0&sIEz6U>!tsM}+~ z@^^!^MZj5_4`#g>?0N@^jLEQ+CbApcL5Yx_xzr2}%^Wa-M^*@UR0lBQp|BPvGt;i} zx&yW`m(O$p@2vwPAd>a84|J{vPqrb?Gy}Zs&u9@oGfE){CGz3g_6N+9Xc+VRxaK$X zg9?X@mY*tdGb*R!8T;h-8wNVrl_wm{Z>+@i`uWbs@MRvt^4$oAP#gY5H0WS58>Tz& zR|`bCAG(1JjO`Ba*27eT;^2s<1}Vx4+ae#xZWFjTeRYsXiUw^KsI(L_Y>*&w5U8TZ)5=P4u>f1*^TzM)SLoh3<_m*kUKaW{#56s%T=;;e`>x zcQk}qHih|e2?kXpRoY4@VtT+`xB+)21g1^|<5`T~HH+19U`Pc)xt#dtXBuc%oiu&+FYt8*XrQvf6{l$Woqy<`Ep<9JPFN7=;q?O@M42@aAZgo0h=qN-n!Ro9JObU3du;8kPAzwsF919sN} zHcNf}>;fWCn{%l=8DBM=KpOHh#bIxxW=DDrl6QvZGm+i0B%}9~nLUK(n4KN?K5W7< zW_$c)Q?avuH(sNfx(x&UINZ17k2B1OQ_PH;cu_p1uk3dt&T!Cgwlr#ore;4nY42oa zDCpK2vVx|-7&y)|NMnfau`M2G=ByUnva59+iPzQGAjnPqS;W`bwT1#3LW`EU#-#yP6=x9PO_29)hR zv*0JQ0Pv9iyP2U5SR;jbDs<}LNp^=(vIy<+4X#22{|$jpUWgN`t~d~W@O)yz#c)^F zu!ru3t#}fB{zeeeQ6LF_!CS8k#+r-MN5RqJGEZqd^Q1k`u^rDgKQl}LCBDTzG?P8H zGI552Z_+ayqz|&suHiKorNSIOCc*XT3WD6!tjTFwi~f2=%s*I%hVjOD4_o%N@tods zZy4=cjCmr+a4tr>HwfxBde-{P)UbVTIK@QJq%-Ue+j(t7i#L)du!=o+CG&1KukAz@i}{nDCd|@t zJeAHop+(Hk;f&yL?tC79p94BPi@&X8yWn{Pxlx~;BzPyH^I9Yz&_fM5i1OimyI5TqL{#{xyyYW=Czp-IsE^sT*AoA z;G9j~C)%-^YqGNcAbKdo3`hlkz>B&^fu9-1zNm8+rAAAYj(_E34%TM|_Tm2FTC<-& zVyrwotu%~QZk|_rR(K1ryqZ+hOM&u*5j(trCA*c~vmHO362=uCj_l$2aZ!B%A8kG} zsW-l^8BI5_>wC1G2aP>+D_nunaXnq1Pq2Cqv39Sbc6`K)a>Jw3z}~Yl>pGbWVgKCc zB#nakGXP%WQ!wdlJn8BGpEVJnv@Kzp?1h_RiQ#lGt40jp4_)jyxZ{Iy!t6!=(K_f6 z%HZK(q6LXZ1$v!vT?@x&3Uh7{<5`apEy*sFfn6vicVwq(;)OqOm)UoT-*bs^9Su*6 z-p`zY!};xnc;#iJsPjOR!{Qv%T*x1Z0+iYy1hwBFPp)2T? zy_WIY0DtGOamTn1>n9~Y-G*5A5U6}wc*gUXvp#%GTJsFgvp1K66F!0c;vN;V3~;sz zljT&R$3hR-&(UaWdooI0@g8W;M<>`zt9X2EX>aA+`0F?Z=Co1eRn;mqA0@j0B| z<&5*rZ&=ZO*s;lQEuviBLqrBwiTaEXMzI6g{UDZc<^^ zv2Dbw4dBn+W6aFA2cpBW(I#vGz55mBMO90+vdmQ2182=Y%*o~eb38n#c4k#mz)d;R zc!cig0Gjp{#x`Rid82E_Cc|UK;1>Kj^_^p^fmXKl4(zHs_It3)n$nx{FtRrg%dG;3 z{s2dZq3pxbrB*nhPm_N_KXnA1ei{!$_A7qnEu|qmo<2CZ4^e(n<|&(%Jt#jeD1U&d zo#OdUG^~t9==*O=N6_*Oq_6!&g!l(l7^V{|%;9Ef^IcPdKUUPdX`D0q;*&6tiU#A2PmLbN zJN(o&df`Jo6L*Cx#(jJhCQ);JGi>~UXhxdBUK|1U5YJjxrtv7vd!bt2gu~F=@N{;_ zd6Z)CPAb51>OlL9P<~ZrDT`>MWvF8wDSyImxvB{26-840YDKjq4*j{+VpJz6tJ;dL z?8mqIbL9)AKH1JSazD8?M@`0Sc>*Ie9dvvLNM|>qfU>lYBtC(a@`&gxh!0#g`@Rdr z?xk5dLc}*u;NjgS^4&_z+n(sPG_DD2Sm%ntR9Ou_MJ0>V2Cm>WYa;5gOXfuSPkFPr z`4~U-MKCYA!oMhPNXBuhfL}rRGDjb&f3Hv0r|LuXBDw%pY(!p1tjJf9U->x{SrOS8 z*-ZtAk@!rMGg=!n%)91##_cwH3S+SXJnrelGL^}PERzd^cKi)*yF4yB+tlUiDRnl; z_ZD@TdRD!QO+kWIMe7aj{=2qPyP@U4PATFk=PBhW=&9Z7+K(kjoo+7abn)F&YakF`Q5 z+{(OJA3F1_0=(mgthgtLO6n5-tz_3;lx%-PXE$rRfEP53F!Kkq9{uN)al%-QV?+}} zL*2KX3IK=n*<5`cs;0;3p~!>CI=C61pu-v*85n7a_U}}9V_1sl;c)m`cz5_xcwP8$ z*dOT_sSGzk)cfkM^`DIgW(zV+RY2Yf?PPt1jd0TxzUao=~ z5Bo5H*70jb-f>>vMgLN8{P(D$+JatHL#?C^QzxopU@MMQzf=3G-FeoM-lov|M2@lp zC$S%uVJMX(YHUoFzhTa+&iu9%*OYJJx_*W$MbxYOEHTRldf!x9aTROqNi^c$@(kZ0 zb^)y68di@xgOPX&6=}NF&3a(|hN3SPr}rXMji^f{ht|f|@DGD{ zlIVJhzF$8~1)_QSGCdDIERA75CDP-X=~eadXeAp)Bt0$iDDqjPPUKZ268ThL6jAku zk<0Ks|Drz8Cn&fYk%6$OJ5!%L;Yg{0+y&RC(~6_qQ%lpA>f+3jNk2a zbsXqws2kx%<_7uOg6By#th_POAXeLaIGUd(s@u)_*@L;OJ-$-E@_ZPv??_^qkBLVL zi}zS5=CL-tLdQ_vy@e-bw%y6zi8ow|IR^$^ZQ~px?k6};1@+~OuOV<)PDOS^)R-96vPzjx{eU%KmLK7%^-6Vdp({i!f-kjN+Hma){+skrfh{0=_yL8&wAa#x&Wo>6~e zJKWkh=?L1tYB&WACFI8L%ny3}m25FYTn`@&p9yyl3!yHd8KLu`?BM+1dqFd&IsRQog0aE2!OOu9!ZjnM_0Nsm<`J`% z{QvX&1GzT2<{OwHQtid6ZWU;bfR8 zWx&jD3P+_Rna$Uz-emagdtrq-_@Bg6mEki~B}?cVb`HA^~nS&!$6Y3Z>#uF%mezj8XL;p&Sd=se` zDNikrG2!vyjL<)!`ynbghAM_?g@VBz@a-yw)(2xlbwX!DeZ$$|vaxi z{nUB1)`yI@&$UsEyZ)>Jg;}?Yo6f6_xiA?>8sv4>NC1C-xde zt&FZjT^-1?`#=v)6TAM82x6BID~^CKa~b6R9bWf~t8jsN!Rs~2H^vJm;gSv@PkEO( z=yP(a&CqUN#u+pTB)f^_wU**kcOL%lYvUyh;6cU^82P_a5#}|1N|*HUsNiGuS&?`3 z4Us~Tagm{sAJHn`3b%~RiF_0pA9lm>kpx!D5|NhrHhqX;7$I}3^^3jAu7?7~b*~F` z;0uqHs?%$xGbeQ??p%*U-58~w(vx^IUH(eRr#zAeqP(4ii`^J{)qN12;mqwj)I+q^ z7urN^HluKmwoI#`RnYQlg|rFq`CF?mlxy^<4oX4VY%x250MXY~c1+ErskB@+J!u~i zYm~>imCR-a9M2?iDfMP1iRT3mZ0s@gqGWcGUq~Wz{S@i|^g<2R(7nvcScv$#vr|v} z*(rlH@x5N2NOZk% zPfs-VN4lD-@o!|PepJ7#f1}?b!Z@Hu^#3A1=%PNEN+7B^9M{r{?3}&?)9mW}1;SCm zoxuLSyYvsE;s>g8^pnT3ulWQ#rz3s~`INCl5VIJoiS)eT%wE64Os%0*hCSVz_DtZ} zM)^1q=?%F#asM27gt7rA#9tY|o9USanWeW;iz^2ujmd2Ok(3CIAo^%!u=ikZu z|%u^NL2ce4fyIYt4ip3@XJ(dXth^JBZK zx!-*4?4}=S&IKJ`cCfAk^w))ETPNFmpIzqH=9f->yS%jvru$O6n0pp4^z|^^y1_Bs zNaf*y)OzngMs7Y>+gxETl{((W>+Cl3TP0?<9VThJPJWL54UlxI#qCVVN`;&IsOZ%L!X*-9ta|Mlg!P<{^<_e%xbm-0(=Cf~an z-VC>>x;BES-*M_no17fCvTWITP=(aQy}u2W4<89P@KJ9hHL)b=OZ$K1h|k#5oz~(Z z>U%#jCy0IBM7s_9qSM~$9qxmrMqg>SaL+@e7+P8(<#zDL{UAq9@>G4L0ibLDYiq3c|K@2#{^~P|z zzi77kkiJ4Y@ksbvxs2b+dZjQK{?*cXyBZFR6P(gg3w%6kiL&*=%~Y#6@$LZXxebtO ziwCVMZUZV5oD}Yfi_AZyo=QY-=XA3ws(IDw*~f%^>QgITnH`!h*q%d?BEksgvhlfJ zk2El^C=ab6;zp+y=Vudl9u@N$al!-XQ{yE%(I}!b**Y&&70;O?(5Ggaot&SA`Z%4> zfrtCA*xOQ^bK*1m0yR*R$j!OVd@5#)r0zfo@q)8bylGw#2T1Ll6ydh?)Y9aEIPRx{ zm8=mG#H$>CiZ~Edvph;hSEOc_lOR&;isU@s|bonP~ww>Sn zgweJ}kA*#q_On7K+@x*$HSIXx{KGjSw{jaxNmPw2E9Y@LxXth&?kaQVTghd&-_D zgIh~+qF-L-`pLNUbSp@XeN0>=9E$2ToNv)IK6f6I-N>>|%F8*%uhLRn1diGx?QY^) z*mA4HNwA$42p3Hmwaa!;&}DWXGLmQU<{kw?dO>_*wG)b1olxqJ#wk1tFMwyvmb--4 z&PDSE`=3s>4~E7>Ayz66X8y`)>zd9gbDD#ujh)75YqR(_SmrY)7M@FEccFDz7-T4swZW$Qe`tYurXYqpnySZPR=)h6-N8$T}uSCCKlcLBGAv zuI@YrWv@q0V>A4ZlI|XIKS8jKRn$G{>{$4JP%#C)WtvbG-zU>4DEvqExiS$ak$=a5?oa&c&QWvk6*;@Y!Z|!d{CEuf z4;<(P8tgAYh@y_2t>}3h(GT{foqMpZl>ujnR;B68n9arY_h+}9K-_g1zx<|BY4Wi} z(T%JHDci)OFR#I5B|dG12CSxZ0kpp48{gPwxSD5^KUqq~v7e;jO}&Tg?9bvD=^To3 z@~r~;Za2*>vyLDK&t>bpV4{~me!N6va`mwxLw zw&-NpKG6G5LEXnWRmgicC+}SgRB1C+F4mIK{10Y@z>%WpZ49VieZJ36M!gkC#u;+7 zhsk9}(TG=MWkTe!s*$n(1SMcS(1@SG0a}v7|CH=kDVzkkH zR?+G>V-{rpy@$yEI?iKPh;$#~Ul-+s)C1Y5jApeuNOcw3JCQxoX>w=Gs>B&(xv$5; zlh5Heu?OsM9IT>BxE1AhU!$6TLFQee4uZuPcn*i)6nf?VK-E%Ep#0&S#3OSPKO@01 zPQg=+Cu8)Gzs!J7@_;sa0J9`Qoztx#*lT(90GyEBcvjpaV|kF)TS+S}hktS$#X~&a z)9-SZ3W6q7A_FuW-2WW-VGG8>C{U3{dB{pbo=qxiE!^WlGQ%KSoje%Xg*wxpV$`+LQnMt{u7l!LK>p= z`ha^~6oyS{GTSl;o6PlP;^%mU(G*1_k8|DUXuCgP^lWra&^m|lemny66D_?Roaii{eu0l<3jgO~w3g@Ch4|(V*}1jB|6)4m!V;J_^FXO~z@Rz>RujSx zG8Y-py5M9}=xv*sx$cu^&V<*Hi?LY+jJh4Y<8$!0FNsupGn=;oXN{tirO7$GOV&9L z<5m^v!0d*C#QO&Oqv#RRX{AR*J?B{8qHZ6Dnez_vd#>V#SVX)2%q%mG5gS!$enqVH z6m!2poa zE@1Qx!Hg1U3zat*Cc#=+R?#aXrnY3vMFqYgXjGVn)AF|tbWC;>C0Bdyqs zp4x!F`iPck&fkB+D4YmlJAotafvdBJ9(IgAcaYisG`Z^iJljt1n@5{ZrTzPH^ba{k zRZxmDv`1O~jk15D43U~(Q+4^N&v$jAPj=^w`l9XZ4j-h$8#Fn3?p67lB5yb@0hlY< za2W2=FD~#n{)Pdu9OlM+`rU6_^-uJvuX&8XSz{Rg_ouDC;4zSSpdWmuf%LmRJo^-8 zMtd>`op3+x`sTe6^yNu>ejfLB8O{~a{Lp;aFS- zNBa||#U0M*F+7W>@GByq-5PzT0FRUHsMJuGo(mn>`tm-0kwqioikkRO*2Wd5JU>P8->e9Gr77=pCerB!v*lBM`miQ{13%$6 zVz{N?UK?TS9$-~D%bIeB$mS(0P1IM|abj3SO7Xo_S-0xD&2XiNdQkP_&P-$;S_1cH z6FvVjv0gg5WQc>E6s)({4*YA21NcY!qvD1K3wK zt$mUm-cWkt+qB&+?(`D+#^sgJtXeKk2j-y07cCM4~y4VF4_YkH+ zEPhLM$%G6dzx|7}kR0wlTvmc`r!ysiot1-kWd@Gl_h3+8#ew89+1oWZkSs(mJr-}v z{@~P2!A0N2Z9|ewcptZj^llId?SdydnKk?)+RRU`E;{B%)2r_@v#camsOnxJ9_>k+ z#KA#bgXhC@>%R3n*!pi)L+eZ248~Z)t=VMDH&}bDYp`Aq;g`RL=igf!tscBHz?u#l zdR6N&p0rW)^cMw2g|5YeGd(QUrACPfG70jwv|H{{xoNOU_GuViyH$X{ya;lKXDTCmW+Z ze~zNDskE5h+U`vZ8BbK`1A~88d_fdmlrbBDtfK2s4qqZ-EV2F8P-qMTP9;yQ(vMP=-XT5%o>xoZAh!S;y~$NaTv99`%`J1u?>1l>0gMVr!=P7B$f| zkd}7(1HH15XsiW8o@ZoJLp~K`VFFj%-6#tSAydDuf2z09i-7>XuV0Tm!`HJSxL;vv zhR=$WVL5r=7jhl)!#VOp(28GiN$t!xHKndsMOv;U-<=5K z<2G6Gk@WY8;-Ab4TWOn7>}&QDAyjg|Wp2@7qfCI461FrLKLNWDo{_269Y$7btAh2f z*?@{!N#J|uamQE|`7N?F{8PAExO}*DSPy*_c0%6plu%mm`(TBf4)>`TdXfZcx~Bs4|Z~^|7)})wMO+ zWo@nYhq_u>CRdb7i05#O91RcX2A+@?V1t!nUHcBj$P2t!R#_{}{zeu3b@UXf$rebi^Fn$svNCu45z?nG&xn&xThPYzs&sT!LrwnNNWe`%lJdq%qicDqfft<;qBOIdhzb#cf!n=g!w z`fH9`E4)73Ba+9cX}L~b$yEOH6!Is;d>D5kKAbQ%@p$6R#0v@UB)o{b6065N4O9(u z_K)z5_e!3}>i0?mTwq7w`Pvd~QYY(#`LEf-`U*ekQbIM!gAQpGeXpGNGjD%yezg9k zb_2dbyiyB)lYn#*X7PSFc_mS}EXPl@i)96Ko{U!aEfIcnYg@?+-N)?lz>$k=`C46>RV`@?dm zXwLfVb=fzwJ;857XCl|l+U`)fwx?X6W!!*-Gl}(*{!6TyI5^=|eEIla(drz>w!PKLwrS1T&-@#bkCyd2~Z?nAzmzmCZ%ZSXhoy;@fAS3aaX>XL* z^>C?h-q7L75qnsR$7AU1LHSv&Q& zgHN)CWE@W0mtHcnR8D+kmwD9fD^Jtz_!bAm_+9Z|C*Dcil=wwrw}b=nOXIWScEl&g z4@}S!lH(r74)nj}%~D2-xT@MEam^`;_wI1BrTrF_SH_BY zO3GC+sZG+t#O(>$@p9b%VwO_-N>!SnAUjKivqkoK=dRmTOcVbHN8oS#(oe`As^6>6 zJj1+)efNCd`j`9L_!Iqu{iA#Ze09Aewb|-^u;5tSY2GJ)ctE_%sxVd>5B@uf`n&Hr ze}IVj?f+S`@Onwn>*}o{GsC?@`9p(q(z6d|I+^n`re8V@Fm+}4q*Y4&G_Z?%`zUEx{5P>Pyq|kdfa0%Y zR2PZ)Cy*Rae7jYuikNG2PGwGfH6ZoG%ZjO6UrkG2l&yzz%z45#IYztXdCRL%^KOmz zyw42OkDrlLFYmqleF_|YOU&Oq_Z{3b?rEkx7e>}G{0_I_RagUlMUqlL-T{_Bz+J>{ z=x?fo?3Z?`KY70lq{eoM>k|Jo=J&vOZ$up}A0`8RN9r!;D6!g6t+_T#y(M>)-V-GE zZ+MlpoG+bha_I-CcO)>*7rQ^&VWWMdTc~Ev&a8b||7O`)53+~moXMV){ax0!jEuCB zX-i(UN-g@zOj(&yH1+<=rpfnSOn)vUZ%kQ`{%dHX^Q*QbR!M4~Tgg=@@w-4jUriW8 zI~fme`+kq97Z@DatL~Qao8N~XXOvF+H}%(7yI);NpP88$ObO30f1=8HKBckt8`aB} zdHeaw`?m+q#_mekoohhe5A(Ope>HFJ+z;cs1txe}DJ^h&{996#^XgSCOZ!oqs;-oe zi0eUedf7Fpyi-qTEpAr^X(Dy=&IFDJ(*4hU37#bFbE3U;QIZca2cJ{xz&`mzSug!0 z6mq6nqV`Q5rJT1A<%Z4w1PhFQbDcf(%a36X<3;dLLHE~wb(YTQ@ z9efQvch&8lkGx%3rPuiX_0E)!y2T=ua)N2EUe8ZGk~SiJPxha|BKlL~p_T3w7j}d4 z2edVwz5d2ASK_xN7`YDRKAiW1d`t38$!q66nWQ9)hp>dnA_giVPj6Ml}14ea%lj&e6V zhObEM-nf@>^8Yd>4iz^bF>Yb#`R^sr0>k;@p z_I&)Yxaxt~+FAfq~{5r#21O%8?z5TL){QZL3`@^8B( z`SysrL6|D7l#|qkn(Zm*z2vpMC%s6msD!r~M$H!W4xEy1@>kMG@hCanQtm*!@aEbd z*nRD3b|P7vrsP$Y9jI@(Kl@Vl&)I9U3uleatdZXFb^TXLO7hF8 z$xo8&CRa#4oV-5e>$EMI-Sru6tU8#UT_z?e(8Qan4pLue>wJI3{F#uGxFvCW>;!Ky zX{=c={3#J$-C!v6FsO0Y%cu5Ad6lB1{{4D* zX1Sa`#y-cCW~rIlUg{gZhOb#jed$^2UmN>JydHlpZj}F5jf&UyW4Qm*&3aZNd%v~I zEMdG1j}6@ljtqSkX<;<5d*Q*jTgs7brMRYeH~9_+_Qtl4-;^*W(U(*t=}Kaa#MJmh zaeu_T^gG^nJ*N6fISRkJsU+b}_K9OzZ(9MJ4F}s>oWH3xbVck#mMbKcl;1`L@VRQL z_lQva>HsQ?Tt_oLOL+yut(s!sFxyY4=q5Ye$c(NK%D{^42v^|~C!hV1bq+;C1M?mn z#hUsrk>cSIq2}-nKh8On-6Q*A*44~)nI|$9XN=33pHVe)epcn+f1%>$6`OoJ8G)Jd z1F0bO<2EY~G<=K#F)?Lh9{J1CQ@)VD6kmb%9YsA>PFU&0+oR2Y_2zm*7*u7bqY-D1 zvL&*I74RTlF3!jMW~=7)D!yF)SN;Mq$6`K=T}Ui4CT3(H#ed58J(1>E?`h9Z+A3wC zbXmyj#zp{mrydF!{uf(7-dnA z$$el(Ix3lNAPj5BR*nb5?#-OB%Wh>)g+blOJZRKmXY-N1Gcq#rOL%PfVd!yaPG~`> zQ>aqtW)LsCoB=uBoQ64SFc4lCd12(XKSiZi6a;f5`MV8nx{x7FRFRR<0`hw-K zNrU7)YEX6Xg!mYI;1!r;Qxt;7aSm=mrl&Hv+G5<8UkVCt$zRE9$ylBujFL%;Iq*xh8D{CAQ;PSvfba1 zrPatmO(7%PoUxhBE;$7))oNy|E{v%6VZoKA(!@wyUN4}fa_j^$RE<%#H%IIEIT_+< ziK8B%kklnrxN|Xe~Q;AIQFfxZ%z|)S^5aT}cY`{aKMNoDbsE&#fJ8~5>zWIlc%D><1QNDEY)Z_zr>$g1y$ z4Y!yqa#Y3H0@S4vds_!=BZ{*f<{I~t(ci)&TBl$Ls`anQk9P*uY6_at$?XfO(g7`1 z0{Pk}WM_AgkzIv`ag8&dEbv#%wQZdaraM|`Z9IAJ$$aZoCiePd@0Wnutp*Q{=Juz63hcxQ_0b!pEwiwW>%3Nnui{2bOrZCDwD6#aqsD?Zdty2`2 zlyEvGK?dUGd2>7!#5kL?d4jg=K3I43+B5j=FY=PNdE7yR_75n_-<;1e&SekRvX?Wu z!o2%0PHcDI{QSe;N6#tzMzNd%7W$f_z5?wq!9)cxBZF_%=$#r!H4>&bKktEhH9_N7 zgDZFs1z;Taqz-+!HLsQfTar23Y*f%L3c)0_lj0jyU_Ra{4WB1BSChzReEeP*6sjnw zSp&vEsW*st4A&g3BG`zwsLxT$^IK6cs`9i^eU2a1w6_3%8%+B}wef{HVr8=YO*lqF zo<;jYy*FnUMdM0)fZxsFtbGXgt@6qaa`OI73Ytj3Y4nDUF zEbedm#x0I?lgD4&i&ONVXa&RbJUh>8(fb%}`)J>Kz)`O7iRk+`cqQ6?ChtGueqMa@ z$=m#GgxCFCT>xe7pM2{>YJ9|S77pi;Nsq|pnZ_#$_d?-H!r*|9c_$tOF3RVS=~vl& z;{%R)lcT@l%2VJBg*fhiJWJsWuJg_n&gdn_dd;i17)h6T{VCu74?izB=52b{DXux0 z&)Kv^3fKICD|k*T2;93^zRSUwV*KahZWiY+-{DRb1^X_}?^Wp! zar|8~dQh}~MUPP9&3!My-xT1}akNA%M@!(26{0s(UTxoP2OW`+}cJT0f z@XhlKMv8~t;(@&t=9N_LMK=GFV&&61tz)9)&V`!RqP_3XLKg3)bB~{KPj2$6$(@bX zu-yi4BieJSiMxcwP9u8iRnYl0L>QfDsYA|mW`Hkg_l?ZwRBd3!AHe6Tf@JRIYF4nC z_~6y-72Z{POEd65`U<4%l+y$5`v_^C@GU&1@lo!a@BsYaZJ~|(h?QhJbrA=`Bbp*? z0YjY0U8B+=d8JnT-U3V;Mj~_fX0YH#AiYaC=kcH>Js6|)xS|cTRBiYpBfyQ@(yphN zj|03DRW7#WoOaRc@4&n3!8{XBPk6?h{2ud69UNL}@_sh0mWOq!KlpMXM*lPX&%1#@ zS7gm-z$os>zap{5Z0_sl%okAw^E8x^J;C!U@qY{c{h0R$ygA|sj`}(H|9H^W{2bw1 z*5_s%YZ_}>Bl=YX)}WGnrYZfRJ##~I&hZyMD=@261se|Vna=d3Xhnpo95>nnhq4De z3Vv0CvoFA1D9Ah>?GMovY$dI;)BTbjJe_&H9Y@uObG~E5R3;7^K&6gob)fFBe$K*H zRJgVYZ)S&Kuw1&Ju51A#)Fe7sMBSI|_+K<7&$AbH#-Gk3*0DTrG5==O>d1X5?S2$x z-SEx<`pP00S~KY{O<~+8GRnL#XWrtNLzoMua6gy1sZM@j0bIeRXb29na(u{|#tg%~ z2*Xa$-NVjc&hog^kr7Z(xNZLi!(%#8z%9owv;x7e?HspLxi8n?v)uqsajZmYQokqu z!ZDfQ9G6w8sWgVs+RW`mb(edz;s)0LKb@}ycwN-8X&~CHya!+m4iSE%+5pu_g!|5y zyjGEM8(n!n5RN&EsVCZj5zvt14yF3X6=$W}TlmZAOpB$_x7Sc_W-s&TG3G*(iYet- z-Iq8QnFm#7h`Y{Jcbc;R*4;XecHFVu#r7=WJLdxP*WXTc;j+CF{NuYfBcq4lCGPsy z+2;<%!GtPKj>-JEjQav}ii$vYi3eH|%{_OQb9M#A$IP^2@ESh~A`~mWau4v1NS`=Q z6x9%>#ZD@(|3y7qs^p0A_$Mu5w!pWFib1>Gdib0LVdB=I_kG13tj4S}fxdE#yS0qC z@0!yIZA7%X#t6pj+c*<7bfaz#Ve08za3(MhpBA>jWcykeLOcFO+8QcRv-+ zN&U&&?-9m}*QqM`1G=TF@+*{cBcyfUvE#w82g5rY$vW^2XIRAjgQ&eSyOcR(AP(}F z$UbEc+{rW4Okun-yANa+-Icw@B=Xq%q#CH05^>htDBYocx`5Mf5bjA?bd*D3fWAW= zj}gpX3*ft)U_PJiOkn5NnH@s{tl;S()B{dL%q9MQ15D-4XucX_@O4Cy}l6{h*$M3y}H@lQOz}G z9jd1KMZ@iGfXCjoUo%%8Wyd-WUTzom`>o}6)K;CZNus5B7>fmRrf1@$UCK9shzP8{q{IU$G6lM&Vjr- zo|)$__OxZ(0DHo!JiEZk^g_&bAG1Eih-Ktl(rM*g`A^x!iGQxFORCyT-h_^RIc)3x zLKaM67}vrBr?tD=IqELL+p&?g#6D%NvR9A|pP=u?=e>R8bMqAJhOQwqI5OBbTmf(M z1ymv(lru9}$aK-WwY0|R>&eZ+i6@_);K?WTP-2zU+V6hXyVqAnJE-kYBgz4_xAH&r zC9_n2badUEad_<1g;_GxtYD_L=ioG_mtq^{qD6EV~#?$Pw(hUs1K_fYeN> zjeq}0tqWPnAJtOe>RC#(x=3#2p)^tanM%61or-R>?$pP2ob`d(-so&RHoBXK&9_Vx zPiSJ(&Ae1V$;JMz2Q%b2DYs&hKlN(oaG}rV4S0X^RPxH6{Js~S|9P5tn&bCS1b>i$ zFqb}nd)UMI-D+(|=yAi%s?KI(4ONV)8(Z9NW`Z!${*VE~cG2t~a@GaM zhpOcq%*@Q{l5ryY{hSBES-~ke&BL>dsrHxn_Ql|N_gHGDBzR5lEUhUlgU4P^;D0fN z0$$%4&j;FMsv2y>8U3+z6^+sjGS|nfwq|~FoRNeE|D#BwNYh9#qTve>W9ig@U1{yI zHwzEN-{o`icghP?_8qk)o_xM{eKGhQ*YK|KEC7>TicpOf(+SAIcNd)i29|T)C))cr_@N8q;>Va)}D}Ell^OaL;Xwr9^X3OciKL0 zz_Z^oTq~~)Q_7*i`%K6s*I~xMj|N6phE>S@#7-l3HVn?&ANE=*+4XbhQ+wlPn_KLYU8r`pI*;0^r6O%h!C6^8I;*sU!9#1x}NhWBQtx!>nqef zn3HuVvr&38bA7mqQ7jY)%|REwhmZU4S^igz}B>>E#SPYR5j5!x2q#fz%=jEJj+uFUsC;LyB8@$#GfueF}uBrsD) zIiNp?&4mGyBo|O7qQHNmUDxV+Wo;guf#Lov-wP_CWP6(X{#I}KT4*c037&gOhT2Zr zt*jJV%KL@`Qy!Wwb)h`G4hx_LI z9Q-@D9?$B_*;{f>W__KVlQ}oR~)eCQbTXAJPEImM_!qVE^~c3-ur=Nfx@xnV#5AP{`Y;OeBXL!!anJ#7L#kC zj;O?X(!>7KnPGiypQ7Gc*g9Z}ZYoEDK`q{fXEs~us?PIP^n`rFy_NlW1FwCzV;cKE zikabS6KLnZBGWax*DeJubeGsIIQG_RsXx_b>Bo@c!s~%QwioT=CJD zF2NYw53*-)FUO!q_fj|dnc>xEL}o;y4nncaXhrP`PBHf#GWsp#NwNY%Y=tl4ujsGu zU*t1E4`O3~jCtZ8@9X4g=UD}du7)}kUCwf1m`=_r>jQheUezjVq!Sx{V2yS58B58; zekz@k3rVZBQgXUyxaVDO-hk?@7IW0A`*ZyDz0Crnv?{*aS{NPsEMmQtE}k}UH1j%( zt-Iz#JdCF3C5&16I{k2@mR`kZsQ+lph&VvUiL4y*yeeseGWEW>@Fn4Dh2-65pjR=EBIo zIY~BErSz88%HRv@fKx+js}7b^@NTJF_&Vm#@lgqVn%z9;uHRh zKiTt4y)A~F7wlvsc3HQYImGM}`O)lW%`?l}J&Y~P5aXrkqE{LtFOe5%)jWsb;cW0b zzQ6q6`*VW@1%2=NM|petG;bVh-yum;l7(-iU)TX`cWPQUtrAqt8)V!wyWr~`Z!FP| zn8S<)Ry7=?s@Wf!gIF!b+O5rp))Wvt$Lh(t_&mJK_*&l<`CXrpb3D{9dqmbg=Dd92 ztn}tNFW}ux%;}xFG9x8rpmEI>A|}Qq;PsRVJrn&=bTs9 zDtj8qkHwFLsqA^Pob&Ei#H|C(b;b-_>bDx#jpfFEM#m8AnNiuQYIZlP+eub!C)1c_ z9@D=z=QxF}G^>eOII>Q^7Hn@k3|5V#hd#;auaC-_ADo(zKhw_Xl#vlBoY6RYb=scv zbLO?kpH?Lzvl!v7^{4ugJW}cBU95WJx)aOfjg`E|{ky%@{AFX#di(ek*=0ZSms?VB z&92s1v$|2-7W8+G_i*V5>3XCLF4x1@^^9_gi)BMuv zV)n+H%wv3LG|*G^J^aiv>Qg7KymgPNTYGVJyFuGkH!lzY@3-z4KN~xZ&XH2a+u;w5 zgTc=t`GdQ&w}z~&$)V(oCs~y;FQ+%lDU#;Pss75B(IEZHP(kDS@E+rx$SSLa6Ax$Y zp)f?-t1gP^7APLL>d*4d^R@QJ2Hxdt-}8K>O_w*wlf-d&tR6Ohv2t6h%sf^Jm;^tC z`rxh8gSwlot!CsF>rip23*IDy)efGC6twQ zGkZw(;EZ~ik26-KjSqGV=CR&2Mwy4~W2WJLFXU5RNmI22o>4LN13P1i_)Q%6d-+d# z`}zBLey3-4P`k;`gmKO`>#+5QmEZow`po*;h&SxW1wBWv#Cc7$|8i!azJH0%;flPK zdTdL)#RHoH{@9waePVaV1_xWx~=7qo`#O?!YGJ^`z1>dNfa&yrO* zVLmqxS@-SsD36cgx}lIqa2dm8v=!PGPeZEt-1nxlQ%UmtuC7+|shgBrvIf(w16fQL z&h=f(GPjUE{>3Ux9l{(u-##!#=m(8r`dTAim#AyDSYID`8XBY5h}<&!>Hjm=MHW%> zV{2q%co_~rtwNdMe}jpkw}Wf4YvokQET8=#eQVaG^llm7qNc;HH6TEey1%ZXnG@+5p5 zvoG#(Oozavn4=!u`zP%6mgI9Dkyl^n)?){pWxcTLStF@3+rX}fj^=&5OPUdfEL0U% zBwM?$n%>Vm4SjcE!!Osm;2%~NCBX~vjPyBK%Ky35P#(P?s`}Afir>gwqla0Awf3mK z#Q27>g2!3VXSNNUHiF@ckwwN(V%AUe_jDu-u(UHm55vcZmJ3GOgl>foWcLiU$*z`D zFlc3L&Z(bqG5c)J+?=J6&B1?h0;^!na}L?-MM1bI|ERpH-PT;y)Y6rTo~l}N&j7Ia zVrmujH{}Y7^35n3d!XpabV^V)&SUq(nbwb^=+E{jXNvouP!QLsOeq(YYYS-8)x}y_ zPft$?Zy#?>Uqe_>QNN|09tX!bm5OqEydT6mr5lL?)|2702({%PR)v4?GNl##}vh=id?K5&v< z8-7qPr?csIb$lLt<{+b|<2N4}$4!m;5}H}bxXmc;?p$SD=W!MqUClnmr*?nqW8S@<(z9zaaXVXB^f?q3Rh&`?Vq0^&h#c!oomh ztvpZl$YGqi{=?C0kJ3iXVh{d@JWVZxhu|Kmi(DJ{aumLv7bIEqNPmj?V3Qscf0XZ{ zUiyL9=3Du56eSC()KXN5m2b&wwUKB{{*>dWNb`~0N|Utl^80Y{VimvIge+x587}4l zW9h=HZJfEnH{w)Rq7`pSDb{rHu$xaD>s%KeOKY9dqF;F>78CkPDPkY8l0LZ_9ExXR zFS!w}QA5d&opSC7-JLU3UB74F7ki+wF328ogEdV^wRa17%qjRp^kVj|YHzc)+O@>S z=48=pZX$9yZ~Y_GbC0o~nqVz7y16}BOH!Nzksa39NUV4h^^K&L&|kX2a8F{<9;u3eY9ualeqBIqQs*c1@{@JKMe~?v>EwJDp%bdex$W<(OgxIftx& zab=$z77I8vq_5?VaOXHeHNS_#6LF)^Rn8}kl_g;?9vTDX{?aV*W96Z|*J&g_mFuY2 zaQ@li?J5m(Hw&M7p1EmuK$)WK6dF5*(%IfDo)s3U<*dJyUUE_EEoY59LY58VGHdc;(onhX4gLU-k=Vmnuo_u8ce!Fi!%!9_X{DX)fwfAog(LA$(K zM_=uBlIpt`jl!NvsP)?$L+r!SWT%I+DEvxkMWk@gJ}rLY-5R=s5-{Y%2&M2^z9oGd z9^>q#&(}~|hOcT%+3(o4M{e$XAx*Z1J8kST&QxKtHCMEpYI48uhi-ejwYQ4C(@65( z40jQ|#x$j;cthVI{sgb(q%$88{w-Al1(A#;jHQVAw_}ozY9+txd?wU(Psq7!yj`jE zdePO*VNxP>YztXu<%Q~d_FWL0s$R=jCQeYEDIV*#^o@8}eq&dQ=w-S@!gtIjam zdDzFn4e6o<#6OH3;%V=HViW6vkm~IeUMbv=f0s3KoFT0hcAB!1 zS6C>nHjBcyr>3f1&?@cblGe!C<~dw$o+ww0bMA5Jx^lqjZ+CGg*nOly;to4j8=-%! z{3-0PZW}k13$Uwxv^pE-)aT|f4aYQlskqWzXH|y_SwJf9UNdi6?|UBEYsA5kALP73 zakl`@M=MyXBk~UGq`p_|Zhs~xxTEA+N@L>(oRI4Z`@|8-W)OrauZ5fDMmIyw<*pOkSpA)^P~WRC<(^C3 zK)((<+r_s)iOLH9vRWv@_f{_XxR~X9i4W9ZWwvG72gQ-<0=tP5)wY6lyQQTg&V3NE zi}FljxNFV~CBdvHZxBbi4aJl86uF4|t5`((S!gR4HjfJx+)DDl)-geme{yyyV{!Vq z?rt(S$h(|h@>(BHV`{er7}}J zCUi3Iglj7kVA|!im$?5ZinZAd8@a_{$_isHpR20&V1GQv+)QQ7aaK*oac3&4obpB+ zobM`0iy|UU7|ofNpBj_ozwDOML$d?*#S2=KZSu5YMf+P}pVWu=Jb@WC)vhD;1x@{z zSro5pX&X#~j?4(n#0AcON)@pKZbL7H5q2!xFhwfpOmh25zk6;>@FrFlPo^@KTqcYCq?nR{hgW3 z;qr07k#iQ29F!allE`T@eh|*Tt3x1c8LMi}#r^5^V0H5Y&^|wfC@NEBuj6@Q&Q4^$K zcs{w+t7$e}S9Q7KQcE!=K2r~9ad@k<)UF_>kB9%$LR!bjCFbkPq!i5>ZUT}iE$ma@ zBb#L#^6IRH+x=TLN;gycG5nY01S>2NeGc=fb~WOaKh^ZGk^LPRinn+|8m#4MyoYXj}LUF`mo@$0up0 z@NQfnSI`r@y^G;&2{1eRTFTe`h(G)uSh$k-e5y%Vx*d3{`bExz)c6w&!d)VOSL-Ga z`Th)6zhGn|+Dwgj@h@i)p|wd{r|U;{A`|x1zhR}A%Jn~kq2(_y4rOe_;gy^Pt|JyEz}9e!o`Gl8EdPa{JegVV&qM?d zCD!^Z9;Xg?R$XAdXTSrnQu>Z~3Zwib^T!1s7(3|}!3aEDItst^JJP=(&FX>#{TyB* znH&{AkVB=n13ivc7e)mE1k!0?S@CM8osv+#Az3Rl${R_b3dNmS|D#0NvpZH z1nDh4>}-%P8^Nvg1*uY=r!NK0=_q_K2I4S3 z#y{fLTEZ_8h8Os75&wj$MMb{FlQIDt$z}Y^hhYt>taK;A+RUWwuE09{RM(VM>;&rk zCT&<948~HhN7MMG0f;Y=HJ}d|il#_SPf#}B;dNh2-&~Cx^+gYiqYf^BBF&^GmJv<- z3Fv_JtSL5u3mwb3)i|p&?IlB4AO3*?pz^MKRmiK4! zg4_q=5CgIx21cOz_|Zi^n|bg}jX=k}O?!QI8pD$IHpt9+pts_|u?gI22&7XsS3FH!YypY1 z6Pu9WI@!4!F1iFfMP`${nHn3`%dzeGZ^KyV0tR zc>ed1nmHiUwqebl#|SVC>8!+4LsujW8vW#h@3nJc~OxbF&TM_ z=!U>fb)IN}jc~5FD#DQW0g=)eoJ$6*R7b!?&ExJSQyT&W`#usfisxI&Q|?Amf8@ST zpeJt7`x21su4u3>)I~ecH1ESAHUMVw4=J}UeKi*p!!t(3ED+x1$>i0PsOyU4Px*uM z*Ma?h1g5(+d=X{H;d2z9`&fF=4w%Wp;D+9i9-#C8AR~c63c=)82fIN%kls;n-lfw| z4^!`B=sU%bg_q!~PJ`rH0P3t9b#{#r_HU5rkKq|PLA1MIlH5XlO(rMGr|LH_`_;x0 zbTGV|>>rK8EkGQY!_PxMlht7<`3uH|vV-r2(t;&mn82J9wyJsEDUjd=T3kip>-@}Nv4CgV0-EAJB-TZ3E^4Y>DW(n@4XWKU>GKl_kv z2(mtnTHM2U5z>tW$v>Sa+O3is`@lGuOb%k*TqO4cV^fOgNF#BBL0G$gBM(g*;wvsv zPgBr2AJa~qxOzwS2T}HRwAM>7U#bvK#ke~7W zcJA&yDUGq=0wt|S+qBlL0o!4YECIvwkLp!dQENZLz5gu{5V0^StPB4Dr^yH8P?-Q% z$vJYSTqlRhaxDJs!b8|ra>P^ebGpgA`3kRFes~&IqgQGR7#4dW(N-O?Jz&*(fQ{O# z>x@pk4j*zYEa%m*^~hv}ER>wsVk2;9B)}_?iDh;jBWF*>&wAYL8%S0|@`g1>=DIN| z^`ifcLiYN>tJ{wj7)JcW1R^9RFc$V@q$nmG#=4fr%&IhO5jo6j2B0qrB3;qptC97H zZj;oA@lQtbfkjDv8zZ^b8iow%mwasy>2 zzqJE>l}rrL26r`Yk zzr!~81l-cEa!F;j?2;GY9X>CA54%q@?1fILI?|mFuj6H;XDPWsKVWY(pMJb~tKe$C zgKQmwLt_9luKL874VB+fzLB#*{gvgHG0c*hV4)lhQ}a^#M_1_uy>})(wj+4jH_+g5 zj5N0xFZQD0MFz^@k#sV~%^(J76r<%>w9SV^5&4O7o~sU3o2&Dg-K0LbZZzLWhEbVYTQVS{QC0niSq2 z`hdFmH{2S=(TQAFX3o%X+LnDe`p1DGwqyf&&T4X4Tvk{8u_$-3?*pt`GNlRefpvCVk0gJ5n!VzN7R8Foidi zzm+ps!^SGrU}l+0U(KP%U6Zol&^$v;)n|sf9X&Fd@oXe>m5ylUyI{`RV6WH+t7;57 zcp9456^>Fj!C`bFoK9VQ2SP0))EI7`MWK)JS{)>_$BR(a&=T@|{2R;)Rt^3Yj0rBq z%hC?NRk`2;*loTAukjX{*V_eKg=&NXp?}CjG=kasc;-G2S*vx`?PS(9L;eRYq?g!O z-qPp5eH4#XFbLAt{2T@!q`6* znBOp%-22jgxvNr<>`E2ju>S;J|Eh*{FxpSnKi2=F-va-A9oqU^p2a8o@IEfAL;k#ALx7G=d3~u^#)^(@oVy&R5q5tPMvP3392p0c$W+f zS>$(Gs+eJ08H66smS&;1OR}b(1`oOhntTU*btA$1j$n3L328pTn4{AM5O237e3#tF z1MqSsg})|tdSLi3M&?Fv#uR`?XdD_GY##c8diW;T7hb2;{yg$c)C~*_^z(ZI`+Zvi z8+~K_Rs9Wo5B#!kRbYz$3+m)fsF?PZIt}~CsmL(ARS%Ke6=ZKV8X6h*8{Q#H^&5t1 z-0{1xpUlwjLuLjl>9C3ofd6zQ+PEJaJ3Ha5Yam^P4dH!hJEL<;@>>0>tkAdA_b@zx zMWnZ}n&}eAvZ3axczK2zS5qtR84aeL#wSLdaik$0Ki_5I^~SSy*$xBSE9SW5c4eNu zfEo7tT63%{hlnD4j4mCn9aa0QvbqiStBavEp&7h(g>-O&l?}T?&Az<9FW z?10(-zd!-`(2x1uzCpfGKD}?fFNyp>J;>wI#8<<6&3lpv=Gwl2{z+u?5`JRAmNY-I zhFSGc<`>iT?-{-^t~NC?-80Px=lC_PR|+3jV{)$844>)GD|w7XyQLP=A=ddK-`{3Q)x6BS}S&(tN2dle<@bV2*3PG#v1~IzGWHtXl zO#5}B{Ck_Hv0XHeh6SrSQSNn2C1J#P1;%HD93ZnnW$7PQogKhv_lx8*hSX&=t;Tx! zUu^~vhKu3HjKG{%g{-!R)n;lJ^)$S3G2ywPaqys>1kqF`_+2oE%u%=DwE8~K5uUHA z{%gLs{CmCUeVx7KeA~PqcxU^{c)i}5UbENX>*|e#|7uxi8{7hJVk#2Q++B2Uqs8hO z>KQ*U-3BRi!&pF;r%UA5I;B5Oj;Jz*YWl6%ESfXp9>>bE5Qe@Iuzi0+41P~m8Q#bp zT|cQSn&TTKPX8l#t#Zc4#xtgurv2vM%(A7Hr6Mx1z|z-}U}_FWK1MRbuc?o(CuD zCswHeyv|lEO6jb`CaF1~LGvR%or&4IVAgA=@2=lU41cs~scE}ul(~laiMhbs&oa<* z&iobVgZXeAd`!-eu?7p-woWMXWmUI`C;ber2azl2B$meokvdpQW~=AbXzZ*tu@#ZU z9sANSJjNcaDjb0SGRG(v?jAb8Sh+L!B`X(CpnC9kU!m`X=WXvcZzZ(hC!$9D<-IoF zVo<$zJPWqIvM(r?W!|*)B2NZ z=mB}E-c@cAiS;w_EwyNwhuFPZz>7W%R8bjF=iRl>$@gB7oE8zK6B%e`85S7A#`Px6 zyw38j^=<1s>))33mJcj@%o%3V;F;=~ni_+K{xH6+kaMwhRby4$0?Us^pESU9)g|HnWvd98k;a4yk&|ub}@`o>d6nWg8YdcnD`gg zU$gOGZzMPRE204!)Aoa~amu=1VYu56>5HXvu2fpyjveU%JZfDcV_0{@htq-r>$8LY z9|)`oUiI4pANyK+4|-%chWyVJC z>rN@hjLXfPtX^|V^RI?q^hcC6`WyOL#+PtnBr1KRb6OcNUI~oKzktj9i)e)op;Dl_ z7l%rO$HM#Yjjo3L6a12645bX6j7g>p(^|`9>z}qGwjs7I)?Jn=X1{5wsk-@*sjR7^ zv5`I`my=4d{>mg$sS{S6X^~d2zUPDg??)s|6XtIsa;pKW{8PHs_$n&GJJ=Y8fu-0I zw!-aDGSW zK~wOk|9yBR_XG-p<3K)c2IDQkAU0fk6}bY|=5u4bX&#(Ge^~lhRqGmiZ+o36lOx(; zx0kebvFu|+J!c$m9Bk-mkdzwoGMFl+L>6l~ST^57etpDclm|=Rm>8(H$&MGsPtlPz z@kR1atswsB89j1^Os*p0oz|&QVa@+V;IenMFU8~W_~2-`;_T{N=Njsgyd8OVo$EjM zfWjQto9H#oSl?$6d9-8mW;a<+3zaKSSQ7zo(Bz7i+dYwwCsx@SI&W)F2TXx(na_cgj6-GM!h}%Uh+jWTg05I~qMk`sI4}5jpuO)~U)b}L_c*Aek>R1C zqp-Ku4SO`0dzE>{U(F#)y0x=)zU{K@L&xx_+fk#VQyqUguG?DLMp~0BUEm|TYuInF z85ZbQDTCw-#Bz3xY|woLd#F!-pJ*gKbIGKDsRl!xT|)G4!AneV{+(q8iGCSi|x zf)B(84lO@4CAc84)0gF0;cDvK=zLO;bO=7!W<-Ns)L${!EnO{(>?0jh?OpAcEsD9Msf4i{+U+s6 zw9;Br)kBn7M>4c+@GkMb<+J;@`fmn44U7u54V?@>jEs=J)|-rlrkmz5mcOjSY>OS+ zqXN-mW8z{aM1L1`-SM}5zdhIX6DaI=z&e*SwK3H-+|XadqR@$qdjU`xsl)^Qf(56Q zq{+R|vWr=1j#GNdm*ARv3wy>erI&J*ycB!L9B_)b`Acw~oW$nV0N%J_o)_*k=Q-yp zXLV;)SEB2%b4B6Q{GJ7BuEXicf5UmOuuQ@J!n*nU3f?Rj?_B22^fm?4)m7J5s!ukZ zj`A_X`{r2d9(!v?ZTnbz1#3^s+n`u48dJee-cZ`eB|%`62+a+ygbj6QV0mCc@P&V8 zpeHk^%3@BXS|ckZn=(Xyfwhev)OjNU&ZZWH(yGJjIsuP_YeK>kg^jA@) zs6Dnz)-C2tV}{`Z5kj-kBoF0=+{wqpHr8N{F^KiwMQqX<_M`wj+(~Hu)=2#jygoOi z6IjAiuo>5)Z(zAqUxZo*uLZXIfA{_1ed5XUWP0y;GCU7lCf70NMCWOj>`HL;ajq?F zT3D&@bU}8(*n&$1pBLy0_c_nHxB7YqIs~VNwg$->1U@AXucy@{W}?L{5p#9xB2y#g z`~{3fQ%q}&tw8kts(+#^Br>amZiO~a^@d*(k2nN&;=jmN8qY}61p8^O`YX2RlDbN8 zyV(re4Pi!i)pE&N$L6=iI9fQojv-MSqN1aJacp;twjZ%&TRU1>ntiZZ{im#k=`tC= zU{7L$CX>}8U28+0fhkyVGr)nR<4G+ePap%>C-S$-GI=akjL(&GFl27VB3OuxwJ&U) zhqP+g+rDGQcNB!ynZOr;YyOgfYrbfo-K)B1xhuF0?qjaY&U=NS0$2XB{3ZEs6?895 zcLv?EcP5YT(;Z+*d7@R*y(^uN?I5)D#@~(gP36q7 zmU`A-!0Ya^4F~o7k8P4IV*S)Iz=sk!cm_C$RNR_794mDR+KEYX_d zqZkiY!<$$Q>cV(36PxEm2^j%h@RD_O3+X!kg89tCW=C$Y9(W0!=tcNhXkX}$V4YxE z;4^<`|52j45_}UqiJn`oC(eG(3xzR-^$WM;pU69!8_YEnly()kUwUUSGYa}_fhpmR zkvFA3vHM&$R5h8+do6E*)()6kTb^5vS(jKZS%#Yi7)~fB@manlA7%zqg}?*3?N;mB}Qh~5}g zJZhptv0Lq@VaV%a8)}&Y6PewxT)Balaa3d*v2<^!L2Mw0g2Ra`_#jvx`&TR2XS(70 z+Cj#juF_TFD$nRAk_YppajI-yq}Y8 za=CM=>v3WA!oLfTxN5LQV2MxlF7p=i#`_Gu zZvIUF4gWG`+6{s|!nL%aaQ?2-XBeg!FPl4C%iGf&W1{v)za5hhdph>R*ut18(H=)D zhs9o=arKC~(9lEKfoC8ZlpUGo!l6)MU?#aEHwBsp>jrXzQ(&c=r1e(Y!@aJH{2^_Y z&nS0^-OSK;VYczo*p1ADpBSGRjv0F!R_Qw{3$T+9iTtJwR2N}68}HxX>*d|xuIQfc zx>`7=uvNj$yc>BR<#x@F&wHMGD`!^Dyqrh5C2|hursQ4nxq}n2UY8C<)XiEzUT^M2 zJAPzqWqaT8jlGxSvHe%aT*rO;D_c4HJ@W*^2q}e`xnNUTujarK_BJezbKT9|vpgSp z{_y_hZA47XoIst>gzzj-WStDHOnuD(bHtKhD-&HgE-qnq{NuQXamV5&$90K&FYap0 z%Ba0IgRQP5#aKaKUoMRN2Ntny_)Mr%(BvQOTgrI7k$k_GLG%ZMWz-eg1RQoN1l15O#&%P+LM;TNo)^3@*U?|gk-2Mb2I z9(wKm;{N8o>fREbiQcY(722stv~){e#~eIpc!#>_V;yPl6#YkBuh=`$9UVjAw|x>F ziLDtkGwKdZ(bJ8IhLg$$*(9%v98;U1@t=bXe(ZlQFe0b~-woagY{OPoI=n&qO?M7I z!D=kN-z!rL7mcMXY33Yr&^+2S&UD(?l5DAQ#&V{sdYgVCzPmbbyi9}lpl#$!?9XTY zmxv_j6^h03)XqQBd&<+rd&OrXW2xP9$}`A&&2_-v$`kgw+~o_7c%off+#k65tIUg815qs9syxsD?$zgn9n4lkTxU6qtNYl3= z58x(HnAbHGuCI@@58-+pgGKHocE?HKO4x26VVAn%ZlQG7Z4VJksg{skx?5`H;K$wp-Z%WG!&||%L=!Xo zGH^6hNw-WMsDDaSd2{8gavk=^udr5bw_LCswJb30H{3FOZyaxaYAJ6`F+P*G>UK-R z$VMM)I3gF*#fN8LgFK|=sjG=~{uic#EHuL#;da_Uthmo$TR5t#Ko-3;>8_z2^WYtZ z=j2v}t&dE+$8ik3Z`jSKzCjtF{H&j>?3TOfp2HK@7No!h?4JSkZPgae3M#PVE+sDN z4eSbcLi>Vi{3nB{;qkgv>ffQ-L6dK5V4^oEyi>gw-0g7}9(4MgPFI{O(Oo(JoGa*= z5m@Xk>)Pc$>pL7=7Ad2D1|Q&(@FHoCKH1vV-qF^~veH=9e9qF%R>ywCzR)tk6o#!e zjdeghrKub(&yL(u2N4amE%0mbAG~#ch7JTK``_{v3;rCg3ZklXWF}0O-)s3&Y2s@u zu}bzDS{Xhevv+s%V#_gNqZgUVU@J=24>r6Fntq@RB1QWUZncYWfK<`$ggXZh2G@rs z6Dx2Y`}P}nNL8Z4yMi8C3o^ToG6;n7_xe|c3C2n0zUB^adOv5C8DnypcA1Wu^oB|L z#quI$IXNyC!?n>VqKCnJJhO+D^pLqQE_V#w4K)pQ51$O(3TKDjV(hrfOl~G^W>W{K zi$a&w6Y4y?Av|-;}3anF=90clKlGesFid94H_BBNPm$s$M!7BdZp}Me)*Vm%(5LZwI)8Q2TSMtB|T&!Qk@ZsFXb~;4gNPn1E z{{rQlVT!(%QcJ!U(McUapiF{ibfQ*MT^SCCMiFbj79UC@dLhynEFO!0BFD7lF#5#84Kx}HW;;0%%UQ0pO15Je-;cL(CVuPz zN=LkGEAYRy;n(8wV9@ak;BQ%`J4-gG*02Z6(2C>HtARbN8Gf?<>U-L3qBJ+EJ~A!F zM)L6wl;Li+!-ADU)c!*0w$?%ZA<|ubH_}fcdnz`iFY(k)hwtGA`5pU3o?$7Sif^Gm z`q>}8qYVu=h;#_6T5Na+u?z#Xk3#3P#-U!?`H)?05R%~RO$F&;Qg`CjA5Bb8x_SVs z^H^f=cf)w$1{Kg7wxRvHo{- zu(Q}rzXP4Mi&fu9uv^Ws9e)kKSACsBYJ`7Vf@ezAb(flw;h_>)mgh=#e3Jjj-IQH; zseMvgc{*!q4bOgK*0t^NzK+DRy&KOO$Wd6v)?iDV2x6xSzLdk@w=$^7rnJO2@>S40;meYO8eA5mB@+9pufrzijz{Js~ykkSzoc?`2bXV1g`B33~4LH@M_t}!D>$P)WMWwbAE{qj zPMlV6YDxlU@&`7=t@Ow&-4^gRr$Ce>zzLKB)@voTI9!SbQ}U2L-45*KY7kMirN^vH z?H~wxfw3&l-hS97&w|=1hPT@V76Q8+9`9d3vn&Cp=7j~THY{5g!A5q_eUFWNCrGrJ zpr_{RHiHjb7UpBs;1)!&D2^W!{>>A z*n`LaX!uH`Ec3!$YPR|!yh()-3!8XM%&z=C!8323-IY|r7HIuF0)I`BW6@XFR8w#fhj`3~G- zRxnOuz$<%cHt%p5ZXtLVD~;@RPrTS($|Oeiqwo@}!Nzy4)L_Pjip}1CYo! zxO!zI&yP>X2=?zD98C|w$j!y8vz)6<=6wj>;3`B~tj%y$8rJ+857IJRWX45$P>N$1! zBUrQ!@=JYXT}fh9X5$YiRQ{Ew8lGqxS;qG14uQA%T<)to9Jv=(q!p13%2(i|C#v~c zmO5M45tLjBqIBP-=M<0;>7vvIHoA+j^ERW6Cg`fb(3J|0%pW|5@RGj;k}*x2k9ODw zTVz}6>k6~+rZ5-|2f2A24#7U4IzyE8E4;uCo@x!8Z#!Tqn*^rvYtBvtd$1JyV-^U@ zZ|OI!K#m>(1-lyrXdh5$jcJQfXoBi+1r{QinIIq^!Fjd?-87B;pCaY4*64I&;G0}Y z*}2SMj^T^g#T}N@9fDVH7zmiLAf_LJcm5A+*7x}I55Z)5hw?Un#O?sbax18>Luk2W zkvEGBl4qxC^yb^#4n2efM^^u;_R@Ke6+$^R4K?OYBT zb1Y-cr(AhF$ng2--5%6VFM2~OdWOhy=)yme7wG}=b{O?E2u95Ae0BxR(FF`=3(&e1 z=nWMqy&6atfe9{8`3>nGZSZ6ZR?ybSemh2>W_;I(UmDT|ZNc!q%h^3SYZ!fEC^i2P z$M@m7oj`}a!?88N8CT^$Vm`Q+24H9!^H~Yq632Hjyd-*X02y~ukAmsZ2X@s%tp;c{ zq2cWydy|-*l%_2O$74d#(SrFkn)WXProIL^<8m+?CUAt6wlB`J2{f$-Z*CSgrvi@1 zWIvTHpZ!$Ic@E=K8m;OAjpfA?>7Lr;2y)=vy;uumXCA` z=@3<1+`j%0<3JEk$S(=3R^o z*NQehdIp9-TJ8_(_%dyNm)8^i%i`M*+QYz@7(r@`%-J>C#zza~QR_k$Zc&G1sNyIu zwgw$IejU0b%n>2lPq5Braqdg*{}mE1B=#|_^Mu+@=d3i2$>UeS5$iz$oV34SP7bhn zxCe2Bid@J@aBc3lGOWo>&_7KmtvCBMsmIFHUMYHLDf(Ilks(gT3@7@)NleZY;%zi) z%7zTs!F4N?lgPIg&XvF@_<1TH_WV3PQ<)v7GqyeF2_JL)N7P0xcO>#<7V>l+j*6lT z12tKk(o0gam5N#`h_qPf_YPXSA|r2kw#Kw~Z6rW^)}-BQ@|}>B3LH_a=!h!lqq3Y) zjyo^Gege{BL{4PYAjALkX)N}`YDvxSP#F@9=E zj5022LF8G@=KpN=1T%CN{W_h!JT`HRhwpqvl9o@)3)bT-zRfB6&Pjb1a$W)T>F1ez zMWu&$GJ@%NUX_F`!)G+_f{8kYJ;6UM@}R~Qy&Tknm9qsmv4zdXui`8lTU60` z@w~?t?ZxqJ4Clud{T9n-G0I6o*C+5wEGoA+EnS-WE{}9nrmY3XbT!^9GsXx;;%bZ` zRrs|mQjpB~B{^SkOBY8X1RJu@AtfnGw3a9_f!`8&#d9^0IW>mw1$VZCP29P-OTiJI z#H$2bDasb5mqr4Ui%KfS8G=RG%JT_XGVruAEgzxZ8rc&yYT%PEsCUsOK0ZaO<c!b z3BN8q2YC|Zr?I`h$_qYIib@pHmc?=62*EdMVcUqy)yh7 z^S_o%K-xqv7IG}c`B>h?H{#wz2|_N#*dREi#hshDp2&i%=PtxZC+CXKBtte_=w7HOke0scS zXCVu@94R!1IN!^@hnI_&=+oj|w3R4HQ*S)Ci*T0J+@_C7t&R|O~YHuM; z;ule_tEe>~F5z5Jnvm53_MDX9`Coa0JzG56>pvmqA_KCh5fQN@dXLz%QgcF1L~V&Z zV^LorsSUr1-X{8ZY*F7#qz;9?aqyeOo@jZYHAH}#gPQTU@myh1^d0oXf1R=EL%yANaS2=QJL1FyAtP%_7@sJ z$cFeI}Vl;SF^!d7-gtid+K(vU^{A3MdzmS)Zq8v8SvO-&k zwsO;kuVq3UpG`}PF-Bx$PvyT1#^&eDPhL{K_)h46Ldwgaj7+Zjir>UK;R1#P5knD91wCO>Ql73&>!M_(IRoQrqJAiFJ6oXf|b4)?Tba4|C8A(%Xek@wnEXn z(AY`*CbA(U^InM~D>I7JV#E+HkttprDR}FPBNK^5av(-FA zTQT1Ico%IaauTFbhL~*#?II-a^_U{sH@`@tgrtdUydITa&q_qw3!UVmJU8V+d&f(> ziz|isR%K7fpU~|>I>ahNj1aH2r(oEBJwif=f!v5^6k0;u`RmbB{H{Zuga#Gg>5wHc zmk~$39$|$ny`G^6?JqV_$BDE?BCS}usK1N#LNc$Ce2O`Mm=TH$^2Iqqj0q)*j(hzs zA{9#)Nd^cQeiida(WYW-7c+y`I>*T8YdtSUHzDH!ktqK3yu|$ptsyj_&>v#1^m@z_ z5-r9SF(!$t2pJR-DgGCCSAtJ5(@8A)6l)~WCxlKE*A;qGXeTiu3Mmm|h^QganxdUW zb`dd;d{v|?o-)sW%4$z!(s;<+`z3w&732NuQC*A-LSN+Zt#7nFhDpK!43SQ6OM2{3wBl@-In`JqsJldfWN0#HgB6>mSii&(I^hd>_ z_AAMiU(ZR!dkp^xEhMB%oR!QOB5#UlT`}f}d1x8V6SGI5bwsJJXRoy=U1*;gtY?G{ z5})EFt|?}?r8rY;;=L5*h<$NoAxqIknn|p3#e7`MPQ28rhCB#a5UnQKSL90pozM0k z_JIey-(xi<{(r!#^nTIaBjn_1(f{I`8>~TZ6usYMrFe_&CP#^Q@>{Iz#5XswW{CIO zd@oA6&vuvJ?-czi^00_+?^5DJetTLZgF?#0Y(vaCUe9>N_>)%E>83%w>ZpBNv-{6@?v#7shrUP42OkwmOnMZbS-9})UOjC^9$`M+^P*iFPy z#fxMyL|G4~yFD8m)bW z75fD&4u4{!Jc~W@I6jkoSRJ>)O|gae`c+sgmw>@uf)#Q$7V7cjsvnIFasUw;@8VtQ z1RA^xTX*)~$0OE_S5It`{m5V-jv7gf(^y_ph=KbGJLMwmnX8N9&UbT_W8~~P#dUw> z9xfwW{~E*l^2oZU*S8O zM@-R5tneF&&e@A){xlZ9=BBvC$GP>iRw& zjTEnYKt9Hr+r6l^+T&U62oAV4zK+&J)3)Q)9n0=}MK#z5o~Xfi ze1=ocagvi}u73x5Bs6nONdJpzQh(HTWTM!ygyD z_r`0~jsHHBdSJQl1dqnsc!U~ARq+NyNfCSjY51Yf5>vc^JaeC618#*5(_>G3fJALW zZpOgI&;k#I6@TSjxI^|4Av%j3as9PU;M-~t&mW~(G!u(VHy*Dvu!K+4dteHmv!9Ni zr%?5&HZ7Sbk9uUAdkSDeWeCBl2K?ae#-f3|BxyA}u>3hkdB;q14?> zT3|kwbisA73rykBB76HkXtEUgim=oOD^h*z%)M#R32RrC5mg`d2;xrg9J*ekOe83yJKmOboYR?06e}@eh$sU8t`++9vR! zEvcVuB41Xi6V+a7L$$1GPzyn&{ujO!77SBoL8jgh7ZQ_TR!f2wX{&yseojQ?TJ;Dh zfG4U8?{sOc0r-*+sPnYSV%o*2jWB7BAptykU=!w>a>2 z9f`sJ9?pn#DFIZ#N7UUmqNAQrZ!y$c8{%(g5Lvke?)RU0{i$44?tqm}Q8M6$yiY{# zdF2qVrOIHkVpLFE)ZQ9ONtRQjos?WxN};cRhQ;^={+SLDL*zUWW8H`xyGFcmf3=F5 zN-b>-j|=w*SBK#sHS`~G*Qa63|2cFvbP1lxhhR2KhC7A(gr{-bAH?IvQja}|Y2Ts# zg(Uik4r_-;b_(V0hrcC*Xtye$m_MM0&PG3OLZ<#jmx>H;F-Ndo{xR_x#yz! zhJuUSpA*oLbMeA$M)&@OR0`XC627<(@TqKuqvj?(wh?~W z{n9I9x`%)l_*XU{Nn>D?%?6Y7Hu;8E5sUgz|B@)yG<_cY=^?#GPspwzo&9t;8GqI9 zB!B1xeJ|p8qV#FXZ^}2s_9iK3P%vR{s^ZM znU<`6ppHTg3*bUAX-$dnY)Sw9Mq5ed#KYQA@Cc`n^2^$N(8C4Vf7(MWm;UV3N`MlG ziX?#-GLh>y2TqOJU4W!h?*|B)ktD#E=ph1VBKgS@U zbI?KC829coE()(!D@tp1h<6aJb!4cZM5?TwDEd{I%^ zpHZh1D1SSi&C^KiUdr0R{7|fMgf8uh$F>#z^^fr=o*+_KBChonsE!)23G`yjnGSzZ zU*zo<>VA*(J4mWT`8s;_75i)XZ!`QJH^C{?Vk}PJ+_H=}mD#3)_^wVZ{>J$IHqqtd zIdU<)D9`Zg_GF~%O8=;*tc~hYc7dgoZM$;D(=h9X`zzn`;5uV}YXkGh(EUhRXIa;` zg^3^uBD$mcnUXJ1D-h>v>N)3|uGYird{~14Sa)@4op+~} zt=Yk?8qw_ii8%e&UqKrmTqDnugN836w_u=il8YfflpNYE5g!W1|2{dhX86WLT7~!O zj!2gcJ>;L%tDrc33h1RSc)>HZYRX_K+31YlZ%B&!NncWK=Q+)>@ZJl)Z1h6YD0 z>ig0Tt+cL7c)s>?XuS4JtxjzJRBajYD|5B&YAVcO4Rve5SlPo@BjNBN-6Pc{e?lIX zlDaKnN!K~FLdptXC-QeLKKwP>_Q(ye^9R(AB0bd`(!XfXx|&s*1M+UC)>(?z+RJ0L zBRW@jz0_766B))RWg&n3ae0?IGxAlqrVbtnm~cPWHP9^57m>ZvMD;Ui9H@qpksP@+ zx_G&kuIr~}OYzJo4{0CZfA6mS4x)Ot9Lv1vru4oxOzN!LBi&HP>IP~JUTfbmR@+%*Y#oSm~Vaed$NvHLbee9=Q@KLBz8U zbl*B%V|gd7b1iaGd8(dK$c3T*Rp&Gsbl>Q^!`}5IbXw;Nm11sMLT`jet*5RqTq*K- zC`s{#em7Rr>KWD0YWZ}inRY;(Og~>Jm6y-r!)_Ouru6_Rl0zTsChsN&Gli_mwHUjK z>)wa|<%(`R8KU}uH5i2la}ZG}CFG{ceB#}|1b=cj+)BN!{vNSPV{|ZV=rU!Wx>BmC z#e*7IsQXg-h8fgq@Ic3;rV*cVNA*innR9)63deEJF^mT|$okC2`vQSMJKq`j%!U9`If{QDX z8k`<9g`0%>Yhm@0vJl*RSob(Qj%fKh-1~6l5y-Jb$*nsW-VjO%UQqJ`Ris?)q|OBv z?j7P-oa8&n3n$6-b^Q%~*w??*$c+YL2l~?-yvLshxbh5xm*jIl>cIuDl z$@3^BQa35uVPZ?>lNtVFLsz)~lM2~b!#&_nsI3ju?#nK%kbF6{w2`6q;f;Z1sxSCT zJ`lMh)z?~RzZ0c=A}EndtRL&0_IjJTLpcz>Eh}nQGJ#oOj4iJ_0iW3xm=lk|8y?Gi z=ZkP3wVLW+27QZ+f$!<+s+|}!Kt;evIYQnbUDUNJVsIF*onh7Zj`XuOT2~GJiq+vi zgJmLl;rsAXZjl~Rt4q~&YI1mYXo`9+d{_->^TQ`1C*WfELOmUMq{&htEXCij-c!(` zTj5*TXegsUrQa!6(r+T7y1m{i*U?K#8F19G#{R?+?=f~|ev>CF#@hNurjw>qrg^3q za`iMf+%(KEG&5Br^5(f|x2ZYtOdCzT^py;6DZ7*lax0~Wex0&jiId0ada3o)wh%3^ z;m(e#zXaEXJpMxO1%Fjvqd*+PlNlD&!jYg>K$DWY&!+dr^hx0!95t;=ldt>>)|t@W)_ ztxl`a`myz;b*}XzyV-W!TH9L2vd^*)?6RLot=+cOwz}3gEfdX;O+Ua0;x_~^f_$&P zCiehq6%(!>3XwIdivLevXI~9pZI972-BpQ^^(W_*!ga2>u0e$ZUFDsdod=zRTuq$i zJkN->9_3g4*8(4a>U}%7Di9w$9_|>t6Z$y3lblB%gjV@C2QmU5`4R$Q_g3$Bo|Z1Z zGpk@lVS?*7mzxmZ3BIIY7ozt!`!!!#7?OUBd?z1JJ~mb`b~Ls%tud7~ceFMnlA^j} zmimSS z9Q8?5FUL8%;+Sh2;YhH#ZJ$~8noAg`8jSi4lfy}60>gyQZ zf}x|oZJd2|)c6>0?7oC0iEkGlSiDTh7bRwtcvRwW@s#4Pl73He6kC)Sh`$+^6O$U% z$Wg|+*ZfRhUoq=$lK*6?{~hmI*H-5@1)U2@<>%(D&%2QOIJbQ6p`1rKzvQ0JnUfcl z-#)Kyp*Js?D72nL^Hy_CEzEE=ab0n0?mLCu-Dik%ZSSt*Z0dUKYT*3Tnd=;!H`-Y# z_s{%N`Qf}k!OFsNg{574UBlcny~BN(fu1lfeos!-F3JI=y3uRQv}W7hjjvFsDlP$JuXY{+BTk22-UZ74Ip^vzQDlS?K~FZp|kVZ}dB zYM9U}ZgSL8TdujMVY|K}v-TtEDf;*n_nm_2xi7Qc%ra(;$coPyk=;FeN0u{PO}C~W zP4AJpEUQz_1ul=N2enbK3s{ZxK-h4(7B%E{%9m1$SzY^hSEt|Z?n(W-cI z;+@!vQI&01Oic|Vl-<(f$lq%BpwnB^WzXM|^Jiw$taq|TXP?PFl)WaaY-T}P^R#+t zp|o?Ejk307AINQ&-7b4yrj&UmV@A#=*>m&f@lLtprdKoIJ_#9maN+ytHBm3TzvlH;Lb;Tz13 z`|2xL(kwS{ey>Ql5dXZy_sJ#8B~|#llBG)dDp{4sR~l9ENQJ=_jOAyS3zv~gbu0N< z;+t`%=x4U;mWigZhEEju`PC_bYn~C#-uW|fe#qXJ^|t5MvT`ztXADelnZ7pt zR_58P{n_<%A7=N>x|P*D!;sk|t6@&x-0t}?d9CwnWgW==IK4u~;LIi&x3g`TO|pN; z*p=~bYLBduX+v`QlPZnx72qWi~TIFXVMqN=a=qKwo}D+l}A_oylRK4+pDywe6NzD(x(*$RTxtC zQR!(Vt`wUa>yF-N+iS@)*3;jX?y3`lG2TV4{rLxTOJv7pwaf&Xm@z56R>s}5ztZld z#iwsdUzs&3vrl&W>=&8+a@M5{%noFHlD#0ipx{)-RaeinX8FZZ$7FR&kD@b{%Wjia zKSQ6<|JCfQb}85MI%QOIe^Ic+H_JN^?z~K`v(j4r#GYWD9X~LtMPfn1SMfh4HAvW( zbR_YE5`z*impC7nlDITFH|lTu6tiYsV^B>ipfjAX(yK$N&3D;<|1XDtX60?#ImN!(cQEgz&Q#ESVTvqK+wQf~6RDD@>ZWVjw zV1?^tGD|KgwkX~eb2I9ZE!A?!P*1t7#s#-|HRtt$7rAY624(%8Sv~7Q#`o#prs>kY zO(~Z)C#`i>e&*%uSy}tD7Ns7_45nPnuw_N%P0oFkpPJb=XItt!skJj|q}6oo?wc{=(EQyth zJx_dE>`ckrq+P{biBsdVqRK|UW1V1sVI4+x%ryNMl2Ka}_`_GzKgqYyf8Jl$>-7EO z>gvAetnQVAmv!$eW^*gsQ~Ttoc`=*fyO#)+X;rCJ)pON7)q7V@sghkqsrq}>A=Ra7 z-&d(ysed_pazU}`32S5BF;5+h>;ui{5MAX^qA%#);v81+Mt+@~>)B&6H)XU=^QAma zt&}o8eMD;Mi~(80GiIgFN$->~=+))SmZ{})CZy!KO1`{OxG}X={^)1Z^E}V`<{W(C z&Uy2BT1L&6ol^I|*qriD>X3|jS?}hbE2!=I(pTAgB#;{VDdLw^OM-PtT<@sTi7jJp zBxJ=dEE-czM_)vnbWRIgmEe$`dgFIIh0 zqeYGXs%@yUufo!@Ba{Crek*ZtVu#q>F(+&r%|0cTJSyM$)4T)S)12?-?aCRG^&%r7 z!U`d%@p)--ac!bq_NBV>{5S~l>+Y)TNx9W z2-Q~GYjqQ{UzXx)Y0$K|7N*rUF9ehoslr5c(*b~E7Yv=UZr7GZ&dxh^3f^@ z)yG!tR=r!b;gx2Uf4}U|QpbvyO$sKIj;|1t?zn7PXNZ^UY8fz&{p^|K>Rzxv$CL4A zMuiM5y=~gdw83eYUmbkaH>KgLH&Yj7yq6k$X-xOO{P$Ihl)mXtQ#)oXNWYwAe_A@X z)61i&@4T9uvFM@q#Z}C^o1R^G`d`+qM-OuC>8h)JR>Qz~celVEU##>%U7@@*$J$H9 zJcwE!UoxRz+^6wli*-)OOlnplzxbC4tK!SVI~^}=H^^xhWouzhHod~n^jxhMoa0|e z-r;=zA-~(x#&gE~ttY`V)8|m{>B7oP<8<>R^JGU3_eyhB- z^4uy9s;;lrrrNM79V+!Id#B9A-ng_ zms2(Cbk_W=z1bzQ|H{ZrACdNmwr`bjJ^Mh`I~o16uF|fF%#coJEzABOZ%bCIe0%zg z{IAlovrDIz$XomJUT&?Fr0mu&LpgI(-pMhiEpUFFKiWIf-8gu|yC(FeZd}+;&cbrh zIo45JzhnG)$DxECu^Gj(5{DPR5L-Fn%jmn&lcRpOb+rFy+i5*boN8Zc-~cQGw*o)F zbd(af>0aZV$N=_hik;JDdJIyH)o5f*v`WytDJ?6uwz_A-`MR z(|n!#LQdKI8m@a;SMrqnN14{V)!C=>wakxPowM_Ui*m1o+vHDEhxjh4*k7ct$-{li zW;gV5EVgusZE72o@FY4WJ}c%@>>G9^Zn=G5Os2KB!)5JnIbwX<@V#^ahNk3*FMLaV zC)6YO2==_P{%r5{z%lR7p$Xbm-6E~2{IN305VCHxFN;Y|FeMq2GK!Tbc`rGx?6op4 z%i794D|@BP;ZhHhZONmO^NXJ;HZk5EHzH=b-EJ#mwwM+n1&h`3p|1X_o}Puv^VSro z`H%Da=l+ykGV6!T2kGgV-{pLn+oE7x&hG42SxR2r+|~Jy@_Tyig%e#l|Bs`y0FUb0 zqVPS}N0XTd3GTtIxVyW%yA+BRcMA^1-HN+IOY!3F4#g58<1^RJeJk(l*Y_yN$h9MT zuk|laCr>qxGtkD{)xR|}F&L9^FQZ~+4^NZKWX~J#9=_!Ns2!kX z#X^=x_MDu~R^%%?ibP#^cXSo74s}$wU5jcSExHOsowvqWO2|blO{{O^X{hmM2bD6f znGb!RK}d3as#K3$4*eTG9B!p08?SMH95pVP`AB!Mmvz1Uuycd!nCnzbr`SB%@+F?h zQ7^}V9368E%kf|KjKpmTQbN^)((wmkez-bBb#TnL-j>^lS9yV(XnxXJ^+e=p;Iy~0 z=cp%}cZKI@M&0!ElsTz)Qk$gA{JAN!p(p&aoA-K3k@N-`J+Oz$9(?ET>h<{e^lZKs znO?usv&DZibFqJX`bv+>vn}%ncHZCor~HTX4gNSiI+CsxP`|LqKj-s^oG=tbx3$uK zd8T!gW3yef4{{{iVl9I0oi)3)u${_XW|ezgQ~^g`>tMOBu!JwpKR`WDrhn=A z)hbGCct$8D$Oen~74HPk8t*R83D0fsip=SmKmCW&3ivFZ@!ruM-rv&y%JV7^?fV!k zoBmW8lioh`$$Q$r(O=K|HuOtoqtLp{>*^p+J$f^^M*B0e7oE)^hMn9aXk5T2)?T=_g&b{1o$A(!heK!T&9GCI4});2`}1XXYidI~e75t)*UH zEpCiZHy}xBj`@%dMXm7^Img+hmGU0sCw8~bun%??h~De&9vg^#8k;M&T1?fLI_`z; zi2H)8O0?V2*5;9?NW)Q=7{u0OOPj+WM!cn+RG%n)6g9Flv@1Bp-^2IV>+sbLJn^Ua z?)f(OTBeorH}Md^-*-J+%s0tD$^Xdf_74rL#Tuk|BzNF&FxB@l&@b3CP%3;RGAdj~ zDWuI;rs|n+)jOefn#(?9p9)D}RMh2D#D3y@)VEk$L)%8VqUD|RQ64YHSv*LYPvkf9 zb=cCNZf#)aA~ht*oNCO_dx9uk57gZ>V**`i{sj)?S>*YIxVb_n+&RBm?^%AcY1Zb> z=FvM{ciffTUtQ0l8%5`hPKmDSYKt|lY@cuKhotRg!UW+Z?`4yj{gr?JWtDEwAXZIEF99m@*U)MpjD=VwE9izYi(hhYRPY} zX{jm=1jRNa3_*%xMbHMD@cmGMzs2+d>#QpKz}TiW0FR)fW;4d4xB3h{zK%4Jb_XlK z&nJs(k-R?AQbZbVX=q#L^hB|)-qClWraP;GY<1ij7ri*@gQJf1fSg-Wg;Byfp^s39 zUr1gUHK0g14!*$zt*~|i?7X$1#L%Kpr|`H?{^0z;onT|1+aLDZgSUe2K>5%^&+s7i zRu0b!1j46+A3`4AEAN)zQQsVOSx{642cCs<`VXs3jG!?OJ@;ndhkqu0NPsVD8zGOA z+DZZvcu!lqId<6v5Xbgf>Pthd%cNYkl}HWkAhZ?+ zco0<|qQkj`{$WNF7dw{!DSQw=iFT=rctRQgKGQ$;F^+Baxwe0;A1!(8rR~q{i)~@+ z0I9ARBedtIfF{?BzYN7lHDfgB%!~BJ$nxt3zS012kLHG6MfQU$G&sbAz*s5R*}pb4 z92~Ow!EWK-0|f)wgI9gY;q1Pr$|c{T-~r#rKtb<6!FGWU{-xphp^Wfzr7^gjm9@%7 zU1PMVv8RX|G^4`&5AM6vQtobDZ9Qk>z!P33kGG^*9@~bXa$7=LgU#Ou95IyBU_jMZxgKJ<;aP!}U=VCc|vV2;$$Q>-xky(4k z;s+OPnI%bjAe9wMh$|t&TnkzWGIO9Sy9XxBRWN4{YG>81+C(K6_6g;cgOS@|M)`*C zUN}+#73_;hlfM*f5?L8+11iq+2xw4&R^i%#-oe`5oZ)=FUH%b)Ca^JN1YZZMg|0`6 zDZfN2C~b_%NOc`a8Y1cN05=d+@lR5Ud`ON%QhHg;&%YqkzMbukHHYLOR1DlP^c~R;fJyLxvJzee5zA1CibRh^g~)3B+Od05$f;i z0A;DtQz@%KdeM4iW2BfiB3w^dt#k_a5A9YCf+gYzF9VyhYVdfZ zXymwhFMK$7J1{KrRJp5PfCiy5?Wxb#>!UMy3M2AmX*-_*KK(Y~sRc1Jwy>qVl#O4^ z=RkJvJnIi@Kk+4d0ZEJR*wc_H78j;NMmxi(OFN;>eiC^GH_h|tNXVv#R7FRmJ~xBoEy*y` zkMvX-q?4tlXdFeqtDE^zD+*H88@;IcSzlpBsVit-9w-67LO+oc*$!Kfl`;~&!W*FK9%m@{>mQ7o=HJlUO$QmZ zG}5!9=~y#2@)-)SVL0L$WR(@BZ^?ReVt-{)LH)RA93#J&y}^_In>I498I_p5@OOIX z2(uAVrocC#`M^@%VwR_a&9mr-e}PKm2Rd;I6cJ6$D0tvwL3pkWVrz5o8z-2xnXz!_ zk1o|pjzLZ4AhaZp%sI?>vm~hlx9bw9W1rxkvFHl_0aikJ zWZc3-43&tTi~?mN%kNu;6osNP#B_keyc2T=6yZE(es(YY58s;*!JI{knljdoGtoP* zOrMj%=4~do`5j8bbMPomh6^>rtV(*)-01n20#kJ>G?ZDY{RZfxmxPA0GT!$z6k@Y* zPPW47SqJX+(nwn$0owE>yhnccHFF?Ktrn>PvdK{N%BwIpk*T{7^pKUvBbjC1XZo0P z;i(^Qo&)*28`J=4@P`(HD)mq3Kp)^IZIB=H40DQ;=|(=A-;m8y9QxCFSXCB<2JVh-w#j$;7Mf6=N5QP>ebIfGubI;&@`2fB0VdU8rGvA|A4$1)% z;<8~h9mebdbL=GUkX_IzU1XZl`XoTJ6xh$8T=GEidIwkIG&<cY=k9LlEMV3XRQb-j!ECds@H-N)a|Rw!Z)<32cNwkJV44l0kexQhpx zHZq>d?0?iv-r#510DDe_cz$ID>-!v1dm275j@|9jT zmoaB(OC||AFNQ3o_fdmx%0!!EnUdxN_6aRb%W_q;s%*Y2at7|hTp+<26p2Jk1#=5C z#q3YhkdDlncbQMdS+a-*p-gDVyh19&3{u%B&7Cz$Lru_?IY5gL30koVq$*uOqChUW zg8Lwo8AUro4Q8bg9Qk%q7&%_JkLWX!7x!;Xn$Dip-r=qpV%}k@Q#*e`?_=I0A^koy zN6Xo(Mk;&S^wQrzU&$h;%mh`x8nX$n&P_`osY!$~M`1@nli7*gZ?q++jLYm=Ew^xm zN*Fi)F^cd{wC7Nrt!Lu!3_rxRY>KCEEy+P0%p;n><}u3hYcaBwg8%)4c?n9Ip=J&? z-5f{HvA5|9QUasZN<2>qB*8q&KG$=w3AlGV;AjqWznWi=f8l2C>5JG>NE$80*QV)Q zLH#8;q|N2l83myxoNa!A#wn-aF&lB+=^W%4XQM66&1569jqIc4xUxug&Chf)+Hq~@ zH|92N#m1Y7+*YG5l57UyJr|ne*^*3I_AYY+I-z3Z7{&nIJOyRJQ@YMpRenr}9(gd27DP$dZMx&9rSefRAhN`_Wf?iOkGgs7ZMyhca`o9!?6uV4s$kfIC zJcGVtJ~4C5tK_~}7Uyy%cYM+yhPf*@@+kH~k)NMbXO`o6{)ODt8jyNW{(A6K z766at75eBo%m&5|W~XtImPa=7SLli&%wSpx{`U$b8jJ=n8=za*riRSj#I<9{0pl2Z z$S};Scw%~C%q;;Ws75Cl!x*128QQ)Vxc7I_%}}8JZpM>M@F`w~64W%jP}?*{e~N|f z>m}6Dnfe2g1LI=`?pmH)fO0y9=}ulTb7&tHsx;;>S#Nfx*}y0GY;<8$%#k3Or9j6x z)$B}8A`kVHnLv|CCJ2-BK}6VQmd2+|FpIOV=?il#g^z;HWmjO-+hwky|9PLj=? zY(30xE-=c%q6^8gY<`1 zsxxNh1DHLgFt_Lx+{;bKXy!RML6@OA8)UA8SkoND`~@PxNRkid&~0c? zTAnl)Fo|Yq`Wmz{DBm&WIhcKzjX9`^7vp$t;F%zx=x>9o;4~Pg34~#0U`APry^cu- zGwYBIS`rGx$9Uofz;96**UEzVyd%cj`|t%^!gENNbr_HHU`ESD4nPrLf%djD=5h~I z&JE0GP(gpj%H%QFIVG5fG!ADwJLp+kp(mSyxhNW3qig0;P%bht;`fJktTyhr3vhG9 zf>=YKd#()aXadMOt-zBgY*xhhmx|vRWFCT2`YO)tVyIDvV+EfH)oDvSZLQ#PO2Rj5 z56{JSsKc_H7YEGlpjLdzqNro6%%HC@yXVJhBOf%XcQH>^F*ie}(hk2LG1rqn%yHzZ z>BPM}5la5UaCPK>cceIuV43+EXF3;Vq8V_`)xw#a<7i!RSiI?$>J&;hGKd;Xyxfh!~LDPOnQJKmjqT4LmuKMmq9}cKyRFa z6~J$p0i9-lBxaq%%pZm?cp1jI^7#2_+=IVkR-Z`DV|B3vDRu==hq{e3*%>Q}-I(>) zAgy*Nh*q1hM#vAsR7E@`*|D;Jf|0H**^V{EDR^82lEu8}gCn{HMejt+&_^+UDwu8I zY{5+Dz&R;}70pxJ!39B_gl>=yAqBBAiN&{i$K0V0i60|kYmE4BuyUM*^S*`{Yc`dlxjGZIo1>2i@LO!x<;HGJg(dZ|oppT}3uGZJ60aeys z#XlxU1v@))aQ+j7(Q+Dtn-6dK&2rPT_lSEO&yB{|uFd0qg^Q zv#?6)By~aR%>_Bda?|p`lGjpKwu{U8^Pr$xu)geQE=2CzR%Jk>6PULr!eb&ukYiOz zy{+%lZX0E=wyDd35h(mCbP}J5cZ4?b6Jf0MTKG+1p-NwYJN;k!2vk*8si;OmHdN;m z!Y`5X(oc1P+;@%o%mVBOZnscedMkgoG`AJ9ow4n=*0im)Y?8mAA0Ormu^$+Nt~Z+N z1J!oQ$jIIBt?*W)#Y~QLP^Kuylz+95Hij<1sDBi+<|sZVDjy()iz|e)Vh^E|a0X78 z0i-5WlKYL2HWy^v4eDM{+685?dR*P7FEtjJOHh4rNFMV!=ObuSNJpgUvLV@o861I! zq<~r8XwIZZF6kG7x&8aRThg!i9%nxB^zd!>J_)@Izf+oER4U8$6=n(7Y%gtUl+T$M z^)+g*t%WVGWs>xd^b^j7nW#m)h@1_+3-0#!@VEE3^0@=Wf(=4XprUK0XQMK?!DSbQ z$ZIU^Emv$itv7A;Y|pG;a7A)UErbN_JsD0bng^lXKBzWS{#9rsw=z4@QyrrEw83CR zZZLk(jTmu8a@V+g!mr43S;dVK%CO`3-{DC)%N-$2*^ac5siJ3e)2O1C0tmi7l2JNo z^Ypq#E?OD9!@2Aj?7^2JZ!*M|=DvY9dyRX6{c08=x;^rm$D>Arb%(x?URCqbWNp0h z6`r3oR5PAJLE4?Qa#7r`dctbRgk zY_w2IYMh3&a9Rz1wi>s+$gHO4 zpayuG_uv+KYo4YzK|9)KM40Xr{0+7~^N}EN8#^q@@JuP%n{7ao2;6=&fb^a0n9(cJ zH(0G$v3jhoCo-d8{^savq{ChY)iu9an6@Wwa|Y(ORA_>p!TGYy{LSz&&y2I647J4` z@v_ku+26gHH<}GA)F)WSuSd>w5_t1}8S9|it_{xP3hHBd3J1UGLc+O82a#l~=pQpb z^jk=0sRw7oTkN>nFe^w0WQsI{TdAt~2b+UXa*vbE1>}YPFSEqB14qp#+L^mzY=HyK zfdt=P*kd0vKa#CRFQzeliPi8SHiRcRjeE!KWOz22{K;M6@(36B>gGDGjX{{b$Z4rf z^XqVSL}t-s^{CbmG=7IsSG!?O)(V+>;pNy92|(TYRISZDS3WYe^y*4V0`|lB=lv7s@PKACDoT&in)a5Qct80*5LDVb-6Rv>oxwo2&9xI7AzNm)Jj1e9ENYp9*MF5mXC01j)@EVs4NH?0B&{7ZeoXIrBuA3mx=#z9oNziy_0XHmoSYW+D{%BeZaEbRoyyqsS}n$e${sDHxqS_)6o z2(zV;%bag?=lYT3mUL-^tV($;&!j=NCbs|h;o@+zl~&eY(tGMQEhjA-nXOjRN^8IA zwarpkh3UM`9^>YVtB~|DM?7l%z>D%i8EHUl8McMW<)==$2yK#B&d1MSTe0UEH-(Jke{7NbLe~YkH!Y31ZN5o=)kl8Y_(A)9n zubx4v^E0CSyZrOim4U7x2&(lTl{lk z#Vl|%I3`k2{mpRUTX#c#Lj_@mbrRptIU`DlnHsYqwqjIn=Th4dX&bi+RgBl_QyS6~ zRR0st%kqPn{@rcljUn(qPy3nCUCB+pU7H%1; z8~Wg%;uq8VcnAHYo*@};Q)Qp+XCrUv;6`6FPqm=zeetuF@7Ii%>31{Ze^&TDIPEQ6 z7FAwuk@zZVIg*jC3aorC_K@R1>q1?w&&RNja zB0AoABl=+U1=lxsFiwi~x*og7SgP<1IIs3z9mO2bmy<3S3)>jIm7Dq~;}2SeYbu1f zZMIUvPFr=pp|p|M`MvB9?hF5?)Rb$c|F53 zF8Z7Mb_Qd^54}&&q0g4qF*A3@{-4`@71JCk&%d7dmidK*cUsEJBi(ZRd~54yzvZ+z zF>0pMZo3sd-g?hEm`_z71jnVfNz42(G`*$2a^#s-M(ZCQqK*dA?sTX*q8Z#ol6&oWNeH z6M4sN6OTJ>6&A#3C&q{6U>l)@gyMjAY zNBV33Oivx=`IK@0OX<+^&!2weOUwD?KW}$WK|Mb*BA&TB#D0x>8tu2H04{LL9`DR= zA1j`brl_frrM?@fF3-82**#6dT_RuMB&eqkG>e;`m@V8V`MmI%>5z3Rzz|JJ`auNtTyzqa|<~%wJH8 zi(*Ao?RF$-KFM=^^Z zefEZQzq%x{BNPd^eAoPQe7$^ALb;U9^pf?nt5yz@e|e$UqPdge3;PO|%46m}n$XN$ zIx1bBLEHvWMuhJAD}|qiTWjNx53oqAW*=mKW4ju)HL7=1tC(!AdQqva9@e!M!h1=cVWHEcPDp zT*?&EN`6iH;rwyo%ZBumDg5V}sn5PHdfzr<$IrRmS=je}lmB$@bJ=6-$F_4na|Z0E zEbT09ES1G}ipA{C&rWQeSD7C z60W7zA)i z^IpicA=^=RL1#-zdtWez^6cle?+JrT~5kdRp-p;mF~Y|ZoBqJtM28lbx{|b zWgKk+M-BbI6QV~bCG>AOWUEPDX#+G13$wr{Jlf9wVR=taMVoC{GMq| zAN1pWhWaBRblTG@^OcW&-}Ym9=Hd^}(|4uk{2B{&*l#H-{O`Za{!t#9`R1V;@*$~c z?k;ifF;}waoLGkGyMwOT!L0IE!l6jl1-_Sq^DgOB^fX zOXfb2Z%wfpg}MA=@0A~I4Yh2{lrgqtgc^$vC%(&R5 z@hPs)&eD!Uay4r~TPzYEu5jm2(@j$U(gvzq!jtuD5egD=G2<)4>3PBnBPIQJ13MzB zvXrifoHMSdOMR;%)EoL)DpVzPy{E0;pHU~U{b%`%$tep{_IWB~ioTkWr1VM|<$qTE z#{8t-%fTnc#ppHC$ZSn>T#K5Yu*30*J0b5zZ)2Jzo)kfEw4d*YcW9>mEmvmM^i%$0 z%4DON*i>>l%G=Ldr#n7cHaJ4gBxh4sP3J6Wgt*z*5`GYz6Ennzd_Jfy$Dx8fo<78`aXC^uSC9k9z$nah zQG2Vu7*_(z)#qyN&_yK4kBXEsW`~aj8w3aV3VRx5l*^puDe!Z!XI}b`w4DA!-g%i# zGCKH|Lsh>#Ff%oeCo!W>YU9YH@EtBWI!zvwm=V3j8H)PDR>V34iH_$Mw`Q1efdewf z7-rn3hYglJ4j##SZVAQi+xSndg+z%OT0^yrIzI9~*i&s5{w3^G@_CO1c6ttF7R|Wg z@nux>WM($=tj}x%&B=DpPk#?jz05uE;~dSf`o}63?u{IDA9Kyh{x)X1E5ExgzeJqR zK2kQ3yXpz}B{O`h0S1C6|;s+NMW$i(GSi*N;5}S7?srI@a#xp_-=TL@?5#EPSx|FaFrbS z6xIXM9SHA5{Z|y{(lJ+BE;$aAoiL^d!Ktn3RS8DL?OqI~Zov+_)Zb zQbJE>3r7vhHO?Z8V;0jNW;SDpc3pQzw)(DQO#kt3ikyDdTh4G1QOptrSWy0>(oWO%WT)OK#8lqQ*$PL@`38M&O)(w^Wx z9N#xr=iJFT3MUkEud@u|E17onGQ4*!BY$a2;GC>s=BJ1BC3;pzJtXG09e220d{mm_ znd7>=UMj(#FmJ+ZeF$o%B=p>dpiBD+B#+i$rLbIe@{HXoY(|F4YWB97Ol9b(S}6xY zQzDN;Z9+>Ti^6Y14MQZ*Hh95{m8>sErrW#4^Chq*Gu^*7*f8yd?@W5)kK113Ta%2H z>FqQBFj|{VIlp_2tA1QmY^`X&E#10YJk6G4Vzs9F#Big~li)tD=-rkvFJq=>p}!#% z+E0YFwlUmW#mlFUM{|M)J4lQu7a^wwT7>vhixI}Y{yWi!x?9D+EXm?;$E&gy{Uavaw*rf`}#3t)I3GXG{pyx%`$Lj*2nP#KJknJk9bR$o)B2a`qze*Ik=!r3Iw)D$jyD ze9wHN{F?%QhWe|2>881XkCsMQ>pGf7$sVw^F5X)olr~4Z%opcsGF6p&;{9%zWXJxC2-4M+yBMe z#>e}92}})_0CsCYxJ6)cFq_}!DdHpEQ{F?8d6hUbY={XN%})qq7!6PoeX>pWNoQBEfuyI7VD7 zY~*^etI2VC1igTABns@5Eg%`JByl7=dTKY=cBreZma^L>*e5%;N9T&F=csCzEtiGE zoQwTx)}!x@ExL$o?TO&&{4i!9Z7afjW#jm9d_yE7*TZhggIt&xb|~3H-l4vo1Kq@1 zpy^PI)*JQh`UyRvPcUxN0_a(lME1u5Bs{I5e}U#P(Wr$C`~Ao%Dysdb=7D}_tP-tG zR4S>lT2VE-mIPPu3+*iYS6#JP+BR*O+D41ls;L!`23S4Ukh>vAr?wE=A_Z0Wy>=gj5@ry#>O@7$|+h=#jbEu|z{ciH8{h2FNj{9~MljMw{W0k9RHC!tLAQaiwFb18 zCd>qi{u$<&_UIanH@Y#c=~QzrGBsUvEOmo2)6Xntyr%ykea<{Fi@B4Lyvzc)tngS~f%u zs>Jqy%3+Rq7JaigXfyhnmq=w(WjfLu`2xW@Nu~ z#FDQV^bI`uo1#7omCOaQ1FC|qti#C59EaW`H@aDi$PM}do&K><{FKHqdb5~g=w3vd z?PwURun3YVqd;x)U{`q)bh!Fpg$3y%bPKP8qjQ^l z_f&-~gC5>&`h)q6%tzwRHad@l^|kC=x|ns)Ex6PEV0s{ZXa?JwWHD9*eMIvazo0n_b{X=Mx3PRvZThv265k&$6FFVl{Aie{3T z+ATJuEn`>Wn@L7}tZ+s``SRX4Lv-T_$Y)DHc&|(@8E4GvBrp1m7xmMm3p%b%$vEWU zJu`-Jm*{DBq47I+_i? zdT*i_C(x~%i|dew9ztg|>;A4wJRoqbprhGd7*!UKEzoV%M%U0!MRPGN2o+2|b_(r* z^TUx%%t(DBQvo`vM5Z@2aE1DTlhT$%Yk%|I(PK};xjx3`(n_%|I)}lXihgJU)T4X( z58Cg{db*R%gAVI>=tUdok4P2N_x;8Syh{kQt$gfk#!JUD8*y(=L$~@Ne^mPd0@!}$ z7uJnuYBjAuZbM&o<`&U|Cf9w}KP20~${^7<}8A6-E;hPp9)`(Pd6ItD$!uHeO*> zH;H=+MN<_pd8#u>P)_VO+OXZA2r)n~8%6Fz`SveXUKfpGq&h9kZ`2lvNt&P+MThYs z?=y0vJ=_aUVH4%aKMs#T>%4L#E$as6F~}*Jv+u5I*AvGnBsPQ_z7st1JRjsW~x85}OGf z6*WBcl)-ZqLDjg3JE$*u4~aN}C%Ch#vPj>dM^HWfn^ZItSPCNGM0y_gCQ^RT>%U63 zK?7C-6@~m5Sq?%qFdI+JOlAQz0bR^pbP_5*$ZTi&k*T)87JB!Y?ENS@3Pq z{jP`1f|?i`zCsP9qhlY1?zac}l}^xdjYGdY07}~oXr-2ePFWNz#0Ox6=f<5i2?~#R zsPmH0GZ_tq&0)MMOKbK5W77(#V@@K!s5G=(N1>i@BB$vZZ=!hJcDU0AeN`N#} zgIS3&zBzJZ(#>_yRLn)ExethzRgnc&1^F)>@d;V%*m2-sjzL#EFMg*ER1?KfE9nZR zRu%N?`)6q@kWAx%wyqpru?(LWgU@lIXP(5Y#c#%8^vH*_hl)^2v_&_&IrN4VkuLEF znz>kX&J}#-clPc6cbk@i}@@=eO49<#(Vg4430VvzGDt(BJv@rr^NriZ4>-=9{g?`juZMVC=&AE zJ163A67hL?aEt!-cSV$%nIN!#8UW zokLaJ9p&)N9hm*HZ>zj_?px7k6+kXEBp6LuIrL ze6%^38;5}>HyXOq5%{kOAoiUC`)m{(VI9z&?1bm*H)vTpL0wZFB+6>|w+Q|>r&;j- z{G5(Z&h&-`W(?k^F?5*87z4I|s5%9IyB+$|&CsTHh3=&un1XYlpc@aRYkjC_I)Z6g z>;LaP=KsXOY547JpoX>wvtv45vl8#J1Y9BR zpZegP+Ta}~;Z-{@8eD_6DodxU;%F}8^PWK^dJ(S}j`!M&cN>Fm&;mc*j=yh&qfQ3n zC>>*2PaMY-Xf=mGW3nGK*E9HpZTJ>g<6vD@7`?)*9JPC3BVK{B>M?k7qoAvujk&lo zxTHIfA=MwxPont%I-9N-7v3<>F~%K*;<6_bi{9iKXncD@Ww;6Yn=0fWX4d+6YC1sm zJ_FBZWhN7U_XzqJH@N~;;{awdR$n#Qb|4vECx3&@*9#2SHDngDXbm!v?Ldyuo#-5= z<4(T=_E{{>#UALzOF~I%l5-##IwVBOEf3u%@Q!4_+zzrm>_E%cd5Bhkf zruH~ejJ6?v(gKEJ-q)Ywd7WbHFq+|-y(fDZ3v_j7*}7zzDVttB55I*DH0$cfX&?=t zw3m=!+QAqJ{c1P5is=l^+zi^&=#H9BTPPNWLhU;O8SpmLb&#|P^4n*_WfZ5S%`0>O z?V+!uEATuVg;qX_>;%zcK^9HW?1g!>y!ndB$$dd)A!y)G(DN8$df-X#N5-+WpqPwc z)QiX1)E6q?hRk#5cvvDrGvFqV=?s0dQH$;b8?_qdV8+Zz|J7STFEkNn;TE*8@pO!~ zS1YI1);{Qkj81fqQI(E`j;E?NSZk)Y)qP-$x{YecW}Tzg(_*!?Y60aBbck#X7%XmuxIaV$ueHPb?tAruqG3??){8-^JpIt~42Jk+9FsPyhg`50K{v5A>8(M?c zxFSf!PT}HtijG5n?mg>f@1W{i8>_)`$UKdwzZh5b{rUjZ2-<ed(H+`I$jX&BTI zb+rt2lsZt^5Sbbo7M>hF7u*-D8yMw(<4w=Zm)Q$d+|oZ^rg}4OhEmm4Yz_G@+eG_4 z`IA^oH~`vJvTaNB3Rk_DQBn15lZ3LMh&i;L+7A2>;LlW2GJJ!ch1PTK-Yae12L6{~|hHcnW@e`fzB8<^f$ zAvnl#wgNv_Xe)jdKZq?wR#b$=;x#c@8ZA|mdWsqRC$0l)Wt)+DU{NJuX6%nADOt~@ zztjHKzG%97N3EdFP#P#jm5a(_Wrg}meWU)WmQ`$uBl1VMR`_u!d)O0P5h@gL1ZMeu z@iz3#&Ul(}JMCjym7iNvE2rf7aVxF4?@_22Gf$e0s{RFMUHd5cmAI3~@mgxw-#ceT z4YmI%7vy^CJ;ECUJN=3NK7rpu2f`0<7miXmb%%OR_Zan{W-QAW7owy$meRIzjzZ26 z&RLE}&}zKlJFqI|yGodyFx%3V#&m6yQYcb9ydR#-S=s@t@?R1!RDE^0f4GCFB~Os6 z+v+>ZMSqFbqWU;b*?P(rfn&cx&-fgxqd}l*V?O}D$t>_tzf%_~IA6FeAosMAzssxT zpOP+40`F{^&{sIm-{cQ-ZY}~B;}q=GyVFX>0OJzsEU9{Bq#nDCfAxdr(b?6$-P+5tNB&#hXBiA7#b%33-T=RNerlE`3}ufTh^$rz>*JL(>S%4Z9uHOg20~ez*h)TYt!!ImZ-Hw%IVv0# ziMr>YwyKs>{4%ik4r-!m*HWN(vt!=wVOGT0m}+dL1(0>WN0=x{){oZf);}zriC*7N9R^+YJy{?gvyiS2;@U!h&mj)3<&4gXb1yMZ27 zM}3tMLj`)(NYH1i2SLL7HB=`ghL4ALgqMWl!h=G;f`eAczutG)dkl5GHQvhJ@t#rM z)xMd*mdX-ykMO~oJ&KEti5}qKt+~Zj`~&VG|DU)>u4K()*(pY2f0{3H#9z>x-_zLp z+IJ(^6gkk93>|y87o-=rkADwO?gcp@-?UV^z#xzVXC=pe9856m)Ek&L6Wm?r7Y_*jMh;QPu4)q^|r@sIm{} zlk|SZ?-c4}P%qoT_hX^E!2)ogirkk=7spxtv9+;}u)nveaw}=7FqmJ!%>dtfI42`L zIwy$%>9n9x4Lr(p^_kjSYp2~nl6yc)!|eJyoUs!z!{*k5`Xy*%9iYxGpaU>}4$wRb zQJ+BX)IQQETsmA1K~ar^r9$)Jkenav9DM7K^0_^G(lgQvXNbNsp~}X1?wcjfnH+s2 z+V8k$U5U_wAJ~gdMul#FlnTPFv7Clal!R9HaR|D+&x#S{jFG&G}6~Xf$%JRM46xuGxsnDNebLOM~v?J zWi3G`bS|@*YcGD6=RnI9ZHu+em+y*mh3ni%b_zU~F6=})V};%g>-IQvopDO9qeYZo z)b+|(%$Nm@(eR#GppJh{|1+i`OT3nTSy%Ls@dZrzc62?xtM^9b)CebpX9Rl%mj*5d ztA^Czk;t>q>2S}`=s?8#N9K(5QE5u1=2p#;2tUCC#XLzpae zvvqf-I_pQDwNJFZb{(>ssD^3Hcbqre@y!a7JdniNW`B%bS zsk)_vwKHb)8`2XYKW8BZy@yp)J!6WI0i|&xoXHHcDV$c@k@x*sZ=?s{1U?Fn#ziIv zdjW~jjY%V>sM(O72KnwDp8WRu1HBIdeJ&eD*1fCljYLP5h9-pW1Tq2z0z071+UFbP z>*IZ!ndBMkE#WT_xkv{J%^l5Sn#cRHeTggVI$>KZX2Sn56)EY%$z-+{t1zF9R;n#n z$ag7YTE@uCPTrRRr;^jKkRV@8u5H`r(4D$-zte6PtbXxteh|9)jlt#X%SExPnSha` zP7V?O3(r7LPG7;`CS@?q$xRSPT6$aWSRGc$Qb-Pojl}jsQDF|>gMY$hu*rDcL$fTc ztsl|;Q7356bWlst8yG20krL&TQXBC9PI4@Fk1fN!<=*pqgsH+fVJJV1ODBHpUB`i< z_B-}>G1zlF02R0iYvGoFB|Cu|#2tmEGJ#CQE_o*EEj2NlpG0r>Gu&L~jPb@#T{C{6 z@8NtqLmGoaroa!dk|~bL>szCQ5e-fITm5^Ms}wByUYcK>3F6XCB|$9%_SB9utN;^GTsZ#X$v z3p3doer1DyMa5LhV@F){OxLdHe$E5dIbv3Qu)5w>J7|3-Rhl|Z~(&0mmqSp~-;r{V1AjJ0o-m+=#svwBB$N91ATPo=qf4=dj5 zY9D2Y5)SuJB=xm+i*_atxrfkOoU<2j=5@up|8Wm;op2J{E5XD3ru_`B49*Rxfe)c| zk(#OqwtN;x_67LZKiMAKajZ7SgPvdrH^hO$FfpBf$n7GvRmx}E`Dx|R2Pj-Og8pU(P}+!?-SKEYDiQL z%yC=A^{j0Eu67R}3kd!=za=m*G(kD6FEjrnyV+0N8LkEoRf!NMJ(TXs@z!{&Zh0#| z7BBMINs3WQ?HBGHtPyw|NDr=!tkCMxfzTqK6+Vf7NxkK7;N6ZDmkRIrOI&Sk0sE4K zq0!1qbLzv?!b(h}O5|ANqjFHYXS89Kas#D8w)~F0(QTraMyEJlTYi@Yqe8n3)RY@R z!`B#^(>ft5x=A2>cnXHY@5qwgLbriK zeBLNX3sDy1J)S7`DLm0K>P$_IZ|Wg+QTT48dW2J)+D-MIUds3Z67F;5L%3rom%qPn zwI`1+CUj2QM9x^AIVZ&ojjIxK&Nb5US(?iZ(qD&j`7PeLp5MF$gCoL&!LB(1O?v~> z87{yFtH6ajj4jKz602GU+FCj0IKJ9$%2W8<%yjK$cxh;qS*SM|#C4HW#UF&eRL5H6z=?*Tca`S#6(j2JGUlVn|+Q zO|`DIK9*;Qm$@QTJ(%YSU|a5zepc<_D6GSe$_=fpK3>~lbVPzl4kM@~8K1Pt`W)>qb*|bu z{5rfAeB1WnJfZ*8UyXjGsra|7g1$Q1+1L3NtKB5wKc=eDSQ#9y5u6cN8=M;~6{)V2 z(cT&>Xc^`+cJfm|{MgGS2&aUO(xEI>ywqB%A#~!NFmsK<+M&p&P-?Jn=tJmElBJWivrV=2vdyv_ zm%0jWlG7}we^Gm?W0j-IV34Cvqkk}o*@F)C0?>k3VTe#4qhW7816T1NKbEhHdu~6e z%9J)28m*v=sjQaMQ`MG6VbI5eWFsdE!}+N~e?FPphUe3dT3#`9ohxAmo^Dp5g-t_` zp|90B;4W>C%mtrfJXkzls9|#H8OADOsj*b|C`n4`P^SO1@3OajsCuNT-jJ;+{A#Zp z73~@lbHcUKxyE`P{f*z%U7;2J2cF{4_g4)p2{lrpkTH{v8l9iHf)u4q=zO&xmAIn( zJYk!3Px>nDmL>_WxM;FlKd$TxCIkli>icH<#|En?z4gI=V0X_j%|`VNWCAN|Re zVgD$!fv=;Zaw1^yX7^O|T=iB96c0a9=bPzZL2a?d*nf2#wmr7pw5-A&elmPGxxluKNMA-ZZ*hIT10Dq(w&O;k#6LObX;Xg__5o=fWj6@PvIe}Su+3w+Nr=6VPC z{1LnUFKI1Svc|YJx#ni+mBFXoV7($FvBiuJ@Id79o%Xi(N&Xdqa-m@um;a@!&Ai-r zzK5U+y!5YhNA^g+N?oN@LU$p;79)MkQ~D6quAB{%$enN~oTLuYrSow!)*C1FZj!37oDi)j`Tpl0FenW@6v7cNOfP$8OW`5-s2Y?4}u z`NSpY@kbGc`Amz@o1otG(ue40jT`j1xr@wUkMkqJ7;7r=(s$7=o)Dh$2OV!YWJwY34owT*#YT+uu>Vf#cG0d$m z0vAI=LVK0Ns4FbfHbB*KlI&nZ;zc>#;&9ZmH@2;`oe?Er56NlN(^`d_1SbW)1quZB zhZ7W^Hl6-K4sj0gjF?lhLY+EB+%BEw1tEp~6Kl68=5gbtk)%}zi?Fq_QC+RhQFrS; zV>?Vs0?&eyb|GVq*~HI%Ry^H%L2<9>>gVQ3L9(M={~KJ(g^ck@5r8TE+ns<(OW;?LUorcua0s+(W`f;SLhm> zRvB_lVf}l3QS76kDy?cB^S3@zR)rL_%unG%hFv^Ta;T}$(-_Wa)-o$ z=11ygu8CP|-o?HA(7?D zriZ!_x-06-@=!_dujMtk&pB5+gASjwsK0XPxatq%u!uHs3zOtx%S-r*|F3AXgh13W z^D^yAWq@~|v!v~2ex>|{`QvRLxn{Vl2l`5FRF^b(Q@!zJHXp$3ECwMLDGDXZ$?kPGn4Um*~NfOC$cWY%qUo>|`+No@%#HDOI3m z-iADx)86jkZUWy1h969=&Hb%&qT5HWjanSh)RJt;)|b#$lJ*A8f$`o! z?rQEf?$Pc_o(A4=!HY_O`0fz(E8>{xx_P>Ty0vd{JSaJT=E zcZqv|tF5!IeJGiPz3j5RyQ88#-u8*h>G0alIi&pe`R@G7*}chkUS&VvFBTr5`7ZKe z{E?!wQ^uESm-3)tW*du zrlZ;}ZFB7?Z5_=fbq$7KIyhN57D%WS&+`ttY6thrNA&wF6{DKR)Qb5v`d#$ws6!ENEwxRb z8n)=#sh^0Ug|&Z+2VTPw743%ih1_?S_GDbmW2P2R;ubTk)xyTiFTtdL0?So(M{oG;Hb6@ z8H+XKYhj%-F%aYX)Z4`Q)}C$~mtQ}>Yu^0a>v?Z;D%h44^l~h)9d(X#`dp8lYYOrU z=I74M`zyPAR_DA2_LAP=%2eIdNN?i8V(QX0QsYy;D_JA4dgMp?Ey0WSw%K*lJG`#& z>d>oB@AiB!+DdxIgh%N*nTJIt#N3Ll5w|q9LQIjUC)Tl+>!yDVO>|8(^;N-ehtOZX z1kYO{46B{%T=ji5l#S|WlR4U!bgsmK(ih4sNusmnV*t(u3(fcVCzg~aT8n6-msv0-uk>TAC6=-$?2H2w_uIu zg0HCdyx}sfi=`4Jl1G=yEgo0&-iKf%r!V}`D;s~ z^}{ShqYuS?8+SJ@I_`SR_t9q}`k1F0wRl~3rRJff{&deq=M8&Bd#XL&dD2ziuamzu zG>?o;XkOxK%7=1O%70PrOzP_7kBXd$jxaV-jrC7+zOjvF^lWe27mgj?E}y>}7V7={r z{?B=r3YO$$7L2t;`ECl0!e)-)7HzckInCuY10kmVPTc$rkd?4!_cVW3C%jFK&AL z&G?gXnz-T7?;~$oRvOpon`(-v{s>nLRr0^}_~<-%+g`!Z$yL{LCa70eHvSg9Hc3v- zPCHb7Sj8?CCzf@UdRO#)w9nW??&Z7Vuoif7GjnGb{Ou^^>!?)Kyf?&HQ=*(kt8IT_CLo={+t_KCiK%-V!? zN#lx7FVZ(j75CgSN@okH+_&-n44EE|C{{IxhZ)| z^K0d`vaPkf&2Q~`Q?SD^z?oQ3%>FafpOYIyJ%*7=jZa*|0<%I3bVk#hs0s-+5}GGn zO!y|Ic2q;-Jk?LZKCXcUZ*s zyG&}yNlB|BkYnKt&kn~A`SbI(=C`-2ynE>-KFWBNzVBz_FDJ}Mu*SWK{@2>jl%%() z9rSY9O8>io!MNb6;90cz_f%Loyx)6nx|g{|W6eK6uYF{jNVK?CL7)86c^?;C%BxZ^ zyx?Ho9DAI7xNVR<$^K2jB-;z5M3aKnxs!8CINCYedK(cRj4}QY`Ds)_!rg@H@j3B* zqIa2B=qpNTzD=%zyrVe_v$lV@l^vNs$3E2K4+Nw{?R~>z^DWC45o_t-`L)GnJfgp> z+n~{@E2t-vEwn9MjM}6oK|fvCiS)QHddB#!`EM)t)zwWMqx}h=lsJ}>p4PuynX=DP zx|ImUkF`2A$AZmVsrfB)QnQUY3-djWM*g+(as6TIy_hLX0(qFYAbuVl-s8+_LzKoU z%~whXANpG`CFvLcl0d1z|2X%_+aG1cj~VVm8?DE97#os ztSTB#P{sZlwM5r49N`~ndzD-KgZ=&54Bv;Za-P|I?g_!#sy}uAm{wY&qAo{{j_gBC z+8?Ib#ud6WtxoHwgW?By9{GL;nAx|)pXt5t+URWNxNRTr{Mo(P@0X_Or$_#nxT$1X zT7!yDD($P}E59aXUa`6{Y5HG7&z*;I@-r@_A5M2>?99oqR}3svZ#MOcK9KON$S=iW zi`kO4#jlOdv@AD#thp!s5()&~5iefr{~xp4UIpg_chZ~GM5oEw!KeP;{at*2dn$PA zxehr^u6hp1`9J4U`z2QuPbFtB??Cr2bQEpwZ0p=%>sqiVr+)73%yl0gefT%)dfrd= z13vO(wfczC(Vn9DMdjiTieFBe5;rXJrZzvk!8_QtCC8s>eE;Wrcjob|KMU?TzxTHe zM{6GGFPZOHA6xAaWg^E|Pf+7vH~e52q8m*Hd>Q#p_;#p7pgsLew>cZz+fvQ^ykMNY zxO10h6?0{d8Zx8&@!d4<2V>B%cUv)Lg3uPpwzkbu^ZHJ70 zxd)wlgH^Q4ExOoQNhOLuDgIUQGetHgd>Q@0GE>)H-W00mTj;U6Yr3Yn-?~To*8972 z%6OO#h$rPPs!n9XT_EOnI1k}k%-O!ztGXw(dITis9L zv%Z$D1_cXq*JW4Ao{`fgPqDRk_x5*GtSYazu5p+-he(Rs@`?48d6Ic4)#mT@({$ss z%hknH+a*!)TE<_`dzBiO-p*Uj39g!+dj8G9dh%<{$EI^pHR9$b-7U7IWUta!N=+&` zpjbg-LhLciZS5X~GkIt00#&X)N6IzkXWOrO_J%%J%{TP4$}ywkFUCKPFB1Q6%&(En z&13b?RHuolruzGOZ+ptqx2GwUZnuM{$Wa#)a^2zW%um}F+(a$Jb6vX=PZ^KHS;@KG=Fhv8J0n+}b3dzfZm+Dnc{}se1skb9GCAhBl6|dwR;3QT zFZ8Ar7FR^i*h#TN;{J&H5Z%$*z%)%&M``3W+9%~q&pPzJS*A0ie)j+J_B*`PwDr@> zH6&V`k!536$8?RI7u_qWAR^WBtFei`o@Sh?x%6YGVW6>pq-U;MyC_f223TciL0W{gvyKyRP$d=Lq{0 zTRD4k+eBNqs+qN9bQsl)5|(LLXL*zXKoQ9ss2Sl>p~jp-Jz zO?VggJmx^;0K3OQ@Dw`W(#{kY(mW0<#^@`>81>tbmeaW%Sj?9kYC zF(2p?KGT|RHX2*#25Zj9ha_3q5p3-5?)%Am#XZa&a8-71^sM%t^A}Y*N>?exLCKHs=idG&`G~-8!7r5?VeIGC zRnbjK^-3Pc+33xXHdrrk##hZ3reo7HZzbPmUz~54Z#9$pnt3yjT=yJDZ0idS=6A>s z71S?SSn#uBy5n>Abmvre*cnF^T@m{W*AwRekIokgERk>PqRk1`BQc}nmd0nrUx?lv zS=qcmdo{elf7$h`?OERMxw~^y3wGEnF0H?q5>Vx6zcQ{c7mfJMI^Npel4i~|7Bv*p zUDnjlyrQn=Ewgg((ZRK!(k=8wu#~@**X*t1S?f`Ib^fuz0;c7?(EegZH zyt6EynmX&-Y4*$Ba6=}k*CG$yqCTNnt~+bEWV9G>7{=%mwI8c@(jo9P=LVBQYv`&v zR^3zA)R<`+X000W4KoH9L}%n8OQwM4$yIrVK- z2`XZ)hwX}u47k<7S-}h@Ed59(M907m{}ZePo3}Aji|V^tx{f&ds@5S64`UxzGULt9A|%bzE=remfhO#4RXZo=;_fnV~@t}iuPL{ z8yo6c$nAoExcfOv**oQ5%AZxx(6-pQ)sqg@ZkC+t$+~m;6sEcvOeYLG=>JT7np)&f zbX&>flGUA6C8?b824DMccwJCoV^^kgzN@b%!oM~+Biu{-D^5r*JqMrsf_$LRx4mrbn zR139F>1ln@a>x3qwX7v*j5F-e+?7j)p9WtNm3-^Z3{(OJ)K&?PL(ElrN!yn*GcCGKC@3!|VZx@fw{ipLar@>383;fHm zz%|5G#hvZh=4;^dc}II&c;h`~Tn6V4&a$pbWQIKRd4fHq59*+Jv<>&$D&aP_21{aNQ&_dlLJ{`$df%0tPd3aDReY}(nn2fC}emU@%E95Wd+ z^)2;l^o{h>=@@pKoP}h0u+k}5fR)_NJKi(Rt9UQ@*9WVSx%-bsukUM&H78g$nOB*c zn${Qsx({?)8%j>^P$J8()D^W=^#=?^jD1bxO(zZg^f}tss)qEdYfP5qOuBM6C!^(8 z)g0|TLmhJkYaoL1(}<#0x2c=qOYKzR8)cZ$IVK6Lyacq6znxFonZ*dW+2(9$3FeNA2LRadn0s3XE2wq-J@ zpqInqsO&1@`qewo`_`N2b$C;~>)e~%OWhxPqCNTE*?~(+njBPj)mJhVwZvPSSdUqn znSL;grMpsD<+Oi=cb{vpGsiL4dC2vb`v}wfmecJxMO{|gO?N{Vp-<8;*S*r#A!_=& z`hj|bx-NaPKadkqpYz56^ddeSsuH~AFYf==Th9BncfL;*XdLv@o3*{Bvu=~&891j0=4)a$AE% znJqh8<?`G6>`8RDV6x67$3dsf zF#+oQ#r?|D-)kiI@tU`a?~SL2_czZwPnvhAx4{26I8rI0I;2k0moq*#I?Y?nr_AFF zbqqDNJyeUsKL%I%lc=1X?QY?j?YZf#?w>=ZQAH_NrenC~jAo{GqPCURtW|3_X;O$m zH_{}KNjO>Eh2Cn9U_0`hFHw;*H(Y~W ztG|a121^A~f^Enut{x~8Snj{+E9*bt`_A{7x2>nKXQO+*>s#j$XR_0R?A+-%<}B}Q z;=1F!;u`Ck>HgQ9Nmch`U-!UeCR;LrUGtZAtv;KH7In>wO>0dF#&P=QnzE`2Qq#~n z{|m3m`_!H4KJV_}+3Gd>*9Fdmwvrv&M%_tsQv025u|7+`K_Axd(vM&U>QZfG?RoV{ z>OJ?-qxbL7s^Ia!MgJLJJ^I7k_f27rZk3QWtl~6qr*^A;ow2kzXz^ONL>MFAMx2Vc zY&~qTn=%Zyb$v7s<+SkM!58GZR`BL|7J2i04Fdm$vZX=l#=1p@$;@PUWqn}Xz|_h+ zrZa{kx>cHNs*SHJ|Afkgt_CY|nsbkSl%=WA(#SL91$6riODpLZNv&l#O*s(S82Xyr zo9ZDYxGh*cbS1DMaL?bxAN2Y?lRcZ=pSv?$%bgjnNY^CSTGt|1g8QYbFTHFIyKcFA zyX$y=^A+=_kgKplrO_Dm&5b@&mbnpG1H&xQ#$WZHYhK8vu*;v}o#`&_+6FZ*aUOK7 z_Kf#!VOVZHoeNQert7UMs+-Jw&J4{ZO$UvGJlW4xm*sNQM*bSgz-nyf zYfJops;9YEkrCL4HYgdt2!Oaqsff)@-JhV+c9jKPq-5ivGNzpf%VwFxfW53c7#y%=e9QKl0!6zF>Nq?0w?v;7Rv(cE)?xdro^aemy<$I#EksRhMoEn`WE; zF=bm?m|hzb^rbcTRd!`qXoG*1udnwH&vZ|!x212Ce`a7fnXiK-r`(s*{H}DI?5vH@ zP0$V2X>~`mXNa~|){fIurDFMke2Uti-sIgc3cL>74;~6mX0G>Kby?m23=hmdT6ae+ zis}~Sj(lp3Hh0xcP|pvK2!uUjU32gbt~$SR_xFW^8Q~}DI{FmTQnSzU%+kQ3H_tL$ z(SEP4PABg&oVLFS{6$|xgEEf3+MlUjt14jI#A~LhuQNv`i>G&1I*=vs19|UjgJXl= z2PLw=+`%xKBqlg1(8M2y=YJW0X^JP^l}Q%jdVCtI_g`O@cbjjq_q1=NXNjkmcZg@A z_kX_bzOKO!RGvOl)zPNwpBT=V#+bFHm!>*~CCH4!a-Z;5!M*;o-pU@UC)It;J;wXk zYYHq2<}2GJi5c0iHSaiC9Ibn!yR0jzd#Amw9iTm;xyaO#x|&4l0iLSWn#1Y{brtmq zRR`5uW`PvQ<>jfW<#KCwwnUXdt!6!SU9!~g$(y8x@EBzxgn)I-4%kIM<^uUBCuft?x5zg6gkEy5UH)}f0r39 z#|e3F>SaDCx0Q*U{Y?+0D2>RS8A5$}U3yc#mtV@|nVJ#o#KA#KI!_JsW~aG-Ohk3&3T**uzIfXj^)lZj$w|zj=}bZ%qCp!P4d%c zfbj^m4gJi9h;JijL?uPdifUk;X3l_GO30Ig2LA(BlB>QWWItkW;7H=Guf0$FokQP+ z+skiMvD((UTG+!?jE79yO>fQrm_Jzluxum0xu&@}vx5%lhB00KH+=q5q0xb#{G0u~ z{po>>;0CDeu{>52(wOyr-B$e(T_xQB?O}DY+D?U~N13l^m1UtgWp~IPI;X5x3Y5mw zf%T%6?{{j^CQ~C8MKzy?3ckHE%2xheeo9{00DSxWpep1;dtCAOTn}B<96g;gY|res zZO80W?8lh+<+NFydu#)oU%J18dv*j`%Of?V4UdiK5eK5y#(WmNIJSFqWpX5U8@6j_ zNPh=Q`u4eNINkQgcAfpMBiq^A^TJz(u3ZuGM3q%LT0hoMnJm)5^i!N=xo=%#&9`2) zTxA|$wvm2Qy03_o4WnvyOK2>oWQ{_-Lw|&FsPitPo~Jpa{X@T8KS94oFXKidPa(|>JLKAhz@&{Sq=~}fsOxf>W;;Zf%8|V@@BcY)HnczdT%Msm zqHUtzYj|x4nns#yG7D~uWvZnHqW(-pdVFe~&|uzV;XbdXwt;Z&SF_LL;wtF&{1^GEydjup0k_I~z!+im+E z+d$VRZj~>~rv!(~x72^@ZW{KO_gXJmk3_7G=w#huZfzjCtKK1hMg{jWe^qZSuh!#p z|A}nP_098dW#&wrbVmL~y+*f7e~Kwy{VbW5Io5&JF%c?jgyp%Zw{eGlxn{T8CoiDl zkctMSb?C3qC1o1ZV`s|;)IX9ha)RkrR-K}Gpiya-t1POG(o3mkn3*9=8Ol(0DMiA4 zlpe|~Cd%vamgmr$H8r$9v?Nqgq5nZ>loAg8PVU>};9-ANRyf7;3eB?!KR4Cc#y-`t zf^+?POhYrfAG?lwulZu=+CNs-)441|f7!Uu>@X>ov8KJ|=7t~jrL_a-)>&WqKG-EN z+PBAd**ny?%s0qCH_$)WmY(Uoq+2RkeMTG9w$<-6TMK~wu=O}_-+1UzV^PFzD{0^SLcoL?DviYZ%vr(ltwSTByXf= z9llY2-wWR}zd86n?sibqPSwJ|+%4*|UK<)vyYXC8UR{qGj~>&*Y|b^{%CEfd70}?NViv9e9NZkBu%-JFt1|QYJHG-qRanHoYoNP^a36tnuNT zm>g0x;a1F%iX?aP@6h|8Hguhs(%Ilus`?HFMh6-NKKPyfY=3kAC48_O{-Cdtzo-8v za{1npvHzzx7D;aOH}wq)y!6jTYV?$9%e(0t?bhrdvm{4%OZ#3sP4k8f^fszGvQ{2N z?OrTZ$@{6D?!nsj%N5Zom*s788C4Wo`4%yN2s|0L=BW0fb`v>PKQOnONVn!IjY;zb z`)jWnM!yxa>V=#xUy=2yEcvc{k_mUCv5-R2H2PG0K)x=<2do!v6~3zkl^Ik_R8dxh zoS}I7Urq~VQUO#s_$shA!1PoqfSLu`1X>5$272*#|G=*S9ldI1`8%=CeZFb_uYEoI z>wW+F7x>!-t^^tczY6^r%2fPHBWl$L$pht%s=9Q%o=jH2cBYLDQ}raH|GAtecPI0w z6;sH*miNf-krEwLJ+M5Ss=j1G+wdX9gtZ}9#)CC=G@omtH9oq%uTd{lSEpX;rK&*n zShZZWLDgF|O(m%+W5HaMS7YthXO}z4C95k1s6W0$4f7l-mMez6RM-rm^UDeNGA6X1 zKJ8tCBZ5tWBgvg#Os?irY9iA3B$>SZV{|ua6RZ||O7-R1z7xkI_~6 zD|AX3NG0egbo+B94DUB((!naGoz#P#V&xI40%mjFRYj@C(V29GI#%6_-Z*J$BR|#E zwbV7p5gnv1t6rimNlsK9s+9++Ls%3x;*GcYS*d=ZZcC4y66yiiO&awv)nV0mc+sM| zD=(4jklXzeoelmXw=;p7@cqgKWh5Q18!E+-K(SaE!BCWPJ#;kmEz&lUI%1JCyqE05 z&CEBP5UdpZG)TT}Fgo}p`K-&BTURXP3r-Hz3B5vYRZ$)=EhG>crbH{llKU1U5L|UeB9#i(x_j)_^xWDr{p=_k8rLXdhQj9*~4?^j5Y5JY2pzWdcp?jg+ zkOz;=L>J0v;FG|0QOX$QBXSL=L8a%F%FMfIP6l%)WLq0Dqrafyemt37=b12+N)M&s zV)I^3r!v{MoZP@R4Tov7fS&Af=y)b`&*I^5*Rr}*mvRTf!v z%PGqwvsUzp=tnn+=JZEk>J8c7JE;8MOZS7D^bb(NYN;Y?Zj9V&!j+Y&;4&e{!kp1~ zc;#XJ6mU%@QzPz#pU{=yJ{^Z{@N)%NpC`ztEJYchi&o=Vf1P89C4Lq-ECuD|5^F39fgbhK#0HycYWfI}lBk?5iE8TX-Q5_WDKUY`P^ zigd&%&Z`78R+3jN^S@%GGW53A$uN^UXNcP=uOz@s5!-#IeoMX=R;~s(fTUPrHOuVH&=LsDr?ot8#1e|b4VvhkZ$^i{O9(EYu3PunVzV9#mUG!hk7p_yRR1a*d ztlr4)VXhPJn1MkI`%dCV;HDHfN##|JP9`bz4u}E{35B1F-_5|-z&)(sDVeK_0;71Y zhyy;Q_$gKhKMUOk)LbjzBYxBHn}BaTpCJr^%Ke+$?X-r)Bpzc~;2RcDj6|04pmUS@Kv(JP;`k zPF}$455aCClphNPi>tHgF;tmIQws1C6sY5q6407N-GURWCGlBRpkyz6k{fwkAI_@H z6)v7{r@KoU-%sINrb6gM1GAVy8jNQJm3S2gMuGy3VB%xA@ME|;4azN5h^aK>R3*6h zBi=0z1u!6qJ~oMbw>+yZQHbFie`f;im;66~q|>lV zC*R594gpqTER+oz_|*yR1X+v1yhJnKQ?W1c91cC`#v`~;bi(1#gKrn`ui#CJe7K+B zb3cCtnL=mbjuLxt74mVI_x|Hf!IL3)RY)Wwu!sfDMcAE?91(>OHu0{t@bAdNRYery zurl;g4o(!(D4Od<&!=RdC$N$Tj!MIu#ra#nQ%F<6JAz}P__ur^UK0vACWgD!D3rZb z`HtYYBHXtMuQX^t$lSET>jjq!8C;s{D;9oIihGylsX{MQ;=Qt5nZ`Aub5d2VD#z7n zg%nkSl_eMUVoKrq5{36mhf!{a0lUc~Q+5G#K&$75qaIVk_0%wAM#ZN}z z_eJ+68+a7bSx8A0s}buFYZH=J;7ZV=&^O{GC_=1OgI0GIzJ$gQtN%{}{l~LH4hw0N z%MS9vxsc{UDv3{p#uD@^=tle}q@WgB{!cr2S;2oaAYh>8&O&ww4Jc$DX>GjX!JCkZ ziT^K~li+uO>qM|E!2e6O`!)wc8rjP0$cfDIb0|sG~7w>ZJ-bnVdSro zmyYM9Gdo>Klbk|{B;>!4-U8!7*A<0oD*=;u{#JuOH=l`h2<;c($`?qObZ9scXjSD} zP9vZ^Ap;7)zo2IE4WX^YZi1{h8+wZ6`_lhQ_gBcGqClw#>tWCz_w>@yE0uc{FNBhy zD?v?SCndo~OEA?GxZ0rT5Ysq2NWHMZiqL0oKb@+cbN6!CV_o>a%nEaWM+6va!1o$( zl^%)uh*u`xP2inY;L<>Syn}dTV|sj!;GG6ww*=4=7S#tJQJ<5Mx%8H6!Fuk7FSD~2 zbfFu=l~Gu=m+5@i)s7_g@+Vc3CBp;QUmWyv z7VCZ~-q>!mp3pK^@eM>LI5Bn3&o1|1_x}U#YJzzoQB#1m1e68>|5ngxDbcBw&SC|` zImR+ib{)^H#C=`tAOJRt0E6{(Z=1{gKce&JU?`&qRC^p4J%v^t1K$oz7#%`?K?Tq7 zQ+{61k?b%%88!54?2UiaiYb7J@)e|onI4Ce=twyn8uMf6q=KttX28}9m*ajP(}nPc zGz?r{!P~hBe162+NkaDg%Fc&DXE)#;AM*bJKJEoPmLYh7yYQ|z;_WU{YC?mXm=S4` zUrHx{l9^6vzd?P=h~8LK~61dqn?Q%a!PJ@(!>1jnWm0X^R{aQxZj|j|$QpzCV&3ljTdj)+@ZS zf%JCk4;?O~gWp7`b32~oCs5ZFdJC_EGh6VT(ct(QH1aRhcoiCKEM^ocBcuiNQuC>dZ&It^dSiCRi;<3 zUl|TQdte1x=t@~l_+?~jJ%Dc7)3u=va~{t!vG6n$a0=djLsj)~xxM@yyiyPUw;u1- zpz_cmouT{Cx6sOY_VX_teckx`qJQFjs+()U&kGCfs}4k#)&bL|P~&jmSWh~{IrAOp z{wb^30XiJ)9jvXleR&0T7^zozeIUu0W5 z=&&ALSko0Fy^$BwA#(%!y#z%jvx}~*auaiD+tbO|3D>P6_I*#O3uhNYhV+>*+BlBgMxf z3nsyrE9jhg2|A7A+pCDkloGuN;W7<X{R6q+M}?6@d;5$i68zk6K*58=UT`!=R`kbP#C+>fU4xxX@E98dE-hH5qBf7ML zo~J(g8j6XTHGpwBW{GAK&)N;wcZ9RID@~LuQby3W1>R80^yNJlpmD<=SlHIbBnS5AfIofd-TUdLpC!BbT}EFx0%}dBX#+4ds;QXyX>Kd@NE|7yjM_%HJ-x90HMk@Zu zKhp8?Z|G(^D^HQL*ikfeTA6qI zAWfH|8QY*^neY^T!pYoNI#X_zev|vj6IkhKZI-3fxFsc?`O`iPTmeDW}NwrKaJ-%2??6FD!I5@%tq}e?Amj z68mL5_*jR``5&-)f^Pa8O)&t>wn1y{=DiZw62BoYPU3U51%j24069=>6=d<}MBQ2; zBPP+yM?pe&LF27NMs0`6HUhD_@Xi}{D(vVyq-j2qV-tNj*TenEeD567VLPC^1|n}u zKuap5bSG%5DD&BWK}vl_mqZ8LbP-+H2Ab4P10CfiAvcd9N2()5z5#PtzVKc{!5k0;?R(jwiCp z5@?oU@S1{E))pK*34a1Mw*sdr;Ve2RtdTyKTgb1$MO`4Z8M|siIF86|G~Cckn!qf2 zeOTtR+e~HeD{sJxIzY$QUy-~Mq4f3gTkxB}bA-n^k}1py=)mpJ{J+pZGqg!AlZ?yA z)#*1m8o4+H+8m2@xDZON1ci*`zVxc(o3G%|hc!}OOq0{nl;Qa_)H5@LQiXOSc>=QxKRb^xgLqr2zW9Z9*($%1ek~nYznP!1(RKo4F}Oyuh1qj z0fWty`(f8m;Oix{kZE6IP zYAoxacy2G5JXn%?rK!>Pk9zP*Q4Eyb={jIR*| zU!FzMOo!s^QhB*Gp3EL7`~=@qIJfA+-u?nEH`r};G_GLN9{ zU$H$~p!pT(sSZ+W2=gWCVex;(=Te~$5Jea9IXvM#lU!4vwe{tbmjEDB>lS{M{IdlzO zeV%iK#n|3G&=uRlX=uck@M3Ksbc&rvDsRGzRi~JR(N#Ih8GTP+&{lqiWSz~Y%{g}| z#-5Xqt#`p$CLUlj>7H_eZal@%0qNl48QK6Z1{_>r7oyYjN66so;Y2wcJ}eDkMnp+? zJWdvRvkuQTKm9_E_NAYf3>7TV5b!t>`$^p z=CWr=ZU())ho;Ac8%wY0U40p;wii4spwoE`=@WSseU!@3zrH`S#I;O2SHTk(z<6!6 zSu=2_hIg*>lv!A^uh6T_&=y;OUtKiXT%M=`PTxU2nPk|Bc)pk3x+&65tiu1nD|OM) z-J#dD@YVvRV;lRW^<2!Wg)CFVb!RmiTR{ z1G;CB@)L5dIULqodW$}}i!Zu_JB?=k!$5laS44XkmnLERo>#tyGi9vH*U&`*UW=Gp ze;O}j3?9)@_Ffg3_v0k#Dswe@G8bhOvoA_wBSfI}x`La2&`&hnn~ObP8#=7Rtn>cZ z8@WtvIECJMj`aKqnRBso|Pv*t5()SdNYwAa6xpq=w&< zHPQqBY!lP&J0OXTSk%{*wqytn$9DaSJzrKfATOVWcjI|1<<1wm(|6cjP9+6O9fQa7 zHFEbms4YI6uBh47zwDtCXZYVr$AQT+y63znldcMw)L*eD@~4TQ%VNEmYeC?9F2){Xsn4 zD^U1Ys<)rx=NG|Jo{pwH33q3MYvD^hB2Ls93~J%jMO@PX`?><0(*iAjLV5<@H$Xle z4IiYdb_;w?F@Ka9W#IQmmt|i zP60AT2W6DO0y&7b_y_&a1&Ue1r)#jrx8P$QhiYd@KgqAqC{E>tlA@d^fAYLMk_i^Q zfZ21X|1qmI^LrFp?N4YpiRi;QIHMFhOqc4)d$DU{r7Wd7Jl=;p??vV&VKE)0lkg=l zAS6Z+X&JWUS+Ka9u2fHvH09wi6>{(sNe_4bqg2AOxR33hh28oy&xuA##Y0__c;W=8 zcs&xdEp|y+Jia^FyJPVE_hRvGg+1$F0W3Xh~Mxa^(AZIJ+gbaVB2f0eq~d!1N*7w1w0cJsHGSU4`cEie_tv-E@mJ zOY%)PpgQ_g3pL&(3t^so5dD&h_1+k3`V|uC4bhnqNU(V9>to31wOo-*cg63eJxYD$ z5A5Uetm9j(ahV;)$R7A~Ad++r8aN8uuPxlVIGmx_!o#p`4)WyH==W~OvjASu`EY>m z#v$EHfceVccn3KRliKt` zhPH=`$lu9#&_PeJD?cdZ@M`7*OD~qvEu{Y%WdNPeZlmQ7GHa@+Y#~`;hcByydEA%mOnyI?jR!}4XJdUcRM1t zGtl;Zup(5@&l2S4_sFHmKz=v>3JZG_nxqWY#Tl-AkL1{blplg-Hu4=kc3Bf-l?%`6 z1yW}f_Qg=7TNK*AEPm-i@K*+ZxB^ZJYI8+ zT4D+wmiq}lD=j&cRzz_cV&@mbdpDzhnxJ`SN*mF>K5Q#dw~z$)foaKV>AgPDsK_M5fKbXO_^F2iRRr?B?ol#RMdlMS8Dv#O9cZ zt+rh5D5Ba}5ku%Ed>Yuc?XSS)cBF@xaZ!?}S#Pv? zb7;jAeo8E$8f$an8MMHn`I9JH3*t(1;Dj3Zv8h-G4Ux9<@k7MysNvYEABW?RGgXm^ zJ>)3hnvA~sl$GZrZ&R_7YN7{Q5;v#`Z>>c7jU~gro?J%QQ%o>Bqo|R*HlQvd-2eqx~zGp7vG z=pB3<`|>^V<{gyL4Lh%ZxU&|^={mlfUpWb;n0GEW2lBOrC8b!gBn(u{^^jX~g!c`e z+%{~w-<7j?W;$7Dx;&&$IabpP&-g8TvkUD$4%sU++f7BxX$Cqi6W{Ys`~eZUw}BOL zdKFKUZvq@$LfXLo7h}WSg6kI$y{k{A+X$rMZX|wx;wMD;u|E$(pA&)Nt?+1eB+hA` zz%_Y=XF-ebrAG2s66Z7jaE}`FH~51&R72pI)o8dYSR+q~1^kG8cLp0g0qJ}lXqM); z(?mjl0&|(rpaYxz9g^}J_V65|khtnN@Gue$c?es%46qu4&siPm))=38B0lX!aJ(5_=#Jj34Wy#j>*riI z2#%q{KN0l-X#XSlyG@{*xjZKxyZIUL?#cJw0Gm13(wDK0GO^?*GB2+^&+mX`q{H^# zghtzr3^f9ej_m3Os6*hQI}*)^4Kx%^{TmF}f!#SYP9O9@C3tW@v@!==$nJsA-`ZlO#!a=h!58! zj@=fSYD4$>r9H&9Z(*>?we>IJ-gQ0{Z=35m$kZ&D??M}@ifNaB7? zu~L>`VHcD8Kn3UEzM=RCqtFL51*h5^dhAXlA=11YzbLc9bbPmka8={u*b(?wI2Wy zdx@Xhl@r)ozYq(0h9$ob%XbpJq|Yk@*wbA6(38X?&LJ`GC|AjsUj;7iU}J8DyX#>o zC!&qZ5X&ANryt(cs+ zC|m|fWd+Yaz{^LF5_wR+UTO&^jNn9SD?IrL6gwV0+L@D&{=ni(;!=5>wO!%LSDdwU z;Asx9UKgFy97t8eHhm6!F7wQK*uDRgFCfF{lnS0!gV~i}oGb@$IDr$5GgwR=fuvUM zDL;Y>I=e_M>M6drt2 zDDNh434{APM9*4c58eR^!u~&w6!`^-6^X6)5MQ#m{INU+T~QDFt~!t$2^H6uX91V) z#G+OqUuHuG^Z90F_WKhW;sxC-r(q{dguWgl$;P6Ge?Ze;g=VG#=lfvh9lYNb&fAMz zkdf%s={?vCI+=B_cPAN!8>N%#R<- zo_{1_g4c%~TMK$^0S(Jkkd%fiYlr74ZceQpAlGd~_tjYGJ|I39**P4ZG(o-3(5o|$ z3cs+&Q^Xo4$T~V!yyVPl8dlXK1rjOth;8O|HeS9ZH;jJv&*3r`8_(p(f|aGB(Lpr2-zrPQ1OvMC#5%UvGiW;?e$OK(nUb;jVeuh;xWk zyhBcI0`J6Ggnz1R`XB}r_)#N<1Kq`9~SV%P! zkkwD2!-rz^#QHzO=Pm-|L-ck*tO?ObS0E%!IQ{@``_Rn z9|kWah=(-h>_Mg?Cze^wqvfu6w-V0@g5yZ;G#~2P2fVWR)*<}d7ySDXdf-=RyPW(L z5#516Vm6SjM2u)D*6LX7!3EezZxnI1Ru5mfEB3}0=;D$@Y@5RAG0?RWxlkL8+aFmr ziqFPkNwmkp8Bv&*@+CBH!;VwPr9XwF{uB!CfMxm)&!z@Z{9)LVtFgPr5apjrWVtjR zv<3fQHWtYmq|FvI;|*xzOL-DDS0!TYRm0yi1uhoMhLk7Rn&+X~zqkik_Q<2rMBx5` zgG6RccTS4*XoP>UCA@IRKx{kNA>r!!btTi~U#X!t+Sb#?F+{^XRoF}6|};*QIK)f(`+5~;fZ zYI=^X{2qDKo=*l~r+kg3+fEcRk@L{o=+7N+-~@Oz$Vqc5JoFrUdM{JK%*fH9*t@6Y z6e8O?_~i~Z)nxEqoG6;ef_TmIsv*7a5h1yUE+K}_>RJ(5dra);b2zvNF@w|KryQ_! zVZ|QeEbG>`RKsQl~g#ZOI;`7q~voPEOzj4=l88juU+#lMPQJ502`_dEFwV zv2ur58=o+7&qDRtdU*>{U?RGUY8Ld*MRcAHj(VrmMgBNB0eHmvbaGSqH?UU~ZBu{^ z+>jiLT~JsZB;Oly5~@L`687pm^j8l~SgV7zc|<)Q5GkENr7&MKz(Dd{uZF^ zZs12wC)zQYXlx0v@CJ%s0@OG0d=qDKH=sL_<5L`}IEq)yJO{?BVP|=$#f-zd%>phj;p(^WXbxA2 zUIQ`k$YA1rSH!6#7*&x0c!)^XI(~+NRUcfw9v^Bt*+(gOMin_ByGtEeB2b--zfm4N z)RS|YIb^7eLQa;#{;Ua29Y)u*Bgf}&Ma0YQ6a9FJO`3>CnTr0tPlO|p>N+xOTFo3S}WmVOu;;WDx=jK#c{{qIC;_2fk7x?+G5s{yw^3n3~ZnJcl2PXd!s zz~DZofaj3=qJC^56f^|yX)>DsB$A{)Kl9=W;2GHP?FhK^D)zz@qPb77RE8jhs-WSz5C=^r z1~>%yx)#0Din}kwD{27Mi0sfTU=oRqI0@--nx|etf1V{zY9SK#Ch?~NJtaM|62F$BD*PKy+dYFgz2vq(>r@Lf7|0;=Dv|e+)fG z1BLV8%?giB<_vl&((w+qbq%iUfJ8lwoUH^*OF&Da63CC#8_iVeg3r$A%tgrXpSd~>4ZDx^?1US=h4~awM8h({s;GHMhYLPqod?nGOR@FB zSRMP|w)@zhF;GYicv95$R7NJ(20HE7PcMEiO3ddPa^L__hl^+p6Yurnvu;=xb@`UC z_ELDx1=cl<_3wa2Q-Rqqcy0(bS1mNc8&-ZDeO{Wo_5}(Kq;+>9Qu7LZ@^xs3$H<8; z#A4RNA@i_uO7m?|{aF)_Vi{%ff{FP^r{!Q=$dnk?Sros`##L9*7G`)O8EtbF59Al%nT`M21Bg~e>Rbnk z``Gaf&h-y=Uw~^Z zJX{(Zr#w`q1}E2{h^KI8Ja_$wHH!(~4px%RwI=Y_6B_ys3^ZlkDk$RtuUE*ZB;;08 zp4K{05hu8x@OvyU6Lp6w_*@M(g}op$MB<_0B3xa9 z6;$UNm7r@;fhKCEjI1`At3~Bs0QsN zb7dYFd(FN@<&3Bc@#1NmK_>2llJeQJs5z~~)0~`R+ytA~@Frh!4^dfIn?uGK;4=p>a`-BW?>!L81O^Ha6_rhb z#v*xQ39gGSyrzIFM8#JgR|yR!s8#~Xq6SA)-U#~puTm=){0iClif@RTY&X|CkpdaK zFRIAiau-ou_5yws`H9cKM&8Q_+QnpsKf~Z3sGw(DtSf4p}+9@0+r zC2A9`g;$D7MZxo;Do?Cd4^{+c$ADW=-5U-5lfaRvWE3^L0m_<>R$1G zab?lMYyPY271xL>L@lpaxrN{Us|^+uFDg7my_2Yk6}5k&>QY!ZqHa@E-r0ExIR95? zSHNeYj!$Hg3C${MOvNYp{QOsSCu%PR<^&~*`-qxUQD^N1dg2?RT1-?JItp;oGtByGR%*Ru~~x*qb^s)Df$NXLW3- zOkb5>tG3Mln4v2BUu_As75~@YYWt`?t@ft}-d9^p)e8tfU#Qok+T$ue)BAsYsYXee zFej*-e08Q#uRWEar(TDJzxzM`cf0?um-FG1444zZI_Y=ssI%MuMhBP`!!i7i?XQlo z!2cbSx_|HQ?Wo*xb(_V%e^UhSs$)Stc6BVO^R~*)S2?FDk6Aq%l})OSA9Wn4^N6Zy zpfcQ5j;eZJR==rUscQeIJ*c)F;AZ~sTKgaST;<`by{fYLOTp_h&=)0s_e@o|PpbW) z@;ysJ@0I%9hw5upbwOp3#r*C`mC^IR-c{#zmGM~y-dF2f<$r&!vcy&PHUkd|MvOWi zsk4oGPcXqKQ$JB>6?HuRPs<<(Eub!?3TnLWhL%$Ku4(X#Z}5#Y7$G^Z zg?Df)uYPa;2ka*q9%_VG2*2<_o2b`^$}?7H_yDx20Dr3td37D4gOQ*r4(MV19tAyB z6QaM;@QgrTsk4aM)Aga(qTz22@@SC|P1J*T%)iG{8Ccb{g@`o{Vgx?&Gm0Q|4U z4AmgctO=`y#t?zDgvg>T#0eE3d!sVE#z5>_2R^9{D~jE)>YEBRyQ=W(Ht_2i5TQs&P&Dh)`l(jhX`oE@A20XW~s?AR$If?DnPVT7S_@vM7i~$ zg4rBaWi?^s1z@d}4b>G1vTSA;x9uTbUJS=G93u7#aNPG{Re1|=d2uibOTpEohihUu ztaSePy)O=`$rYew#=-bi`RYE{W;KYiEB!vdcvv3{g!ykSTq&w{i3W~9hKO2?_u}9Y z3s(*JjKkTg+6@EY)4yQrYJ0zfF`5Qz{dMr!Jh&>t;23K{Q)qQZ3 zI%w06fGsmZ2Ic^)D`&(08^CWj0#@>G$RDT?NOk!AWXL)|{s;Eh11d$3b3qQnTGoaP zhkb-0+3>0ctmj+9%+eFtuk=9vg{@Wquf+`5&s3=WnlyEgp3nx*0in7SSp=Elu}BOc zTqi@8V-92yYa*AGf00>`HJ^&?RCXa(KuxzYWIY2=pR9_+LN--J^#x$oZ;k-#2bm8? zVnIqleQPyrorhz70@;k$kWmLtG*S!3-k)$RYhZNlg1wf3?8?#KHHcww{7;|-wUEoA zVFtMd*TYlbjP!=p#Q~)a)I@p#nyCfg*D3?*AWl(plrE^v2Vt(+3$x2kxc+X#%FqvU zLK#3Pl!i}+L#<{84&E}r>q0CCE5$1?PGVpT9Dz~3 zOfwk){1xJW5!M^FbQCiA>CiJjVC(6ChkYPFM?T7RkWFwk#{xd&D`W>>K>k55@6eo< zqBM6v*J!_TSK2E#lS@gxaSJ~Yl+8bIGldMchp-4#VF~^S_XlU-XY*Nn zC-I(`D5Of`q)e#&en1?M!OjKDdoRSV;gDfy3}QeZBLzU8js=aes(4Mj9zKz1Ob#G_ z5j)5kWKC*0)s3>yS)`q^(>8h=eUSE$rzn}4K-XnBEl0PZ_fsFiU9_zBDpiTRMa;*h zV0MrSy8JWPs{WyDzrZ36;Tm$8!r%bM zmJWKoJviQ5I`Gz$P&}x3F8Idh6hA6{<38fm2k!FAg&xW|RM8ZJCUP42fcBG4>J%}S zxPTpp{Ipq_Ew_iUupdx4@8vb1^wtaSLG&YOdMIsVD(JfCOX-*B|J0Ay&Cq{iZ2G6n z2dzLksr%FdQldtYG1N$EGMRdcrqcey_SVwTZya>9jGSV3@LMLJ;Nhewwzlps? za?!om7qk(w2sClmLcT9U`78#d?ov}w>a7y$<1Gl5DqiQoeVtqreB0cwiaYozujC)? z3-~X3R$&f&7oLuv#Rg+mnD+obC0~cC zkyge)ZL%tI78`(hhy?N&9i@-dPct4c)G@3xG%?mUt~R_emeDJ^`OHFk8NHF7M6aL- z#zzb%bE&y_eJX;AAlJ~Xh<`zWbsyOobRKU~?WhY_b?sI(lB$m*=tJDDX^iG8YvdgH zt$yN+XryOAf`Vf1)>D?AH>>jR=^ntPQ$!BxRG zi@oeG=4_ISZwps%59Bej3;Re`#43}IFoSxJk*>(4B%Yeszn!R@8A7t7Wz!@E-psM1tC(TVDTjYkc51sQZDq(;Io{x-YRf8TS(HJ~ug z6;)K;)yDN-VJVlPpr)I5jP*u@w2segNk>EX#=u^u+keLQM{!$!Wp8X~SKuY!#mh-c z&>!+XWjpm2sY=YlA;*f;037@`&3TC3IIj}gVbDA8sVK4r{WoO>x7QcsckNobq&Av~ z)(+9~+7--jZ575x)&Mof%S0tm3;d=yg>)f~%k>*V<2~1ls(9)aHE|3sHW%a;NzR7_ z-QD|&ac4L0Psc@1)#BqtKl~#KySZ<>8Wj_P{l%}@uYvQLilF3w4!f_^AvTkxi8}Od zQ1-k7c$sRNM}SD%FOCEK!;0XjZi91*0!_X}SPD^soI(GhlXNAG4~)g8KIVm%>Xv1e zN0wB}2}_o#qzN_B#`bzn_dr)!`=6H7?j~z8)2I;TqF+!4=tkNy+KbxR%uA*X-BtgT zuAytCZ9?Z$gUR{0AB8)W+>-~ei0{5TsyLyzTA|6czOZcGoWh6#F7LamTYjpePf@kv zKl8I)v+`Fv4&>)J5_8PX@vcLzX1-d!XMwTe1okl!qYOl&$o0rfj3H8iKQRY6BCl7* zN^68u{J-GV@Jc9@8$i9eCd4aqv19lhay@l`3Th*?PZ?UDY#J4jYR$K;v`;qGv6r&{ zvi}o(!)^_`2~Hk0jcfF0O_!L%3=3|@kBDjXJne9%5|hV7GpDr8ba%Bq3^fh!bVqe@ z)Zdf~n~b*8l$E-2*Zpt3UEKMuIj#+b-3mYEugf#!zb}l!7FUPwS2{ZtXSdsV@_2BtnS_t8Bg(S69WBM^=#0V@M+`TuD1=x2qAPo>C{-C&AUIW#ky5Lol zCdF%Jqi4wv6iw@y4%#)k_Qr&;B@veBDbemoeIy-SDki#wtIWc&|7V#W_OGp% zy}miYP?>Iv)ukKLlePQ7L8vF;!*7!9VN^EOGrCRGC!#w(00qwhBwBeZ3=3%kXNvC? zXBF3VEq2azp3S|NT_(3#R%Z6E+#l&h*~5}grJYWxnxdqyN=r_|(wqDmT@aRC+WX44 z%3De(8!9dL!;aILaZyIRn+pxX7)bzRLf*NtG&$h z((a@`Xg|;|@f2b(@sqGp<>{{A8H*A((4K(znTmk5kl0E9hYSB!|0>@b&wOug@odLh z=kfw~fu>N*o0YdED={mS-Xc9B-Iv}!txM|C3^9FR_Uat3!|vSRS`nBK`Y6p+>S8U( zUifJerRy^V+Iiqiv6QL^iLRIA9HKnY5BH*5k@}i#NDgo-exjAJ(bzk@C5dZS>HacV z!>TN{UjXT$>xS|AScYVRkVX@*3ix6y z32lL=;2TjAHK8_$Q3e9uvWhZV)=3s{qT0qg*n#Yt(5pZc+;{eX2Wms-=c3&OdATKX z$m~5CM>91yyW)yJ_f7uTcj-cFwz>?gna@Jqg={E zHKC)y7kvlamfk~8rrgAFydqu$KZySThqo>8>LgK}UZX9kn`)S3nrx|N`y>2)#Hc8J zj5W4ntQhk%`dzdks$taVh&pzob%Z6#)ZIA3fasQK2|5=PcRhH0yae_HIR)5)ZSrQR zx_Fl-`10H|c0wpC=nlO1_wdj3)$~4ccPU=!tmgQ)a9_SQPs;9^bvJ{^7@UTrh5Z_w zs!fecZJw5rxj*+&!9hp$;%DAF!CInCIfpzzzk|obK;{Lw7qr!{(ofO4@6>8^^9_E}#;^k0-tbVwx2W1N$6}kr=Ekgy=@~6Y z9f}N(7-Q=i*3;a^^vRH>`>f5Nc+j;ThX0F^=mAZA#R{$^>&3r?&zzBShi-?aKsu*v z@L8a9V2{7E|FUnAx3b6XUIMAU1x23<&gDht(mBJkVluy^|N51mYD%4xYDlk>y(L%5 zKUehCnd+Gpa`L!rg7t?3t4Nk-`h$<$1W=c6t>0h_8tWUv^=+8h^fP)p-IrQK%)z%{ zvAB&GLFSU%sWaMvx_yR!%#Fk9TVL3V!yiRliws82kE|M5Ir3WgObA%VSe}`-8%e`e z<^?^C?1cA1Pa_fF_1RVGEb!bFE{bakI@HTT8-qK8hk~nv{=oF$@ZfoH&HEha5qJ&G z@>xE?8}KZ52VAF}a~#3K(uLXiy1bRShjV)6Jji~N-6eNx{?UR?4yWshJIOaQG?cF; z-GtTJFYF?bM&Hx6)wy(EbkFo@hSA2#hF|&{y7o*d=3i|BU7G4nrr_`J!$fOx7-^*j z(2VvkW~Y9uq1brKj99i>)`wLI^H|a?(=1*SVahj@10N}x*-RfNKNIgT1vo1C5R(j% za0wiQz?VLk+Y0U@LwN(gnRoKr_#ylVKETCs^T2~Fk?jhf9%e7FJJ>^^CBZX+DS^%Y zo4y9V#ooSNtvA^d?V0Y;cpaW)o>=c$Z%6;503E8%hB#UnB9>QT!5yvymV@;rjMO^% zDV;;Vpx@HqU$5=YY=N^`KwSVYqS@qGVm|Hy)#BM$cf2Ywfv}RV$uU$GZD1BL7ns(% zI=aR5Vp_R}D)TJpHkD2TfaD-U{-XM<9vCz)Y*pM|;A+$Dh zK18x3*{SSw_7%I3d(XAt$M6Y!9$#5#E~J1a`X+I=sFnJHCh|SG3Di_SK+dr$;K#nh zN_`pR9t@CAz#%uW0&4Jo!8)%yS`+JxEx{V&&++L5PdJDIq9o~ui25A$4BV_XQ1$3u z^i;YEXt8&oPEr@C6%+>!0q-#bHXFT!7J^2%9eakJLieLTK>PYC;B6)X=4(Hw1Pli) z_|4!j^-}ySUIm8(uV|M%VnF;Tz7|i3YsIl*jF>CD6Rrsxg{lzE^y6Fe_qmGT2Vvtg zxi~(NOX5~=|8XO@>D&Y^k(z*qiQG!2lKJE?Y92WF^oK`0^^-hJ)+L7%A8`YIALCFNT}UW6 zxYj6H+(f_A*PX2=m&c}Jkz$6gn&+GEqx4)m-gHNI4IRM7fp>NTv=SAm4dTb89^5+q zyqtu-Bj4lknz}+!2z*(&@8W)Wu>4e#g_nFO!6#mnJHk;m00iZ3K(Tn`@|ruy5i}h% zz>^UM2|~UwUb6}mlzSizH7xTgVcuA65V>fK1R5S_@ba8L*-hx&Rr9?ZXb@Wr-+)hI!=?XlpM= z-UAk3uT)7oB|VVl0iNp~C|tqH;&bp4#2+|LoWmAiG1wbaK<0yH^>}m_S_yngUP7A{ zKwhzlvO-!Y6$*lIS4@!%@(w^gU(gIit%#sO(38kZbSZKf)D}EQMZo^$XsqZ7XmEU!@5?)+u~Jtl1b$;J#S@|c`te&Ow^T~zBnC!HeR&D=YbNA@ zW`gq6cBB?s1^l{LtQ&q0D+e(}JoX5wfc}frMSG&FpijS}S!frm5scuq*b(e1HW}*# z4mH=&I_PJJ1f7tR(?DF^TvJDhm%B?N#Op$B;V~cNr}G#1C*VtVK>~$Aq&(gM+My=V zpLz}V`6gI-WH}&r8b~9>I^uJIlnCjEI9z-vy%l@H8p2n*MOO5Dxls90B7F~^1qcK2Yw$OF8L-f6VjBk zfKc3t^aTXoN>CbagZu~EQ#AqZ1N;6PaDs;c>-#Cf0O1 zs|-nyOBfEhR|INmEV13BBv zfbN}(bVZhc9@iRR3zP&rz8=PdOOt>;L2jahP#$#U|G@U4kI@)7o|cFf$px&|S;)3> zfJ0jcX!NGQ4Y?)vgCqC_-l1A)f>bEYkiLN;J_{r9xts*qnIdrdKmo(tQuzu_Eyn=8 zHv(yh*2O|tCESF+!dgSlXcS}+#{#xK3>ckRnjqv1kAcIo2AhK2!v4mdVQ;az*n03_ z-HFUX;-D{*L6J+RoP+UX0(96=shM04T;Ar1B3F*@#ZKlM26wZa1xkJ?&5?%#hRZMS zgADf?<$_dFnhZ|!ZKa1?XR(d&N$4ipK)ZjwR9nH}TKOa&03O)~P>Fz?I8qYkip|(5 z>@wbfNX5&MK|Gt7POKrU_(i-fz8tTLS+Ec~2~7kag-+lAH37K>Sg8RplePq=lB%u$ zhdc{frMHwP-W3@sQ+xyaza&vov2++7=V11k363_^ z96>t3);-|xFdw|2nu6BCk3y&o2e+naiLTm7!{Rk z>?(}p$`HdeGs%5=loY8Z>?7?QB|};#Ri$5s;)EwGEwl!gmJ_^=|K*=1>?zDr+5}p1 zt%P#GfZ2--U?n0;Zh?VP7G8wEXU9?Jj3S5l1<{X1O#+*@v)fFWn)e z6Ans2z<=69X7i)S{zx2tkzcHM&^Y3Rasxdf_6N_{drAkYfVct9Kn;a3ZCy{aguYAz25b<(=d)J{qYbmtsqi zaF18g@K?$^v4NnGM`%8gE%0&Rhrw2^uX`(c)R%AeHKm!4=+xZ1%eE>1iqbmiKy z2D!6$N}!e)2X1POm9vof8Y?vlS@4J8KT=BMmC@)@F-Z!Nca)X7i~I-t7CH{EPplSi zGB5Ewxe#3;&C%M#1Y#z#1fMK)RsKYMd?Uy`<@2OA73RTi;E~si9*;kQ8g)Gx$8O8j zh%wl7-i2NkS}+^sP09%(EYy_GhuTXbw@S>AIDZW`MIiXE!gHatcz`ct592e0jpCxf z7UX`=$o_*Z*4%@&*}YH;?73!`G>`e0+#`(_`;xoK26zSOKkY936!}V{V{Tv;xh<~H z>yW8*IQdBaK^((35r>pc#8jxdJdu~v+n8I@Sa~CYXbl|BJlH&J25 zt>E5O8SaeW0D|q&oD_@2Bya)FK(5K9HF<2f&_WX?X3EL(WVjRF<#XhDnm;5vXiU^Y zcS)c#Cgw?BfbH}L#5@xL0oqKnSlA$6l<$Ez*bHd~a!ffQo|X%w5r_!Nn#Iy(z=c?( zrt(XLMvf@?@@-JUNkJmzDo7rl2T1BFCrNoUG1{LH^8U^C})8a^^n)DVF zTUN^zq~79vu_34cmXod0a5+L?g}uD$*EU?((;uJ~GMJ8}OB!dGrkHewxy)NbxM^w_ZR%kCH>|eh zhN+Kbymh{5fEBTJx9l?q&G$`*j7`m5jCTx;bVn&YA)qrzg!qQNgBF}6WU*~T5h{-x z_&32Z{#U*`fq4J+;KIOP&M5bA*GR|d;+w86&aI9Mh06=S<}`4O%^Q>Rv@kk5H;2zI zmEAX|W5$5&u=LMaO>*!5wB*gp?dLA%9uthvEc9*HyoP(bVySGBE3}h{aqszase7v(+2x=<6shxKL=% z?U?>O?QY8DWbp^_W5f^r*Vs>kKbQD)=-a{fN8Vpfod0_E`}uF){Me$&$s=+fWc}%y zfv&|4#vTn@XG0phGBKrw0UT+b^-g4ys4mtI zW;bcW4*>`3Px78}NC*a!JXhU`&ZXY7P;XoB9_j2~7?xAFFs!go{*T;qIqNgJrU+kO zef9pB_3h<1+lT!3L%!bqkpJb)yU}m` zC%PuGcHH6c(8^f;9pJm7nhAlJVv|f3|TG<-Xc5If>qibVaX{c+8j+`A= zw#>-Nb8F136Hzm!YOAWvEB;gSM$Hw~zEntye-v5ScG7+@{DUpVGSAdqJD7L~en*mW zTq>zFl*+SXyc67YJzKo#!OH%d-r>dlib$u28u43p90W=V40qFx1uGS8)zb@rF8Yc4Z-^8lmEyhp#+7OqWn#FwIH?_+wsVrB`TF6D%7c>R9jJbQRSqn zovNO!IJ3;jvc8hiJHo1|MpSN;L#2z~cm_7!*|JhZ2) zSL;h~&v%K=dae)7mX33U-wNmDRmwe(k&`wfwZ@N;KbL-A_@G_R*CwC72IeCR^{Bn%eu@euWoWDp}L?f6%k2Qgjr# z1sjLRpk{Mht|454Jlj|$5sTD4HBUCb3R_|w9rdQf-cp0g#+1KVF1}(y#bEivWhRx3 zDCvxw97{*H4Q~_n!Ppxd4ih0))DN8q+GvPWTDZ(j2|0sr18su`{3CrCzN>*teuJNL zmv`&ko8A4rP27xYOW~gU-g(QjA7&WR8IVQGOO8%@|842V%uj*O%adySip&0(y(s@| zq0OTUGTdN9K+M{z+Lbzs-lu!TbTfGL;2Nv%WvprF!X#?@Qx>v1{x4bzR!YsW+E^Jp znt<4e@-Vebqb!%LCBx&RbTKt!@HkhAZSkK=)+nhd*|NmixQj7GQBrs(+f>W2FwC4| z$~1J+4b~=7T8br`kb8*c#Gm+MaHqVhT$Vh-7yc%M!wwo-0b_JQXSFd<>GSLA=zr?N^*()nBV`^E zmST;zEw&%Cmxy>5wJh2iofMTCg+_gerlUtkRf+h|KHv7z_Qe_<_R};EDi1zp57St? zguIGR#$JGf*(%L6*(acU2li^PM4*YUjVHoguXwwYag;8+oxeZtP0rNpP{xLIe`@X2 zdq2niz>_1BOy5d;tN*ppw3rfzAHz#I7Vk90Gcs&o~odgzFi7=l*P zw1cR_AdKP2&=P+S-+QmoThnu;xRg`pxKVheU}^rE+&5XV8ULh@Nb{$zN}2cLSJJ8P z4ZpYlp76cnk0&V=GB0N}&U58gbA*cL1n;u_#cDF4{E1~y3z(6*MfwW*Hio;VhL&X0 zaMOQ=?)sAYlll?NXQ-PibXjJBZk_HVQ%bi%M;ZdA2+MBE+puTWwYHDpqa$yIH;i;d zRE#i29gaK^zSaKPKFQj`*2-$NOf!jwkbbC6WL`q8;3{zv{|9Y>4%h6Fk4fdkAh(ZQ z6*T(y`#yS?ddhpE+>@N+99qZAf~NV_yaw5|U>0keCZsk=dH8c#(*C3`-;aH}{e4nO zr5}H1yvVtq^E_{eqoXI$SCe%KqomTx5d1S0&%9hJ068FrWsnfsgTn!XzTHB8b! z(9hJJXSy;+nLdEOxvRgU)9XhYT;R2R(xMB?56iPg*sDe0Q8y#xh$WGKMa+!c8|erS z3;$>zZOgV!vOY0SF`YF^dQ>kkGqj`02H@M-7wd)AN9rrjrP0!8p(k%=(ID%0`<6rP z$?1+RUhQ1%yj(Q3@Q;GBc@?vb+5U_g=~dGlsfT_}_<1}zH|b(>*OWhgUCn%%Q!=kv z!B^)@cl%%w_e*FkCu=0sMcrm{q1D|6qp`O6pal!tV7_3Q1yRyG;}gR}-9epM-$Z{~ zKi+Ub-$Vb%u*i7T-RGKm}>L`;wo_%{|j@W!;mDv#QZ5O6dLhw*-qf*eAvIiXY;o8WV=2& zIcEa>(+Kga(}`guR~ZW^97CUCei-S~m;}bI8?Sq)o1)vI z+XtSv*NuB1iHFR_GG;s=_hb=~DA#OQAdL=d!zH)!U(x;7|@xSoS_H=^#Oo9V(jxGu>`~Y=M zI%it;?5r^vi_$lw4Npx>-Ja4gWkbsAU&qp(XI0J3&5J6`auV+M{&Q?+euDTzzN|?C zcl#ss1E6Yr&?W2tHr6sJfUF4`6g_UB4Wo2*b^D-pSPgpUrtY!sy?&`-i{T$r1M?hn zg5^z^Gi;%?r9H(~!S1nFw5QuX+b`Q&*hbk#S~pq(mM)g=reVf@hBvyO%nc@sdPKU3 z-uOuDIGU%qE^FjLA}P2z9UBvD66ob$?(6Nn4yc;8aM#@K*jzNOaC`pC+*diPvm>*f zWVXzp)61nL{?eosrH#)>&U}~CCa+at2S>-^-k#5aP$-MnORb?sn24_+{{ck96hNbn+8wt2O1x#1Wf;rVZ z%#v&_X)#&WS+<(TnD?44z!kXP@V9=5E{e&g-%}y7AM_B39f3ONV5O!kh#A5v{wlYF zT^yPkoE8}3KjLfbedPJaU0j^%8UUD%F^<2A1{H28uou+NznXU`w{q^_95&}sPMzG| zxoCcJe#^qOMG4M+#Unf;eHQ~~LKnG-!gdLmR*vB}=m}0nNh&D7e;D+*sxBAukec+VdOkZ6er)PDi zbnA6Jbpo>l5QwX_8Bhi7Ox-7s6NR`QAC7HCA3}|!6X-MTlpRo&yeei1y@c=lQhpHs zgS*Pr<3zS95FU226`*!~EYvnsIhYZ672y3Ve7$`UzCGTTo-v-Do*JHb_l4r3Vzaxn z`&+TeeachZQ~30It6M4hlLW^FsOl^6hBExP-k0&EJd$i_wggd8S*AI6-?M} zXtm4^rkIf!g^AXsF@u=iOg%={+8GgWX;Dm;_MkRVo1pEgHEENfLf?)4K@FoyQLo9l zWHr)BTp>mi?TIqP13(~@$J4NtSO@GUc)GSnBcayv80rJDNEfJJG=h9d3vl}lLTUyKj;ees_5KJmu+ z#`|vj1m6h%JAb3Vi9qvUQgCr7j?HBcal`o1LRqM#7D=q^0r!?G$Racbdx_P@U*o-q z45A&movcP3r)=~*x`=L}y{5&OCNM|OW=1kAnd{6NW+StfSq<~tU?zPll!1?ny(K{0Cqbj4~Eb$SCCD}Xz(J)0)@8Iz)t7{d|lN&tq@qcYe9XCRMO=RnmU9c zyT#A!O3+2?p&aFB@!h4W(n#qV)TtW^`C>N#}U2I|)fbMe!d21&`&sVn?ZibOu~|8_L_{E>N*A4cPqGP{&9`{)YO& zC-gEp6OOnW9Ql3xF+P{*P22~6_A3NMqT~bOAYma9as@Gi*o3bnN)kKp6sQB@_)kD? zpGUtzb$L5F7S*DwK$UDgREI4{e^3^94-D;Tn!3P$x(~dBL!jyR6Fd)QC}!CPS}kdS zK0hkm22bH8G6Hp>aOnx;4?9WQ#MjclfY8Vj0h0x8!*8S;;t1)PXpv*2I#OwI6jb+( zQi4!dN*C)%W^iPl0hzOh@&k#Im&qwoMeuz22sNdO$V+(vuny)?osf0-CZr1vbx!0t zb_d-{?M9>Uo!Dh;3$YD}1NzWFVk=pT&bE%AmTG4zE<ef?g9H=I3k2ggJsa67td-cpjGhln&I*cjf70ZmmwE$4jGN=l!inH z;wX`i9G5>J6VXXT7-&rom)oNv-URQ6FrY&64r_=f;KPVwr7`{)IYfQL%b*^#GjWJ+ zhfkw*p->Avf6cYDtV4O@1Vp_0X7Dq~F5bune@`35R;$Wbu&Lgm-f5 zLmPsLp-p@w^o-bb{K+-G}c zZ)v?|PS#s>24>*%i3qi@ZNE$EyBR=DPvKlrRoDiZ}oZS+X13uqkPtbYPHPeyknCnU|v=@~w zJeH%$GG+Bik4`VnD4m;}Rk~nu?zVyh&V*n_Xe`$YI9;ciPR0qQ=cXpMmUcSgL(KK4 zS&;`LKSiY2hFI^Khw2lw8woSE8tM?Q`1s%xKkMu6E$elACKn%fukeKX_XmpEJG>|_ zA^K6N2CH#TSkQ7HqEU2I?8*2gC03L;U*b&Mnz(mSHzIf2dfWcCUNSE*^2~9n1@RIP zD@)~)l0}sGEa3rnl*6fzFr5kf{G6&@!DEi^C3(tb-;w#CF_tI|GSA^AMXImND*XU4mt%zo}F1AHB zB5aLilJSP_10-b`Q9z)k*)s6knxmUe=% zk)fG+r@3TQVZ^8Cb|r%G`LT~;oH6l{Ld1OgTYE>_xUjCKiN=e%5ky^V75YwN1jNBX zd4hCWSi{%hsOTJ}da1TJ zUbqnI5n2$u6U+%s3uuGi{k?r(+@D=TieEY|x(2%h_jyOa;mi-_oyobK(=p4RrO6(h z*&_E=ZjZcj`CXmafxX^pqFV8bM3>MxYHW5tvcSOMXNzf82XNf>G? zj+hbiN7TURim}C!?IY@kcMV@*W6k@`mrSQjhnODpzvL#w587qm0WZaiI`N$NgFgxL z@9AJg_Fy0^R3c~$A_1$93AFX##bZ2(E7|?r+23=*5pwYbogK+}*9$h}^vP|US0N`U zH@2vEK~HA||JmXk|9SR$XsL8w)?)|Ii`pJkRb7;Ms?KUSX3RJIr=MWDqc74Q)gjuI z+O^bPJWV5`E0tODV&x0m=Uv=E{se!RH}Zu-6!>m5#C`aD>J{0argR1R3+CI#a<(Yz z9$N)l{qSkFdSU0o&YG`U-k9E+#v4cLcWN5~f>4IvSTq~3M#ybIX;hJ`h}nSSNDMCK z6WHxy zx{`|W971859|!i8pIsny{Ca=!TSp1*k#Pc*;SLR3V*2 zCSYrzCbU-BB^QFeTsr8Gozz@JN&vU^3i=g0fb_z~(|3pxw23LH9jjkSb<=-j>KY<+ zPW=sSNBv@E5ZzN(g%05Lv{Ue1T2@&{Q=pIaT>Gc+f*BV;jT3`QnMJ{Wi5bd3Zwca# z&{Dq79Ya3ygN0NU*Q7&F#0g_Ud*uai)=42rFoh2D4}-78QEt7E6s#qu`BsWazE`5v zTaTX;+!*Q@Y~qu+d%=4CzXDIWJ%NwxE$*H0kjvyFgihcRc3i9?-$D|kDC9ZX2i=QY z(wv2Mw!yl%51x(mz=xyjvCjAx?ChK#o)1iXy&PlO)U|uX8iB zXV_|Tf7yw~a!* zkJyMmuxU~(wOg4C{Kjj5VcLN7z%DD5$tU7P?G<0Qu*%*$)DbM3C5&3OoOTX}(wX93 z;-fg9+{Xo|7NOmSWZxvJG|~_vprB@iPz66B(MWC4g%1{cf;RPRzA-VG)dBM2As0gx za_s~Py@6il-T>#bpEyIRj~Rfhb~N;1LxW(xBVKZ`@F_ivJZl17fOxaxmDJ z8Z3)}vRa&bNti-35<%?;xz6hT8uBk@iT@mS0WTS-Yv}9awBLp5Vo&lFyB%+!^yR9N zQ@AMPhH{)Q1+2`Rz%qMqhaydY&gj0=+uNZSRV>l0RNcHr&sg0vaEh)rSZQd?02_l@YmpQZL7 zLMTN)HPn_Wr|b#b);|%BAU%Xz%oLG9=E-}M)#wVu!`;W5qo?tPyr>yN86-~W&uwRJ zD?=qchRP2KNs}C?ZRzj%gQ1jn!WpW9xDgx_ULq2)Js6K9XwM0okiLRITt&)pF~U;n zCi*$3BY(&nH5i%|3Zr|W$Jqm9vhoJ00y?+(8i&wNcQW`3Z>+HrC)pHy9dbm&N;RN` z-=SU5a%ha;CgzJhfZuADX>2Yxoot9@2{ojCcqa0O1MLBd68iB$ZB=#&*_jK|#rp0j z_f1z_XEoRLqPGeXVc6jkhybqTn*;Bto!(k}9WBG`@t!79gkmg0k=Wz9=fW;=6t*}d zY9n}K@V@PYyDfE@olI{Btgl|aLq7|RldBo~yA9|uZ6Gw$zuLAlkf3-Fl-R?u=v^@y znIZNf^&E=b7h9A0+<#bx@ER{GSAZz0N2sp#gnKzUlUnUfH>3o+30iF>4(fh^Z*(Dk z+f~Z;uk(eu*f*PhX>aH7(|r(|^tXOc{y=-Wa;NBy)gGv7-kNhEKGr|jKihuO%?c;9 zUC>S}B5$*pxgPvY;B#1nG~cm^7=HnW+orkf`gy+n`Vn#sZl>7I(4N1=WgCA67eV$S z6u1qz^~cU0y5aaUPD}iwDGt7)K%JPahJTgs2wv?(-)(d*vxF-U-;o&q58f;^ov4ML zadTnwkb92ou;$A4P&@pi@($m^uEh{i6Sz#R)nvH~A{rK)!anQs+{=j($QhZTFS*)l zhwE0PbFt3C2&OG@IxiGQ7uQs}Q5$3=v{JO{c59aT$6+UNPzDYxLh|$-i)UiP&C$7$ z5$B8EQ0=K$PX?8Q#L3OkNB&%fr-v1cH_cVNoSSat9%#D8*|}(DUnqrQSOz|h4d&`?!AIWLwypldFnfwfb-lgl6EOrYEq!rc3$N#IicqvE zR7id*o=D~6n~H8l*@Ge8tKaB5M?Wp3qV4`HKWCXGSmY$3r}EI&Fy%hIE-EiOQ+Z^# zugUiOC#Wk;!y$-ql{@R!KjY^PaC%E25)vI~Z@iTyVmE*zI-K ztJ2}~tP&N3Hf$e#nAfH2;J0D}sV~L9jC$shb2$}8W8V78X(S*n!GqEX=1=*7H#M?? z`-8hzr06W9DoQo5=FFc#t#BZ^m3yY6blHZk2}DULKeXHa)ni5qsGZ_7@|G)F>tiN( z4RTjwl+;99*4Y|cX@8#GRIC`WNN%m!SyaoKMy>P|2E*}oQOP+2q~TVBd4BeBE>z-Z z@uDK5uD+pLa1YWS`t+H%JvGra#M)E3o;4^wUZ^XYeAVrzkz%fxZKmfXCQzWgQ{2dU zM}FrZ)`&*dKwz5x!+T#`FZ?U)IpB{ zRLuqKHSaO#3YL&vOs5LU;Qi?S!7QE6f6la>e*)|WM7yoE-v*uBGUFkzoKiP_&4DPfY3$+q|$wTpS z!Lfl0y83{;->+GXFCz~7?&FF61H@+bh2S9PAO>nh=n_!|s~Nm2aLQ-G5^9I+$JiJs6!ef_M%rxOS!$^ zovVmh{5GUI76eO#)=r-(oy^65A8l2|g5UGz-LK*a7g> z>a6sZS^|~{2g$2fns!n}WP-Q~Jtije6|u(R6sZYh%AWw|_mF%LSu9(XkDw!witSc_ zC5_Jl$AksQ7(f>F2L+Jz;5obwOTaByUwppCjWz;Bhk1bBIs$pj-N<|NBl-=QgEm2{ zDQ&PSs*WzIm-k|;K#!s()(JQzI64qe>F1E%2#El5LK%P_K-y^{!NvF!M4bDPk4RTd zeWW*he?3|#w?b8Zd1?t3SK>PO(Bm}v(H1Hz$44j*uiVf`#S|2Nrub?vW8t}g; z@+WXAcY!iP6VM6lseD0Fz_TF*v_d|C9=rk?7%XVCcF~jujgg(uD$hU%W4=5CSr5vf zfYAgUMk_G#B+YTT0jT4P@>fkq*$DXemohj50vAUIj0{j`hi%jZ#gl8Gigy!u>K+Xv z%b>-vRz9z(0OZ*tN1W-QO2RisU@?gk@o|h-W{u4ms<`XCv z=7SziQ^02L1qJv2RSmhkf{G6To>h68uL29 zSObb3 zXq8+CMU=~+gqxtb2o*?`%m6>g?`=!ej&R+3lwxE*azjm4OzB~8cbAAV8q9fp_CgF3A;ArQV^7L}N$6EOD6~mKQ0YXkQ zoXf=^E_VdUeL2V~eZc^tz=BG||NjnNpJ8w~Y{proOo_M`?fN=o82O=QTZ#Xp1y%+P zKFcF`KRTE$5*o2W*xw^EfuZ zw}rj^rU_ocosES8ZhU)O*wWn1OnlW`fQ(Qj0JZQI+Il z>JWAU)~Hq&5!FNx^$5762AGoAGsH}0Ix@n`-FLOo=%*b7wW<~Q1TKuj zrd7&4%plqsKWfu(PHE7bvsI_I8w#Xupy|vu*1~sbHrb$bP9g?s0=)pwk0+=-L~ANr zxnj(q`kE%wU8!bzsTQS_8=py$SZxb2<%uPCX(7gDCDAlpzCe72b-lv$qn<#<>05{# zb&&pCnNIM?R5_=;)dSqgsKNH@R-%Wofe@98`bu@UzEf)sM~Oq5LiB;ZtBrb38KCvY zBq7!GQ2)p1rH;~1sZZ5AaP|A^$zTYe0#$#~yx}s7HhF{Q5@z(*F5nq#n=wwEYHChR zRWIv7$~mMuShUZ`zt!`^0QGnM7c~dz^YQ8eJr#V45Ysx{M<1&GYD8-`q+et}Q@TNJ zZV<{tW4_wY=%C%eJvh&B==Tgwo1y2cZ@_842;zlcBxxV%t@Zv!51f60oMdbwDzvY# zHqAz`K3sc*EKXIgz^=1e8vsVqe$yo4Q{3^*HDCRak&ZO^DDZpZuw$}%8})N=kj_B) z>p}ko@5fr)MT5YOPNZ_EeH6%?L?^i3E@3}OfnPV@w2I{6$w?!AgnwtH@en?tV_2^j zFhQRQLehWysSvD|Nnn$&z=R|hjy)1trG25-PJ=V7Gsr$UnyiPz6XMn>II`;Wk=kjb zc6$^3;p*w5ZB=ipm(&iX41I&C6Ee-(L$$k2y@*-OF|7>V022OG$@9 zHwxJ|ckqtG^v3v(5s^Q!2{TP!VmSGd5Xoo6DEQ(w!0&U491pKxCMHV$bXzKzeucSU ze`Xv#ijKqVsV_5`IfhxQFYCgLLT2N*L2LuaCTYwc^rzH95T91-g1S<%E0dJ<%3Nio zG*P|GQ;NBv>6u*Wd zYiIowdoM>DAzfaijsg*G3%QH=%<`Sj;DFMQqoG^FibJ1=mWCFE%nLDsrv~-)>);b; zy<@IpuTs;A9DFZ~H7aoj6%r>*Hsp);g179H(HVZv0nke9(!I0?m_R=htwLBsMs1#L zX33V~5k&(EBZ^82mgToAtjzzVa7szD@?lO!zeTa0zj~kcJMDivXjy1Mf@?<|RIZDnb4k73{zn&0~_ATxi@rtRN^Wz{iL8OyLdsDDjQ@N)8dv3TC*S`^an5G4SZkg1_pq zmWNfaortDS(!taQvZJwAD{`B}VNTM1v}QuZ_|hKvDX*KnN_~Fie+jQca%UHXmBiP) z5rfqfKHB=Z-{!!%Sjjq-%-F3?4kQyufbu$%D+2}d+FglaD&P+3ZWa($gwLDH;(j%WA%sYpE}yQ7K>rZ z1Z}k`k#=x@TjqLf@F=r9<40Q-LKD%*!;inqtso9-C2p6iiBqvR{ZLi4x8h2%EpOS| zbYV;otIcDz&yeKVPmf*a@@PP$uTdYH#M0No*$MNN6J_i|O1JxFY&J*2i+LcP6$ibH>hsEk?{{U%n7`8>KTa(5#!DAH$&xhv_V zelH|DiX(`_pTMKc{V*93+pBH~w@W1?`w^4c93!WDyml3uV)tU99Tt6Ek`_^kt z*xiVGvBTp}#pFcIiZ~UzDzLS$Xx+jk5PRf3j=b7b+xfEalGpD#74I$nw3MqTsp@5a zB)YXKW{!1??}nhxq1PHQ;pf5&BmRxp6fq`zN0==r$+wM%K()}%NdGzqI!YRXotIpV zrAzW)?V!dS5oqzGhk{)Gc@4MUXgAy|S1L<@s+2ky1Mqro51yQLnLD7pM zH;4HKdiad?n8b~wd<{}LE!2U86;XZOR#q{zVt>VXTY1&6+NJeLf>Q~nLd|CHfBlaK zJpxgqNAQWDK7m^U&IJtjzwS51r`CF$|3LLH7AtY?ba9|qBwmoZ%X5?;wQr4b)0fl( zI*mQW*|@HJ4Y!k1*#qoBC{O-mr?NMhVCEzhPd+wf>am(#_ID2wS2%}(P~%azy=GZ; z%j!1OH>!?TAFqC2U0gG=_C?+I_6rR)u0HZe<074G+2h^b-xNdz?+!Etbn%I^j^szv zdr`ObP|M^+?tap8aXscf>s@1At6g<2Q5YwECugbKjassr=?cEs87Lf&Lk0QXy1-+O z<+gdZd5?LExv%*Lz7uznxk}}L-#FU1h1|=La31|m6p>@;mCQbNK9|I==5O;$`F!qs zZZ0b`3{ywXpyR;xJ3$`Bl(@HkRc)r^x|1bJdhs6V80NKw1z45x_zJhjeUCk z=!RC#P$6IXLs9i?VmCdJ&E_O7lUu`@8Hw_PZsj!jm^@D|1MA0!@JG!%K)<4eYYynN zBB6TA*EWLKwaH|LcYQomT4Ah}o5zjh8~79E?v`1WEtW{j9PQu%*EW~t+~BNs)Z*3DxzX9r6(#H!o4Cg)&9pA4;-AAUKZ2ydY{&&w zH5IkaW+=aIq8GCn)qflu=Q0?w5n%7eqWiEM-ueF+8x`O#45fxs!|9%MA0~@g&9dA8 z?kV>x?(?>sgI&PJvm&liSI{N{*g7T~XJ9dt!^W{Up+K{;2XL%=neI$Wx*q3NAy1K6 zAu)s z`bHcsZWmq(Wx@sQSe!gp`54bjex@aa15W*6qzm0CWCgOQE9 z&=wleL{HL-Y6)&+E{>%E%D`WU`=BnYCB|Y``4Sx2OJL{jG@2nrW*B;1wZ=j4Y-b^r zE0z2MTmYFWqkm%hg8@2^d(F+^mUFw<4eTjqJnEEk`Y1h@UJM1@RVtDC0D{6>QYE`k zGvJYFkI(#?8UVtuwyU0Q^LJbJVJV+DW&=DyBrS}!_f}DgMvy8sN3}Jmh zxcLDb#zkxw_$Z%%UX;kBFuj>Upq=H@kLgfQu+l*O*$)O)DU}S`MK08aIpA4z#Pn`6 zdQ)>jMScOIbO1UprP^FgP$RVx5ES1jk~~ko?e60~D#=owbW&P`yYjHqS(+@Bi5?g3r!WX!g5iO2BoeoTHs_69FSLY7c1 z$P0ejPIZ&|n>t?|g4NzxF9T7wpD77koaR&#J)8N9U5jtxpZM``hMwb+xu4-W+Rb!i z0-5#r{;H;0Q++5q@h@suCoz*;Ox`21uuCo?+k#AVhR8Bq#;SjZ%-0ksc9YRJPjTajcusA?kL@qdbyi&R~fH5)i7jm z-bbc3$j0OZYCGM5t>n_oR*!l(OFCG)dNi|?@)_Jl_6)O{F);#ti>{(yQJ1Mtu)e+^ z^Klohr=E~c>2&fEH4jvj%{ZPx#HXPCyu_aN23@x4=&pn)+46LEhUCIkX@b?%LmVqa ziGSkArV0C8om`us{RwPn(~w)A)gU@12#IdcF^N1ng@0xF+L~d#=n-!5;H$v1x<;)> z#lMg4$K=pe)c53K(+h2tG6YctVd6M3LYm_KOKt-~Py`stEeT(;9kqt8VZ3?GoNArq zb=!NnPl``p?|Gj6JaYIuYyeY1Wnu*{fGXiUUel>y%0XshN8E|ukI6*pF=@jtHH4f? L_9LGXTZ#Vzn3yCM literal 0 HcmV?d00001 diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 00000000000..bdd437b84a1 --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,33 @@ +{ + "boi_token": "", + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eoi_token": "", + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "image_token": "", + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 00000000000..9c2583b0aa4 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,51347 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "5": { + "content": "[multimodal]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "6": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "7": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "8": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "9": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "10": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "11": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "12": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "13": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "14": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "15": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "16": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "17": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "18": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "19": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "20": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "21": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "22": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "23": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "24": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "25": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "26": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "27": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "28": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "29": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "30": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "31": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "32": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "33": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "34": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "35": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "36": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "37": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "38": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "39": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "40": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "41": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "42": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "43": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "44": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "45": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "46": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "47": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "48": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "49": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "50": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "51": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "52": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "53": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "54": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "55": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "56": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "57": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "58": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "59": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "60": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "61": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "62": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "63": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "64": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "65": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "66": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "67": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "68": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "69": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "70": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "71": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "72": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "73": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "74": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "75": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "76": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "77": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "78": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "79": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "80": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "81": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "82": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "83": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "84": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "85": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "86": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "87": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "88": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "89": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "90": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "91": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "92": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "93": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "94": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "95": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "96": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "97": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "98": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "99": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "100": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "101": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "102": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "103": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "104": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "105": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "106": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "107": { + "content": "\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "108": { + "content": "\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "109": { + "content": "\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "110": { + "content": "\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "111": { + "content": "\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "112": { + "content": "\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "113": { + "content": "\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "114": { + "content": "\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "115": { + "content": "\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "116": { + "content": "\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "117": { + "content": "\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "118": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "119": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "120": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "121": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "122": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "123": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "124": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "125": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "126": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "127": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "128": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "129": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "130": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "131": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "132": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "133": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "134": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "135": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "136": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "137": { + "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "138": { + "content": "\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "139": { + "content": "\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "140": { + "content": "\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "141": { + "content": "\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "142": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "143": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "144": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "145": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "146": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "147": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "148": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "149": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "150": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "152": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "153": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "154": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "155": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "156": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "157": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "158": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "159": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "160": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "161": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "162": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "164": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "165": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "166": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "167": { + "content": "\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581\u2581", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "168": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "169": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "171": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "172": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "173": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "174": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "170": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "175": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "176": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "177": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "178": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "179": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "180": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "181": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "182": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "183": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "184": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "185": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "186": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "187": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "188": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "189": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "190": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "191": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "192": { + "content": "

", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "193": { + "content": "
", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "194": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "195": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "196": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "197": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "198": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "199": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "200": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "201": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "202": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "203": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "204": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "205": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "206": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "207": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "208": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "209": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "210": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "211": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "212": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "213": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "214": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "215": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "216": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "217": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "218": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "219": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "220": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "221": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "222": { + "content": "