Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ func (s *SQLite) BlobDBFile() (string, error) {
}

func (s *SQLite) connectionStringForDBFile(dbFile string) string {
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate", dbFile)
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_auto_vacuum=incremental", dbFile)
if s.BusyTimeoutSeconds > 0 {
timeout := s.BusyTimeoutSeconds * 1000
connectionString = fmt.Sprintf("%s&_busy_timeout=%d", connectionString, timeout)
Expand Down
4 changes: 2 additions & 2 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ func TestGormParams(t *testing.T) {
dbType, uri, err := cfg.GormParams()
require.Nil(t, err)
require.Equal(t, SQLiteBackend, dbType)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate"), uri)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_auto_vacuum=incremental"), uri)

cfg.SQLite.BusyTimeoutSeconds = 5
dbType, uri, err = cfg.GormParams()
require.Nil(t, err)
require.Equal(t, SQLiteBackend, dbType)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_busy_timeout=5000"), uri)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_auto_vacuum=incremental&_busy_timeout=5000"), uri)

cfg.DbBackend = MySQLBackend
cfg.MySQL = getMySQLDefaultConfig()
Expand Down
97 changes: 54 additions & 43 deletions database/sql/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,47 @@ import (
"github.com/cloudbase/garm/util"
)

// func (s *sqlDatabase) CreateFileObject(ctx context.Context, name string, size int64, tags []string, reader io.Reader) (fileObjParam params.FileObject, err error) {
// streamBlobContent opens a raw SQLite blob handle, streams initialData followed
// by the rest of r into it, and returns the hex-encoded SHA256 of the written content.
// The raw *sql.Conn is closed before returning so the caller can safely use
// s.objectsConn afterwards without pool starvation.
func (s *sqlDatabase) streamBlobContent(ctx context.Context, blobID uint, initialData []byte, r io.Reader) (string, error) {
conn, err := s.objectsSQLDB.Conn(ctx)
if err != nil {
return "", fmt.Errorf("failed to get connection from pool: %w", err)
}
defer conn.Close()

var sha256sum string
err = conn.Raw(func(driverConn any) error {
sqliteConn := driverConn.(*sqlite3.SQLiteConn)

blob, err := sqliteConn.Blob("main", "file_blobs", "content", int64(blobID), 1)
if err != nil {
return fmt.Errorf("failed to open blob: %w", err)
}
defer blob.Close()

hasher := sha256.New()

if _, err := blob.Write(initialData); err != nil {
return fmt.Errorf("failed to write blob initial buffer: %w", err)
}
hasher.Write(initialData)

if _, err := io.Copy(io.MultiWriter(blob, hasher), r); err != nil {
return fmt.Errorf("failed to write blob: %w", err)
}

sha256sum = hex.EncodeToString(hasher.Sum(nil))
return nil
})
if err != nil {
return "", fmt.Errorf("failed to write blob: %w", err)
}
return sha256sum, nil
}

func (s *sqlDatabase) CreateFileObject(ctx context.Context, param params.CreateFileObjectParams, reader io.Reader) (fileObjParam params.FileObject, err error) {
// Save the file to temporary storage first. This allows us to accept the entire file, even over
// a slow connection, without locking the database as we stream the file to the DB.
Expand Down Expand Up @@ -99,44 +139,12 @@ func (s *sqlDatabase) CreateFileObject(ctx context.Context, param params.CreateF
if err != nil {
return params.FileObject{}, fmt.Errorf("failed to create database entry for blob: %w", err)
}
// Stream file to blob and compute SHA256
conn, err := s.objectsSQLDB.Conn(ctx)
if err != nil {
return params.FileObject{}, fmt.Errorf("failed to get connection from pool: %w", err)
}
defer conn.Close()

var sha256sum string
err = conn.Raw(func(driverConn any) error {
sqliteConn := driverConn.(*sqlite3.SQLiteConn)

blob, err := sqliteConn.Blob("main", "file_blobs", "content", int64(fileBlob.ID), 1)
if err != nil {
return fmt.Errorf("failed to open blob: %w", err)
}
defer blob.Close()

// Create SHA256 hasher
hasher := sha256.New()

// Write the buffered data first
if _, err := blob.Write(buffer[:n]); err != nil {
return fmt.Errorf("failed to write blob initial buffer: %w", err)
}
hasher.Write(buffer[:n])

// Stream the rest with hash computation
_, err = io.Copy(io.MultiWriter(blob, hasher), tmpFile)
if err != nil {
return fmt.Errorf("failed to write blob: %w", err)
}

// Get final hash
sha256sum = hex.EncodeToString(hasher.Sum(nil))
return nil
})
// Stream file to blob and compute SHA256.
// We obtain a raw *sql.Conn for the SQLite blob API, which pins a connection
// from the pool. We must close it before using s.objectsConn again.
sha256sum, err := s.streamBlobContent(ctx, fileBlob.ID, buffer[:n], tmpFile)
if err != nil {
return params.FileObject{}, fmt.Errorf("failed to write blob: %w", err)
return params.FileObject{}, err
}

// Update document with SHA256
Expand Down Expand Up @@ -405,15 +413,18 @@ func (s *sqlDatabase) SearchFileObjectByTags(_ context.Context, tags []string, p

// OpenFileObjectContent opens a blob for reading and returns an io.ReadCloser
func (s *sqlDatabase) OpenFileObjectContent(ctx context.Context, objID uint) (io.ReadCloser, error) {
conn, err := s.objectsSQLDB.Conn(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}

// Query the blob metadata first, before pinning a raw connection.
// With MaxOpenConns(1), pinning the connection before this query would
// deadlock because GORM needs the same pooled connection.
var fileBlob FileBlob
if err := s.objectsConn.Where("file_object_id = ?", objID).Omit("content").First(&fileBlob).Error; err != nil {
return nil, fmt.Errorf("failed to get file blob: %w", err)
}

conn, err := s.objectsSQLDB.Conn(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
var blobReader io.ReadCloser
err = conn.Raw(func(driverConn any) error {
sqliteConn := driverConn.(*sqlite3.SQLiteConn)
Expand Down
144 changes: 0 additions & 144 deletions database/sql/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
package sql

import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"testing"
Expand All @@ -26,7 +24,6 @@ import (

runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
garmTesting "github.com/cloudbase/garm/internal/testing"
Expand Down Expand Up @@ -909,144 +906,3 @@ func (s *GithubTestSuite) TestDeleteGithubEndpointFailsWithOrgsReposOrCredential
func TestGithubTestSuite(t *testing.T) {
suite.Run(t, new(GithubTestSuite))
}

func TestCredentialsAndEndpointMigration(t *testing.T) {
cfg := garmTesting.GetTestSqliteDBConfig(t)

// Copy the sample DB
data, err := os.ReadFile("../../testdata/db/v0.1.4/garm.db")
if err != nil {
t.Fatalf("failed to read test data: %s", err)
}

if cfg.SQLite.DBFile == "" {
t.Fatalf("DB file not set")
}
if err := os.WriteFile(cfg.SQLite.DBFile, data, 0o600); err != nil {
t.Fatalf("failed to write test data: %s", err)
}

// define some credentials
credentials := []config.Github{
{
Name: "test-creds",
Description: "test creds",
AuthType: config.GithubAuthTypePAT,
PAT: config.GithubPAT{
OAuth2Token: "test",
},
},
{
Name: "ghes-test",
Description: "ghes creds",
APIBaseURL: testAPIBaseURL,
UploadBaseURL: testUploadBaseURL,
BaseURL: testBaseURL,
AuthType: config.GithubAuthTypeApp,
App: config.GithubApp{
AppID: 1,
InstallationID: 99,
PrivateKeyPath: "../../testdata/certs/srv-key.pem",
},
},
}
// Set the config credentials in the cfg. This is what happens in the main function.
// of GARM as well.
cfg.MigrateCredentials = credentials

ctx := context.Background()
watcher.InitWatcher(ctx)
defer watcher.CloseWatcher()

db, err := NewSQLDatabase(ctx, cfg)
if err != nil {
t.Fatalf("failed to create db connection: %s", err)
}

// We expect that 2 endpoints will exist in the migrated DB and 2 credentials.
ctx = garmTesting.ImpersonateAdminContext(ctx, db, t)

endpoints, err := db.ListGithubEndpoints(ctx)
if err != nil {
t.Fatalf("failed to list endpoints: %s", err)
}
if len(endpoints) != 2 {
t.Fatalf("expected 2 endpoints, got %d", len(endpoints))
}
if endpoints[0].Name != defaultGithubEndpoint {
t.Fatalf("expected default endpoint to exist, got %s", endpoints[0].Name)
}
if endpoints[1].Name != "example.com" {
t.Fatalf("expected example.com endpoint to exist, got %s", endpoints[1].Name)
}
if endpoints[1].UploadBaseURL != testUploadBaseURL {
t.Fatalf("expected upload base URL to be %s, got %s", testUploadBaseURL, endpoints[1].UploadBaseURL)
}
if endpoints[1].BaseURL != testBaseURL {
t.Fatalf("expected base URL to be %s, got %s", testBaseURL, endpoints[1].BaseURL)
}
if endpoints[1].APIBaseURL != testAPIBaseURL {
t.Fatalf("expected API base URL to be %s, got %s", testAPIBaseURL, endpoints[1].APIBaseURL)
}

creds, err := db.ListGithubCredentials(ctx)
if err != nil {
t.Fatalf("failed to list credentials: %s", err)
}
if len(creds) != 2 {
t.Fatalf("expected 2 credentials, got %d", len(creds))
}
if creds[0].Name != "test-creds" {
t.Fatalf("expected test-creds to exist, got %s", creds[0].Name)
}
if creds[1].Name != "ghes-test" {
t.Fatalf("expected ghes-test to exist, got %s", creds[1].Name)
}
if creds[0].Endpoint.Name != defaultGithubEndpoint {
t.Fatalf("expected test-creds to be associated with default endpoint, got %s", creds[0].Endpoint.Name)
}
if creds[1].Endpoint.Name != "example.com" {
t.Fatalf("expected ghes-test to be associated with example.com endpoint, got %s", creds[1].Endpoint.Name)
}

if creds[0].AuthType != params.ForgeAuthTypePAT {
t.Fatalf("expected test-creds to have PAT auth type, got %s", creds[0].AuthType)
}
if creds[1].AuthType != params.ForgeAuthTypeApp {
t.Fatalf("expected ghes-test to have App auth type, got %s", creds[1].AuthType)
}
if len(creds[0].CredentialsPayload) == 0 {
t.Fatalf("expected test-creds to have credentials payload, got empty")
}

var pat params.GithubPAT
if err := json.Unmarshal(creds[0].CredentialsPayload, &pat); err != nil {
t.Fatalf("failed to unmarshal test-creds credentials payload: %s", err)
}
if pat.OAuth2Token != "test" {
t.Fatalf("expected test-creds to have PAT token test, got %s", pat.OAuth2Token)
}

var app params.GithubApp
if err := json.Unmarshal(creds[1].CredentialsPayload, &app); err != nil {
t.Fatalf("failed to unmarshal ghes-test credentials payload: %s", err)
}
if app.AppID != 1 {
t.Fatalf("expected ghes-test to have app ID 1, got %d", app.AppID)
}
if app.InstallationID != 99 {
t.Fatalf("expected ghes-test to have installation ID 99, got %d", app.InstallationID)
}
if app.PrivateKeyBytes == nil {
t.Fatalf("expected ghes-test to have private key bytes, got nil")
}

certBundle, err := credentials[1].App.PrivateKeyBytes()
if err != nil {
t.Fatalf("failed to read CA cert bundle: %s", err)
}

if !bytes.Equal(app.PrivateKeyBytes, certBundle) {
t.Fatalf("expected ghes-test private key to be equal to the CA cert bundle")
}
}
2 changes: 1 addition & 1 deletion database/sql/instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param p
return fmt.Errorf("error fetching pool: %w", err)
}
var cnt int64
q := s.conn.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt)
q := tx.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt)
if q.Error != nil {
return fmt.Errorf("error fetching instance count: %w", q.Error)
}
Expand Down
29 changes: 29 additions & 0 deletions database/sql/migrations/0001_baseline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2026 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

package migrations

import (
"github.com/go-gormigrate/gormigrate/v2"
"gorm.io/gorm"
)

func init() {
Register(&gormigrate.Migration{
ID: "0001_baseline",
Migrate: func(tx *gorm.DB) error {
return nil
},
})
}
29 changes: 29 additions & 0 deletions database/sql/migrations/0001_baseline_file_objects.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2026 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

package migrations

import (
"github.com/go-gormigrate/gormigrate/v2"
"gorm.io/gorm"
)

func init() {
RegisterFileObjects(&gormigrate.Migration{
ID: "0001_baseline",
Migrate: func(tx *gorm.DB) error {
return nil
},
})
}
Loading