From eed4b782fd6811b132f91397a4c38fc588915155 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 18 Dec 2025 09:56:43 -0800 Subject: [PATCH 01/14] modifies multipart upload to call single part when file is <5GB --- client/g3cmd/upload-multipart.go | 34 +++++++++++++-------- client/g3cmd/upload-single.go | 51 ++++++++++++++++---------------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/client/g3cmd/upload-multipart.go b/client/g3cmd/upload-multipart.go index fa6d26f..bc658b4 100644 --- a/client/g3cmd/upload-multipart.go +++ b/client/g3cmd/upload-multipart.go @@ -104,22 +104,19 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi g3.Logger().Printf("File Name: '%s', File Size: '%d'\n", stat.Name(), stat.Size()) - if stat.Size() == 0 { + fileSize := stat.Size() + if fileSize == 0 { return fmt.Errorf("file is empty: %s", req.Filename) } - // Initialize multipart upload - uploadID, finalGUID, err := InitMultipartUpload(g3, req, bucketName) - if err != nil { - return fmt.Errorf("failed to initiate multipart upload: %w", err) + if fileSize < 5*1024*1024*1024 { + g3.Logger().Printf("File size < 5GB (%d bytes), using single-part upload\n", fileSize) + err := UploadSingle(g3.GetCredential().Profile, req.GUID, req.FilePath, req.Bucket, showProgress) + if err != nil { + g3.Logger().Fatal(err.Error()) + } + return nil } - req.GUID = finalGUID // update with server-provided GUID - - key := finalGUID + "/" + req.Filename - chunkSize := optimalChunkSize(stat.Size()) - - numChunks := int((stat.Size() + chunkSize - 1) / chunkSize) - parts := make([]MultipartPartObject, 0, numChunks) // Progress bar setup (modern mpb) var p *mpb.Progress @@ -138,6 +135,19 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi ) } + // Initialize multipart upload + uploadID, finalGUID, err := InitMultipartUpload(g3, req, bucketName) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + req.GUID = finalGUID // update with server-provided GUID + + key := finalGUID + "/" + req.Filename + chunkSize := optimalChunkSize(stat.Size()) + + numChunks := int((stat.Size() + chunkSize - 1) / chunkSize) + parts := make([]MultipartPartObject, 0, numChunks) + // Channel for chunk indices chunks := make(chan int, numChunks) for i := 1; i <= numChunks; i++ { diff --git a/client/g3cmd/upload-single.go b/client/g3cmd/upload-single.go index 8395370..f154c4d 100644 --- a/client/g3cmd/upload-single.go +++ b/client/g3cmd/upload-single.go @@ -43,18 +43,19 @@ func init() { RootCmd.AddCommand(uploadSingleCmd) } -func UploadSingle(profile string, guid string, filePath string, bucketName string, enableLogs bool) error { +func UploadSingle(profile string, guid string, filePath string, bucketName string, showProgress bool) error { - logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog()) - if enableLogs { - logger, closer = logs.New( - profile, - logs.WithSucceededLog(), - logs.WithFailedLog(), - logs.WithScoreboard(), - logs.WithConsole(), - ) + opts := []logs.Option{ + logs.WithSucceededLog(), + logs.WithFailedLog(), + logs.WithMessageFile(), } + + if showProgress { + opts = append(opts, logs.WithScoreboard(), logs.WithConsole()) + } + + logger, closer := logs.New(profile, opts...) defer closer() // Instantiate interface to Gen3 @@ -67,6 +68,14 @@ func UploadSingle(profile string, guid string, filePath string, bucketName strin return fmt.Errorf("failed to parse config on profile %s: %w", profile, err) } + updateUI := func() { + if showProgress { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } + } + filePaths, err := common.ParseFilePaths(filePath, false) if len(filePaths) > 1 { return errors.New("more than 1 file location has been found. Do not use \"*\" in file path or provide a folder as file path") @@ -80,17 +89,13 @@ func UploadSingle(profile string, guid string, filePath string, bucketName strin filename := filepath.Base(filePath) if _, err := os.Stat(filePath); os.IsNotExist(err) { g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() + updateUI() return fmt.Errorf("[ERROR] The file you specified \"%s\" does not exist locally\n", filePath) } file, err := os.Open(filePath) if err != nil { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() + updateUI() g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) g3i.Logger().Println("File open error: " + err.Error()) return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", filePath, err.Error()) @@ -98,25 +103,21 @@ func UploadSingle(profile string, guid string, filePath string, bucketName strin defer file.Close() furObject := common.FileUploadRequestObject{FilePath: filePath, Filename: filename, GUID: guid, Bucket: bucketName} - furObject, err = GenerateUploadRequest(g3i, furObject, file, nil) if err != nil { file.Close() g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() + updateUI() g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) } err = uploadFile(g3i, furObject, 0) if err != nil { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) + updateUI() return fmt.Errorf("[ERROR] Error uploading file %s: %s\n", filePath, err.Error()) - } else { - g3i.Logger().Scoreboard().IncrementSB(0) } - g3i.Logger().Scoreboard().PrintSB() + if showProgress { + g3i.Logger().Scoreboard().PrintSB() + } return nil } From c6be76f17edbbf4ecea627bfcf62a633410c0aab Mon Sep 17 00:00:00 2001 From: Matthew Peterkort <33436238+matthewpeterkort@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:52:29 -0800 Subject: [PATCH 02/14] split client out further into more digestable parts (#13) * split client out further into more digestable parts * simplify, factor business logic out of cmd package and into upload and download packages * rearrange some more stuff. Start testing * fix some bugs * abstract out token handling into lowest level request interface * fix a bunch of bugs * adds very minimal docs --- client/api/gen3.go | 366 ++++++++++++ client/{jwt/utils.go => api/types.go} | 17 +- client/client/client.go | 69 +++ client/common/common.go | 105 +--- client/common/constants.go | 89 +++ client/common/logHelper.go | 6 - client/common/types.go | 59 ++ client/conf/config.go | 256 +++++++++ client/conf/validate.go | 69 +++ client/download/batch.go | 164 ++++++ client/download/downloader.go | 166 ++++++ client/download/file_info.go | 125 +++++ client/download/types.go | 60 ++ client/download/url_resolution.go | 80 +++ client/download/utils.go | 79 +++ client/g3cmd/delete.go | 34 -- client/g3cmd/download-multiple.go | 495 ---------------- client/g3cmd/gitversion.go | 6 - client/g3cmd/retry-upload.go | 215 ------- client/g3cmd/root.go | 124 ---- client/g3cmd/upload-multipart.go | 309 ---------- client/g3cmd/upload-multiple.go | 236 -------- client/g3cmd/upload-single.go | 123 ---- client/g3cmd/utils.go | 686 ----------------------- client/gen3Client/client.go | 120 ---- client/jwt/configure.go | 321 ----------- client/jwt/functions.go | 370 ------------ client/jwt/update.go | 78 --- client/logs/scoreboard.go | 18 - client/logs/tee_logger.go | 10 +- client/mocks/mock_configure.go | 129 ++--- client/mocks/mock_functions.go | 110 ++-- client/mocks/mock_gen3interface.go | 132 +++-- client/mocks/mock_request.go | 49 +- client/request/auth.go | 103 ++++ client/request/builder.go | 54 ++ client/request/request.go | 96 ++++ client/upload/batch.go | 161 ++++++ client/upload/multipart.go | 303 ++++++++++ client/upload/request.go | 125 +++++ client/upload/retry.go | 171 ++++++ client/upload/singleFile.go | 97 ++++ client/upload/types.go | 73 +++ client/upload/upload.go | 125 +++++ client/upload/utils.go | 133 +++++ {client/g3cmd => cmd}/auth.go | 12 +- {client/g3cmd => cmd}/configure.go | 25 +- cmd/delete.go | 42 ++ cmd/download-multipart.go | 261 +++++++++ cmd/download-multiple.go | 111 ++++ {client/g3cmd => cmd}/download-single.go | 28 +- {client/g3cmd => cmd}/generate-tsv.go | 2 +- cmd/gitversion.go | 6 + cmd/retry-upload.go | 59 ++ cmd/root.go | 31 + cmd/upload-multipart.go | 82 +++ cmd/upload-multiple.go | 176 ++++++ cmd/upload-single.go | 37 ++ {client/g3cmd => cmd}/upload.go | 61 +- docs/DEVELOPER_DOCS.md | 91 +++ go.mod | 4 +- go.sum | 16 +- main.go | 4 +- tests/download-multiple_test.go | 271 +++++---- tests/functions_test.go | 286 +++++----- tests/utils_test.go | 281 +++++----- 66 files changed, 4675 insertions(+), 3927 deletions(-) create mode 100644 client/api/gen3.go rename client/{jwt/utils.go => api/types.go} (62%) create mode 100644 client/client/client.go create mode 100644 client/common/constants.go create mode 100644 client/common/types.go create mode 100644 client/conf/config.go create mode 100644 client/conf/validate.go create mode 100644 client/download/batch.go create mode 100644 client/download/downloader.go create mode 100644 client/download/file_info.go create mode 100644 client/download/types.go create mode 100644 client/download/url_resolution.go create mode 100644 client/download/utils.go delete mode 100644 client/g3cmd/delete.go delete mode 100644 client/g3cmd/download-multiple.go delete mode 100644 client/g3cmd/gitversion.go delete mode 100644 client/g3cmd/retry-upload.go delete mode 100644 client/g3cmd/root.go delete mode 100644 client/g3cmd/upload-multipart.go delete mode 100644 client/g3cmd/upload-multiple.go delete mode 100644 client/g3cmd/upload-single.go delete mode 100644 client/g3cmd/utils.go delete mode 100644 client/gen3Client/client.go delete mode 100644 client/jwt/configure.go delete mode 100644 client/jwt/functions.go delete mode 100644 client/jwt/update.go create mode 100644 client/request/auth.go create mode 100644 client/request/builder.go create mode 100644 client/request/request.go create mode 100644 client/upload/batch.go create mode 100644 client/upload/multipart.go create mode 100644 client/upload/request.go create mode 100644 client/upload/retry.go create mode 100644 client/upload/singleFile.go create mode 100644 client/upload/types.go create mode 100644 client/upload/upload.go create mode 100644 client/upload/utils.go rename {client/g3cmd => cmd}/auth.go (88%) rename {client/g3cmd => cmd}/configure.go (83%) create mode 100644 cmd/delete.go create mode 100644 cmd/download-multipart.go create mode 100644 cmd/download-multiple.go rename {client/g3cmd => cmd}/download-single.go (83%) rename {client/g3cmd => cmd}/generate-tsv.go (96%) create mode 100644 cmd/gitversion.go create mode 100644 cmd/retry-upload.go create mode 100644 cmd/root.go create mode 100644 cmd/upload-multipart.go create mode 100644 cmd/upload-multiple.go create mode 100644 cmd/upload-single.go rename {client/g3cmd => cmd}/upload.go (70%) create mode 100644 docs/DEVELOPER_DOCS.md diff --git a/client/api/gen3.go b/client/api/gen3.go new file mode 100644 index 0000000..1746638 --- /dev/null +++ b/client/api/gen3.go @@ -0,0 +1,366 @@ +package api + +//go:generate mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/api FunctionInterface + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/request" + "github.com/hashicorp/go-version" +) + +func NewFunctions(config conf.ManagerInterface, request request.RequestInterface, cred *conf.Credential, logger logs.Logger) FunctionInterface { + return &Functions{ + RequestInterface: request, + Cred: cred, + Config: config, + Logger: logger, + } +} + +type Functions struct { + request.RequestInterface + + Cred *conf.Credential + Config conf.ManagerInterface + Logger logs.Logger +} + +type FunctionInterface interface { + request.RequestInterface + + CheckPrivileges(ctx context.Context) (map[string]any, error) + CheckForShepherdAPI(ctx context.Context) (bool, error) + DeleteRecord(ctx context.Context, guid string) (string, error) + GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) + + ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) + ExportCredential(ctx context.Context, cred *conf.Credential) error +} + +func (f *Functions) NewAccessToken(ctx context.Context) error { + if f.Cred.APIKey == "" { + return errors.New("APIKey is required to refresh access token") + } + + payload, err := json.Marshal(map[string]string{"api_key": f.Cred.APIKey}) + if err != nil { + return err + } + bodyReader := bytes.NewReader(payload) + + resp, err := f.Do( + ctx, + f.New(http.MethodPost, f.Cred.APIEndpoint+common.FenceAccessTokenEndpoint). + WithHeader(common.HeaderContentType, common.MIMEApplicationJSON). + WithBody(bodyReader), + ) + + if err != nil { + return fmt.Errorf("Error when calling Request.Do: %s", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("failed to refresh token, status: " + strconv.Itoa(resp.StatusCode)) + } + + var result common.AccessTokenStruct + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return errors.New("failed to parse token response: " + err.Error()) + } + + f.Cred.AccessToken = result.AccessToken + return nil +} + +func (f *Functions) GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { + hasShepherd, err := f.CheckForShepherdAPI(ctx) // error already logged upstream + if err == nil && hasShepherd { + return f.resolveFromShepherd(ctx, guid) + } + return f.resolveFromFence(ctx, guid, protocolText) +} + +// Todo: why isn't this calld in every fence response that has a body ? why is this seperated out +func (f *Functions) ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) { + msg := FenceResponse{} + if resp == nil { + return msg, errors.New("Nil response received") + } + + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + bodyStr := buf.String() + + err := json.Unmarshal(buf.Bytes(), &msg) + if err != nil { + return msg, fmt.Errorf("failed to decode JSON: %w (Raw body: %s)", err, buf.String()) + } + + if !(resp.StatusCode == 200 || resp.StatusCode == 201 || resp.StatusCode == 204) { + strUrl := resp.Request.URL.String() + switch resp.StatusCode { + case http.StatusUnauthorized: + return msg, fmt.Errorf("401 Unauthorized: %s (URL: %s)", bodyStr, strUrl) + case http.StatusForbidden: + return msg, fmt.Errorf("403 Forbidden: %s (URL: %s)", bodyStr, strUrl) + case http.StatusNotFound: + return msg, fmt.Errorf("404 Not Found: %s (URL: %s)", bodyStr, strUrl) + case http.StatusInternalServerError: + return msg, fmt.Errorf("500 Internal Server Error: %s (URL: %s)", bodyStr, strUrl) + case http.StatusServiceUnavailable: + return msg, fmt.Errorf("503 Service Unavailable: %s (URL: %s)", bodyStr, strUrl) + case http.StatusBadGateway: + return msg, fmt.Errorf("502 Bad Gateway: %s (URL: %s)", bodyStr, strUrl) + default: + return msg, fmt.Errorf("Unexpected Error (%d): %s (URL: %s)", resp.StatusCode, bodyStr, strUrl) + } + } + + // Logic for successful status codes + if strings.Contains(bodyStr, "Can't find a location for the data") { + return msg, errors.New("The provided GUID is not found") + } + + return msg, nil +} + +func (f *Functions) CheckForShepherdAPI(ctx context.Context) (bool, error) { + // Check if Shepherd is enabled + if f.Cred.UseShepherd == "false" { + return false, nil + } + if f.Cred.UseShepherd != "true" && common.DefaultUseShepherd == false { + return false, nil + } + // If Shepherd is enabled, make sure that the commons has a compatible version of Shepherd deployed. + // Compare the version returned from the Shepherd version endpoint with the minimum acceptable Shepherd version. + var minShepherdVersion string + if f.Cred.MinShepherdVersion == "" { + minShepherdVersion = common.DefaultMinShepherdVersion + } else { + minShepherdVersion = f.Cred.MinShepherdVersion + } + + res, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.Cred.APIEndpoint + common.ShepherdVersionEndpoint, + Method: http.MethodGet, + Token: f.Cred.AccessToken, + }, + ) + if err != nil { + return false, errors.New("Error occurred during generating HTTP request: " + err.Error()) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return false, nil + } + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return false, errors.New("Error occurred when reading HTTP request: " + err.Error()) + } + body, err := strconv.Unquote(string(bodyBytes)) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) + } + // Compare the version in the response to the target version + ver, err := version.NewVersion(body) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) + } + minVer, err := version.NewVersion(minShepherdVersion) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing minimum acceptable Shepherd version: %v: %v", minShepherdVersion, err) + } + if ver.GreaterThanOrEqual(minVer) { + return true, nil + } + return false, fmt.Errorf("Shepherd is enabled, but %v does not have correct Shepherd version. (Need Shepherd version >=%v, got %v)", f.Cred.APIEndpoint, minVer, ver) +} +func (f *Functions) CheckPrivileges(ctx context.Context) (map[string]any, error) { + /* + Return user privileges from specified profile + */ + var err error + var data map[string]any + + resp, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.Cred.APIEndpoint + common.FenceUserEndpoint, + Method: http.MethodGet, + Token: f.Cred.AccessToken, + }, + ) + if err != nil { + return nil, errors.New("Error occurred when getting response from remote: " + err.Error()) + } + defer resp.Body.Close() + + str := ResponseToString(resp) + err = json.Unmarshal([]byte(str), &data) + if err != nil { + return nil, errors.New("Error occurred when unmarshalling response: " + err.Error()) + } + + resourceAccess, ok := data["authz"].(map[string]any) + + // If the `authz` section (Arborist permissions) is empty or missing, try get `project_access` section (Fence permissions) + if len(resourceAccess) == 0 || !ok { + resourceAccess, ok = data["project_access"].(map[string]any) + if !ok { + return nil, errors.New("Not possible to read access privileges of user") + } + } + + return resourceAccess, err +} + +func (f *Functions) DeleteRecord(ctx context.Context, guid string) (string, error) { + endpoint := common.FenceDataEndpoint + "/" + guid + msg := "" + hasShepherd, err := f.CheckForShepherdAPI(ctx) + if err != nil { + f.Logger.Printf("WARNING: Error checking Shepherd API: %v. Falling back to Fence.\n", err) + } else if hasShepherd { + endpoint = common.ShepherdEndpoint + "/objects/" + guid + } + + resp, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.Cred.APIEndpoint + endpoint, + Method: http.MethodDelete, + Token: f.Cred.AccessToken, + }, + ) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == 204 { + msg = "Record with GUID " + guid + " has been deleted" + } else { + _, err = f.ParseFenceURLResponse(resp) + if err != nil { + return "", err + } + } + return msg, nil +} + +func (f *Functions) ExportCredential(ctx context.Context, cred *conf.Credential) error { + + if cred.Profile == "" { + return fmt.Errorf("profile name is required") + } + if cred.APIEndpoint == "" { + return fmt.Errorf("API endpoint is required") + } + + // Normalize endpoint + cred.APIEndpoint = strings.TrimSpace(cred.APIEndpoint) + cred.APIEndpoint = strings.TrimSuffix(cred.APIEndpoint, "/") + + // Validate URL format + parsedURL, err := conf.ValidateUrl(cred.APIEndpoint) + if err != nil { + return fmt.Errorf("invalid apiendpoint URL: %w", err) + } + fenceBase := parsedURL.Scheme + "://" + parsedURL.Host + if _, err := f.Config.Load(cred.Profile); err != nil && !errors.Is(err, conf.ErrProfileNotFound) { + return err + } + + if cred.APIKey != "" { + // Always refresh the access token — ignore any old one that might be in the struct + err = f.NewAccessToken(ctx) + if err != nil { + if strings.Contains(err.Error(), "401") { + return fmt.Errorf("authentication failed (401) for %s — your API key is invalid, revoked, or expired", fenceBase) + } + if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "no such host") { + return fmt.Errorf("cannot reach Fence at %s — is this a valid Gen3 commons?", fenceBase) + } + return fmt.Errorf("failed to refresh access token: %w", err) + } + } else { + f.Logger.Printf("WARNING: Your profile will only be valid for 24 hours since you have only provided a refresh token for authentication") + } + + // Clean up shepherd flags + cred.UseShepherd = strings.TrimSpace(cred.UseShepherd) + cred.MinShepherdVersion = strings.TrimSpace(cred.MinShepherdVersion) + + if cred.MinShepherdVersion != "" { + if _, err = version.NewVersion(cred.MinShepherdVersion); err != nil { + return fmt.Errorf("invalid min-shepherd-version: %w", err) + } + } + + if err := f.Config.Save(cred); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + + return nil +} + +func (f *Functions) resolveFromShepherd(ctx context.Context, guid string) (string, error) { + // We use f.Cred.APIEndpoint because the struct owns the credential state + url := fmt.Sprintf("%s%s/objects/%s/download", f.Cred.APIEndpoint, common.ShepherdEndpoint, guid) + + // We call f.Do directly because of method promotion (embedding) + resp, err := f.Do(ctx, f.New(http.MethodGet, url)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("shepherd error: %d", resp.StatusCode) + } + + var result struct { + URL string `json:"url"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to decode shepherd response: %w", err) + } + + return result.URL, nil +} + +func (f *Functions) resolveFromFence(ctx context.Context, guid, protocolText string) (string, error) { + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Url: f.Cred.APIEndpoint + common.FenceDataDownloadEndpoint + "/" + guid + protocolText, + Method: http.MethodGet, + Token: f.Cred.AccessToken, + }, + ) + if err != nil { + return "", errors.New("Failed to get URL from Fence via DoAuthenticatedRequest: " + err.Error()) + } + defer resp.Body.Close() + + msg, err := f.ParseFenceURLResponse(resp) + if err != nil || msg.URL == "" { + return "", errors.New("Failed to get URL from Fence via ParseFenceURLResponse: " + err.Error()) + } + + return msg.URL, nil +} diff --git a/client/jwt/utils.go b/client/api/types.go similarity index 62% rename from client/jwt/utils.go rename to client/api/types.go index 466a3a0..59feec5 100644 --- a/client/jwt/utils.go +++ b/client/api/types.go @@ -1,20 +1,14 @@ -package jwt +package api import ( "bytes" - "encoding/json" "net/http" ) type Message any - type Response any -type AccessTokenStruct struct { - AccessToken string `json:"access_token"` -} - -type JsonMessage struct { +type FenceResponse struct { URL string `json:"url"` GUID string `json:"guid"` UploadID string `json:"uploadId"` @@ -24,15 +18,8 @@ type JsonMessage struct { Size int64 `json:"size"` } -type DoRequest func(*http.Response) *http.Response - func ResponseToString(resp *http.Response) string { buf := new(bytes.Buffer) buf.ReadFrom(resp.Body) // nolint: errcheck return buf.String() } - -func DecodeJsonFromString(str string, msg Message) error { - err := json.Unmarshal([]byte(str), &msg) - return err -} diff --git a/client/client/client.go b/client/client/client.go new file mode 100644 index 0000000..41fb41f --- /dev/null +++ b/client/client/client.go @@ -0,0 +1,69 @@ +package client + +import ( + "context" + "fmt" + + "github.com/calypr/data-client/client/api" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/request" +) + +//go:generate mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/client Gen3Interface + +// Top level wrapper Interface for calling lower level interface functions. +// +// Gen3Interface contains minimum number of methods to enable calling functions in the FunctionInterface +// The credential is embedded in the implementation, so it doesn't need to be passed to each method. +type Gen3Interface interface { + GetCredential() *conf.Credential + Logger() *logs.TeeLogger + + api.FunctionInterface +} + +// Gen3Client wraps jwt.FunctionInterface and embeds the credential +type Gen3Client struct { + Ctx context.Context + api.FunctionInterface + + credential *conf.Credential + logger *logs.TeeLogger +} + +func (g *Gen3Client) Logger() *logs.TeeLogger { + return g.logger +} + +// GetCredential returns the embedded credential +func (g *Gen3Client) GetCredential() *conf.Credential { + return g.credential +} + +// NewGen3Interface returns a Gen3Client that embeds the credential and implements Gen3Interface. +// This eliminates the need to pass credentials around everywhere. +func NewGen3Interface(profile string, logger *logs.TeeLogger, opts ...func(*Gen3Client)) (Gen3Interface, error) { + config := conf.NewConfigure(logger) + cred, err := config.Load(profile) + if err != nil { + return nil, err + } + + if valid, err := config.IsValid(cred); !valid { + return nil, fmt.Errorf("invalid credential: %v", err) + } + + apiClient := api.NewFunctions( + config, + request.NewRequestInterface(logger, cred, config), + cred, + logger, + ) + + return &Gen3Client{ + FunctionInterface: apiClient, + credential: cred, + logger: logger, + }, nil +} diff --git a/client/common/common.go b/client/common/common.go index 8eea7bf..57dce2b 100644 --- a/client/common/common.go +++ b/client/common/common.go @@ -1,112 +1,25 @@ package common import ( + "bytes" + "encoding/json" "fmt" "io" "log" - "net/http" "os" "path/filepath" "strings" - "time" "github.com/hashicorp/go-multierror" - "github.com/vbauerster/mpb/v8" ) -// DefaultUseShepherd sets whether gen3client will attempt to use the Shepherd / Object Management API -// endpoints if available. -// The user can override this default using the `data-client configure` command. -const DefaultUseShepherd = false - -// DefaultMinShepherdVersion is the minimum version of Shepherd that the gen3client will use. -// Before attempting to use Shepherd, the client will check for Shepherd's version, and if the version is -// below this number the gen3client will instead warn the user and fall back to fence/indexd. -// The user can override this default using the `data-client configure` command. -const DefaultMinShepherdVersion = "2.0.0" - -// ShepherdEndpoint is the endpoint postfix for SHEPHERD / the Object Management API -const ShepherdEndpoint = "/mds" - -// ShepherdVersionEndpoint is the endpoint used to check what version of Shepherd a commons has deployed -const ShepherdVersionEndpoint = "/mds/version" - -// IndexdIndexEndpoint is the endpoint postfix for INDEXD index -const IndexdIndexEndpoint = "/index/index" - -// FenceUserEndpoint is the endpoint postfix for FENCE user -const FenceUserEndpoint = "/user/user" - -// FenceDataEndpoint is the endpoint postfix for FENCE data -const FenceDataEndpoint = "/user/data" - -// FenceAccessTokenEndpoint is the endpoint postfix for FENCE access token -const FenceAccessTokenEndpoint = "/user/credentials/api/access_token" - -// FenceDataUploadEndpoint is the endpoint postfix for FENCE data upload -const FenceDataUploadEndpoint = FenceDataEndpoint + "/upload" - -// FenceDataDownloadEndpoint is the endpoint postfix for FENCE data download -const FenceDataDownloadEndpoint = FenceDataEndpoint + "/download" - -// FenceDataMultipartInitEndpoint is the endpoint postfix for FENCE multipart init -const FenceDataMultipartInitEndpoint = FenceDataEndpoint + "/multipart/init" - -// FenceDataMultipartUploadEndpoint is the endpoint postfix for FENCE multipart upload -const FenceDataMultipartUploadEndpoint = FenceDataEndpoint + "/multipart/upload" - -// FenceDataMultipartCompleteEndpoint is the endpoint postfix for FENCE multipart complete -const FenceDataMultipartCompleteEndpoint = FenceDataEndpoint + "/multipart/complete" - -// PathSeparator is os dependent path separator char -const PathSeparator = string(os.PathSeparator) - -// DefaultTimeout is used to set timeout value for http client -const DefaultTimeout = 120 * time.Second - -// FileUploadRequestObject defines a object for file upload -type FileUploadRequestObject struct { - FilePath string - Filename string - FileMetadata FileMetadata - GUID string - PresignedURL string - Request *http.Request - Progress *mpb.Progress - Bar *mpb.Bar - Bucket string `json:"bucket,omitempty"` -} - -// FileDownloadResponseObject defines a object for file download -type FileDownloadResponseObject struct { - DownloadPath string - Filename string - GUID string - URL string - Range int64 - Overwrite bool - Skip bool - Response *http.Response - Writer io.Writer -} - -// FileMetadata defines the metadata accepted by the new object management API, Shepherd -type FileMetadata struct { - Authz []string `json:"authz"` - Aliases []string `json:"aliases"` - // Metadata is an encoded JSON string of any arbitrary metadata the user wishes to upload. - Metadata map[string]any `json:"metadata"` -} - -// RetryObject defines a object for retry upload -type RetryObject struct { - FilePath string - Filename string - FileMetadata FileMetadata - GUID string - RetryCount int - Multipart bool - Bucket string +func ToJSONReader(payload any) (io.Reader, error) { + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to encode JSON payload: %w", err) + } + return &buf, nil } // ParseRootPath parses dirname that has "~" in the beginning diff --git a/client/common/constants.go b/client/common/constants.go new file mode 100644 index 0000000..aae8f10 --- /dev/null +++ b/client/common/constants.go @@ -0,0 +1,89 @@ +package common + +import ( + "os" + "time" +) + +const ( + // B is bytes + B int64 = iota + // KB is kilobytes + KB int64 = 1 << (10 * iota) + // MB is megabytes + MB + // GB is gigabytes + GB + // TB is terrabytes + TB +) +const ( + // DefaultUseShepherd sets whether gen3client will attempt to use the Shepherd / Object Management API + // endpoints if available. + // The user can override this default using the `data-client configure` command. + DefaultUseShepherd = false + + // DefaultMinShepherdVersion is the minimum version of Shepherd that the gen3client will use. + // Before attempting to use Shepherd, the client will check for Shepherd's version, and if the version is + // below this number the gen3client will instead warn the user and fall back to fence/indexd. + // The user can override this default using the `data-client configure` command. + DefaultMinShepherdVersion = "2.0.0" + + // ShepherdEndpoint is the endpoint postfix for SHEPHERD / the Object Management API + ShepherdEndpoint = "/mds" + + // ShepherdVersionEndpoint is the endpoint used to check what version of Shepherd a commons has deployed + ShepherdVersionEndpoint = "/mds/version" + + // IndexdIndexEndpoint is the endpoint postfix for INDEXD index + IndexdIndexEndpoint = "/index/index" + + // FenceUserEndpoint is the endpoint postfix for FENCE user + FenceUserEndpoint = "/user/user" + + // FenceDataEndpoint is the endpoint postfix for FENCE data + FenceDataEndpoint = "/user/data" + + // FenceAccessTokenEndpoint is the endpoint postfix for FENCE access token + FenceAccessTokenEndpoint = "/user/credentials/api/access_token" + + // FenceDataUploadEndpoint is the endpoint postfix for FENCE data upload + FenceDataUploadEndpoint = FenceDataEndpoint + "/upload" + + // FenceDataDownloadEndpoint is the endpoint postfix for FENCE data download + FenceDataDownloadEndpoint = FenceDataEndpoint + "/download" + + // FenceDataMultipartInitEndpoint is the endpoint postfix for FENCE multipart init + FenceDataMultipartInitEndpoint = FenceDataEndpoint + "/multipart/init" + + // FenceDataMultipartUploadEndpoint is the endpoint postfix for FENCE multipart upload + FenceDataMultipartUploadEndpoint = FenceDataEndpoint + "/multipart/upload" + + // FenceDataMultipartCompleteEndpoint is the endpoint postfix for FENCE multipart complete + FenceDataMultipartCompleteEndpoint = FenceDataEndpoint + "/multipart/complete" + + // PathSeparator is os dependent path separator char + PathSeparator = string(os.PathSeparator) + + // DefaultTimeout is used to set timeout value for http client + DefaultTimeout = 120 * time.Second + + HeaderContentType = "Content-Type" + MIMEApplicationJSON = "application/json" + + // FileSizeLimit is the maximun single file size for non-multipart upload (5GB) + FileSizeLimit = 5 * GB + + // MultipartFileSizeLimit is the maximun single file size for multipart upload (5TB) + MultipartFileSizeLimit = 5 * TB + MinMultipartChunkSize = 5 * MB + + // MaxRetryCount is the maximum retry number per record + MaxRetryCount = 5 + MaxWaitTime = 300 + + MaxMultipartParts = 10000 + MaxConcurrentUploads = 10 + MaxRetries = 5 + MinChunkSize = 5 * 1024 * 1024 +) diff --git a/client/common/logHelper.go b/client/common/logHelper.go index a117bbc..5622694 100644 --- a/client/common/logHelper.go +++ b/client/common/logHelper.go @@ -16,9 +16,3 @@ func LoadFailedLog(path string) (map[string]RetryObject, error) { } return m, nil } - -func AlreadySucceededFromFile(filePath string) bool { - // Simple: check if any succeeded log contains this path - // Or just return false — safer to re-upload than skip - return false -} diff --git a/client/common/types.go b/client/common/types.go new file mode 100644 index 0000000..5a0ac8d --- /dev/null +++ b/client/common/types.go @@ -0,0 +1,59 @@ +package common + +import ( + "io" + "net/http" +) + +type AccessTokenStruct struct { + AccessToken string `json:"access_token"` +} + +// FileUploadRequestObject defines a object for file upload +type FileUploadRequestObject struct { + FilePath string + Filename string + FileMetadata FileMetadata + GUID string + PresignedURL string + Bucket string `json:"bucket,omitempty"` +} + +// FileDownloadResponseObject defines a object for file download +type FileDownloadResponseObject struct { + DownloadPath string + Filename string + GUID string + URL string + Range int64 + Overwrite bool + Skip bool + Response *http.Response + Writer io.Writer +} + +// FileMetadata defines the metadata accepted by the new object management API, Shepherd +type FileMetadata struct { + Authz []string `json:"authz"` + Aliases []string `json:"aliases"` + // Metadata is an encoded JSON string of any arbitrary metadata the user wishes to upload. + Metadata map[string]any `json:"metadata"` +} + +// RetryObject defines a object for retry upload +type RetryObject struct { + FilePath string + Filename string + FileMetadata FileMetadata + GUID string + RetryCount int + Multipart bool + Bucket string +} + +type ManifestObject struct { + ObjectID string `json:"object_id"` + SubjectID string `json:"subject_id"` + Title string `json:"title"` + Size int64 `json:"size"` +} diff --git a/client/conf/config.go b/client/conf/config.go new file mode 100644 index 0000000..4297f50 --- /dev/null +++ b/client/conf/config.go @@ -0,0 +1,256 @@ +package conf + +//go:generate mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/conf ManagerInterface + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path" + "strings" + + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "gopkg.in/ini.v1" +) + +var ErrProfileNotFound = errors.New("profile not found in config file") + +type Credential struct { + Profile string + KeyID string + APIKey string + AccessToken string + APIEndpoint string + UseShepherd string + MinShepherdVersion string +} + +type Manager struct { + Logger logs.Logger +} + +func NewConfigure(logs logs.Logger) ManagerInterface { + return &Manager{ + Logger: logs, + } +} + +type ManagerInterface interface { + // Loads credential from ~/.gen3/ credential file + Import(filePath, fenceToken string) (*Credential, error) + + // Loads credential from ~/.gen3/config.ini + Load(profile string) (*Credential, error) + Save(cred *Credential) error + + EnsureExists() error + IsValid(*Credential) (bool, error) +} + +func (man *Manager) configPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + configPath := path.Join( + homeDir + + common.PathSeparator + + ".gen3" + + common.PathSeparator + + "gen3_client_config.ini", + ) + return configPath, nil +} + +func (man *Manager) Load(profile string) (*Credential, error) { + /* + Looking profile in config file. The config file is a text file located at ~/.gen3 directory. It can + contain more than 1 profile. If there is no profile found, the user is asked to run a command to + create the profile + + The format of config file is described as following + + [profile1] + key_id=key_id_example_1 + api_key=api_key_example_1 + access_token=access_token_example_1 + api_endpoint=http://localhost:8000 + use_shepherd=true + min_shepherd_version=2.0.0 + + [profile2] + key_id=key_id_example_2 + api_key=api_key_example_2 + access_token=access_token_example_2 + api_endpoint=http://localhost:8000 + use_shepherd=false + min_shepherd_version= + + Args: + profile: the specific profile in config file + Returns: + An instance of Credential + */ + + homeDir, err := os.UserHomeDir() + if err != nil { + errs := fmt.Errorf("Error occurred when getting home directory: %s", err.Error()) + man.Logger.Printf(errs.Error()) + return nil, errs + } + configPath := path.Join(homeDir + common.PathSeparator + ".gen3" + common.PathSeparator + "gen3_client_config.ini") + + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return nil, fmt.Errorf("%w Run configure command (with a profile if desired) to set up account credentials \n"+ + "Example: ./data-client configure --profile= --cred= --apiendpoint=https://data.mycommons.org", ErrProfileNotFound) + } + + // If profile not in config file, prompt user to set up config first + cfg, err := ini.Load(configPath) + if err != nil { + errs := fmt.Errorf("Error occurred when reading config file: %s", err.Error()) + return nil, errs + } + sec, err := cfg.GetSection(profile) + if err != nil { + return nil, fmt.Errorf("%w: Need to run \"data-client configure --profile="+profile+" --cred= --apiendpoint=\" first", ErrProfileNotFound) + } + + profileConfig := &Credential{ + Profile: profile, + KeyID: sec.Key("key_id").String(), + APIKey: sec.Key("api_key").String(), + AccessToken: sec.Key("access_token").String(), + APIEndpoint: sec.Key("api_endpoint").String(), + UseShepherd: sec.Key("use_shepherd").String(), + MinShepherdVersion: sec.Key("min_shepherd_version").String(), + } + + if profileConfig.KeyID == "" && profileConfig.APIKey == "" && profileConfig.AccessToken == "" { + errs := fmt.Errorf("key_id, api_key and access_token not found in profile.") + return nil, errs + } + if profileConfig.APIEndpoint == "" { + errs := fmt.Errorf("api_endpoint not found in profile.") + return nil, errs + } + + return profileConfig, nil +} + +func (man *Manager) Save(profileConfig *Credential) error { + /* + Overwrite the config file with new credential + + Args: + profileConfig: Credential object represents config of a profile + configPath: file path to config file + */ + configPath, err := man.configPath() + if err != nil { + errs := fmt.Errorf("error occurred when getting config path: %s", err.Error()) + man.Logger.Println(errs.Error()) + return errs + } + cfg, err := ini.Load(configPath) + if err != nil { + errs := fmt.Errorf("error occurred when loading config file: %s", err.Error()) + man.Logger.Println(errs.Error()) + return errs + } + + section := cfg.Section(profileConfig.Profile) + if profileConfig.KeyID != "" { + section.Key("key_id").SetValue(profileConfig.KeyID) + } + if profileConfig.APIKey != "" { + section.Key("api_key").SetValue(profileConfig.APIKey) + } + if profileConfig.AccessToken != "" { + section.Key("access_token").SetValue(profileConfig.AccessToken) + } + if profileConfig.APIEndpoint != "" { + section.Key("api_endpoint").SetValue(profileConfig.APIEndpoint) + } + + section.Key("use_shepherd").SetValue(profileConfig.UseShepherd) + section.Key("min_shepherd_version").SetValue(profileConfig.MinShepherdVersion) + err = cfg.SaveTo(configPath) + if err != nil { + errs := fmt.Errorf("error occurred when saving config file: %s", err.Error()) + man.Logger.Println(errs.Error()) + return fmt.Errorf("error occurred when saving config file: %s", err.Error()) + } + return nil +} + +func (man *Manager) EnsureExists() error { + /* + Make sure the config exists on start up + */ + configPath, err := man.configPath() + if err != nil { + return err + } + + if _, err := os.Stat(path.Dir(configPath)); os.IsNotExist(err) { + osErr := os.Mkdir(path.Join(path.Dir(configPath)), os.FileMode(0777)) + if osErr != nil { + return err + } + _, osErr = os.Create(configPath) + if osErr != nil { + return err + } + } + if _, err := os.Stat(configPath); os.IsNotExist(err) { + _, osErr := os.Create(configPath) + if osErr != nil { + return err + } + } + _, err = ini.Load(configPath) + + return err +} + +func (man *Manager) Import(filePath, fenceToken string) (*Credential, error) { + var cred Credential + + if filePath != "" { + fullPath, err := common.GetAbsolutePath(filePath) + if err != nil { + man.Logger.Println("error parsing credential file path: " + err.Error()) + return nil, err + } + + content, err := os.ReadFile(fullPath) + if err != nil { + if os.IsNotExist(err) { + man.Logger.Println("File not found: " + fullPath) + } else { + man.Logger.Println("error reading file: " + err.Error()) + } + return nil, err + } + + jsonStr := strings.ReplaceAll(string(content), "\n", "") + // Normalize keys from snake_case to CamelCase for unmarshaling + jsonStr = strings.ReplaceAll(jsonStr, "key_id", "KeyID") + jsonStr = strings.ReplaceAll(jsonStr, "api_key", "APIKey") + + if err := json.Unmarshal([]byte(jsonStr), &cred); err != nil { + errMsg := fmt.Errorf("cannot parse JSON credential file: %w", err) + man.Logger.Println(errMsg.Error()) + return nil, errMsg + } + } else if fenceToken != "" { + cred.AccessToken = fenceToken + } else { + return nil, errors.New("either credential file or fence token must be provided") + } + + return &cred, nil +} diff --git a/client/conf/validate.go b/client/conf/validate.go new file mode 100644 index 0000000..d50362b --- /dev/null +++ b/client/conf/validate.go @@ -0,0 +1,69 @@ +package conf + +import ( + "errors" + "fmt" + "net/url" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func ValidateUrl(apiEndpoint string) (*url.URL, error) { + parsedURL, err := url.Parse(apiEndpoint) + if err != nil { + return parsedURL, errors.New("Error occurred when parsing apiendpoint URL: " + err.Error()) + } + if parsedURL.Host == "" { + return parsedURL, errors.New("Invalid endpoint. A valid endpoint looks like: https://www.tests.com") + } + return parsedURL, nil +} + +func (man *Manager) IsValid(profileConfig *Credential) (bool, error) { + if profileConfig == nil { + return false, fmt.Errorf("profileConfig is nil") + } + /* Checks to see if credential in credential file is still valid */ + // Parse the token without verifying the signature to access the claims. + token, _, err := new(jwt.Parser).ParseUnverified(profileConfig.APIKey, jwt.MapClaims{}) + if err != nil { + return false, fmt.Errorf("invalid token format: %v", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return false, fmt.Errorf("unable to parse claims from provided token %#v", token) + } + + exp, ok := claims["exp"].(float64) + if !ok { + return false, fmt.Errorf("'exp' claim not found or is not a number for claims %s", claims) + } + + iat, ok := claims["iat"].(float64) + if !ok { + return false, fmt.Errorf("'iat' claim not found or is not a number for claims %s", claims) + } + + now := time.Now().UTC() + expTime := time.Unix(int64(exp), 0).UTC() + iatTime := time.Unix(int64(iat), 0).UTC() + + if expTime.Before(now) { + return false, fmt.Errorf("key %s expired %s < %s", profileConfig.APIKey, expTime.Format(time.RFC3339), now.Format(time.RFC3339)) + } + if iatTime.After(now) { + return false, fmt.Errorf("key %s not yet valid %s > %s", profileConfig.APIKey, iatTime.Format(time.RFC3339), now.Format(time.RFC3339)) + } + + delta := expTime.Sub(now) + // threshold days set to 10 + if delta > 0 && delta.Hours() < float64(10*24) { + daysUntilExpiration := int(delta.Hours() / 24) + if daysUntilExpiration > 0 { + return true, fmt.Errorf("warning %s: Key will expire in %d days, on %s", profileConfig.APIKey, daysUntilExpiration, expTime.Format(time.RFC3339)) + } + } + return true, nil +} diff --git a/client/download/batch.go b/client/download/batch.go new file mode 100644 index 0000000..de86659 --- /dev/null +++ b/client/download/batch.go @@ -0,0 +1,164 @@ +package download + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "sync/atomic" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "github.com/hashicorp/go-multierror" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + "golang.org/x/sync/errgroup" +) + +// downloadFiles performs bounded parallel downloads and collects ALL errors +func downloadFiles( + ctx context.Context, + g3i client.Gen3Interface, + files []common.FileDownloadResponseObject, + numParallel int, + protocol string, +) (int, error) { + if len(files) == 0 { + return 0, nil + } + + logger := g3i.Logger() + + protocolText := "" + if protocol != "" { + protocolText = "?protocol=" + protocol + } + + // Scoreboard: maxRetries = 0 for now (no retry logic yet) + sb := logs.NewSB(0, logger) + + p := mpb.New(mpb.WithOutput(os.Stdout)) + + var eg errgroup.Group + eg.SetLimit(numParallel) + + var success atomic.Int64 + var mu sync.Mutex + var allErrors []*multierror.Error + + for i := range files { + fdr := &files[i] // capture loop variable + + eg.Go(func() error { + var err error + + defer func() { + if err != nil { + // Final failure bucket + sb.IncrementSB(len(sb.Counts) - 1) + + mu.Lock() + allErrors = append(allErrors, multierror.Append(nil, err)) + mu.Unlock() + } else { + success.Add(1) + sb.IncrementSB(0) // success, no retries + } + }() + + // Get presigned URL + if err = GetDownloadResponse(ctx, g3i, fdr, protocolText); err != nil { + err = fmt.Errorf("get URL for %s (GUID: %s): %w", fdr.Filename, fdr.GUID, err) + return err + } + + // Prepare directories + fullPath := filepath.Join(fdr.DownloadPath, fdr.Filename) + if dir := filepath.Dir(fullPath); dir != "." { + if err = os.MkdirAll(dir, 0766); err != nil { + _ = fdr.Response.Body.Close() + err = fmt.Errorf("mkdir for %s: %w", fullPath, err) + return err + } + } + + flags := os.O_CREATE | os.O_WRONLY + if fdr.Range > 0 { + flags |= os.O_APPEND + } else if fdr.Overwrite { + flags |= os.O_TRUNC + } + + file, err := os.OpenFile(fullPath, flags, 0666) + if err != nil { + _ = fdr.Response.Body.Close() + err = fmt.Errorf("open local file %s: %w", fullPath, err) + return err + } + + // Progress bar for this file + total := fdr.Response.ContentLength + fdr.Range + bar := p.AddBar(total, + mpb.PrependDecorators( + decor.Name(truncateFilename(fdr.Filename, 40)+" "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), "% .1f"), + ), + ) + + if fdr.Range > 0 { + bar.SetCurrent(fdr.Range) + } + + writer := bar.ProxyWriter(file) + + _, copyErr := io.Copy(writer, fdr.Response.Body) + _ = fdr.Response.Body.Close() + _ = file.Close() + + if copyErr != nil { + bar.Abort(true) + err = fmt.Errorf("download failed for %s: %w", fdr.Filename, copyErr) + return err + } + + return nil + }) + } + + // Wait for all downloads + _ = eg.Wait() + p.Wait() + + // Combine errors + var combinedError error + mu.Lock() + if len(allErrors) > 0 { + multiErr := multierror.Append(nil, nil) + for _, e := range allErrors { + multiErr = multierror.Append(multiErr, e.Errors...) + } + combinedError = multiErr.ErrorOrNil() + } + mu.Unlock() + + downloaded := int(success.Load()) + + // Print scoreboard summary + sb.PrintSB() + + if combinedError != nil { + logger.Printf("%d files downloaded, but %d failed:\n", downloaded, len(allErrors)) + logger.Println(combinedError.Error()) + } else { + logger.Printf("%d files downloaded successfully.\n", downloaded) + } + + return downloaded, combinedError +} diff --git a/client/download/downloader.go b/client/download/downloader.go new file mode 100644 index 0000000..92683e2 --- /dev/null +++ b/client/download/downloader.go @@ -0,0 +1,166 @@ +package download + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +// DownloadMultiple is the public entry point called from g3cmd +func DownloadMultiple( + ctx context.Context, + g3i client.Gen3Interface, + objects []common.ManifestObject, + downloadPath string, + filenameFormat string, + rename bool, + noPrompt bool, + protocol string, + numParallel int, + skipCompleted bool, +) error { + logger := g3i.Logger() + + // === Input validation === + if numParallel < 1 { + return fmt.Errorf("numparallel must be a positive integer") + } + + var err error + downloadPath, err = common.ParseRootPath(downloadPath) + if err != nil { + return fmt.Errorf("invalid download path: %w", err) + } + if !strings.HasSuffix(downloadPath, "/") { + downloadPath += "/" + } + + filenameFormat = strings.ToLower(strings.TrimSpace(filenameFormat)) + if filenameFormat != "original" && filenameFormat != "guid" && filenameFormat != "combined" { + return fmt.Errorf("filename-format must be one of: original, guid, combined") + } + if (filenameFormat == "guid" || filenameFormat == "combined") && rename { + logger.Println("NOTICE: rename flag is ignored in guid/combined mode") + rename = false + } + + // === Warnings and user confirmation === + if err := handleWarningsAndConfirmation(logger, downloadPath, filenameFormat, rename, noPrompt); err != nil { + return err // aborted by user + } + + // === Create download directory === + if err := os.MkdirAll(downloadPath, 0766); err != nil { + return fmt.Errorf("cannot create directory %s: %w", downloadPath, err) + } + + // === Prepare files (metadata + local validation) === + toDownload, skipped, renamed, err := prepareFiles(ctx, g3i, objects, downloadPath, filenameFormat, rename, skipCompleted, protocol) + if err != nil { + return err + } + + logger.Printf("Total objects: %d | To download: %d | Skipped: %d\n", + len(objects), len(toDownload), len(skipped)) + + // === Download phase === + downloaded, downloadErr := downloadFiles(ctx, g3i, toDownload, numParallel, protocol) + + // === Final summary === + logger.Printf("%d files downloaded successfully.\n", downloaded) + printRenamed(logger, renamed) + printSkipped(logger, skipped) + + if downloadErr != nil { + logger.Printf("Some downloads failed. See errors above.\n") + } + + return nil // we log failures but don't fail the whole command unless critical +} + +// handleWarningsAndConfirmation prints warnings and asks for confirmation if needed +func handleWarningsAndConfirmation(logger logs.Logger, downloadPath, filenameFormat string, rename, noPrompt bool) error { + if filenameFormat == "guid" || filenameFormat == "combined" { + logger.Printf("WARNING: in %q mode, duplicate files in %q will be overwritten\n", filenameFormat, downloadPath) + } else if !rename { + logger.Printf("WARNING: rename=false in original mode – duplicates in %q will be overwritten\n", downloadPath) + } else { + logger.Printf("NOTICE: rename=true in original mode – duplicates in %q will be renamed with a counter\n", downloadPath) + } + + if noPrompt { + return nil + } + if !AskForConfirmation(logger, "Proceed? (y/N)") { + logger.Fatal("Aborted by user") + } + return nil +} + +// prepareFiles gathers metadata, checks local files, collects skips/renames +func prepareFiles( + ctx context.Context, + g3i client.Gen3Interface, + objects []common.ManifestObject, + downloadPath, filenameFormat string, + rename, skipCompleted bool, + protocol string, +) ([]common.FileDownloadResponseObject, []RenamedOrSkippedFileInfo, []RenamedOrSkippedFileInfo, error) { + logger := g3i.Logger() + renamed := make([]RenamedOrSkippedFileInfo, 0) + skipped := make([]RenamedOrSkippedFileInfo, 0) + toDownload := make([]common.FileDownloadResponseObject, 0, len(objects)) + + p := mpb.New(mpb.WithOutput(os.Stdout)) + bar := p.AddBar(int64(len(objects)), + mpb.PrependDecorators(decor.Name("Preparing "), decor.CountersNoUnit("%d / %d")), + mpb.AppendDecorators(decor.Percentage()), + ) + + for _, obj := range objects { + if obj.ObjectID == "" { + logger.Println("Empty GUID, skipping entry") + bar.Increment() + continue + } + + info := &IndexdResponse{Name: obj.Title, Size: obj.Size} + var err error + if info.Name == "" || info.Size == 0 { + // Very strict object id checking + info, err = AskGen3ForFileInfo(ctx, g3i, obj.ObjectID, protocol, downloadPath, filenameFormat, rename, &renamed) + if err != nil { + return nil, nil, nil, err + } + } + + fdr := common.FileDownloadResponseObject{ + DownloadPath: downloadPath, + Filename: info.Name, + GUID: obj.ObjectID, + } + + if !rename { + validateLocalFileStat(logger, &fdr, int64(info.Size), skipCompleted) + } + + if fdr.Skip { + logger.Printf("Skipping %q (GUID: %s) – complete local copy exists\n", fdr.Filename, fdr.GUID) + skipped = append(skipped, RenamedOrSkippedFileInfo{GUID: fdr.GUID, OldFilename: fdr.Filename}) + } else { + toDownload = append(toDownload, fdr) + } + + bar.Increment() + } + p.Wait() + logger.Println("Preparation complete") + return toDownload, skipped, renamed, nil +} diff --git a/client/download/file_info.go b/client/download/file_info.go new file mode 100644 index 0000000..e3b6a89 --- /dev/null +++ b/client/download/file_info.go @@ -0,0 +1,125 @@ +package download + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/request" +) + +func AskGen3ForFileInfo( + ctx context.Context, + g3i client.Gen3Interface, + guid, protocol, downloadPath, filenameFormat string, + rename bool, + renamedFiles *[]RenamedOrSkippedFileInfo, +) (*IndexdResponse, error) { + hasShepherd, err := g3i.CheckForShepherdAPI(ctx) + if err != nil { + g3i.Logger().Println("Error checking Shepherd API: " + err.Error()) + g3i.Logger().Println("Falling back to Indexd...") + hasShepherd = false + } + + if hasShepherd { + return fetchFromShepherd(ctx, g3i, guid, downloadPath, filenameFormat, renamedFiles) + } + return fetchFromIndexd(ctx, g3i, http.MethodGet, guid, protocol, downloadPath, filenameFormat, rename, renamedFiles) +} + +func fetchFromShepherd( + ctx context.Context, + g3i client.Gen3Interface, + guid, downloadPath, filenameFormat string, + renamedFiles *[]RenamedOrSkippedFileInfo, +) (*IndexdResponse, error) { + cred := g3i.GetCredential() + res, err := g3i.Do(ctx, + &request.RequestBuilder{ + Url: cred.APIEndpoint + "/" + cred.AccessToken + common.ShepherdEndpoint + "/objects/" + guid, + Method: http.MethodGet, + Token: cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var decoded struct { + Record struct { + FileName string `json:"file_name"` + Size int64 `json:"size"` + } `json:"record"` + } + if err := json.NewDecoder(res.Body).Decode(&decoded); err != nil { + return nil, err + } + + return &IndexdResponse{applyFilenameFormat(decoded.Record.FileName, guid, downloadPath, filenameFormat, false, renamedFiles), decoded.Record.Size}, nil +} + +func fetchFromIndexd( + ctx context.Context, + g3i client.Gen3Interface, method, + guid, protocol, downloadPath, filenameFormat string, + rename bool, + renamedFiles *[]RenamedOrSkippedFileInfo, +) (*IndexdResponse, error) { + + cred := g3i.GetCredential() + resp, err := g3i.Do( + ctx, + &request.RequestBuilder{ + Url: cred.APIEndpoint + common.IndexdIndexEndpoint + "/" + guid, + Method: method, + Token: cred.AccessToken, + }, + ) + if err != nil { + return nil, fmt.Errorf("Error in fetch FromIndexd: %s", err) + } + + defer resp.Body.Close() + msg, err := g3i.ParseFenceURLResponse(resp) + if err != nil { + return nil, err + } + + if filenameFormat == "guid" { + return &IndexdResponse{guid, msg.Size}, nil + } + + if msg.FileName == "" { + return nil, fmt.Errorf("FileName is a required field in Indexd to download the file, but upload record %#v does not contain it", msg) + } + + return &IndexdResponse{applyFilenameFormat(msg.FileName, guid, downloadPath, filenameFormat, rename, renamedFiles), msg.Size}, nil +} + +func applyFilenameFormat(baseName, guid, downloadPath, format string, rename bool, renamedFiles *[]RenamedOrSkippedFileInfo) string { + switch format { + case "guid": + return guid + case "combined": + return guid + "_" + baseName + case "original": + if !rename { + return baseName + } + newName := processOriginalFilename(downloadPath, baseName) + if newName != baseName { + *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{ + GUID: guid, + OldFilename: baseName, + NewFilename: newName, + }) + } + return newName + default: + return baseName + } +} diff --git a/client/download/types.go b/client/download/types.go new file mode 100644 index 0000000..651b97e --- /dev/null +++ b/client/download/types.go @@ -0,0 +1,60 @@ +package download + +import ( + "os" + + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" +) + +type IndexdResponse struct { + Name string + Size int64 +} +type RenamedOrSkippedFileInfo struct { + GUID string + OldFilename string + NewFilename string +} + +func validateLocalFileStat( + logger logs.Logger, + fdr *common.FileDownloadResponseObject, + filesize int64, + skipCompleted bool, +) { + fullPath := fdr.DownloadPath + fdr.Filename + + fi, err := os.Stat(fullPath) + if err != nil { + if os.IsNotExist(err) { + // No local file → full download, nothing special + return + } + logger.Printf("Error statting local file \"%s\": %s\n", fullPath, err.Error()) + logger.Println("Will attempt full download anyway") + return + } + + localSize := fi.Size() + + // User doesn't want to skip completed files → force full overwrite + if !skipCompleted { + fdr.Overwrite = true + return + } + + // Exact match → skip entirely + if localSize == filesize { + fdr.Skip = true + return + } + + // Local file larger than expected → overwrite fully (corrupted or different file) + if localSize > filesize { + fdr.Overwrite = true + return + } + + fdr.Range = localSize +} diff --git a/client/download/url_resolution.go b/client/download/url_resolution.go new file mode 100644 index 0000000..475a55e --- /dev/null +++ b/client/download/url_resolution.go @@ -0,0 +1,80 @@ +package download + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/request" +) + +// GetDownloadResponse gets presigned URL and prepares HTTP response +func GetDownloadResponse(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject, protocolText string) error { + url, err := g3.GetDownloadPresignedUrl(ctx, fdr.GUID, protocolText) + if err != nil { + return err + } + fdr.URL = url + + if fdr.Range > 0 && !isCloudPresignedURL(url) { + if !supportsRange(url) { + fdr.Range = 0 + } + } + + return makeDownloadRequest(ctx, g3, fdr) +} + +func isCloudPresignedURL(url string) bool { + return strings.Contains(url, "X-Amz-Signature") || strings.Contains(url, "X-Goog-Signature") +} + +func supportsRange(url string) bool { + resp, err := http.Head(url) + if err != nil || resp.Header.Get("Accept-Ranges") != "bytes" { + return false + } + return true +} + +func makeDownloadRequest(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject) error { + headers := map[string]string{} + if fdr.Range > 0 { + headers["Range"] = "bytes=" + strconv.FormatInt(fdr.Range, 10) + "-" + } + + resp, err := g3.Do( + ctx, + &request.RequestBuilder{ + Method: http.MethodGet, + Url: fdr.URL, + Headers: headers, + }, + ) + + if err != nil { + return errors.New("Request failed: " + strings.ReplaceAll(err.Error(), fdr.URL, "")) + } + + // Check for non-success status codes + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + defer resp.Body.Close() // Ensure the body is closed + + bodyBytes, err := io.ReadAll(resp.Body) + bodyString := "" + if err == nil { + bodyString = string(bodyBytes) + } + + return fmt.Errorf("non-OK response: %d, body: %s", resp.StatusCode, bodyString) + } + + fdr.Response = resp + return nil +} diff --git a/client/download/utils.go b/client/download/utils.go new file mode 100644 index 0000000..864a0c6 --- /dev/null +++ b/client/download/utils.go @@ -0,0 +1,79 @@ +package download + +import ( + "bufio" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/calypr/data-client/client/logs" +) + +// AskForConfirmation asks user for confirmation before proceed, will wait if user entered garbage +func AskForConfirmation(logger logs.Logger, s string) bool { + reader := bufio.NewReader(os.Stdin) + + for { + logger.Printf("%s [y/n]: ", s) + + response, err := reader.ReadString('\n') + if err != nil { + logger.Fatal("Error occurred during parsing user's confirmation: " + err.Error()) + } + + switch strings.ToLower(strings.TrimSpace(response)) { + case "y", "yes": + return true + case "n", "no": + return false + default: + return false // Example of defaulting to false + } + } +} + +func processOriginalFilename(downloadPath string, actualFilename string) string { + _, err := os.Stat(downloadPath + actualFilename) + if os.IsNotExist(err) { + return actualFilename + } + extension := filepath.Ext(actualFilename) + filename := strings.TrimSuffix(actualFilename, extension) + counter := 2 + for { + newFilename := filename + "_" + strconv.Itoa(counter) + extension + _, err := os.Stat(downloadPath + newFilename) + if os.IsNotExist(err) { + return newFilename + } + counter++ + } +} + +// truncateFilename shortens long filenames for progress bar display +func truncateFilename(name string, max int) string { + if len(name) <= max { + return name + } + return "..." + name[len(name)-max+3:] +} + +// printRenamed shows renamed files in final summary +func printRenamed(logger logs.Logger, renamed []RenamedOrSkippedFileInfo) { + if len(renamed) == 0 { + return + } + logger.Printf("%d files renamed:\n", len(renamed)) + for _, r := range renamed { + logger.Printf(" %q (GUID: %s) → %q\n", r.OldFilename, r.GUID, r.NewFilename) + } +} + +// printSkipped shows skipped files in final summary +func printSkipped(logger logs.Logger, skipped []RenamedOrSkippedFileInfo) { + if len(skipped) == 0 { + return + } + logger.Printf("%d files skipped (complete local copy exists)\n", len(skipped)) +} diff --git a/client/g3cmd/delete.go b/client/g3cmd/delete.go deleted file mode 100644 index 5be6795..0000000 --- a/client/g3cmd/delete.go +++ /dev/null @@ -1,34 +0,0 @@ -package g3cmd - -import ( - "log" - - "github.com/spf13/cobra" -) - -//Not support yet, place holder only - -var deleteCmd = &cobra.Command{ // nolint:deadcode,unused,varcheck - Use: "delete", - Short: "Send DELETE HTTP Request for given URI", - Long: `Deletes a given URI from the database. -If no profile is specified, "default" profile is used for authentication.`, - Example: `./data-client delete --uri=v0/submission/bpa/test/entities/example_id - ./data-client delete --profile=user1 --uri=v0/submission/bpa/test/entities/1af1d0ab-efec-4049-98f0-ae0f4bb1bc64`, - Run: func(cmd *cobra.Command, args []string) { - log.Fatalf("Not supported!") - // request := new(jwt.Request) - // configure := new(jwt.Configure) - // function := new(jwt.Functions) - - // function.Config = configure - // function.Request = request - - // fmt.Println(jwt.ResponseToString( - // function.DoRequestWithSignedHeader(RequestDelete, profile, "txt", uri))) - }, -} - -func init() { - // RootCmd.AddCommand(deleteCmd) -} diff --git a/client/g3cmd/download-multiple.go b/client/g3cmd/download-multiple.go deleted file mode 100644 index d8dbca8..0000000 --- a/client/g3cmd/download-multiple.go +++ /dev/null @@ -1,495 +0,0 @@ -package g3cmd - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - - "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" - "github.com/calypr/data-client/client/logs" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" - - "github.com/spf13/cobra" -) - -// mockgen -destination=../mocks/mock_gen3interface.go -package=mocks . Gen3Interface - -func AskGen3ForFileInfo(g3i client.Gen3Interface, guid string, protocol string, downloadPath string, filenameFormat string, rename bool, renamedFiles *[]RenamedOrSkippedFileInfo) (string, int64) { - var fileName string - var fileSize int64 - - // If the commons has the newer Shepherd API deployed, get the filename and file size from the Shepherd API. - // Otherwise, fall back on Indexd and Fence. - hasShepherd, err := g3i.CheckForShepherdAPI() - if err != nil { - g3i.Logger().Println("Error occurred when checking for Shepherd API: " + err.Error()) - g3i.Logger().Println("Falling back to Indexd...") - } - if hasShepherd { - endPointPostfix := common.ShepherdEndpoint + "/objects/" + guid - _, res, err := g3i.GetResponse(endPointPostfix, "GET", "", nil) - if err != nil { - g3i.Logger().Println("Error occurred when querying filename from Shepherd: " + err.Error()) - g3i.Logger().Println("Using GUID for filename instead.") - if filenameFormat != "guid" { - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: "N/A", NewFilename: guid}) - } - return guid, 0 - } - - decoded := struct { - Record struct { - FileName string `json:"file_name"` - Size int64 `json:"size"` - } - }{} - err = json.NewDecoder(res.Body).Decode(&decoded) - if err != nil { - g3i.Logger().Println("Error occurred when reading response from Shepherd: " + err.Error()) - g3i.Logger().Println("Using GUID for filename instead.") - if filenameFormat != "guid" { - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: "N/A", NewFilename: guid}) - } - return guid, 0 - } - defer res.Body.Close() - - fileName = decoded.Record.FileName - fileSize = decoded.Record.Size - - } else { - // Attempt to get the filename from Indexd - endPointPostfix := common.IndexdIndexEndpoint + "/" + guid - indexdMsg, err := g3i.DoRequestWithSignedHeader(endPointPostfix, "", nil) - if err != nil { - g3i.Logger().Println("Error occurred when querying filename from IndexD: " + err.Error()) - g3i.Logger().Println("Using GUID for filename instead.") - if filenameFormat != "guid" { - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: "N/A", NewFilename: guid}) - } - return guid, 0 - } - - if filenameFormat == "guid" { - return guid, indexdMsg.Size - } - - actualFilename := indexdMsg.FileName - if actualFilename == "" { - if len(indexdMsg.URLs) > 0 { - // Indexd record has no file name but does have URLs, try to guess file name from URL - var indexdURL = indexdMsg.URLs[0] - if protocol != "" { - for _, url := range indexdMsg.URLs { - if strings.HasPrefix(url, protocol) { - indexdURL = url - } - } - } - - actualFilename = guessFilenameFromURL(indexdURL) - if actualFilename == "" { - g3i.Logger().Println("Error occurred when guessing filename for object " + guid) - g3i.Logger().Println("Using GUID for filename instead.") - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: "N/A", NewFilename: guid}) - return guid, indexdMsg.Size - } - } else { - // Neither file name nor URLs exist in the Indexd record - // Indexd record is busted for that file, just return as we are renaming the file for now - // The download logic will handle the errors - g3i.Logger().Println("Neither file name nor URLs exist in the Indexd record of " + guid) - g3i.Logger().Println("The attempt of downloading file is likely to fail! Check Indexd record!") - g3i.Logger().Println("Using GUID for filename instead.") - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: "N/A", NewFilename: guid}) - return guid, indexdMsg.Size - } - } - - fileName = actualFilename - fileSize = indexdMsg.Size - } - - if filenameFormat == "original" { - if !rename { // no renaming in original mode - return fileName, fileSize - } - newFilename := processOriginalFilename(downloadPath, fileName) - if fileName != newFilename { - *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: fileName, NewFilename: newFilename}) - } - return newFilename, fileSize - } - // filenameFormat == "combined" - combinedFilename := guid + "_" + fileName - return combinedFilename, fileSize -} - -func guessFilenameFromURL(URL string) string { - splittedURLWithFilename := strings.Split(URL, "/") - actualFilename := splittedURLWithFilename[len(splittedURLWithFilename)-1] - return actualFilename -} - -func processOriginalFilename(downloadPath string, actualFilename string) string { - _, err := os.Stat(downloadPath + actualFilename) - if os.IsNotExist(err) { - return actualFilename - } - extension := filepath.Ext(actualFilename) - filename := strings.TrimSuffix(actualFilename, extension) - counter := 2 - for { - newFilename := filename + "_" + strconv.Itoa(counter) + extension - _, err := os.Stat(downloadPath + newFilename) - if os.IsNotExist(err) { - return newFilename - } - counter++ - } -} - -func validateLocalFileStat(logger logs.Logger, downloadPath string, filename string, filesize int64, skipCompleted bool) common.FileDownloadResponseObject { - fi, err := os.Stat(downloadPath + filename) // check filename for local existence - if err != nil { - if os.IsNotExist(err) { - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename} // no local file, normal full length download - } - logger.Printf("Error occurred when getting information for file \"%s\": %s\n", downloadPath+filename, err.Error()) - logger.Println("Will try to download the whole file") - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename} // errorred when trying to get local FI, normal full length download - } - - // have existing local file and may want to skip, check more conditions - if !skipCompleted { - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename, Overwrite: true} // not skipping any local files, normal full length download - } - - localFilesize := fi.Size() - if localFilesize == filesize { - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename, Skip: true} // both filename and filesize matches, consider as completed - } - if localFilesize > filesize { - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename, Overwrite: true} // local filesize is greater than INDEXD record, overwrite local existing - } - // local filesize is less than INDEXD record, try ranged download - return common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename, Range: localFilesize} -} - -func batchDownload(g3 client.Gen3Interface, progress *mpb.Progress, batchFDRSlice []common.FileDownloadResponseObject, protocolText string, workers int, errCh chan error) int { - fdrs := make([]common.FileDownloadResponseObject, 0) - for _, fdrObject := range batchFDRSlice { - err := GetDownloadResponse(g3, &fdrObject, protocolText) - if err != nil { - errCh <- err - continue - } - - fileFlag := os.O_CREATE | os.O_RDWR - if fdrObject.Range != 0 { - fileFlag = os.O_APPEND | os.O_RDWR - } else if fdrObject.Overwrite { - fileFlag = os.O_TRUNC | os.O_RDWR - } - - subDir := filepath.Dir(fdrObject.Filename) - if subDir != "." && subDir != "/" { - err = os.MkdirAll(fdrObject.DownloadPath+subDir, 0766) - if err != nil { - errCh <- err - continue - } - } - file, err := os.OpenFile(fdrObject.DownloadPath+fdrObject.Filename, fileFlag, 0666) - if err != nil { - errCh <- errors.New("Error occurred during opening local file: " + err.Error()) - continue - } - total := fdrObject.Response.ContentLength + fdrObject.Range - bar := progress.AddBar(total, - mpb.PrependDecorators( - decor.Name(fdrObject.Filename+" "), - decor.CountersKibiByte("% .1f / % .1f"), - ), - mpb.AppendDecorators( - decor.Percentage(), - decor.AverageSpeed(decor.SizeB1024(0), " % .1f"), - ), - ) - if fdrObject.Range > 0 { - bar.SetCurrent(fdrObject.Range) - } - writer := bar.ProxyWriter(file) - fdrObject.Writer = writer - fdrs = append(fdrs, fdrObject) - defer file.Close() - defer fdrObject.Response.Body.Close() - } - - fdrCh := make(chan common.FileDownloadResponseObject, len(fdrs)) - wg := sync.WaitGroup{} - succeeded := 0 - var err error - for range workers { - wg.Add(1) - go func() { - for fdr := range fdrCh { - if _, err = io.Copy(fdr.Writer, fdr.Response.Body); err != nil { - errCh <- errors.New("io.Copy error: " + err.Error()) - return - } - succeeded++ - } - wg.Done() - }() - } - - for _, fdr := range fdrs { - fdrCh <- fdr - } - close(fdrCh) - - wg.Wait() - return succeeded -} - -// AskForConfirmation asks user for confirmation before proceed, will wait if user entered garbage -func AskForConfirmation(logger logs.Logger, s string) bool { - reader := bufio.NewReader(os.Stdin) - - for { - logger.Printf("%s [y/n]: ", s) - - response, err := reader.ReadString('\n') - if err != nil { - logger.Fatal("Error occurred during parsing user's confirmation: " + err.Error()) - } - - switch strings.ToLower(strings.TrimSpace(response)) { - case "y", "yes": - return true - case "n", "no": - return false - default: - return false // Example of defaulting to false - } - } -} - -func downloadFile(g3i client.Gen3Interface, objects []ManifestObject, downloadPath string, filenameFormat string, rename bool, noPrompt bool, protocol string, numParallel int, skipCompleted bool) error { - if numParallel < 1 { - return fmt.Errorf("invalid value for option \"numparallel\": must be a positive integer! Please check your input") - } - - downloadPath, err := common.ParseRootPath(downloadPath) - if err != nil { - return fmt.Errorf("downloadFile Error: %s", err.Error()) - } - if !strings.HasSuffix(downloadPath, "/") { - downloadPath += "/" - } - filenameFormat = strings.ToLower(strings.TrimSpace(filenameFormat)) - if (filenameFormat == "guid" || filenameFormat == "combined") && rename { - g3i.Logger().Println("NOTICE: flag \"rename\" only works if flag \"filename-format\" is \"original\"") - rename = false - } - - if filenameFormat != "original" && filenameFormat != "guid" && filenameFormat != "combined" { - return fmt.Errorf("invalid option found! option \"filename-format\" can either be \"original\", \"guid\" or \"combined\" only") - } - if filenameFormat == "guid" || filenameFormat == "combined" { - g3i.Logger().Printf("WARNING: in \"guid\" or \"combined\" mode, duplicated files under \"%s\" will be overwritten\n", downloadPath) - if !noPrompt && !AskForConfirmation(g3i.Logger(), "Proceed?") { - g3i.Logger().Fatal("Aborted by user") - } - } else if !rename { - g3i.Logger().Printf("WARNING: flag \"rename\" was set to false in \"original\" mode, duplicated files under \"%s\" will be overwritten\n", downloadPath) - if !noPrompt && !AskForConfirmation(g3i.Logger(), "Proceed?") { - g3i.Logger().Fatal("Aborted by user") - } - } else { - g3i.Logger().Printf("NOTICE: flag \"rename\" was set to true in \"original\" mode, duplicated files under \"%s\" will be renamed by appending a counter value to the original filenames\n", downloadPath) - } - - protocolText := "" - if protocol != "" { - protocolText = "?protocol=" + protocol - } - - err = os.MkdirAll(downloadPath, 0766) - if err != nil { - return fmt.Errorf("cannot create folder %s", downloadPath) - } - - renamedFiles := make([]RenamedOrSkippedFileInfo, 0) - skippedFiles := make([]RenamedOrSkippedFileInfo, 0) - fdrObjects := make([]common.FileDownloadResponseObject, 0) - - g3i.Logger().Printf("Total number of objects in manifest: %d\n", len(objects)) - g3i.Logger().Println("Preparing file info for each file, please wait...") - fileInfoProgress := mpb.New(mpb.WithOutput(os.Stdout)) - fileInfoBar := fileInfoProgress.AddBar(int64(len(objects)), - mpb.PrependDecorators( - decor.Name("Preparing files "), - decor.CountersNoUnit("%d / %d"), - ), - mpb.AppendDecorators(decor.Percentage()), - ) - for _, obj := range objects { - if obj.ObjectID == "" { - g3i.Logger().Println("Found empty object_id (GUID), skipping this entry") - continue - } - var fdrObject common.FileDownloadResponseObject - filename := obj.Filename - filesize := obj.Filesize - // only queries Gen3 services if any of these 2 values doesn't exists in manifest - if filename == "" || filesize == 0 { - filename, filesize = AskGen3ForFileInfo(g3i, obj.ObjectID, protocol, downloadPath, filenameFormat, rename, &renamedFiles) - } - fdrObject = common.FileDownloadResponseObject{DownloadPath: downloadPath, Filename: filename} - if !rename { - fdrObject = validateLocalFileStat(g3i.Logger(), downloadPath, filename, filesize, skipCompleted) - } - fdrObject.GUID = obj.ObjectID - fdrObjects = append(fdrObjects, fdrObject) - fileInfoBar.Increment() - } - fileInfoProgress.Wait() - g3i.Logger().Println("File info prepared successfully") - - totalCompeleted := 0 - workers, _, errCh, _ := initBatchUploadChannels(numParallel, len(fdrObjects)) - downloadProgress := mpb.New(mpb.WithOutput(os.Stdout)) - batchFDRSlice := make([]common.FileDownloadResponseObject, 0) - for _, fdrObject := range fdrObjects { - if fdrObject.Skip { - g3i.Logger().Printf("File \"%s\" (GUID: %s) has been skipped because there is a complete local copy\n", fdrObject.Filename, fdrObject.GUID) - skippedFiles = append(skippedFiles, RenamedOrSkippedFileInfo{GUID: fdrObject.GUID, OldFilename: fdrObject.Filename}) - continue - } - - if len(batchFDRSlice) < workers { - batchFDRSlice = append(batchFDRSlice, fdrObject) - } else { - totalCompeleted += batchDownload(g3i, downloadProgress, batchFDRSlice, protocolText, workers, errCh) - batchFDRSlice = make([]common.FileDownloadResponseObject, 0) - batchFDRSlice = append(batchFDRSlice, fdrObject) - } - } - totalCompeleted += batchDownload(g3i, downloadProgress, batchFDRSlice, protocolText, workers, errCh) // download remainders - downloadProgress.Wait() - - g3i.Logger().Printf("%d files downloaded.\n", totalCompeleted) - - if len(renamedFiles) > 0 { - g3i.Logger().Printf("%d files have been renamed as the following:\n", len(renamedFiles)) - for _, rfi := range renamedFiles { - g3i.Logger().Printf("File \"%s\" (GUID: %s) has been renamed as: %s\n", rfi.OldFilename, rfi.GUID, rfi.NewFilename) - } - } - if len(skippedFiles) > 0 { - g3i.Logger().Printf("%d files have been skipped\n", len(skippedFiles)) - } - if len(errCh) > 0 { - close(errCh) - g3i.Logger().Printf("%d files have encountered an error during downloading, detailed error messages are:\n", len(errCh)) - for err := range errCh { - g3i.Logger().Println(err.Error()) - } - } - return nil -} - -func init() { - var manifestPath string - var downloadPath string - var filenameFormat string - var rename bool - var noPrompt bool - var protocol string - var numParallel int - var skipCompleted bool - - var downloadMultipleCmd = &cobra.Command{ - Use: "download-multiple", - Short: "Download multiple of files from a specified manifest", - Long: `Get presigned URLs for multiple of files specified in a manifest file and then download all of them.`, - Example: `./data-client download-multiple --profile= --manifest= --download-path=`, - Run: func(cmd *cobra.Command, args []string) { - // don't initialize transmission logs for non-uploading related commands - - logger, logCloser := logs.New(profile, logs.WithConsole(), logs.WithFailedLog(), logs.WithScoreboard(), logs.WithSucceededLog()) - defer logCloser() - - g3i, err := client.NewGen3Interface(context.Background(), profile, logger) - if err != nil { - log.Fatalf("Failed to parse config on profile %s, %v", profile, err) - } - - manifestPath, _ = common.GetAbsolutePath(manifestPath) - manifestFile, err := os.Open(manifestPath) - if err != nil { - g3i.Logger().Fatalf("Failed to open manifest file %s, %v\n", manifestPath, err) - } - defer manifestFile.Close() - manifestFileStat, err := manifestFile.Stat() - if err != nil { - g3i.Logger().Fatalf("Failed to get manifest file stats %s, %v\n", manifestPath, err) - } - g3i.Logger().Println("Reading manifest...") - manifestFileSize := manifestFileStat.Size() - manifestProgress := mpb.New(mpb.WithOutput(os.Stdout)) - manifestFileBar := manifestProgress.AddBar(manifestFileSize, - mpb.PrependDecorators( - decor.Name("Manifest "), - decor.CountersKibiByte("% .1f / % .1f"), - ), - mpb.AppendDecorators(decor.Percentage()), - ) - - manifestFileReader := manifestFileBar.ProxyReader(manifestFile) - - manifestBytes, err := io.ReadAll(manifestFileReader) - if err != nil { - g3i.Logger().Fatalf("Failed reading manifest %s, %v\n", manifestPath, err) - } - manifestProgress.Wait() - - var objects []ManifestObject - err = json.Unmarshal(manifestBytes, &objects) - if err != nil { - g3i.Logger().Fatalf("Error has occurred during unmarshalling manifest object: %v\n", err) - } - - err = downloadFile(g3i, objects, downloadPath, filenameFormat, rename, noPrompt, protocol, numParallel, skipCompleted) - if err != nil { - g3i.Logger().Fatal(err.Error()) - } - }, - } - - downloadMultipleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") - downloadMultipleCmd.MarkFlagRequired("profile") //nolint:errcheck - downloadMultipleCmd.Flags().StringVar(&manifestPath, "manifest", "", "The manifest file to read from. A valid manifest can be acquired by using the \"Download Manifest\" button in Data Explorer from a data common's portal") - downloadMultipleCmd.MarkFlagRequired("manifest") //nolint:errcheck - downloadMultipleCmd.Flags().StringVar(&downloadPath, "download-path", ".", "The directory in which to store the downloaded files") - downloadMultipleCmd.Flags().StringVar(&filenameFormat, "filename-format", "original", "The format of filename to be used, including \"original\", \"guid\" and \"combined\"") - downloadMultipleCmd.Flags().BoolVar(&rename, "rename", false, "Only useful when \"--filename-format=original\", will rename file by appending a counter value to its filename if set to true, otherwise the same filename will be used") - downloadMultipleCmd.Flags().BoolVar(&noPrompt, "no-prompt", false, "If set to true, will not display user prompt message for confirmation") - downloadMultipleCmd.Flags().StringVar(&protocol, "protocol", "", "Specify the preferred protocol with --protocol=s3") - downloadMultipleCmd.Flags().IntVar(&numParallel, "numparallel", 1, "Number of downloads to run in parallel") - downloadMultipleCmd.Flags().BoolVar(&skipCompleted, "skip-completed", false, "If set to true, will check for filename and size before download and skip any files in \"download-path\" that matches both") - RootCmd.AddCommand(downloadMultipleCmd) -} diff --git a/client/g3cmd/gitversion.go b/client/g3cmd/gitversion.go deleted file mode 100644 index cb3a308..0000000 --- a/client/g3cmd/gitversion.go +++ /dev/null @@ -1,6 +0,0 @@ -package g3cmd - -var ( - gitcommit = "N/A" - gitversion = "2023.11" -) diff --git a/client/g3cmd/retry-upload.go b/client/g3cmd/retry-upload.go deleted file mode 100644 index edd5c52..0000000 --- a/client/g3cmd/retry-upload.go +++ /dev/null @@ -1,215 +0,0 @@ -package g3cmd - -import ( - "context" - "os" - "path/filepath" - "time" - - "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" - "github.com/calypr/data-client/client/logs" - - "github.com/spf13/cobra" -) - -func handleFailedRetry(g3i client.Gen3Interface, ro common.RetryObject, retryObjCh chan common.RetryObject, err error) { - logger := g3i.Logger() - - // Record failure in JSON log - logger.Failed(ro.FilePath, ro.Filename, ro.FileMetadata, ro.GUID, ro.RetryCount, ro.Multipart) - - if err != nil { - logger.Println("Error:", err) - } - - if ro.RetryCount < MaxRetryCount { - retryObjCh <- ro - return - } - - // Max retries reached — clean up - if ro.GUID != "" { - if msg, err := DeleteRecord(g3i, ro.GUID); err == nil { - logger.Println(msg) - } else { - logger.Println("Cleanup failed:", err) - } - } - - // Final failure - sb, err := logs.FromSBContext(context.Background()) - if err != nil { - logger.Println(err) - } - sb.IncrementSB(MaxRetryCount + 1) - - if len(retryObjCh) == 0 { - close(retryObjCh) - logger.Println("Retry channel closed — all done") - } -} - -func retryUpload(g3i client.Gen3Interface, failedLogMap map[string]common.RetryObject) { - logger := g3i.Logger() - - sb, err := logs.FromSBContext(context.Background()) - if err != nil { - logger.Println(err) - } - - if len(failedLogMap) == 0 { - logger.Println("No failed files to retry.") - return - } - - logger.Println("Starting retry-upload...") - retryObjCh := make(chan common.RetryObject, len(failedLogMap)) - - // Load failed entries (skip already succeeded ones) - for _, ro := range failedLogMap { - // Simple check: if succeeded log exists and contains this path, skip - if common.AlreadySucceededFromFile(ro.FilePath) { - logger.Printf("Already uploaded: %s — skipping\n", ro.FilePath) - continue - } - retryObjCh <- ro - } - - if len(retryObjCh) == 0 { - logger.Println("All failed files were already successfully uploaded in a previous run.") - return - } - - for ro := range retryObjCh { - ro.RetryCount++ - logger.Printf("#%d retry — %s\n", ro.RetryCount, ro.FilePath) - logger.Printf("Waiting %.0f seconds...\n", GetWaitTime(ro.RetryCount).Seconds()) - time.Sleep(GetWaitTime(ro.RetryCount)) - - // Optional: delete old record - if ro.GUID != "" { - if msg, err := DeleteRecord(g3i, ro.GUID); err == nil { - logger.Println(msg) - } - } - - // Fix missing filename if needed - if ro.Filename == "" { - absPath, _ := common.GetAbsolutePath(ro.FilePath) - ro.Filename = filepath.Base(absPath) - } - - var err error - if ro.Multipart { - // Multipart retry - req := common.FileUploadRequestObject{ - FilePath: ro.FilePath, - Filename: ro.Filename, - GUID: ro.GUID, - } - err = MultipartUpload(context.Background(), g3i, req, ro.Bucket, true) - if err == nil { - logger.Succeeded(ro.FilePath, req.GUID) - sb.IncrementSB(ro.RetryCount - 1) // success on this retry - continue - } - } else { - // Single-part retry - var presignedURL, guid string - presignedURL, guid, err = GeneratePresignedURL(g3i, ro.Filename, ro.FileMetadata, ro.Bucket) - if err != nil { - handleFailedRetry(g3i, ro, retryObjCh, err) - continue - } - - file, err := os.Open(ro.FilePath) - if err != nil { - handleFailedRetry(g3i, ro, retryObjCh, err) - continue - } - stat, _ := file.Stat() - file.Close() - - if stat.Size() > FileSizeLimit { - ro.Multipart = true - retryObjCh <- ro - continue - } - - fur := common.FileUploadRequestObject{ - FilePath: ro.FilePath, - Filename: ro.Filename, - FileMetadata: ro.FileMetadata, - GUID: guid, - PresignedURL: presignedURL, - } - - fur, err = GenerateUploadRequest(g3i, fur, nil, nil) - if err != nil { - handleFailedRetry(g3i, ro, retryObjCh, err) - continue - } - - err = uploadFile(g3i, fur, ro.RetryCount) - if err != nil { - handleFailedRetry(g3i, ro, retryObjCh, err) - continue - } - - logger.Succeeded(ro.FilePath, fur.GUID) - sb.IncrementSB(ro.RetryCount - 1) - } - - if len(retryObjCh) == 0 { - close(retryObjCh) - } - } -} - -func init() { - var failedLogPath, profile string - - var retryUploadCmd = &cobra.Command{ - Use: "retry-upload", - Short: "Retry failed uploads from a failed_log.json", - Long: `Re-uploads files listed in a failed log using exponential backoff and progress bars.`, - Example: `./data-client retry-upload --profile=myprofile --failed-log-path=/path/to/failed_log.json`, - Run: func(cmd *cobra.Command, args []string) { - Logger, closer := logs.New(profile, - logs.WithConsole(), - logs.WithMessageFile(), - logs.WithFailedLog(), - logs.WithSucceededLog(), - ) - defer closer() - - g3, err := client.NewGen3Interface(context.Background(), profile, Logger) - if err != nil { - Logger.Fatalf("Failed to initialize client: %v", err) - } - - logger := g3.Logger() - - // Create scoreboard with our logger injected - sb := logs.NewSB(MaxRetryCount, logger) - - // Load failed log - failedMap, err := common.LoadFailedLog(failedLogPath) - if err != nil { - logger.Fatalf("Cannot read failed log: %v", err) - } - - retryUpload(g3, failedMap) - sb.PrintSB() - }, - } - - retryUploadCmd.Flags().StringVar(&profile, "profile", "", "Profile to use") - retryUploadCmd.MarkFlagRequired("profile") - - retryUploadCmd.Flags().StringVar(&failedLogPath, "failed-log-path", "", "Path to failed_log.json") - retryUploadCmd.MarkFlagRequired("failed-log-path") - - RootCmd.AddCommand(retryUploadCmd) -} diff --git a/client/g3cmd/root.go b/client/g3cmd/root.go deleted file mode 100644 index 8bc7ab9..0000000 --- a/client/g3cmd/root.go +++ /dev/null @@ -1,124 +0,0 @@ -package g3cmd - -import ( - "encoding/json" - "net/http" - "os" - "strconv" - "time" - - "github.com/calypr/data-client/client/jwt" - "github.com/calypr/data-client/client/logs" - "github.com/spf13/cobra" - "golang.org/x/mod/semver" -) - -var profile string - -// Package-level variable to hold the closer function -// (Assuming logs.Closer is a type that can hold a function, like func() error) -var logCloser func() - -// Or just: -// var logCloser io.Closer // if closer implements io.Closer - -// RootCmd represents the base command when called without any subcommands -var RootCmd = &cobra.Command{ - Use: "data-client", - Short: "Use the data-client to interact with a Gen3 Data Commons", - Long: "Gen3 Client for downloading, uploading and submitting data to data commons.\ndata-client version: " + gitversion + ", commit: " + gitcommit, - Version: gitversion, -} - -// Execute adds all child commands to the root command sets flags appropriately -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if logCloser != nil { - defer func() { - logCloser() - }() - } - - if err := RootCmd.Execute(); err != nil { - os.Stderr.WriteString("Error: " + err.Error() + "\n") - os.Exit(1) - } -} - -func init() { - cobra.OnInitialize(initConfig) - - // Define flags and configuration settings. - RootCmd.PersistentFlags().StringVar(&profile, "profile", "", "Specify profile to use") - _ = RootCmd.MarkFlagRequired("profile") -} - -type GitHubRelease struct { - TagName string `json:"tag_name"` -} - -func initConfig() { - // The logger is needed throughout the application, so we don't store it here, - // but the closer must be stored. - logger, closer := logs.New(profile, - logs.WithConsole(), - logs.WithMessageFile(), - logs.WithFailedLog(), - logs.WithSucceededLog(), - ) - - // 2. ASSIGN CLOSER TO PACKAGE VARIABLE - logCloser = closer - - // The rest of the function remains the same, except for removing the 'defer resp.Body.Close()' - // from the initConfig body, as that was unrelated to the logs closer. - // The rest of your original logic follows... - - conf := jwt.Configure{} - // init local config file - err := conf.InitConfigFile() - if err != nil { - logger.Fatal("Error occurred when trying to init config file: " + err.Error()) - } - - // version checker - if os.Getenv("GEN3_CLIENT_VERSION_CHECK") != "false" && - gitversion != "" && gitversion != "N/A" { - - const ( - owner = "uc-cdis" - repository = "cdis-data-client" - // The official GitHub API endpoint for the latest release - apiURL = "https://api.github.com/repos/" + owner + "/" + repository + "/releases/latest" - ) - - client := http.Client{Timeout: 5 * time.Second} - resp, err := client.Get(apiURL) - if err != nil { - logger.Println("Error occurred when fetching latest version (HTTP request failed): " + err.Error()) - // Continue execution, as version check failure is non-fatal - return - } - - // This defer is correct and should remain, as it cleans up the HTTP response body - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - logger.Println("Error occurred when fetching latest version (GitHub API returned status " + strconv.Itoa(resp.StatusCode) + ")") - return - } - - var release GitHubRelease - if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { - logger.Println("Error occurred when decoding latest version response: " + err.Error()) - return - } - - latestVersionTag := release.TagName - - if semver.Compare(gitversion, latestVersionTag) < 0 { - logger.Println("A new version of data-client is available! The latest version is " + latestVersionTag + ". You are using version " + gitversion) - logger.Println("Please download the latest data-client release from https://github.com/uc-cdis/cdis-data-client/releases/latest") - } - } -} diff --git a/client/g3cmd/upload-multipart.go b/client/g3cmd/upload-multipart.go deleted file mode 100644 index bc658b4..0000000 --- a/client/g3cmd/upload-multipart.go +++ /dev/null @@ -1,309 +0,0 @@ -package g3cmd - -import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "sort" - "strings" - "sync" - "time" - - "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" - "github.com/calypr/data-client/client/logs" - "github.com/spf13/cobra" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" -) - -const ( - minChunkSize = 5 * 1024 * 1024 // S3 minimum part size - maxMultipartParts = 10000 - maxConcurrentUploads = 10 - maxRetries = 5 -) - -func NewUploadMultipartCmd() *cobra.Command { - var ( - filePath string - guid string - bucketName string - ) - - cmd := &cobra.Command{ - Use: "upload-multipart", - Short: "Upload a single file using multipart upload", - Long: `Uploads a large file to object storage using multipart upload. -This method is resilient to network interruptions and supports resume capability.`, - Example: `./data-client upload-multipart --profile=myprofile --file-path=./large.bam -./data-client upload-multipart --profile=myprofile --file-path=./data.bam --guid=existing-guid`, - RunE: func(cmd *cobra.Command, args []string) error { - profile, _ := cmd.Flags().GetString("profile") - - return UploadSingleFile(profile, bucketName, filePath, guid) - }, - } - - cmd.Flags().StringVar(&filePath, "file-path", "", "Path to the file to upload") - cmd.Flags().StringVar(&guid, "guid", "", "Optional existing GUID (otherwise generated)") - cmd.Flags().StringVar(&bucketName, "bucket", "", "Target bucket (defaults to configured DATA_UPLOAD_BUCKET)") - - _ = cmd.MarkFlagRequired("profile") - _ = cmd.MarkFlagRequired("file-path") - - return cmd -} - -func UploadSingleFile(profile, bucket, filePath, guid string) error { - - logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) - defer closer() - g3, err := client.NewGen3Interface( - context.Background(), - profile, - logger, - ) - if err != nil { - return fmt.Errorf("failed to initialize Gen3 interface: %w", err) - } - - absPath, err := common.GetAbsolutePath(filePath) - if err != nil { - return fmt.Errorf("invalid file path: %w", err) - } - - fileInfo := common.FileUploadRequestObject{ - FilePath: absPath, - Filename: filepath.Base(absPath), - GUID: guid, - FileMetadata: common.FileMetadata{}, - } - - return MultipartUpload(context.TODO(), g3, fileInfo, bucket, true) -} - -// MultipartUpload is now clean, context-aware, and uses modern progress bars -func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, bucketName string, showProgress bool) error { - g3.Logger().Printf("File Upload Request: %#v\n", req) - - file, err := os.Open(req.FilePath) - if err != nil { - return fmt.Errorf("cannot open file %s: %w", req.FilePath, err) - } - defer file.Close() - - stat, err := file.Stat() - if err != nil { - return fmt.Errorf("cannot stat file: %w", err) - } - - g3.Logger().Printf("File Name: '%s', File Size: '%d'\n", stat.Name(), stat.Size()) - - fileSize := stat.Size() - if fileSize == 0 { - return fmt.Errorf("file is empty: %s", req.Filename) - } - - if fileSize < 5*1024*1024*1024 { - g3.Logger().Printf("File size < 5GB (%d bytes), using single-part upload\n", fileSize) - err := UploadSingle(g3.GetCredential().Profile, req.GUID, req.FilePath, req.Bucket, showProgress) - if err != nil { - g3.Logger().Fatal(err.Error()) - } - return nil - } - - // Progress bar setup (modern mpb) - var p *mpb.Progress - var bar *mpb.Bar - if showProgress { - p = mpb.New(mpb.WithOutput(os.Stdout)) - bar = p.AddBar(stat.Size(), - mpb.PrependDecorators( - decor.Name(req.Filename+" "), - decor.CountersKibiByte("%.1f / %.1f"), - ), - mpb.AppendDecorators( - decor.Percentage(), - decor.AverageSpeed(decor.SizeB1024(0), " % .1f"), - ), - ) - } - - // Initialize multipart upload - uploadID, finalGUID, err := InitMultipartUpload(g3, req, bucketName) - if err != nil { - return fmt.Errorf("failed to initiate multipart upload: %w", err) - } - req.GUID = finalGUID // update with server-provided GUID - - key := finalGUID + "/" + req.Filename - chunkSize := optimalChunkSize(stat.Size()) - - numChunks := int((stat.Size() + chunkSize - 1) / chunkSize) - parts := make([]MultipartPartObject, 0, numChunks) - - // Channel for chunk indices - chunks := make(chan int, numChunks) - for i := 1; i <= numChunks; i++ { - chunks <- i - } - close(chunks) - - var ( - wg sync.WaitGroup - mu sync.Mutex - uploadErrors []error - ) - - worker := func() { - defer wg.Done() - buf := make([]byte, chunkSize) - - for partNum := range chunks { - offset := int64(partNum-1) * chunkSize - end := offset + chunkSize - end = min(end, stat.Size()) - size := end - offset - - // Read chunk - if _, err := file.Seek(offset, io.SeekStart); err != nil { - mu.Lock() - uploadErrors = append(uploadErrors, fmt.Errorf("seek failed for part %d: %w", partNum, err)) - mu.Unlock() - continue - } - n, err := io.ReadFull(file, buf[:size]) - if err != nil && err != io.ErrUnexpectedEOF { - mu.Lock() - uploadErrors = append(uploadErrors, fmt.Errorf("read failed for part %d: %w", partNum, err)) - mu.Unlock() - continue - } - - reader := bytes.NewReader(buf[:n]) - - // Get presigned URL + upload with retry - var etag string - if err := retryWithBackoff(ctx, maxRetries, func() error { - url, err := GenerateMultipartPresignedURL(g3, key, uploadID, partNum, bucketName) - if err != nil { - return err - } - - return uploadPart(url, reader, &etag) - }); err != nil { - mu.Lock() - uploadErrors = append(uploadErrors, fmt.Errorf("part %d failed after retries: %w", partNum, err)) - mu.Unlock() - continue - } - - // Success - mu.Lock() - etag = strings.Trim(etag, `"`) - parts = append(parts, MultipartPartObject{PartNumber: partNum, ETag: etag}) - g3.Logger().Printf("Appended part %d with ETag %s\n", partNum, etag) - if bar != nil { - bar.IncrBy(n) - } - mu.Unlock() - } - } - - // Launch workers - for range maxConcurrentUploads { - wg.Add(1) - go worker() - } - wg.Wait() - - if p != nil { - p.Wait() - } - - if len(uploadErrors) > 0 { - return fmt.Errorf("multipart upload failed: %d parts failed: %v", len(uploadErrors), uploadErrors) - } - - // Sort parts by PartNumber - sort.Slice(parts, func(i, j int) bool { - return parts[i].PartNumber < parts[j].PartNumber - }) - - g3.Logger().Printf("Completing multipart upload with %d parts for file %s\n", len(parts), req.Filename) - for _, part := range parts { - g3.Logger().Printf(" Part %d: ETag=%s\n", part.PartNumber, part.ETag) - } - - if err := CompleteMultipartUpload(g3, key, uploadID, parts, bucketName); err != nil { - return fmt.Errorf("failed to complete multipart upload: %w", err) - } - - g3.Logger().Printf("Successfully uploaded %s as %s (%d)", req.Filename, finalGUID, stat.Size()) - return nil -} - -// Helper: exponential backoff retry -func retryWithBackoff(ctx context.Context, attempts int, fn func() error) error { - var err error - for i := range attempts { - if err = fn(); err == nil { - return nil - } - if i == attempts-1 { - break - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(backoffDuration(i)): - } - } - return fmt.Errorf("after %d attempts: %w", attempts, err) -} - -func backoffDuration(attempt int) time.Duration { - return min(time.Duration(1< --manifest= --upload-path= --bucket= --force-multipart= --include-subdirname= --batch=`, - Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Notice: this is the upload method which requires the user to provide GUIDs. In this method files will be uploaded to specified GUIDs.\nIf your intention is to upload files without pre-existing GUIDs, consider to use \"./data-client upload\" instead.\n\n") - - logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) - defer closer() - - // Instantiate interface to Gen3 - g3i, err := client.NewGen3Interface(context.Background(), profile, logger) - if err != nil { - g3i.Logger().Fatalf("Failed to parse config on profile %s, %v", profile, err) - } - - host, err := g3i.GetHost() - if err != nil { - g3i.Logger().Fatal("Error occurred during parsing config file for hostname: " + err.Error()) - } - dataExplorerURL := host.Scheme + "://" + host.Host + "/explorer" - - var objects []ManifestObject - - manifestFile, err := os.Open(manifestPath) - if err != nil { - g3i.Logger().Println("Failed to open manifest file") - g3i.Logger().Fatal("A valid manifest can be acquired by using the \"Download Manifest\" button on " + dataExplorerURL) - } - defer manifestFile.Close() - switch { - case strings.EqualFold(filepath.Ext(manifestPath), ".json"): - manifestBytes, err := os.ReadFile(manifestPath) - if err != nil { - g3i.Logger().Printf("Failed reading manifest %s, %v\n", manifestPath, err) - g3i.Logger().Fatal("A valid manifest can be acquired by using the \"Download Manifest\" button on " + dataExplorerURL) - } - err = json.Unmarshal(manifestBytes, &objects) - if err != nil { - g3i.Logger().Fatal("Unmarshalling manifest failed with error: " + err.Error()) - } - default: - g3i.Logger().Println("Unsupported manifast format") - g3i.Logger().Fatal("A valid manifest can be acquired by using the \"Download Manifest\" button on " + dataExplorerURL) - } - - absUploadPath, err := common.GetAbsolutePath(uploadPath) - if err != nil { - g3i.Logger().Fatalf("Error when parsing file paths: %s", err.Error()) - } - - // Create unified upload request objects - uploadRequestObjects := make([]common.FileUploadRequestObject, 0, len(objects)) - - for _, object := range objects { - var localFilePath string - // Determine the local file path - if object.Filename != "" { - // conform to fence naming convention - localFilePath, err = getFullFilePath(absUploadPath, object.Filename) - } else { - // Otherwise, here we are assuming the local filename will be the same as GUID - localFilePath, err = getFullFilePath(absUploadPath, object.ObjectID) - } - - if err != nil { - g3i.Logger().Println(err.Error()) - continue - } - - fileInfo, err := ProcessFilename(g3i.Logger(), absUploadPath, localFilePath, object.ObjectID, includeSubDirName, false) - if err != nil { - g3i.Logger().Println("Process filename error: " + err.Error()) - g3i.Logger().Failed(localFilePath, filepath.Base(localFilePath), common.FileMetadata{}, object.ObjectID, 0, false) - continue - } - - // Convert FileInfo to the unified common.FileUploadRequestObject - furObject := common.FileUploadRequestObject{ - FilePath: fileInfo.FilePath, - Filename: fileInfo.Filename, - FileMetadata: fileInfo.FileMetadata, - GUID: fileInfo.GUID, - } - uploadRequestObjects = append(uploadRequestObjects, furObject) - } - - // Separate into single-part and multipart objects - singlePartObjects, multipartObjects := separateSingleAndMultipartUploads(g3i, uploadRequestObjects, forceMultipart) - // Pass the unified objects to the upload handlers - if batch { - workers, respCh, errCh, batchFURObjects := initBatchUploadChannels(numParallel, len(singlePartObjects)) - for i, furObject := range singlePartObjects { - // FileInfo processing and path normalization are already done, so we use the object directly - if len(batchFURObjects) < workers { - batchFURObjects = append(batchFURObjects, furObject) - } else { - batchUpload(g3i, batchFURObjects, workers, respCh, errCh, bucketName) - batchFURObjects = []common.FileUploadRequestObject{furObject} - } - if !forceMultipart && i == len(singlePartObjects)-1 && len(batchFURObjects) > 0 { // upload remainders - batchUpload(g3i, batchFURObjects, workers, respCh, errCh, bucketName) - } - } - } else { - processSingleUploads(g3i, singlePartObjects, bucketName, includeSubDirName, absUploadPath) // Assuming updated - } - - if len(multipartObjects) > 0 { - err := processMultipartUpload(g3i, multipartObjects, bucketName, includeSubDirName, absUploadPath) - if err != nil { - g3i.Logger().Fatal(err.Error()) - } - } - - if len(g3i.Logger().GetSucceededLogMap()) == 0 { - retryUpload(g3i, g3i.Logger().GetFailedLogMap()) - } - - g3i.Logger().Scoreboard().PrintSB() - }, - } - - uploadMultipleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") - uploadMultipleCmd.MarkFlagRequired("profile") //nolint:errcheck - uploadMultipleCmd.Flags().StringVar(&manifestPath, "manifest", "", "The manifest file to read from. A valid manifest can be acquired by using the \"Download Manifest\" button in Data Explorer for Common portal") - uploadMultipleCmd.MarkFlagRequired("manifest") //nolint:errcheck - uploadMultipleCmd.Flags().StringVar(&uploadPath, "upload-path", "", "The directory in which contains files to be uploaded") - uploadMultipleCmd.MarkFlagRequired("upload-path") //nolint:errcheck - uploadMultipleCmd.Flags().BoolVar(&batch, "batch", true, "Upload in parallel") - uploadMultipleCmd.Flags().IntVar(&numParallel, "numparallel", 3, "Number of uploads to run in parallel") - uploadMultipleCmd.Flags().StringVar(&bucketName, "bucket", "", "The bucket to which files will be uploaded. If not provided, defaults to Gen3's configured DATA_UPLOAD_BUCKET.") - uploadMultipleCmd.Flags().BoolVar(&forceMultipart, "force-multipart", false, "Force to use multipart upload when possible (file size >= 5MB)") - uploadMultipleCmd.Flags().BoolVar(&includeSubDirName, "include-subdirname", true, "Include subdirectory names in file name") - RootCmd.AddCommand(uploadMultipleCmd) -} - -func processSingleUploads(g3i client.Gen3Interface, singleObjects []common.FileUploadRequestObject, bucketName string, includeSubDirName bool, uploadPath string) { - for _, furObject := range singleObjects { - filePath := furObject.FilePath - file, err := os.Open(filePath) - if err != nil { - g3i.Logger().Println("File open error: " + err.Error()) - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - continue - } - startSingleFileUpload(g3i, furObject, file, bucketName) - file.Close() - } -} - -func startSingleFileUpload(g3i client.Gen3Interface, furObject common.FileUploadRequestObject, file *os.File, bucketName string) { - - fi, err := file.Stat() - if err != nil { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - g3i.Logger().Println("File stat error for file" + fi.Name() + ", file may be missing or unreadable because of permissions.\n") - return - } - - respURL, guid, err := GeneratePresignedURL(g3i, furObject.Filename, furObject.FileMetadata, bucketName) - if err != nil { - g3i.Logger().Println(err.Error()) - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, guid, 0, false) - return - } - furObject.GUID = guid - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - furObject.PresignedURL = respURL - - furObject, err = GenerateUploadRequest(g3i, furObject, file, nil) - if err != nil { - file.Close() - g3i.Logger().Printf("Error occurred during request generation: %s\n", err.Error()) - return - } - - err = uploadFile(g3i, furObject, 0) - if err != nil { - g3i.Logger().Println(err.Error()) - } else { - g3i.Logger().Scoreboard().IncrementSB(0) - } - - file.Close() -} - -func processMultipartUpload(g3i client.Gen3Interface, multipartObjects []common.FileUploadRequestObject, bucketName string, includeSubDirName bool, uploadPath string) error { - cred := g3i.GetCredential() - if cred.UseShepherd == "true" || - cred.UseShepherd == "" && common.DefaultUseShepherd == true { - return fmt.Errorf("error: Shepherd currently does not support multipart uploads. For the moment, please disable Shepherd with\n $ data-client configure --profile=%v --use-shepherd=false\nand try again", cred.Profile) - } - g3i.Logger().Println("Multipart uploading...") - - for _, furObject := range multipartObjects { - // No more redundant ProcessFilename call! - // Pass the complete FileUploadRequestObject to the streamlined multipartUpload. - // Enable progress bar for batch uploads (interactive CLI use) - err := MultipartUpload(context.Background(), g3i, furObject, bucketName, true) - - if err != nil { - g3i.Logger().Println(err.Error()) - } else { - g3i.Logger().Scoreboard().IncrementSB(0) - } - } - return nil -} diff --git a/client/g3cmd/upload-single.go b/client/g3cmd/upload-single.go deleted file mode 100644 index f154c4d..0000000 --- a/client/g3cmd/upload-single.go +++ /dev/null @@ -1,123 +0,0 @@ -package g3cmd - -// Deprecated: Use upload instead. -import ( - "context" - "errors" - "fmt" - "log" - "os" - "path/filepath" - - "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" - "github.com/calypr/data-client/client/logs" - "github.com/spf13/cobra" -) - -func init() { - var guid string - var filePath string - var bucketName string - - var uploadSingleCmd = &cobra.Command{ - Use: "upload-single", - Short: "Upload a single file to a GUID", - Long: `Gets a presigned URL for which to upload a file associated with a GUID and then uploads the specified file.`, - Example: `./data-client upload-single --profile= --guid=f6923cf3-xxxx-xxxx-xxxx-14ab3f84f9d6 --file=`, - Run: func(cmd *cobra.Command, args []string) { - // initialize transmission logs - err := UploadSingle(profile, guid, filePath, bucketName, true) - if err != nil { - log.Fatalln(err.Error()) - } - }, - } - uploadSingleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") - uploadSingleCmd.MarkFlagRequired("profile") //nolint:errcheck - uploadSingleCmd.Flags().StringVar(&guid, "guid", "", "Specify the guid for the data you would like to work with") - uploadSingleCmd.MarkFlagRequired("guid") //nolint:errcheck - uploadSingleCmd.Flags().StringVar(&filePath, "file", "", "Specify file to upload to with --file=~/path/to/file") - uploadSingleCmd.MarkFlagRequired("file") //nolint:errcheck - uploadSingleCmd.Flags().StringVar(&bucketName, "bucket", "", "The bucket to which files will be uploaded. If not provided, defaults to Gen3's configured DATA_UPLOAD_BUCKET.") - RootCmd.AddCommand(uploadSingleCmd) -} - -func UploadSingle(profile string, guid string, filePath string, bucketName string, showProgress bool) error { - - opts := []logs.Option{ - logs.WithSucceededLog(), - logs.WithFailedLog(), - logs.WithMessageFile(), - } - - if showProgress { - opts = append(opts, logs.WithScoreboard(), logs.WithConsole()) - } - - logger, closer := logs.New(profile, opts...) - defer closer() - - // Instantiate interface to Gen3 - g3i, err := client.NewGen3Interface( - context.Background(), - profile, - logger, - ) - if err != nil { - return fmt.Errorf("failed to parse config on profile %s: %w", profile, err) - } - - updateUI := func() { - if showProgress { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() - } - } - - filePaths, err := common.ParseFilePaths(filePath, false) - if len(filePaths) > 1 { - return errors.New("more than 1 file location has been found. Do not use \"*\" in file path or provide a folder as file path") - } - if err != nil { - return errors.New("file path parsing error: " + err.Error()) - } - if len(filePaths) == 1 { - filePath = filePaths[0] - } - filename := filepath.Base(filePath) - if _, err := os.Stat(filePath); os.IsNotExist(err) { - g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - updateUI() - return fmt.Errorf("[ERROR] The file you specified \"%s\" does not exist locally\n", filePath) - } - - file, err := os.Open(filePath) - if err != nil { - updateUI() - g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - g3i.Logger().Println("File open error: " + err.Error()) - return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", filePath, err.Error()) - } - defer file.Close() - - furObject := common.FileUploadRequestObject{FilePath: filePath, Filename: filename, GUID: guid, Bucket: bucketName} - furObject, err = GenerateUploadRequest(g3i, furObject, file, nil) - if err != nil { - file.Close() - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) - updateUI() - g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) - return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) - } - err = uploadFile(g3i, furObject, 0) - if err != nil { - updateUI() - return fmt.Errorf("[ERROR] Error uploading file %s: %s\n", filePath, err.Error()) - } - if showProgress { - g3i.Logger().Scoreboard().PrintSB() - } - return nil -} diff --git a/client/g3cmd/utils.go b/client/g3cmd/utils.go deleted file mode 100644 index d65f488..0000000 --- a/client/g3cmd/utils.go +++ /dev/null @@ -1,686 +0,0 @@ -package g3cmd - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net/http" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" - "github.com/calypr/data-client/client/logs" - - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" -) - -// ManifestObject represents an object from manifest that downloaded from windmill / data-portal -type ManifestObject struct { - ObjectID string `json:"object_id"` - SubjectID string `json:"subject_id"` - Filename string `json:"file_name"` - Filesize int64 `json:"file_size"` -} - -// InitRequestObject represents the payload that sends to FENCE for getting a singlepart upload presignedURL or init a multipart upload for new object file -type InitRequestObject struct { - Filename string `json:"file_name"` - Bucket string `json:"bucket,omitempty"` - GUID string `json:"guid,omitempty"` -} - -// ShepherdInitRequestObject represents the payload that sends to Shepherd for getting a singlepart upload presignedURL or init a multipart upload for new object file -type ShepherdInitRequestObject struct { - Filename string `json:"file_name"` - Authz struct { - Version string `json:"version"` - ResourcePaths []string `json:"resource_paths"` - } `json:"authz"` - Aliases []string `json:"aliases"` - // Metadata is an encoded JSON string of any arbitrary metadata the user wishes to upload. - Metadata map[string]any `json:"metadata"` -} - -// MultipartUploadRequestObject represents the payload that sends to FENCE for getting a presignedURL for a part -type MultipartUploadRequestObject struct { - Key string `json:"key"` - UploadID string `json:"uploadId"` - PartNumber int `json:"partNumber"` - Bucket string `json:"bucket,omitempty"` -} - -// MultipartCompleteRequestObject represents the payload that sends to FENCE for completeing a multipart upload -type MultipartCompleteRequestObject struct { - Key string `json:"key"` - UploadID string `json:"uploadId"` - Parts []MultipartPartObject `json:"parts"` - Bucket string `json:"bucket,omitempty"` -} - -// MultipartPartObject represents a part object -type MultipartPartObject struct { - PartNumber int `json:"PartNumber"` - ETag string `json:"ETag"` -} - -// FileInfo is a helper struct for including subdirname as filename -type FileInfo struct { - FilePath string - Filename string - FileMetadata common.FileMetadata - ObjectId string -} - -// RenamedOrSkippedFileInfo is a helper struct for recording renamed or skipped files -type RenamedOrSkippedFileInfo struct { - GUID string - OldFilename string - NewFilename string -} - -const ( - // B is bytes - B int64 = iota - // KB is kilobytes - KB int64 = 1 << (10 * iota) - // MB is megabytes - MB - // GB is gigabytes - GB - // TB is terrabytes - TB -) - -var unitMap = map[int64]string{ - B: "B", - KB: "KB", - MB: "MB", - GB: "GB", - TB: "TB", -} - -// FileSizeLimit is the maximun single file size for non-multipart upload (5GB) -const FileSizeLimit = 5 * GB - -// MultipartFileSizeLimit is the maximun single file size for multipart upload (5TB) -const MultipartFileSizeLimit = 5 * TB -const minMultipartChunkSize = 5 * MB - -// MaxRetryCount is the maximum retry number per record -const MaxRetryCount = 5 -const maxWaitTime = 300 - -// InitMultipartUpload helps sending requests to FENCE to init a multipart upload -func InitMultipartUpload(g3 client.Gen3Interface, furObject common.FileUploadRequestObject, bucketName string) (string, string, error) { - // Use Filename and GUID directly from the unified request object - multipartInitObject := InitRequestObject{Filename: furObject.Filename, Bucket: bucketName, GUID: furObject.GUID} - - objectBytes, err := json.Marshal(multipartInitObject) - if err != nil { - return "", "", errors.New("Error has occurred during marshalling data for multipart upload initialization, detailed error message: " + err.Error()) - } - - msg, err := g3.DoRequestWithSignedHeader(common.FenceDataMultipartInitEndpoint, "application/json", objectBytes) - - if err != nil { - if strings.Contains(err.Error(), "404") { - return "", "", errors.New(err.Error() + "\nPlease check to ensure FENCE version is at 2.8.0 or beyond") - } - return "", "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) - } - if msg.UploadID == "" || msg.GUID == "" { - return "", "", errors.New("unknown error has occurred during multipart upload initialization. Please check logs from Gen3 services") - } - return msg.UploadID, msg.GUID, err -} - -// GenerateMultipartPresignedURL helps sending requests to FENCE to get a presigned URL for a part during a multipart upload -func GenerateMultipartPresignedURL(g3 client.Gen3Interface, key string, uploadID string, partNumber int, bucketName string) (string, error) { - multipartUploadObject := MultipartUploadRequestObject{Key: key, UploadID: uploadID, PartNumber: partNumber, Bucket: bucketName} - objectBytes, err := json.Marshal(multipartUploadObject) - if err != nil { - return "", errors.New("Error has occurred during marshalling data for multipart upload presigned url generation, detailed error message: " + err.Error()) - } - - msg, err := g3.DoRequestWithSignedHeader(common.FenceDataMultipartUploadEndpoint, "application/json", objectBytes) - - if err != nil { - return "", errors.New("Error has occurred during multipart upload presigned url generation, detailed error message: " + err.Error()) - } - if msg.PresignedURL == "" { - return "", errors.New("unknown error has occurred during multipart upload presigned url generation. Please check logs from Gen3 services") - } - return msg.PresignedURL, err -} - -// CompleteMultipartUpload helps sending requests to FENCE to complete a multipart upload -func CompleteMultipartUpload(g3 client.Gen3Interface, key string, uploadID string, parts []MultipartPartObject, bucketName string) error { - multipartCompleteObject := MultipartCompleteRequestObject{Key: key, UploadID: uploadID, Parts: parts, Bucket: bucketName} - objectBytes, err := json.Marshal(multipartCompleteObject) - if err != nil { - return errors.New("Error has occurred during marshalling data for multipart upload, detailed error message: " + err.Error()) - } - - _, err = g3.DoRequestWithSignedHeader(common.FenceDataMultipartCompleteEndpoint, "application/json", objectBytes) - if err != nil { - return errors.New("Error has occurred during completing multipart upload, detailed error message: " + err.Error()) - } - return nil -} - -// GetDownloadResponse helps grabbing a response for downloading a file specified with GUID -func GetDownloadResponse(g3 client.Gen3Interface, fdrObject *common.FileDownloadResponseObject, protocolText string) error { - // Attempt to get the file download URL from Shepherd if it's deployed in this commons, - // otherwise fall back to Fence. - var fileDownloadURL string - hasShepherd, err := g3.CheckForShepherdAPI() - if err != nil { - g3.Logger().Println("Error occurred when checking for Shepherd API: " + err.Error()) - g3.Logger().Println("Falling back to Indexd...") - } else if hasShepherd { - endPointPostfix := common.ShepherdEndpoint + "/objects/" + fdrObject.GUID + "/download" - _, r, err := g3.GetResponse(endPointPostfix, "GET", "", nil) - if err != nil { - return errors.New("Error occurred when getting download URL for object " + fdrObject.GUID + " from endpoint " + endPointPostfix + " . Details: " + err.Error()) - } - defer r.Body.Close() - if r.StatusCode != 200 { - buf := new(bytes.Buffer) - buf.ReadFrom(r.Body) // nolint:errcheck - body := buf.String() - return errors.New("Error when getting download URL at " + endPointPostfix + " for file " + fdrObject.GUID + " : Shepherd returned non-200 status code " + strconv.Itoa(r.StatusCode) + " . Request body: " + body) - } - // Unmarshal into json - urlResponse := struct { - URL string `json:"url"` - }{} - err = json.NewDecoder(r.Body).Decode(&urlResponse) - if err != nil { - return errors.New("Error occurred when getting download URL for object " + fdrObject.GUID + " from endpoint " + endPointPostfix + " . Details: " + err.Error()) - } - fileDownloadURL = urlResponse.URL - if fileDownloadURL == "" { - return errors.New("Unknown error occurred when getting download URL for object " + fdrObject.GUID + " from endpoint " + endPointPostfix + " : No URL found in response body. Check the Shepherd logs") - } - } else { - endPointPostfix := common.FenceDataDownloadEndpoint + "/" + fdrObject.GUID + protocolText - msg, err := g3.DoRequestWithSignedHeader(endPointPostfix, "", nil) - - if err != nil || msg.URL == "" { - errorMsg := "Error occurred when getting download URL for object " + fdrObject.GUID - if err != nil { - errorMsg += "\n Details of error: " + err.Error() - } - return errors.New(errorMsg) - } - fileDownloadURL = msg.URL - } - - // TODO: for now we don't print fdrObject.URL in error messages since it is sensitive - // Later after we had log level we could consider for putting URL into debug logs... - fdrObject.URL = fileDownloadURL - if fdrObject.Range != 0 && !strings.Contains(fdrObject.URL, "X-Amz-Signature") && !strings.Contains(fdrObject.URL, "X-Goog-Signature") { // Not S3 or GS URLs and we want resume, send HEAD req first to check if server supports range - resp, err := http.Head(fdrObject.URL) - if err != nil { - errorMsg := "Error occurred when sending HEAD req to URL associated with GUID " + fdrObject.GUID - errorMsg += "\n Details of error: " + sanitizeErrorMsg(err.Error(), fdrObject.URL) - return errors.New(errorMsg) - } - if resp.Header.Get("Accept-Ranges") != "bytes" { // server does not support range, download without range header - fdrObject.Range = 0 - } - } - - headers := map[string]string{} - if fdrObject.Range != 0 { - headers["Range"] = "bytes=" + strconv.FormatInt(fdrObject.Range, 10) + "-" - } - resp, err := g3.MakeARequest(http.MethodGet, fdrObject.URL, "", "", headers, nil, true) - if err != nil { - errorMsg := "Error occurred when making request to URL associated with GUID " + fdrObject.GUID - errorMsg += "\n Details of error: " + sanitizeErrorMsg(err.Error(), fdrObject.URL) - return errors.New(errorMsg) - } - if resp.StatusCode != 200 && resp.StatusCode != 206 { - errorMsg := "Got a non-200 or non-206 response when making request to URL associated with GUID " + fdrObject.GUID - errorMsg += "\n HTTP status code for response: " + strconv.Itoa(resp.StatusCode) - return errors.New(errorMsg) - } - fdrObject.Response = resp - return nil -} - -func sanitizeErrorMsg(errorMsg string, sensitiveURL string) string { - return strings.ReplaceAll(errorMsg, sensitiveURL, "") -} - -// GeneratePresignedURL helps sending requests to Shepherd/Fence and parsing the response in order to get presigned URL for the new upload flow -func GeneratePresignedURL(g3 client.Gen3Interface, filename string, fileMetadata common.FileMetadata, bucketName string) (string, string, error) { - // Attempt to get the presigned URL of this file from Shepherd if it's deployed, otherwise fall back to Fence. - hasShepherd, err := g3.CheckForShepherdAPI() - if err != nil { - g3.Logger().Println("Error occurred when checking for Shepherd API: " + err.Error()) - g3.Logger().Println("Falling back to Fence...") - } else if hasShepherd { - purObject := ShepherdInitRequestObject{ - Filename: filename, - Authz: struct { - Version string `json:"version"` - ResourcePaths []string `json:"resource_paths"` - }{ - "0", - fileMetadata.Authz, - }, - Aliases: fileMetadata.Aliases, - Metadata: fileMetadata.Metadata, - } - objectBytes, err := json.Marshal(purObject) - if err != nil { - return "", "", errors.New("Error occurred when creating upload request for file " + filename + ". Details: " + err.Error()) - } - endPointPostfix := common.ShepherdEndpoint + "/objects" - _, r, err := g3.GetResponse(endPointPostfix, "POST", "", objectBytes) - if err != nil { - return "", "", errors.New("Error occurred when requesting upload URL from " + endPointPostfix + " for file " + filename + ". Details: " + err.Error()) - } - defer r.Body.Close() - if r.StatusCode != 201 { - buf := new(bytes.Buffer) - buf.ReadFrom(r.Body) // nolint:errcheck - body := buf.String() - return "", "", errors.New("Error when requesting upload URL at " + endPointPostfix + " for file " + filename + ": Shepherd returned non-200 status code " + strconv.Itoa(r.StatusCode) + ". Request body: " + body) - } - res := struct { - GUID string `json:"guid"` - URL string `json:"upload_url"` - }{} - err = json.NewDecoder(r.Body).Decode(&res) - if err != nil { - return "", "", errors.New("Error occurred when creating upload URL for file " + filename + ": . Details: " + err.Error()) - } - if res.URL == "" || res.GUID == "" { - return "", "", errors.New("unknown error has occurred during presigned URL or GUID generation. Please check logs from Gen3 services") - } - return res.URL, res.GUID, nil - } - - // Otherwise, fall back to Fence - purObject := InitRequestObject{Filename: filename, Bucket: bucketName} - objectBytes, err := json.Marshal(purObject) - if err != nil { - return "", "", errors.New("Error occurred when marshalling object: " + err.Error()) - } - msg, err := g3.DoRequestWithSignedHeader(common.FenceDataUploadEndpoint, "application/json", objectBytes) - - if err != nil { - return "", "", errors.New("Something went wrong. Maybe you don't have permission to upload data or Fence is misconfigured. Detailed error message: " + err.Error()) - } - if msg.URL == "" || msg.GUID == "" { - return "", "", errors.New("unknown error has occurred during presigned URL or GUID generation. Please check logs from Gen3 services") - } - return msg.URL, msg.GUID, err -} - -// GenerateUploadRequest helps preparing the HTTP request for upload and the progress bar for single part upload -func GenerateUploadRequest(g3 client.Gen3Interface, furObject common.FileUploadRequestObject, file *os.File, progress *mpb.Progress) (common.FileUploadRequestObject, error) { - if furObject.PresignedURL == "" { - endPointPostfix := common.FenceDataUploadEndpoint + "/" + furObject.GUID + "?file_name=" + url.QueryEscape(furObject.Filename) - - // ensure bucket is set - if furObject.Bucket != "" { - endPointPostfix += "&bucket=" + furObject.Bucket - } - msg, err := g3.DoRequestWithSignedHeader(endPointPostfix, "application/json", nil) - if err != nil && !strings.Contains(err.Error(), "No GUID found") { - return furObject, errors.New("Upload error: " + err.Error()) - } - if msg.URL == "" { - return furObject, errors.New("Upload error: error in generating presigned URL for " + furObject.Filename) - } - furObject.PresignedURL = msg.URL - } - - fi, err := file.Stat() - if err != nil { - return furObject, errors.New("File stat error for file" + furObject.Filename + ", file may be missing or unreadable because of permissions.\n") - } - - if fi.Size() > FileSizeLimit { - return furObject, errors.New("The file size of file " + furObject.Filename + " exceeds the limit allowed and cannot be uploaded. The maximum allowed file size is " + FormatSize(FileSizeLimit) + ".\n") - } - - if progress == nil { - progress = mpb.New(mpb.WithOutput(os.Stdout)) - } - bar := progress.AddBar(fi.Size(), - mpb.PrependDecorators( - decor.Name(furObject.Filename+" "), - decor.CountersKibiByte("% .1f / % .1f"), - ), - mpb.AppendDecorators( - decor.Percentage(), - decor.AverageSpeed(decor.SizeB1024(0), " % .1f"), - ), - ) - pr, pw := io.Pipe() - - go func() { - var writer io.Writer - defer pw.Close() - defer file.Close() - - writer = bar.ProxyWriter(pw) - if _, err = io.Copy(writer, file); err != nil { - err = errors.New("io.Copy error: " + err.Error() + "\n") - } - if err = pw.Close(); err != nil { - err = errors.New("Pipe writer close error: " + err.Error() + "\n") - } - }() - if err != nil { - return furObject, err - } - - req, err := http.NewRequest(http.MethodPut, furObject.PresignedURL, pr) - req.ContentLength = fi.Size() - - furObject.Request = req - furObject.Progress = progress - furObject.Bar = bar - - return furObject, err -} - -// DeleteRecord helps sending requests to FENCE to delete a record from INDEXD as well as its storage locations -func DeleteRecord(g3 client.Gen3Interface, guid string) (string, error) { - return g3.DeleteRecord(guid) -} - -func separateSingleAndMultipartUploads(g3i client.Gen3Interface, objects []common.FileUploadRequestObject, forceMultipart bool) ([]common.FileUploadRequestObject, []common.FileUploadRequestObject) { - fileSizeLimit := FileSizeLimit // 5GB - if forceMultipart { - fileSizeLimit = minMultipartChunkSize // 5MB - } - singlepartObjects := make([]common.FileUploadRequestObject, 0) - multipartObjects := make([]common.FileUploadRequestObject, 0) - - for _, object := range objects { - filePath := object.FilePath - - // Check if file exists locally - if _, err := os.Stat(filePath); os.IsNotExist(err) { - g3i.Logger().Printf("The file you specified \"%s\" does not exist locally\n", filePath) - g3i.Logger().Failed(object.FilePath, object.Filename, object.FileMetadata, object.GUID, 0, false) - continue - } - - // Use a closure to handle file operations and cleanup - func(obj common.FileUploadRequestObject) { - file, err := os.Open(filePath) - if err != nil { - g3i.Logger().Println("File open error occurred when validating file path: " + err.Error()) - g3i.Logger().Failed(obj.FilePath, obj.Filename, obj.FileMetadata, obj.GUID, 0, false) - return - } - defer file.Close() - - fi, err := file.Stat() - if err != nil { - g3i.Logger().Println("File stat error occurred when validating file path: " + err.Error()) - g3i.Logger().Failed(obj.FilePath, obj.Filename, obj.FileMetadata, obj.GUID, 0, false) - return - } - if fi.IsDir() { - return - } - - _, ok := g3i.Logger().GetSucceededLogMap()[filePath] - if ok { - g3i.Logger().Println("File \"" + filePath + "\" has been found in local submission history and has been skipped to prevent duplicated submissions.") - return - } - - // Add to failed log initially, it will be removed on success - // This is an existing pattern, keeping it here. - g3i.Logger().Failed(obj.FilePath, obj.Filename, obj.FileMetadata, obj.GUID, 0, false) - - if fi.Size() > MultipartFileSizeLimit { - g3i.Logger().Printf("The file size of %s has exceeded the limit allowed and cannot be uploaded. The maximum allowed file size is %s\n", fi.Name(), FormatSize(MultipartFileSizeLimit)) - } else if fi.Size() > int64(fileSizeLimit) { - multipartObjects = append(multipartObjects, obj) - } else { - singlepartObjects = append(singlepartObjects, obj) - } - }(object) - } - return singlepartObjects, multipartObjects -} - -// ProcessFilename returns an FileInfo object which has the information about the path and name to be used for upload of a file -func ProcessFilename(logger logs.Logger, uploadPath string, filePath string, objectId string, includeSubDirName bool, includeMetadata bool) (common.FileUploadRequestObject, error) { - var err error - filePath, err = common.GetAbsolutePath(filePath) - if err != nil { - return common.FileUploadRequestObject{}, err - } - - filename := filepath.Base(filePath) // Default to base filename - - var metadata common.FileMetadata - if includeSubDirName { - absUploadPath, err := common.GetAbsolutePath(uploadPath) - if err != nil { - return common.FileUploadRequestObject{}, err - } - - // Ensure absUploadPath is a directory path for relative calculation - // Trim the optional wildcard if present - uploadDir := strings.TrimSuffix(absUploadPath, common.PathSeparator+"*") - fileInfo, err := os.Stat(uploadDir) - if err != nil { - return common.FileUploadRequestObject{}, err - } - if fileInfo.IsDir() { - // Calculate the path of the file relative to the upload directory - relPath, err := filepath.Rel(uploadDir, filePath) - if err != nil { - return common.FileUploadRequestObject{}, err - } - filename = relPath - } - } - - if includeMetadata { - // The metadata path is the file name plus '_metadata.json' - metadataFilePath := strings.TrimSuffix(filePath, filepath.Ext(filePath)) + "_metadata.json" - var metadataFileBytes []byte - if _, err := os.Stat(metadataFilePath); err == nil { - metadataFileBytes, err = os.ReadFile(metadataFilePath) - if err != nil { - return common.FileUploadRequestObject{}, errors.New("Error reading metadata file " + metadataFilePath + ": " + err.Error()) - } - err := json.Unmarshal(metadataFileBytes, &metadata) - if err != nil { - return common.FileUploadRequestObject{}, errors.New("Error parsing metadata file " + metadataFilePath + ": " + err.Error()) - } - } else { - // No metadata file was found for this file -- proceed, but warn the user. - logger.Printf("WARNING: File metadata is enabled, but could not find the metadata file %v for file %v. Execute `data-client upload --help` for more info on file metadata.\n", metadataFilePath, filePath) - } - } - return common.FileUploadRequestObject{FilePath: filePath, Filename: filename, FileMetadata: metadata, GUID: objectId}, nil -} - -func getFullFilePath(filePath string, filename string) (string, error) { - filePath, err := common.GetAbsolutePath(filePath) - if err != nil { - return "", err - } - fi, err := os.Stat(filePath) - if err != nil { - return "", err - } - switch mode := fi.Mode(); { - case mode.IsDir(): - if strings.HasSuffix(filePath, "/") { - return filePath + filename, nil - } - return filePath + "/" + filename, nil - case mode.IsRegular(): - return "", errors.New("in manifest upload mode filePath must be a dir") - default: - return "", errors.New("full file path creation unsuccessful") - } -} - -func uploadFile(g3i client.Gen3Interface, furObject common.FileUploadRequestObject, retryCount int) error { - g3i.Logger().Println("Uploading data ...") - if furObject.Progress != nil { - defer furObject.Progress.Wait() - } - - client := &http.Client{} - resp, err := client.Do(furObject.Request) - if err != nil { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, retryCount, false) - return errors.New("Error occurred during upload: " + err.Error()) - } - if resp.StatusCode != 200 { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, retryCount, false) - return errors.New("Upload request got a non-200 response with status code " + strconv.Itoa(resp.StatusCode)) - } - g3i.Logger().Printf("Successfully uploaded file \"%s\" to GUID %s.\n", furObject.FilePath, furObject.GUID) - g3i.Logger().DeleteFromFailedLog(furObject.FilePath) - g3i.Logger().Succeeded(furObject.FilePath, furObject.GUID) - return nil -} - -func getNumberOfWorkers(numParallel int, inputSliceLen int) int { - workers := numParallel - if workers < 1 || workers > inputSliceLen { - workers = inputSliceLen - } - return workers -} - -func initBatchUploadChannels(numParallel int, inputSliceLen int) (int, chan *http.Response, chan error, []common.FileUploadRequestObject) { - workers := getNumberOfWorkers(numParallel, inputSliceLen) - respCh := make(chan *http.Response, inputSliceLen) - errCh := make(chan error, inputSliceLen) - batchFURSlice := make([]common.FileUploadRequestObject, 0) - return workers, respCh, errCh, batchFURSlice -} - -func batchUpload(g3i client.Gen3Interface, furObjects []common.FileUploadRequestObject, workers int, respCh chan *http.Response, errCh chan error, bucketName string) { - progress := mpb.New(mpb.WithOutput(os.Stdout)) - respURL := "" - var err error - var guid string - - for i := range furObjects { - if furObjects[i].Bucket == "" { - furObjects[i].Bucket = bucketName - } - if furObjects[i].GUID == "" { - respURL, guid, err = GeneratePresignedURL(g3i, furObjects[i].Filename, furObjects[i].FileMetadata, bucketName) - if err != nil { - g3i.Logger().Failed(furObjects[i].FilePath, furObjects[i].Filename, furObjects[i].FileMetadata, guid, 0, false) - errCh <- err - continue - } - furObjects[i].PresignedURL = respURL - furObjects[i].GUID = guid - // update failed log with new guid - g3i.Logger().Failed(furObjects[i].FilePath, furObjects[i].Filename, furObjects[i].FileMetadata, guid, 0, false) - } - file, err := os.Open(furObjects[i].FilePath) - if err != nil { - g3i.Logger().Failed(furObjects[i].FilePath, furObjects[i].Filename, furObjects[i].FileMetadata, furObjects[i].GUID, 0, false) - errCh <- errors.New("File open error: " + err.Error()) - continue - } - defer file.Close() - - furObjects[i], err = GenerateUploadRequest(g3i, furObjects[i], file, progress) - if err != nil { - file.Close() - g3i.Logger().Failed(furObjects[i].FilePath, furObjects[i].Filename, furObjects[i].FileMetadata, furObjects[i].GUID, 0, false) - errCh <- errors.New("Error occurred during request generation: " + err.Error()) - continue - } - } - - furObjectCh := make(chan common.FileUploadRequestObject, len(furObjects)) - - client := &http.Client{} - wg := sync.WaitGroup{} - for range workers { - wg.Add(1) - go func() { - for furObject := range furObjectCh { - if furObject.Request != nil { - resp, err := client.Do(furObject.Request) - if err != nil { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - errCh <- err - } else { - if resp.StatusCode != 200 { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - } else { - respCh <- resp - g3i.Logger().DeleteFromFailedLog(furObject.FilePath) - g3i.Logger().Succeeded(furObject.FilePath, furObject.GUID) - g3i.Logger().Scoreboard().IncrementSB(0) - } - } - } else if furObject.FilePath != "" { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) - } - } - wg.Done() - }() - } - - for i := range furObjects { - furObjectCh <- furObjects[i] - } - close(furObjectCh) - - wg.Wait() - progress.Wait() -} - -// GetWaitTime calculates the wait time for the next retry based on retry count -func GetWaitTime(retryCount int) time.Duration { - exponentialWaitTime := math.Pow(2, float64(retryCount)) - return time.Duration(math.Min(exponentialWaitTime, float64(maxWaitTime))) * time.Second -} - -// FormatSize helps to parse a int64 size into string -func FormatSize(size int64) string { - var unitSize int64 - switch { - case size >= TB: - unitSize = TB - case size >= GB: - unitSize = GB - case size >= MB: - unitSize = MB - case size >= KB: - unitSize = KB - default: - unitSize = B - } - - return fmt.Sprintf("%.1f"+unitMap[unitSize], float64(size)/float64(unitSize)) -} diff --git a/client/gen3Client/client.go b/client/gen3Client/client.go deleted file mode 100644 index a4e46c7..0000000 --- a/client/gen3Client/client.go +++ /dev/null @@ -1,120 +0,0 @@ -package client - -import ( - "bytes" - "context" - "errors" - "fmt" - "net/http" - "net/url" - - "github.com/calypr/data-client/client/jwt" - "github.com/calypr/data-client/client/logs" -) - -//go:generate mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/gen3Client Gen3Interface - -// Gen3Interface contains methods used to make authorized http requests to Gen3 services. -// The credential is embedded in the implementation, so it doesn't need to be passed to each method. -type Gen3Interface interface { - CheckPrivileges() (string, map[string]any, error) - CheckForShepherdAPI() (bool, error) - GetResponse(endpointPostPrefix string, method string, contentType string, bodyBytes []byte) (string, *http.Response, error) - DoRequestWithSignedHeader(endpointPostPrefix string, contentType string, bodyBytes []byte) (jwt.JsonMessage, error) - MakeARequest(method string, apiEndpoint string, accessToken string, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) - GetHost() (*url.URL, error) - GetCredential() *jwt.Credential - DeleteRecord(guid string) (string, error) - - Logger() *logs.TeeLogger -} - -// Gen3Client wraps jwt.FunctionInterface and embeds the credential -type Gen3Client struct { - Ctx context.Context - FunctionInterface jwt.FunctionInterface - credential *jwt.Credential - - logger *logs.TeeLogger -} - -func (g *Gen3Client) Logger() *logs.TeeLogger { - return g.logger -} - -// CheckPrivileges wraps the underlying method with embedded credential -func (g *Gen3Client) CheckPrivileges() (string, map[string]any, error) { - return g.FunctionInterface.CheckPrivileges(g.credential) -} - -// CheckForShepherdAPI wraps the underlying method with embedded credential -func (g *Gen3Client) CheckForShepherdAPI() (bool, error) { - return g.FunctionInterface.CheckForShepherdAPI(g.credential) -} - -// GetResponse wraps the underlying method with embedded credential -func (g *Gen3Client) GetResponse(endpointPostPrefix string, method string, contentType string, bodyBytes []byte) (string, *http.Response, error) { - return g.FunctionInterface.GetResponse(g.credential, endpointPostPrefix, method, contentType, bodyBytes) -} - -// DoRequestWithSignedHeader wraps the underlying method with embedded credential -func (g *Gen3Client) DoRequestWithSignedHeader(endpointPostPrefix string, contentType string, bodyBytes []byte) (jwt.JsonMessage, error) { - return g.FunctionInterface.DoRequestWithSignedHeader(g.credential, endpointPostPrefix, contentType, bodyBytes) -} - -// GetHost wraps the underlying method with embedded credential -func (g *Gen3Client) GetHost() (*url.URL, error) { - return g.FunctionInterface.GetHost(g.credential) -} - -// GetCredential returns the embedded credential -func (g *Gen3Client) GetCredential() *jwt.Credential { - return g.credential -} - -// MakeARequest wraps the underlying Request.MakeARequest method -func (g *Gen3Client) MakeARequest(method string, apiEndpoint string, accessToken string, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) { - // Access the underlying Request through the Functions struct - // We need to create a temporary Request instance since we can't access it directly - if functions, ok := g.FunctionInterface.(*jwt.Functions); ok { - return functions.Request.MakeARequest(method, apiEndpoint, accessToken, contentType, headers, body, noTimeout) - } - return nil, errors.New("unable to access MakeARequest method") -} - -// DeleteRecord deletes a record from INDEXD as well as its storage locations -func (g *Gen3Client) DeleteRecord(guid string) (string, error) { - // Use the embedded credential - // Since DeleteRecord is not part of FunctionInterface, we need to access it via type assertion - // or create a new Functions instance. We'll use type assertion first. - if functions, ok := g.FunctionInterface.(*jwt.Functions); ok { - return functions.DeleteRecord(g.credential, guid) - } - - // This should never happen, but handle it gracefully - return "", errors.New("unable to access DeleteRecord method") -} - -// NewGen3Interface returns a Gen3Client that embeds the credential and implements Gen3Interface. -// This eliminates the need to pass credentials around everywhere. -func NewGen3Interface(ctx context.Context, profile string, logger *logs.TeeLogger, opts ...func(*Gen3Client)) (Gen3Interface, error) { - // Note: A tee logger must be passed here otherwise you risk causing panics. - - config := &jwt.Configure{} - request := &jwt.Request{Ctx: ctx, Logs: logger} - client := jwt.NewFunctions(ctx, config, request) - - cred, err := config.ParseConfig(profile) - if err != nil { - return nil, err - } - if valid, err := config.IsValidCredential(cred); !valid { - return nil, fmt.Errorf("invalid credential: %v", err) - } - - return &Gen3Client{ - FunctionInterface: client, - credential: &cred, - logger: logger, - }, nil -} diff --git a/client/jwt/configure.go b/client/jwt/configure.go deleted file mode 100644 index 2b9b7f6..0000000 --- a/client/jwt/configure.go +++ /dev/null @@ -1,321 +0,0 @@ -package jwt - -//go:generate mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/jwt ConfigureInterface - -import ( - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "path" - "regexp" - "strings" - "time" - - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/golang-jwt/jwt/v5" - "gopkg.in/ini.v1" -) - -var ErrProfileNotFound = errors.New("profile not found in config file") - -type Credential struct { - Profile string - KeyId string - APIKey string - AccessToken string - APIEndpoint string - UseShepherd string - MinShepherdVersion string -} - -type Configure struct { - Logs logs.Logger -} - -type ConfigureInterface interface { - ReadFile(string, string) string - ValidateUrl(string) (*url.URL, error) - GetConfigPath() (string, error) - UpdateConfigFile(Credential) error - ParseKeyValue(str string, expr string) (string, error) - ParseConfig(profile string) (Credential, error) - IsValidCredential(Credential) (bool, error) -} - -func (conf *Configure) ReadFile(filePath string, fileType string) string { - //Look in config file - fullFilePath, err := common.GetAbsolutePath(filePath) - if err != nil { - conf.Logs.Println("error occurred when parsing config file path: " + err.Error()) - return "" - } - if _, err := os.Stat(fullFilePath); err != nil { - conf.Logs.Println("File specified at " + fullFilePath + " not found") - return "" - } - - content, err := os.ReadFile(fullFilePath) - if err != nil { - conf.Logs.Println("error occurred when reading file: " + err.Error()) - return "" - } - - contentStr := string(content[:]) - - if fileType == "json" { - contentStr = strings.ReplaceAll(contentStr, "\n", "") - } - return contentStr -} - -func (conf *Configure) ValidateUrl(apiEndpoint string) (*url.URL, error) { - parsedURL, err := url.Parse(apiEndpoint) - if err != nil { - return parsedURL, errors.New("Error occurred when parsing apiendpoint URL: " + err.Error()) - } - if parsedURL.Host == "" { - return parsedURL, errors.New("Invalid endpoint. A valid endpoint looks like: https://www.tests.com") - } - return parsedURL, nil -} - -func (conf *Configure) ReadCredentials(filePath string, fenceToken string) (*Credential, error) { - var profileConfig Credential - if filePath != "" { - jsonContent := conf.ReadFile(filePath, "json") - jsonContent = strings.ReplaceAll(jsonContent, "key_id", "KeyId") - jsonContent = strings.ReplaceAll(jsonContent, "api_key", "APIKey") - err := json.Unmarshal([]byte(jsonContent), &profileConfig) - if err != nil { - errs := fmt.Errorf("Cannot read json file: %s", err.Error()) - conf.Logs.Println(errs.Error()) - return nil, errs - } - } else if fenceToken != "" { - profileConfig.AccessToken = fenceToken - } - return &profileConfig, nil -} - -func (conf *Configure) GetConfigPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", err - } - configPath := path.Join(homeDir + common.PathSeparator + ".gen3" + common.PathSeparator + "gen3_client_config.ini") - return configPath, nil -} - -func (conf *Configure) InitConfigFile() error { - /* - Make sure the config exists on start up - */ - configPath, err := conf.GetConfigPath() - if err != nil { - return err - } - - if _, err := os.Stat(path.Dir(configPath)); os.IsNotExist(err) { - osErr := os.Mkdir(path.Join(path.Dir(configPath)), os.FileMode(0777)) - if osErr != nil { - return err - } - _, osErr = os.Create(configPath) - if osErr != nil { - return err - } - } - if _, err := os.Stat(configPath); os.IsNotExist(err) { - _, osErr := os.Create(configPath) - if osErr != nil { - return err - } - } - _, err = ini.Load(configPath) - - return err -} - -func (conf *Configure) UpdateConfigFile(profileConfig Credential) error { - /* - Overwrite the config file with new credential - - Args: - profileConfig: Credential object represents config of a profile - configPath: file path to config file - */ - configPath, err := conf.GetConfigPath() - if err != nil { - errs := fmt.Errorf("error occurred when getting config path: %s", err.Error()) - conf.Logs.Println(errs.Error()) - return errs - } - cfg, err := ini.Load(configPath) - if err != nil { - errs := fmt.Errorf("error occurred when loading config file: %s", err.Error()) - conf.Logs.Println(errs.Error()) - return errs - } - - section := cfg.Section(profileConfig.Profile) - if profileConfig.KeyId != "" { - section.Key("key_id").SetValue(profileConfig.KeyId) - } - if profileConfig.APIKey != "" { - section.Key("api_key").SetValue(profileConfig.APIKey) - } - if profileConfig.AccessToken != "" { - section.Key("access_token").SetValue(profileConfig.AccessToken) - } - if profileConfig.APIEndpoint != "" { - section.Key("api_endpoint").SetValue(profileConfig.APIEndpoint) - } - - section.Key("use_shepherd").SetValue(profileConfig.UseShepherd) - section.Key("min_shepherd_version").SetValue(profileConfig.MinShepherdVersion) - err = cfg.SaveTo(configPath) - if err != nil { - errs := fmt.Errorf("error occurred when saving config file: %s", err.Error()) - return errs - } - return nil -} - -func (conf *Configure) ParseKeyValue(str string, expr string) (string, error) { - r, err := regexp.Compile(expr) - if err != nil { - return "", fmt.Errorf("error occurred when parsing key/value: %v", err.Error()) - } - match := r.FindStringSubmatch(str) - if len(match) == 0 { - return "", fmt.Errorf("No match found") - } - return match[1], nil -} - -func (conf *Configure) ParseConfig(profile string) (Credential, error) { - /* - Looking profile in config file. The config file is a text file located at ~/.gen3 directory. It can - contain more than 1 profile. If there is no profile found, the user is asked to run a command to - create the profile - - The format of config file is described as following - - [profile1] - key_id=key_id_example_1 - api_key=api_key_example_1 - access_token=access_token_example_1 - api_endpoint=http://localhost:8000 - use_shepherd=true - min_shepherd_version=2.0.0 - - [profile2] - key_id=key_id_example_2 - api_key=api_key_example_2 - access_token=access_token_example_2 - api_endpoint=http://localhost:8000 - use_shepherd=false - min_shepherd_version= - - Args: - profile: the specific profile in config file - Returns: - An instance of Credential - */ - - homeDir, err := os.UserHomeDir() - if err != nil { - errs := fmt.Errorf("Error occurred when getting home directory: %s", err.Error()) - return Credential{}, errs - } - configPath := path.Join(homeDir + common.PathSeparator + ".gen3" + common.PathSeparator + "gen3_client_config.ini") - profileConfig := Credential{ - Profile: profile, - KeyId: "", - APIKey: "", - AccessToken: "", - APIEndpoint: "", - } - if _, err := os.Stat(configPath); os.IsNotExist(err) { - return Credential{}, fmt.Errorf("%w Run configure command (with a profile if desired) to set up account credentials \n"+ - "Example: ./data-client configure --profile= --cred= --apiendpoint=https://data.mycommons.org", ErrProfileNotFound) - } - - // If profile not in config file, prompt user to set up config first - cfg, err := ini.Load(configPath) - if err != nil { - errs := fmt.Errorf("Error occurred when reading config file: %s", err.Error()) - return Credential{}, errs - } - sec, err := cfg.GetSection(profile) - if err != nil { - return Credential{}, fmt.Errorf("%w: Need to run \"data-client configure --profile="+profile+" --cred= --apiendpoint=\" first", ErrProfileNotFound) - } - // Read in API key, key ID and endpoint for given profile - profileConfig.KeyId = sec.Key("key_id").String() - profileConfig.APIKey = sec.Key("api_key").String() - profileConfig.AccessToken = sec.Key("access_token").String() - - if profileConfig.KeyId == "" && profileConfig.APIKey == "" && profileConfig.AccessToken == "" { - errs := fmt.Errorf("key_id, api_key and access_token not found in profile.") - return Credential{}, errs - } - profileConfig.APIEndpoint = sec.Key("api_endpoint").String() - if profileConfig.APIEndpoint == "" { - errs := fmt.Errorf("api_endpoint not found in profile.") - return Credential{}, errs - } - // UseShepherd and MinShepherdVersion are optional - profileConfig.UseShepherd = sec.Key("use_shepherd").String() - profileConfig.MinShepherdVersion = sec.Key("min_shepherd_version").String() - - return profileConfig, nil -} - -func (conf *Configure) IsValidCredential(profileConfig Credential) (bool, error) { - /* Checks to see if credential in credential file is still valid */ - const expirationThresholdDays = 10 - // Parse the token without verifying the signature to access the claims. - token, _, err := new(jwt.Parser).ParseUnverified(profileConfig.APIKey, jwt.MapClaims{}) - if err != nil { - return false, fmt.Errorf("ERROR: Invalid token format: %v", err) - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return false, fmt.Errorf("Unable to parse claims from provided token %#v", token) - } - - exp, ok := claims["exp"].(float64) - if !ok { - return false, fmt.Errorf("ERROR: 'exp' claim not found or is not a number for claims %s", claims) - } - - iat, ok := claims["iat"].(float64) - if !ok { - return false, fmt.Errorf("ERROR: 'iat' claim not found or is not a number for claims %s", claims) - } - - now := time.Now().UTC() - expTime := time.Unix(int64(exp), 0).UTC() - iatTime := time.Unix(int64(iat), 0).UTC() - - if expTime.Before(now) { - return false, fmt.Errorf("key %s expired %s < %s", profileConfig.APIKey, expTime.Format(time.RFC3339), now.Format(time.RFC3339)) - } - if iatTime.After(now) { - return false, fmt.Errorf("key %s not yet valid %s > %s", profileConfig.APIKey, iatTime.Format(time.RFC3339), now.Format(time.RFC3339)) - } - - delta := expTime.Sub(now) - if delta > 0 && delta.Hours() < float64(expirationThresholdDays*24) { - daysUntilExpiration := int(delta.Hours() / 24) - if daysUntilExpiration > 0 { - return true, fmt.Errorf("WARNING %s: Key will expire in %d days, on %s", profileConfig.APIKey, daysUntilExpiration, expTime.Format(time.RFC3339)) - } - } - return true, nil -} diff --git a/client/jwt/functions.go b/client/jwt/functions.go deleted file mode 100644 index 004d61b..0000000 --- a/client/jwt/functions.go +++ /dev/null @@ -1,370 +0,0 @@ -package jwt - -//go:generate mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/jwt FunctionInterface -//go:generate mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/jwt RequestInterface - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "strconv" - "strings" - - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/hashicorp/go-version" -) - -func NewFunctions(ctx context.Context, config ConfigureInterface, request RequestInterface) FunctionInterface { - return &Functions{ - Config: config, - Request: request, - } -} - -type Functions struct { - Request RequestInterface - Config ConfigureInterface -} - -type FunctionInterface interface { - CheckPrivileges(profileConfig *Credential) (string, map[string]any, error) - CheckForShepherdAPI(profileConfig *Credential) (bool, error) - GetResponse(profileConfig *Credential, endpointPostPrefix string, method string, contentType string, bodyBytes []byte) (string, *http.Response, error) - DoRequestWithSignedHeader(profileConfig *Credential, endpointPostPrefix string, contentType string, bodyBytes []byte) (JsonMessage, error) - ParseFenceURLResponse(resp *http.Response) (JsonMessage, error) - GetHost(profileConfig *Credential) (*url.URL, error) -} - -type Request struct { - Logs logs.Logger - Ctx context.Context -} - -type RequestInterface interface { - MakeARequest(method string, apiEndpoint string, accessToken string, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) - RequestNewAccessToken(accessTokenEndpoint string, profileConfig *Credential) error - Logger() logs.Logger -} - -func (r *Request) Logger() logs.Logger { - return r.Logs -} - -func (r *Request) MakeARequest(method string, apiEndpoint string, accessToken string, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) { - /* - Make http request with header and body - */ - if headers == nil { - headers = make(map[string]string) - } - if accessToken != "" { - headers["Authorization"] = "Bearer " + accessToken - } - if contentType != "" { - headers["Content-Type"] = contentType - } - var client *http.Client - if noTimeout { - client = &http.Client{} - } else { - client = &http.Client{Timeout: common.DefaultTimeout} - } - var req *http.Request - var err error - if body == nil { - req, err = http.NewRequestWithContext(r.Ctx, method, apiEndpoint, nil) - } else { - req, err = http.NewRequestWithContext(r.Ctx, method, apiEndpoint, body) - } - if err != nil { - return nil, errors.New("Error occurred during generating HTTP request: " + err.Error()) - } - for k, v := range headers { - req.Header.Add(k, v) - } - resp, err := client.Do(req) - if err != nil { - return nil, errors.New("Error occurred during making HTTP request: " + err.Error()) - } - return resp, nil -} - -func (r *Request) RequestNewAccessToken(accessTokenEndpoint string, profileConfig *Credential) error { - /* - Request new access token to replace the expired one. - - Args: - accessTokenEndpoint: the api endpoint for request new access token - Returns: - profileConfig: new credential - err: error - - */ - body := bytes.NewBufferString("{\"api_key\": \"" + profileConfig.APIKey + "\"}") - resp, err := r.MakeARequest("POST", accessTokenEndpoint, "", "application/json", nil, body, false) - var m AccessTokenStruct - // parse resp error codes first for profile configuration verification - if resp != nil && resp.StatusCode != 200 { - return errors.New("Error occurred in RequestNewAccessToken with error code " + strconv.Itoa(resp.StatusCode) + ", check FENCE log for more details.") - } - if err != nil { - return errors.New("Error occurred in RequestNewAccessToken: " + err.Error()) - } - defer resp.Body.Close() - - str := ResponseToString(resp) - err = DecodeJsonFromString(str, &m) - if err != nil { - return errors.New("Error occurred in RequestNewAccessToken: " + err.Error()) - } - - if m.AccessToken == "" { - return errors.New("Could not get new access key from response string: " + str) - } - profileConfig.AccessToken = m.AccessToken - return nil -} - -func (f *Functions) ParseFenceURLResponse(resp *http.Response) (JsonMessage, error) { - msg := JsonMessage{} - - if resp == nil { - return msg, errors.New("Nil response received") - } - - // Capture the body for error reporting before we do anything else - // Using your existing ResponseToString helper - bodyStr := ResponseToString(resp) - - if !(resp.StatusCode == 200 || resp.StatusCode == 201) { - // Prepare a base error that includes the body content - errorMessage := fmt.Sprintf("Status: %d | Response: %s", resp.StatusCode, bodyStr) - - switch resp.StatusCode { - case 401: - return msg, fmt.Errorf("401 Unauthorized: %s", errorMessage) - case 403: - return msg, fmt.Errorf("403 Forbidden: %s (URL: %s)", bodyStr, resp.Request.URL.String()) - case 404: - return msg, fmt.Errorf("404 Not Found: %s (URL: %s)", bodyStr, resp.Request.URL.String()) - case 500: - return msg, fmt.Errorf("500 Internal Server Error: %s", bodyStr) - case 503: - return msg, fmt.Errorf("503 Service Unavailable: %s", bodyStr) - default: - return msg, fmt.Errorf("Unexpected Error (%d): %s", resp.StatusCode, bodyStr) - } - } - - // Logic for successful status codes - if strings.Contains(bodyStr, "Can't find a location for the data") { - return msg, errors.New("The provided GUID is not found") - } - - err := DecodeJsonFromString(bodyStr, &msg) - if err != nil { - return msg, fmt.Errorf("failed to decode JSON: %w (Raw body: %s)", err, bodyStr) - } - - return msg, nil -} -func (f *Functions) CheckForShepherdAPI(profileConfig *Credential) (bool, error) { - // Check if Shepherd is enabled - if profileConfig.UseShepherd == "false" { - return false, nil - } - if profileConfig.UseShepherd != "true" && common.DefaultUseShepherd == false { - return false, nil - } - // If Shepherd is enabled, make sure that the commons has a compatible version of Shepherd deployed. - // Compare the version returned from the Shepherd version endpoint with the minimum acceptable Shepherd version. - var minShepherdVersion string - if profileConfig.MinShepherdVersion == "" { - minShepherdVersion = common.DefaultMinShepherdVersion - } else { - minShepherdVersion = profileConfig.MinShepherdVersion - } - - _, res, err := f.GetResponse(profileConfig, common.ShepherdVersionEndpoint, "GET", "", nil) - if err != nil { - return false, errors.New("Error occurred during generating HTTP request: " + err.Error()) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return false, nil - } - bodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return false, errors.New("Error occurred when reading HTTP request: " + err.Error()) - } - body, err := strconv.Unquote(string(bodyBytes)) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) - } - // Compare the version in the response to the target version - ver, err := version.NewVersion(body) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) - } - minVer, err := version.NewVersion(minShepherdVersion) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing minimum acceptable Shepherd version: %v: %v", minShepherdVersion, err) - } - if ver.GreaterThanOrEqual(minVer) { - return true, nil - } - return false, fmt.Errorf("Shepherd is enabled, but %v does not have correct Shepherd version. (Need Shepherd version >=%v, got %v)", profileConfig.APIEndpoint, minVer, ver) -} - -func (f *Functions) GetResponse(profileConfig *Credential, endpointPostPrefix string, method string, contentType string, bodyBytes []byte) (string, *http.Response, error) { - - var resp *http.Response - var err error - - if profileConfig.APIKey == "" && profileConfig.AccessToken == "" && profileConfig.APIEndpoint == "" { - return "", resp, fmt.Errorf("No credentials found in the configuration file! Please use \"./data-client configure\" to configure your credentials first %s", profileConfig) - } - - host, _ := url.Parse(profileConfig.APIEndpoint) - prefixEndPoint := host.Scheme + "://" + host.Host - apiEndpoint := host.Scheme + "://" + host.Host + endpointPostPrefix - isExpiredToken := false - if profileConfig.AccessToken != "" { - resp, err = f.Request.MakeARequest(method, apiEndpoint, profileConfig.AccessToken, contentType, nil, bytes.NewBuffer(bodyBytes), false) - if err != nil { - return "", resp, fmt.Errorf("Error while requesting user access token at %v: %v", apiEndpoint, err) - } - - // 401 code is general error code from FENCE. the error message is also not clear for the case - // that the token expired. Temporary solution: get new access token and make another attempt. - if resp != nil && (resp.StatusCode == 401 || resp.StatusCode == 503) { - isExpiredToken = true - } else { - return prefixEndPoint, resp, err - } - } - if profileConfig.AccessToken == "" || isExpiredToken { - err := f.Request.RequestNewAccessToken(prefixEndPoint+common.FenceAccessTokenEndpoint, profileConfig) - if err != nil { - return prefixEndPoint, resp, err - } - err = f.Config.UpdateConfigFile(*profileConfig) - if err != nil { - return prefixEndPoint, resp, err - } - - resp, err = f.Request.MakeARequest(method, apiEndpoint, profileConfig.AccessToken, contentType, nil, bytes.NewBuffer(bodyBytes), false) - if err != nil { - return prefixEndPoint, resp, err - } - } - - return prefixEndPoint, resp, nil -} - -func (f *Functions) GetHost(profileConfig *Credential) (*url.URL, error) { - if profileConfig.APIEndpoint == "" { - return nil, errors.New("No APIEndpoint found in the configuration file! Please use \"./data-client configure\" to configure your credentials first") - } - host, _ := url.Parse(profileConfig.APIEndpoint) - return host, nil -} - -func (f *Functions) DoRequestWithSignedHeader(profileConfig *Credential, endpointPostPrefix string, contentType string, bodyBytes []byte) (JsonMessage, error) { - /* - Do request with signed header. User may have more than one profile and use a profile to make a request - */ - var err error - var msg JsonMessage - - method := "GET" - if bodyBytes != nil { - method = "POST" - } - - _, resp, err := f.GetResponse(profileConfig, endpointPostPrefix, method, contentType, bodyBytes) - if err != nil { - return msg, err - } - defer resp.Body.Close() - - msg, err = f.ParseFenceURLResponse(resp) - return msg, err -} - -func (f *Functions) CheckPrivileges(profileConfig *Credential) (string, map[string]any, error) { - /* - Return user privileges from specified profile - */ - var err error - var data map[string]any - - host, resp, err := f.GetResponse(profileConfig, common.FenceUserEndpoint, "GET", "", nil) - if err != nil { - return "", nil, errors.New("Error occurred when getting response from remote: " + err.Error()) - } - defer resp.Body.Close() - - str := ResponseToString(resp) - err = json.Unmarshal([]byte(str), &data) - if err != nil { - return "", nil, errors.New("Error occurred when unmarshalling response: " + err.Error()) - } - - resourceAccess, ok := data["authz"].(map[string]any) - - // If the `authz` section (Arborist permissions) is empty or missing, try get `project_access` section (Fence permissions) - if len(resourceAccess) == 0 || !ok { - resourceAccess, ok = data["project_access"].(map[string]any) - if !ok { - return "", nil, errors.New("Not possible to read access privileges of user") - } - } - - return host, resourceAccess, err -} - -func (f *Functions) DeleteRecord(profileConfig *Credential, guid string) (string, error) { - var err error - var msg string - - hasShepherd, err := f.CheckForShepherdAPI(profileConfig) - if err != nil { - f.Request.Logger().Printf("WARNING: Error while checking for Shepherd API: %v. Falling back to Fence to delete record.\n", err) - } else if hasShepherd { - endPointPostfix := common.ShepherdEndpoint + "/objects/" + guid - _, resp, err := f.GetResponse(profileConfig, endPointPostfix, "DELETE", "", nil) - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode == 204 { - msg = "Record with GUID " + guid + " has been deleted" - } else if resp.StatusCode == 500 { - err = errors.New("Internal server error occurred when deleting " + guid + "; could not delete stored files, or not able to delete INDEXD record") - } - return msg, err - } - - endPointPostfix := common.FenceDataEndpoint + "/" + guid - - _, resp, err := f.GetResponse(profileConfig, endPointPostfix, "DELETE", "", nil) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode == 204 { - msg = "Record with GUID " + guid + " has been deleted" - } else if resp.StatusCode == 500 { - err = errors.New("Internal server error occurred when deleting " + guid + "; could not delete stored files, or not able to delete INDEXD record") - } - - return msg, err -} diff --git a/client/jwt/update.go b/client/jwt/update.go deleted file mode 100644 index b2d9cc9..0000000 --- a/client/jwt/update.go +++ /dev/null @@ -1,78 +0,0 @@ -package jwt - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/hashicorp/go-version" -) - -func UpdateConfig(logger logs.Logger, cred *Credential) error { - var conf Configure - var req Request = Request{Ctx: context.Background()} - - if cred.Profile == "" { - return fmt.Errorf("profile name is required") - } - if cred.APIEndpoint == "" { - return fmt.Errorf("API endpoint is required") - } - - // Normalize endpoint - cred.APIEndpoint = strings.TrimSpace(cred.APIEndpoint) - cred.APIEndpoint = strings.TrimSuffix(cred.APIEndpoint, "/") - - // Validate URL format - parsedURL, err := conf.ValidateUrl(cred.APIEndpoint) - if err != nil { - return fmt.Errorf("invalid apiendpoint URL: %w", err) - } - fenceBase := parsedURL.Scheme + "://" + parsedURL.Host - if existingCfg, err := conf.ParseConfig(cred.Profile); err == nil { - // Only copy optional fields if the user didn't override them via flags - if cred.UseShepherd == "" { - cred.UseShepherd = existingCfg.UseShepherd - } - if cred.MinShepherdVersion == "" { - cred.MinShepherdVersion = existingCfg.MinShepherdVersion - } - } else if !errors.Is(err, ErrProfileNotFound) { - return err - } - - if cred.APIKey != "" { - // Always refresh the access token — ignore any old one that might be in the struct - err = req.RequestNewAccessToken(fenceBase+common.FenceAccessTokenEndpoint, cred) - if err != nil { - if strings.Contains(err.Error(), "401") { - return fmt.Errorf("authentication failed (401) for %s — your API key is invalid, revoked, or expired", fenceBase) - } - if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "no such host") { - return fmt.Errorf("cannot reach Fence at %s — is this a valid Gen3 commons?", fenceBase) - } - return fmt.Errorf("failed to refresh access token: %w", err) - } - } else { - logger.Printf("WARNING: Your profile will only be valid for 24 hours since you have only provided a refresh token for authentication") - } - - // Clean up shepherd flags - cred.UseShepherd = strings.TrimSpace(cred.UseShepherd) - cred.MinShepherdVersion = strings.TrimSpace(cred.MinShepherdVersion) - - if cred.MinShepherdVersion != "" { - if _, err = version.NewVersion(cred.MinShepherdVersion); err != nil { - return fmt.Errorf("invalid min-shepherd-version: %w", err) - } - } - - if err := conf.UpdateConfigFile(*cred); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - - return nil -} diff --git a/client/logs/scoreboard.go b/client/logs/scoreboard.go index a73117c..c2f4d25 100644 --- a/client/logs/scoreboard.go +++ b/client/logs/scoreboard.go @@ -1,16 +1,11 @@ package logs import ( - "context" "fmt" "sync" "text/tabwriter" ) -type key int - -const scoreboardKey key = 0 - // Scoreboard holds retry statistics type Scoreboard struct { mu sync.Mutex @@ -73,16 +68,3 @@ func (s *Scoreboard) PrintSB() { fmt.Fprintf(w, "TOTAL\t%d\n", total) w.Flush() } - -// Context helpers — so you don't have to pass scoreboard around - -func NewSBContext(parent context.Context, sb *Scoreboard) context.Context { - return context.WithValue(parent, scoreboardKey, sb) -} - -func FromSBContext(ctx context.Context) (*Scoreboard, error) { - if sb, ok := ctx.Value(scoreboardKey).(*Scoreboard); ok { - return sb, nil - } - return nil, fmt.Errorf("Scoreboard is not of type Scoreboard") -} diff --git a/client/logs/tee_logger.go b/client/logs/tee_logger.go index bd78d8a..08a6d0f 100644 --- a/client/logs/tee_logger.go +++ b/client/logs/tee_logger.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" // Added for standard logging methods like Fatal + "maps" "os" "sync" @@ -93,9 +94,8 @@ func (t *TeeLogger) GetSucceededLogMap() map[string]string { defer t.succeededMu.Unlock() // Return a copy to prevent external modification copiedMap := make(map[string]string, len(t.succeededMap)) - for k, v := range t.succeededMap { - copiedMap[k] = v - } + maps.Copy(copiedMap, t.succeededMap) + return copiedMap } @@ -105,9 +105,7 @@ func (t *TeeLogger) GetFailedLogMap() map[string]common.RetryObject { defer t.failedMu.Unlock() // Return a copy to prevent external modification copiedMap := make(map[string]common.RetryObject, len(t.FailedMap)) - for k, v := range t.FailedMap { - copiedMap[k] = v - } + maps.Copy(copiedMap, t.FailedMap) return copiedMap } diff --git a/client/mocks/mock_configure.go b/client/mocks/mock_configure.go index 697c3da..4ff0813 100644 --- a/client/mocks/mock_configure.go +++ b/client/mocks/mock_configure.go @@ -1,145 +1,114 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/jwt (interfaces: ConfigureInterface) +// Source: github.com/calypr/data-client/client/conf (interfaces: ManagerInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/jwt ConfigureInterface +// mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/conf ManagerInterface // // Package mocks is a generated GoMock package. package mocks import ( - url "net/url" reflect "reflect" - jwt "github.com/calypr/data-client/client/jwt" + conf "github.com/calypr/data-client/client/conf" gomock "go.uber.org/mock/gomock" ) -// MockConfigureInterface is a mock of ConfigureInterface interface. -type MockConfigureInterface struct { +// MockManagerInterface is a mock of ManagerInterface interface. +type MockManagerInterface struct { ctrl *gomock.Controller - recorder *MockConfigureInterfaceMockRecorder + recorder *MockManagerInterfaceMockRecorder isgomock struct{} } -// MockConfigureInterfaceMockRecorder is the mock recorder for MockConfigureInterface. -type MockConfigureInterfaceMockRecorder struct { - mock *MockConfigureInterface +// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface. +type MockManagerInterfaceMockRecorder struct { + mock *MockManagerInterface } -// NewMockConfigureInterface creates a new mock instance. -func NewMockConfigureInterface(ctrl *gomock.Controller) *MockConfigureInterface { - mock := &MockConfigureInterface{ctrl: ctrl} - mock.recorder = &MockConfigureInterfaceMockRecorder{mock} +// NewMockManagerInterface creates a new mock instance. +func NewMockManagerInterface(ctrl *gomock.Controller) *MockManagerInterface { + mock := &MockManagerInterface{ctrl: ctrl} + mock.recorder = &MockManagerInterfaceMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConfigureInterface) EXPECT() *MockConfigureInterfaceMockRecorder { +func (m *MockManagerInterface) EXPECT() *MockManagerInterfaceMockRecorder { return m.recorder } -// GetConfigPath mocks base method. -func (m *MockConfigureInterface) GetConfigPath() (string, error) { +// EnsureExists mocks base method. +func (m *MockManagerInterface) EnsureExists() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetConfigPath") - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "EnsureExists") + ret0, _ := ret[0].(error) + return ret0 } -// GetConfigPath indicates an expected call of GetConfigPath. -func (mr *MockConfigureInterfaceMockRecorder) GetConfigPath() *gomock.Call { +// EnsureExists indicates an expected call of EnsureExists. +func (mr *MockManagerInterfaceMockRecorder) EnsureExists() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfigPath", reflect.TypeOf((*MockConfigureInterface)(nil).GetConfigPath)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureExists", reflect.TypeOf((*MockManagerInterface)(nil).EnsureExists)) } -// IsValidCredential mocks base method. -func (m *MockConfigureInterface) IsValidCredential(arg0 jwt.Credential) (bool, error) { +// Import mocks base method. +func (m *MockManagerInterface) Import(filePath, fenceToken string) (*conf.Credential, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsValidCredential", arg0) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Import", filePath, fenceToken) + ret0, _ := ret[0].(*conf.Credential) ret1, _ := ret[1].(error) return ret0, ret1 } -// IsValidCredential indicates an expected call of IsValidCredential. -func (mr *MockConfigureInterfaceMockRecorder) IsValidCredential(arg0 any) *gomock.Call { +// Import indicates an expected call of Import. +func (mr *MockManagerInterfaceMockRecorder) Import(filePath, fenceToken any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsValidCredential", reflect.TypeOf((*MockConfigureInterface)(nil).IsValidCredential), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Import", reflect.TypeOf((*MockManagerInterface)(nil).Import), filePath, fenceToken) } -// ParseConfig mocks base method. -func (m *MockConfigureInterface) ParseConfig(profile string) (jwt.Credential, error) { +// IsValid mocks base method. +func (m *MockManagerInterface) IsValid(arg0 *conf.Credential) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ParseConfig", profile) - ret0, _ := ret[0].(jwt.Credential) + ret := m.ctrl.Call(m, "IsValid", arg0) + ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// ParseConfig indicates an expected call of ParseConfig. -func (mr *MockConfigureInterfaceMockRecorder) ParseConfig(profile any) *gomock.Call { +// IsValid indicates an expected call of IsValid. +func (mr *MockManagerInterfaceMockRecorder) IsValid(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseConfig", reflect.TypeOf((*MockConfigureInterface)(nil).ParseConfig), profile) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsValid", reflect.TypeOf((*MockManagerInterface)(nil).IsValid), arg0) } -// ParseKeyValue mocks base method. -func (m *MockConfigureInterface) ParseKeyValue(str, expr string) (string, error) { +// Load mocks base method. +func (m *MockManagerInterface) Load(profile string) (*conf.Credential, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ParseKeyValue", str, expr) - ret0, _ := ret[0].(string) + ret := m.ctrl.Call(m, "Load", profile) + ret0, _ := ret[0].(*conf.Credential) ret1, _ := ret[1].(error) return ret0, ret1 } -// ParseKeyValue indicates an expected call of ParseKeyValue. -func (mr *MockConfigureInterfaceMockRecorder) ParseKeyValue(str, expr any) *gomock.Call { +// Load indicates an expected call of Load. +func (mr *MockManagerInterfaceMockRecorder) Load(profile any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseKeyValue", reflect.TypeOf((*MockConfigureInterface)(nil).ParseKeyValue), str, expr) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockManagerInterface)(nil).Load), profile) } -// ReadFile mocks base method. -func (m *MockConfigureInterface) ReadFile(arg0, arg1 string) string { +// Save mocks base method. +func (m *MockManagerInterface) Save(cred *conf.Credential) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadFile", arg0, arg1) - ret0, _ := ret[0].(string) - return ret0 -} - -// ReadFile indicates an expected call of ReadFile. -func (mr *MockConfigureInterfaceMockRecorder) ReadFile(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFile", reflect.TypeOf((*MockConfigureInterface)(nil).ReadFile), arg0, arg1) -} - -// UpdateConfigFile mocks base method. -func (m *MockConfigureInterface) UpdateConfigFile(arg0 jwt.Credential) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateConfigFile", arg0) + ret := m.ctrl.Call(m, "Save", cred) ret0, _ := ret[0].(error) return ret0 } -// UpdateConfigFile indicates an expected call of UpdateConfigFile. -func (mr *MockConfigureInterfaceMockRecorder) UpdateConfigFile(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConfigFile", reflect.TypeOf((*MockConfigureInterface)(nil).UpdateConfigFile), arg0) -} - -// ValidateUrl mocks base method. -func (m *MockConfigureInterface) ValidateUrl(arg0 string) (*url.URL, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateUrl", arg0) - ret0, _ := ret[0].(*url.URL) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ValidateUrl indicates an expected call of ValidateUrl. -func (mr *MockConfigureInterfaceMockRecorder) ValidateUrl(arg0 any) *gomock.Call { +// Save indicates an expected call of Save. +func (mr *MockManagerInterfaceMockRecorder) Save(cred any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUrl", reflect.TypeOf((*MockConfigureInterface)(nil).ValidateUrl), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockManagerInterface)(nil).Save), cred) } diff --git a/client/mocks/mock_functions.go b/client/mocks/mock_functions.go index 6c48765..de3b1bc 100644 --- a/client/mocks/mock_functions.go +++ b/client/mocks/mock_functions.go @@ -1,20 +1,22 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/jwt (interfaces: FunctionInterface) +// Source: github.com/calypr/data-client/client/api (interfaces: FunctionInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/jwt FunctionInterface +// mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/api FunctionInterface // // Package mocks is a generated GoMock package. package mocks import ( + context "context" http "net/http" - url "net/url" reflect "reflect" - jwt "github.com/calypr/data-client/client/jwt" + api "github.com/calypr/data-client/client/api" + conf "github.com/calypr/data-client/client/conf" + request "github.com/calypr/data-client/client/request" gomock "go.uber.org/mock/gomock" ) @@ -43,87 +45,113 @@ func (m *MockFunctionInterface) EXPECT() *MockFunctionInterfaceMockRecorder { } // CheckForShepherdAPI mocks base method. -func (m *MockFunctionInterface) CheckForShepherdAPI(profileConfig *jwt.Credential) (bool, error) { +func (m *MockFunctionInterface) CheckForShepherdAPI(ctx context.Context) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckForShepherdAPI", profileConfig) + ret := m.ctrl.Call(m, "CheckForShepherdAPI", ctx) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // CheckForShepherdAPI indicates an expected call of CheckForShepherdAPI. -func (mr *MockFunctionInterfaceMockRecorder) CheckForShepherdAPI(profileConfig any) *gomock.Call { +func (mr *MockFunctionInterfaceMockRecorder) CheckForShepherdAPI(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockFunctionInterface)(nil).CheckForShepherdAPI), profileConfig) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockFunctionInterface)(nil).CheckForShepherdAPI), ctx) } // CheckPrivileges mocks base method. -func (m *MockFunctionInterface) CheckPrivileges(profileConfig *jwt.Credential) (string, map[string]any, error) { +func (m *MockFunctionInterface) CheckPrivileges(ctx context.Context) (map[string]any, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckPrivileges", profileConfig) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(map[string]any) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret := m.ctrl.Call(m, "CheckPrivileges", ctx) + ret0, _ := ret[0].(map[string]any) + ret1, _ := ret[1].(error) + return ret0, ret1 } // CheckPrivileges indicates an expected call of CheckPrivileges. -func (mr *MockFunctionInterfaceMockRecorder) CheckPrivileges(profileConfig any) *gomock.Call { +func (mr *MockFunctionInterfaceMockRecorder) CheckPrivileges(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockFunctionInterface)(nil).CheckPrivileges), profileConfig) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockFunctionInterface)(nil).CheckPrivileges), ctx) } -// DoRequestWithSignedHeader mocks base method. -func (m *MockFunctionInterface) DoRequestWithSignedHeader(profileConfig *jwt.Credential, endpointPostPrefix, contentType string, bodyBytes []byte) (jwt.JsonMessage, error) { +// DeleteRecord mocks base method. +func (m *MockFunctionInterface) DeleteRecord(ctx context.Context, guid string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoRequestWithSignedHeader", profileConfig, endpointPostPrefix, contentType, bodyBytes) - ret0, _ := ret[0].(jwt.JsonMessage) + ret := m.ctrl.Call(m, "DeleteRecord", ctx, guid) + ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } -// DoRequestWithSignedHeader indicates an expected call of DoRequestWithSignedHeader. -func (mr *MockFunctionInterfaceMockRecorder) DoRequestWithSignedHeader(profileConfig, endpointPostPrefix, contentType, bodyBytes any) *gomock.Call { +// DeleteRecord indicates an expected call of DeleteRecord. +func (mr *MockFunctionInterfaceMockRecorder) DeleteRecord(ctx, guid any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoRequestWithSignedHeader", reflect.TypeOf((*MockFunctionInterface)(nil).DoRequestWithSignedHeader), profileConfig, endpointPostPrefix, contentType, bodyBytes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecord", reflect.TypeOf((*MockFunctionInterface)(nil).DeleteRecord), ctx, guid) } -// GetHost mocks base method. -func (m *MockFunctionInterface) GetHost(profileConfig *jwt.Credential) (*url.URL, error) { +// Do mocks base method. +func (m *MockFunctionInterface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHost", profileConfig) - ret0, _ := ret[0].(*url.URL) + ret := m.ctrl.Call(m, "Do", ctx, req) + ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetHost indicates an expected call of GetHost. -func (mr *MockFunctionInterfaceMockRecorder) GetHost(profileConfig any) *gomock.Call { +// Do indicates an expected call of Do. +func (mr *MockFunctionInterfaceMockRecorder) Do(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockFunctionInterface)(nil).Do), ctx, req) +} + +// ExportCredential mocks base method. +func (m *MockFunctionInterface) ExportCredential(ctx context.Context, cred *conf.Credential) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExportCredential", ctx, cred) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExportCredential indicates an expected call of ExportCredential. +func (mr *MockFunctionInterfaceMockRecorder) ExportCredential(ctx, cred any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHost", reflect.TypeOf((*MockFunctionInterface)(nil).GetHost), profileConfig) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportCredential", reflect.TypeOf((*MockFunctionInterface)(nil).ExportCredential), ctx, cred) } -// GetResponse mocks base method. -func (m *MockFunctionInterface) GetResponse(profileConfig *jwt.Credential, endpointPostPrefix, method, contentType string, bodyBytes []byte) (string, *http.Response, error) { +// GetPresignedUrl mocks base method. +func (m *MockFunctionInterface) GetPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResponse", profileConfig, endpointPostPrefix, method, contentType, bodyBytes) + ret := m.ctrl.Call(m, "GetPresignedUrl", ctx, guid, protocolText) ret0, _ := ret[0].(string) - ret1, _ := ret[1].(*http.Response) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPresignedUrl indicates an expected call of GetPresignedUrl. +func (mr *MockFunctionInterfaceMockRecorder) GetPresignedUrl(ctx, guid, protocolText any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPresignedUrl", reflect.TypeOf((*MockFunctionInterface)(nil).GetPresignedUrl), ctx, guid, protocolText) +} + +// New mocks base method. +func (m *MockFunctionInterface) New(method, url string) *request.RequestBuilder { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "New", method, url) + ret0, _ := ret[0].(*request.RequestBuilder) + return ret0 } -// GetResponse indicates an expected call of GetResponse. -func (mr *MockFunctionInterfaceMockRecorder) GetResponse(profileConfig, endpointPostPrefix, method, contentType, bodyBytes any) *gomock.Call { +// New indicates an expected call of New. +func (mr *MockFunctionInterfaceMockRecorder) New(method, url any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponse", reflect.TypeOf((*MockFunctionInterface)(nil).GetResponse), profileConfig, endpointPostPrefix, method, contentType, bodyBytes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockFunctionInterface)(nil).New), method, url) } // ParseFenceURLResponse mocks base method. -func (m *MockFunctionInterface) ParseFenceURLResponse(resp *http.Response) (jwt.JsonMessage, error) { +func (m *MockFunctionInterface) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ParseFenceURLResponse", resp) - ret0, _ := ret[0].(jwt.JsonMessage) + ret0, _ := ret[0].(api.FenceResponse) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/client/mocks/mock_gen3interface.go b/client/mocks/mock_gen3interface.go index 44dd849..99f3f25 100644 --- a/client/mocks/mock_gen3interface.go +++ b/client/mocks/mock_gen3interface.go @@ -1,22 +1,23 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/gen3Client (interfaces: Gen3Interface) +// Source: github.com/calypr/data-client/client/client (interfaces: Gen3Interface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/gen3Client Gen3Interface +// mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/client Gen3Interface // // Package mocks is a generated GoMock package. package mocks import ( - bytes "bytes" + context "context" http "net/http" - url "net/url" reflect "reflect" - jwt "github.com/calypr/data-client/client/jwt" + api "github.com/calypr/data-client/client/api" + conf "github.com/calypr/data-client/client/conf" logs "github.com/calypr/data-client/client/logs" + request "github.com/calypr/data-client/client/request" gomock "go.uber.org/mock/gomock" ) @@ -45,109 +46,106 @@ func (m *MockGen3Interface) EXPECT() *MockGen3InterfaceMockRecorder { } // CheckForShepherdAPI mocks base method. -func (m *MockGen3Interface) CheckForShepherdAPI() (bool, error) { +func (m *MockGen3Interface) CheckForShepherdAPI(ctx context.Context) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckForShepherdAPI") + ret := m.ctrl.Call(m, "CheckForShepherdAPI", ctx) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // CheckForShepherdAPI indicates an expected call of CheckForShepherdAPI. -func (mr *MockGen3InterfaceMockRecorder) CheckForShepherdAPI() *gomock.Call { +func (mr *MockGen3InterfaceMockRecorder) CheckForShepherdAPI(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockGen3Interface)(nil).CheckForShepherdAPI)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockGen3Interface)(nil).CheckForShepherdAPI), ctx) } // CheckPrivileges mocks base method. -func (m *MockGen3Interface) CheckPrivileges() (string, map[string]any, error) { +func (m *MockGen3Interface) CheckPrivileges(ctx context.Context) (map[string]any, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckPrivileges") - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(map[string]any) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret := m.ctrl.Call(m, "CheckPrivileges", ctx) + ret0, _ := ret[0].(map[string]any) + ret1, _ := ret[1].(error) + return ret0, ret1 } // CheckPrivileges indicates an expected call of CheckPrivileges. -func (mr *MockGen3InterfaceMockRecorder) CheckPrivileges() *gomock.Call { +func (mr *MockGen3InterfaceMockRecorder) CheckPrivileges(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockGen3Interface)(nil).CheckPrivileges)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockGen3Interface)(nil).CheckPrivileges), ctx) } // DeleteRecord mocks base method. -func (m *MockGen3Interface) DeleteRecord(guid string) (string, error) { +func (m *MockGen3Interface) DeleteRecord(ctx context.Context, guid string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRecord", guid) + ret := m.ctrl.Call(m, "DeleteRecord", ctx, guid) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // DeleteRecord indicates an expected call of DeleteRecord. -func (mr *MockGen3InterfaceMockRecorder) DeleteRecord(guid any) *gomock.Call { +func (mr *MockGen3InterfaceMockRecorder) DeleteRecord(ctx, guid any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecord", reflect.TypeOf((*MockGen3Interface)(nil).DeleteRecord), guid) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecord", reflect.TypeOf((*MockGen3Interface)(nil).DeleteRecord), ctx, guid) } -// DoRequestWithSignedHeader mocks base method. -func (m *MockGen3Interface) DoRequestWithSignedHeader(endpointPostPrefix, contentType string, bodyBytes []byte) (jwt.JsonMessage, error) { +// Do mocks base method. +func (m *MockGen3Interface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoRequestWithSignedHeader", endpointPostPrefix, contentType, bodyBytes) - ret0, _ := ret[0].(jwt.JsonMessage) + ret := m.ctrl.Call(m, "Do", ctx, req) + ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } -// DoRequestWithSignedHeader indicates an expected call of DoRequestWithSignedHeader. -func (mr *MockGen3InterfaceMockRecorder) DoRequestWithSignedHeader(endpointPostPrefix, contentType, bodyBytes any) *gomock.Call { +// Do indicates an expected call of Do. +func (mr *MockGen3InterfaceMockRecorder) Do(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoRequestWithSignedHeader", reflect.TypeOf((*MockGen3Interface)(nil).DoRequestWithSignedHeader), endpointPostPrefix, contentType, bodyBytes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockGen3Interface)(nil).Do), ctx, req) } -// GetCredential mocks base method. -func (m *MockGen3Interface) GetCredential() *jwt.Credential { +// ExportCredential mocks base method. +func (m *MockGen3Interface) ExportCredential(ctx context.Context, cred *conf.Credential) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCredential") - ret0, _ := ret[0].(*jwt.Credential) + ret := m.ctrl.Call(m, "ExportCredential", ctx, cred) + ret0, _ := ret[0].(error) return ret0 } -// GetCredential indicates an expected call of GetCredential. -func (mr *MockGen3InterfaceMockRecorder) GetCredential() *gomock.Call { +// ExportCredential indicates an expected call of ExportCredential. +func (mr *MockGen3InterfaceMockRecorder) ExportCredential(ctx, cred any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockGen3Interface)(nil).GetCredential)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportCredential", reflect.TypeOf((*MockGen3Interface)(nil).ExportCredential), ctx, cred) } -// GetHost mocks base method. -func (m *MockGen3Interface) GetHost() (*url.URL, error) { +// GetCredential mocks base method. +func (m *MockGen3Interface) GetCredential() *conf.Credential { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHost") - ret0, _ := ret[0].(*url.URL) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "GetCredential") + ret0, _ := ret[0].(*conf.Credential) + return ret0 } -// GetHost indicates an expected call of GetHost. -func (mr *MockGen3InterfaceMockRecorder) GetHost() *gomock.Call { +// GetCredential indicates an expected call of GetCredential. +func (mr *MockGen3InterfaceMockRecorder) GetCredential() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHost", reflect.TypeOf((*MockGen3Interface)(nil).GetHost)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockGen3Interface)(nil).GetCredential)) } -// GetResponse mocks base method. -func (m *MockGen3Interface) GetResponse(endpointPostPrefix, method, contentType string, bodyBytes []byte) (string, *http.Response, error) { +// GetPresignedUrl mocks base method. +func (m *MockGen3Interface) GetPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResponse", endpointPostPrefix, method, contentType, bodyBytes) + ret := m.ctrl.Call(m, "GetPresignedUrl", ctx, guid, protocolText) ret0, _ := ret[0].(string) - ret1, _ := ret[1].(*http.Response) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 } -// GetResponse indicates an expected call of GetResponse. -func (mr *MockGen3InterfaceMockRecorder) GetResponse(endpointPostPrefix, method, contentType, bodyBytes any) *gomock.Call { +// GetPresignedUrl indicates an expected call of GetPresignedUrl. +func (mr *MockGen3InterfaceMockRecorder) GetPresignedUrl(ctx, guid, protocolText any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponse", reflect.TypeOf((*MockGen3Interface)(nil).GetResponse), endpointPostPrefix, method, contentType, bodyBytes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPresignedUrl", reflect.TypeOf((*MockGen3Interface)(nil).GetPresignedUrl), ctx, guid, protocolText) } // Logger mocks base method. @@ -164,17 +162,31 @@ func (mr *MockGen3InterfaceMockRecorder) Logger() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockGen3Interface)(nil).Logger)) } -// MakeARequest mocks base method. -func (m *MockGen3Interface) MakeARequest(method, apiEndpoint, accessToken, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) { +// New mocks base method. +func (m *MockGen3Interface) New(method, url string) *request.RequestBuilder { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeARequest", method, apiEndpoint, accessToken, contentType, headers, body, noTimeout) - ret0, _ := ret[0].(*http.Response) + ret := m.ctrl.Call(m, "New", method, url) + ret0, _ := ret[0].(*request.RequestBuilder) + return ret0 +} + +// New indicates an expected call of New. +func (mr *MockGen3InterfaceMockRecorder) New(method, url any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockGen3Interface)(nil).New), method, url) +} + +// ParseFenceURLResponse mocks base method. +func (m *MockGen3Interface) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseFenceURLResponse", resp) + ret0, _ := ret[0].(api.FenceResponse) ret1, _ := ret[1].(error) return ret0, ret1 } -// MakeARequest indicates an expected call of MakeARequest. -func (mr *MockGen3InterfaceMockRecorder) MakeARequest(method, apiEndpoint, accessToken, contentType, headers, body, noTimeout any) *gomock.Call { +// ParseFenceURLResponse indicates an expected call of ParseFenceURLResponse. +func (mr *MockGen3InterfaceMockRecorder) ParseFenceURLResponse(resp any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeARequest", reflect.TypeOf((*MockGen3Interface)(nil).MakeARequest), method, apiEndpoint, accessToken, contentType, headers, body, noTimeout) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseFenceURLResponse", reflect.TypeOf((*MockGen3Interface)(nil).ParseFenceURLResponse), resp) } diff --git a/client/mocks/mock_request.go b/client/mocks/mock_request.go index 74f87de..1021d18 100644 --- a/client/mocks/mock_request.go +++ b/client/mocks/mock_request.go @@ -1,21 +1,20 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/jwt (interfaces: RequestInterface) +// Source: github.com/calypr/data-client/client/request (interfaces: RequestInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/jwt RequestInterface +// mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/request RequestInterface // // Package mocks is a generated GoMock package. package mocks import ( - bytes "bytes" + context "context" http "net/http" reflect "reflect" - jwt "github.com/calypr/data-client/client/jwt" - logs "github.com/calypr/data-client/client/logs" + request "github.com/calypr/data-client/client/request" gomock "go.uber.org/mock/gomock" ) @@ -43,45 +42,31 @@ func (m *MockRequestInterface) EXPECT() *MockRequestInterfaceMockRecorder { return m.recorder } -// Logger mocks base method. -func (m *MockRequestInterface) Logger() logs.Logger { +// Do mocks base method. +func (m *MockRequestInterface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Logger") - ret0, _ := ret[0].(logs.Logger) - return ret0 -} - -// Logger indicates an expected call of Logger. -func (mr *MockRequestInterfaceMockRecorder) Logger() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockRequestInterface)(nil).Logger)) -} - -// MakeARequest mocks base method. -func (m *MockRequestInterface) MakeARequest(method, apiEndpoint, accessToken, contentType string, headers map[string]string, body *bytes.Buffer, noTimeout bool) (*http.Response, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeARequest", method, apiEndpoint, accessToken, contentType, headers, body, noTimeout) + ret := m.ctrl.Call(m, "Do", ctx, req) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } -// MakeARequest indicates an expected call of MakeARequest. -func (mr *MockRequestInterfaceMockRecorder) MakeARequest(method, apiEndpoint, accessToken, contentType, headers, body, noTimeout any) *gomock.Call { +// Do indicates an expected call of Do. +func (mr *MockRequestInterfaceMockRecorder) Do(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeARequest", reflect.TypeOf((*MockRequestInterface)(nil).MakeARequest), method, apiEndpoint, accessToken, contentType, headers, body, noTimeout) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockRequestInterface)(nil).Do), ctx, req) } -// RequestNewAccessToken mocks base method. -func (m *MockRequestInterface) RequestNewAccessToken(accessTokenEndpoint string, profileConfig *jwt.Credential) error { +// New mocks base method. +func (m *MockRequestInterface) New(method, url string) *request.RequestBuilder { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RequestNewAccessToken", accessTokenEndpoint, profileConfig) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "New", method, url) + ret0, _ := ret[0].(*request.RequestBuilder) return ret0 } -// RequestNewAccessToken indicates an expected call of RequestNewAccessToken. -func (mr *MockRequestInterfaceMockRecorder) RequestNewAccessToken(accessTokenEndpoint, profileConfig any) *gomock.Call { +// New indicates an expected call of New. +func (mr *MockRequestInterfaceMockRecorder) New(method, url any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNewAccessToken", reflect.TypeOf((*MockRequestInterface)(nil).RequestNewAccessToken), accessTokenEndpoint, profileConfig) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockRequestInterface)(nil).New), method, url) } diff --git a/client/request/auth.go b/client/request/auth.go new file mode 100644 index 0000000..7d08f65 --- /dev/null +++ b/client/request/auth.go @@ -0,0 +1,103 @@ +package request + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "sync" + + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/conf" +) + +func (t *AuthTransport) NewAccessToken(ctx context.Context) error { + if t.Cred.APIKey == "" { + return errors.New("APIKey is required to refresh access token") + } + + refreshClient := &http.Client{Transport: t.Base} + + payload := map[string]string{"api_key": t.Cred.APIKey} + reader, err := common.ToJSONReader(payload) + if err != nil { + return err + } + + refreshUrl := t.Cred.APIEndpoint + common.FenceAccessTokenEndpoint + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshUrl, reader) + if err != nil { + return err + } + req.Header.Set(common.HeaderContentType, common.MIMEApplicationJSON) + + resp, err := refreshClient.Do(req) + if err != nil { + return fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("failed to refresh token, status: " + strconv.Itoa(resp.StatusCode)) + } + + var result common.AccessTokenStruct + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + + t.mu.Lock() + t.Cred.AccessToken = result.AccessToken + if t.Manager != nil { + t.Manager.Save(t.Cred) + } + t.mu.Unlock() + return nil +} + +type AuthTransport struct { + Manager conf.ManagerInterface + Base http.RoundTripper + Cred *conf.Credential + mu sync.RWMutex + refreshMu sync.Mutex +} + +func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + + resp, err := t.Base.RoundTrip(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadGateway { + resp.Body.Close() + + newToken, refreshErr := t.tryRefresh(req.Context()) + if refreshErr != nil { + return nil, refreshErr + } + + retryReq := req.Clone(req.Context()) + retryReq.Header.Set("Authorization", "Bearer "+newToken) + return t.Base.RoundTrip(retryReq) + } + + return resp, nil +} + +func (t *AuthTransport) tryRefresh(ctx context.Context) (string, error) { + // Only one goroutine can enter this block + t.refreshMu.Lock() + defer t.refreshMu.Unlock() + + if err := t.NewAccessToken(ctx); err != nil { + return "", err + } + + t.mu.RLock() + defer t.mu.RUnlock() + return t.Cred.AccessToken, nil +} diff --git a/client/request/builder.go b/client/request/builder.go new file mode 100644 index 0000000..1280fb7 --- /dev/null +++ b/client/request/builder.go @@ -0,0 +1,54 @@ +package request + +import ( + "io" + + "github.com/calypr/data-client/client/common" +) + +// New addition to your request package +type RequestBuilder struct { + //Req *Request // the underlying retry client holder + Method string + Url string + Body io.Reader // store as []byte for easy reuse + Headers map[string]string + Token string + PartSize int64 +} + +func (r *Request) New(method, url string) *RequestBuilder { + return &RequestBuilder{ + //Req: r, + Method: method, + Url: url, + Headers: make(map[string]string), + } +} + +func (ar *RequestBuilder) WithToken(token string) *RequestBuilder { + ar.Token = token + return ar +} + +func (ar *RequestBuilder) WithJSONBody(v any) (*RequestBuilder, error) { + reader, err := common.ToJSONReader(v) + if err != nil { + return nil, err + } + + ar.Body = reader + ar.Headers[common.HeaderContentType] = common.MIMEApplicationJSON + return ar, nil + +} + +func (ar *RequestBuilder) WithBody(body io.Reader) *RequestBuilder { + ar.Body = body + return ar +} + +func (ar *RequestBuilder) WithHeader(key, value string) *RequestBuilder { + ar.Headers[key] = value + return ar +} diff --git a/client/request/request.go b/client/request/request.go new file mode 100644 index 0000000..a1076e7 --- /dev/null +++ b/client/request/request.go @@ -0,0 +1,96 @@ +package request + +//go:generate mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/request RequestInterface + +import ( + "context" + "errors" + "net" + "net/http" + "time" + + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/logs" + "github.com/hashicorp/go-retryablehttp" +) + +type Request struct { + Logs logs.Logger + Ctx context.Context + RetryClient *retryablehttp.Client +} + +type RequestInterface interface { + New(method, url string) *RequestBuilder + Do(ctx context.Context, req *RequestBuilder) (*http.Response, error) +} + +func NewRequestInterface( + logger logs.Logger, + cred *conf.Credential, + conf conf.ManagerInterface, +) RequestInterface { + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = 3 + retryClient.Logger = logger + + baseTransport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + TLSHandshakeTimeout: 5 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } + + authTransport := &AuthTransport{ + Base: baseTransport, + Cred: cred, + Manager: conf, + } + + retryClient.HTTPClient = &http.Client{ + Timeout: 0, + Transport: authTransport, // The outer shell is now AuthTransport + } + + return &Request{ + RetryClient: retryClient, + Logs: logger, + } +} + +func (r *Request) Do(ctx context.Context, rb *RequestBuilder) (*http.Response, error) { + // Prepare body reader + + httpReq, err := http.NewRequestWithContext(ctx, rb.Method, rb.Url, rb.Body) + if err != nil { + return nil, errors.New("failed to create HTTP request: " + err.Error()) + } + + for key, value := range rb.Headers { + httpReq.Header.Add(key, value) + } + + if rb.Token != "" { + httpReq.Header.Set("Authorization", "Bearer "+rb.Token) + } + + if rb.PartSize != 0 { + httpReq.ContentLength = rb.PartSize + } + // Convert to retryablehttp.Request + retryReq, err := retryablehttp.FromRequest(httpReq) + if err != nil { + return nil, err + } + + resp, err := r.RetryClient.Do(retryReq) + if err != nil { + return resp, errors.New("request failed after retries: " + err.Error()) + } + + return resp, nil +} diff --git a/client/upload/batch.go b/client/upload/batch.go new file mode 100644 index 0000000..5e08af4 --- /dev/null +++ b/client/upload/batch.go @@ -0,0 +1,161 @@ +package upload + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "sync" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/request" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +func InitBatchUploadChannels(numParallel int, inputSliceLen int) (int, chan *http.Response, chan error, []common.FileUploadRequestObject) { + workers := numParallel + if workers < 1 || workers > inputSliceLen { + workers = inputSliceLen + } + if workers < 1 { + workers = 1 + } + + respCh := make(chan *http.Response, inputSliceLen) + errCh := make(chan error, inputSliceLen) + batchSlice := make([]common.FileUploadRequestObject, 0, workers) + + return workers, respCh, errCh, batchSlice +} + +func BatchUpload( + ctx context.Context, + g3i client.Gen3Interface, + furObjects []common.FileUploadRequestObject, + workers int, + respCh chan *http.Response, + errCh chan error, + bucketName string, +) { + if len(furObjects) == 0 { + return + } + + // Ensure bucket is set + for i := range furObjects { + if furObjects[i].Bucket == "" { + furObjects[i].Bucket = bucketName + } + } + + progress := mpb.New(mpb.WithOutput(os.Stdout)) + + workCh := make(chan common.FileUploadRequestObject, len(furObjects)) + + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for fur := range workCh { + // --- Ensure presigned URL --- + if fur.PresignedURL == "" { + resp, err := GeneratePresignedUploadURL(ctx, g3i, fur.Filename, fur.FileMetadata, fur.Bucket) + if err != nil { + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, "", 0, false) + errCh <- err + continue + } + fur.PresignedURL = resp.URL + fur.GUID = resp.GUID + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, resp.GUID, 0, false) // update log + } + + // --- Open file --- + file, err := os.Open(fur.FilePath) + if err != nil { + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + errCh <- fmt.Errorf("file open error: %w", err) + continue + } + + fi, err := file.Stat() + if err != nil { + file.Close() + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + errCh <- fmt.Errorf("file stat error: %w", err) + continue + } + + if fi.Size() > common.FileSizeLimit { + file.Close() + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + errCh <- fmt.Errorf("file size exceeds limit: %s", fur.Filename) + continue + } + + // --- Progress bar --- + bar := progress.AddBar(fi.Size(), + mpb.PrependDecorators( + decor.Name(fur.Filename+" "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), " % .1f"), + ), + ) + + proxyReader := bar.ProxyReader(file) + + // --- Upload using DoAuthenticatedRequest (no manual http.Request!) --- + resp, err := g3i.Do( + ctx, + &request.RequestBuilder{ + Method: http.MethodPut, + Url: fur.PresignedURL, + Body: proxyReader, + }, + ) + + // Cleanup + file.Close() + bar.Abort(false) + + if err != nil { + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + errCh <- err + continue + } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + errMsg := fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + errCh <- errMsg + continue + } + + resp.Body.Close() + + // Success + respCh <- resp + g3i.Logger().DeleteFromFailedLog(fur.FilePath) + g3i.Logger().Succeeded(fur.FilePath, fur.GUID) + g3i.Logger().Scoreboard().IncrementSB(0) + } + }() + } + + for _, obj := range furObjects { + workCh <- obj + } + close(workCh) + + wg.Wait() + progress.Wait() +} diff --git a/client/upload/multipart.go b/client/upload/multipart.go new file mode 100644 index 0000000..be97d69 --- /dev/null +++ b/client/upload/multipart.go @@ -0,0 +1,303 @@ +package upload + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "sort" + "strings" + "sync" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + req "github.com/calypr/data-client/client/request" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, file *os.File, showProgress bool) error { + g3.Logger().Printf("File Upload Request: %#v\n", req) + + stat, err := file.Stat() + if err != nil { + return fmt.Errorf("cannot stat file: %w", err) + } + + fileSize := stat.Size() + if fileSize == 0 { + return fmt.Errorf("file is empty: %s", req.Filename) + } + + var p *mpb.Progress + var bar *mpb.Bar + if showProgress { + p = mpb.New(mpb.WithOutput(os.Stdout)) + bar = p.AddBar(fileSize, + mpb.PrependDecorators( + decor.Name(req.Filename+" "), + decor.CountersKibiByte("%.1f / %.1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), " % .1f"), + ), + ) + } + + // 1. Initialize multipart upload + uploadID, finalGUID, err := initMultipartUpload(ctx, g3, req, req.Bucket) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + // 2. Construct the S3 Key correctly + // Ensure finalGUID is not empty to avoid a leading slash + key := fmt.Sprintf("%s/%s", finalGUID, req.Filename) + g3.Logger().Printf("Initialized Upload: ID=%s, Key=%s\n", uploadID, key) + + optimalChunkSize := func(fSize int64) int64 { + if fSize <= 512*common.MB { + return 32 * common.MB + } + chunkSize := fSize / common.MaxMultipartParts + if chunkSize < common.MinChunkSize { + chunkSize = common.MinChunkSize + } + return ((chunkSize + common.MB - 1) / common.MB) * common.MB + } + + chunkSize := optimalChunkSize(fileSize) + numChunks := int((fileSize + chunkSize - 1) / chunkSize) + + chunks := make(chan int, numChunks) + for i := 1; i <= numChunks; i++ { + chunks <- i + } + close(chunks) + + var ( + wg sync.WaitGroup + mu sync.Mutex + parts []MultipartPartObject + uploadErrors []error + ) + + // 3. Worker logic + worker := func() { + defer wg.Done() + + for partNum := range chunks { + + offset := int64(partNum-1) * chunkSize + size := chunkSize + if offset+size > fileSize { + size = fileSize - offset + } + + // SectionReader implements io.Reader, io.ReaderAt, and io.Seeker + // It allows each worker to read its own segment without a shared buffer. + section := io.NewSectionReader(file, offset, size) + + url, err := generateMultipartPresignedURL(ctx, g3, key, uploadID, partNum, req.Bucket) + if err != nil { + mu.Lock() + uploadErrors = append(uploadErrors, fmt.Errorf("URL generation failed part %d: %w", partNum, err)) + mu.Unlock() + return + } + + // Perform the upload using the section directly + etag, err := uploadPart(ctx, url, section, size) + if err != nil { + mu.Lock() + uploadErrors = append(uploadErrors, fmt.Errorf("upload failed part %d: %w", partNum, err)) + mu.Unlock() + return + } + + mu.Lock() + parts = append(parts, MultipartPartObject{ + PartNumber: partNum, + ETag: etag, + }) + if bar != nil { + bar.IncrInt64(size) + } + mu.Unlock() + } + } + + // Launch workers + for range common.MaxConcurrentUploads { + wg.Add(1) + go worker() + } + wg.Wait() + + if p != nil { + p.Wait() + } + + if len(uploadErrors) > 0 { + return fmt.Errorf("multipart upload failed with %d errors: %v", len(uploadErrors), uploadErrors) + } + + // 5. Finalize the upload + sort.Slice(parts, func(i, j int) bool { + return parts[i].PartNumber < parts[j].PartNumber + }) + + if err := CompleteMultipartUpload(ctx, g3, key, uploadID, parts, req.Bucket); err != nil { + return fmt.Errorf("failed to complete multipart upload: %w", err) + } + + g3.Logger().Printf("Successfully uploaded %s to %s", req.Filename, key) + g3.Logger().Succeeded(req.FilePath, req.GUID) + return nil +} + +// InitMultipartUpload helps sending requests to FENCE to init a multipart upload +func initMultipartUpload(ctx context.Context, g3 client.Gen3Interface, furObject common.FileUploadRequestObject, bucketName string) (string, string, error) { + // Use Filename and GUID directly from the unified request object + + reader, err := common.ToJSONReader( + InitRequestObject{ + Filename: furObject.Filename, + Bucket: bucketName, + GUID: furObject.GUID, + }, + ) + + cred := g3.GetCredential() + resp, err := g3.Do( + ctx, + &req.RequestBuilder{ + Method: http.MethodPost, + Url: cred.APIEndpoint + common.FenceDataMultipartInitEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: reader, + Token: cred.AccessToken, + }, + ) + + if err != nil { + if strings.Contains(err.Error(), "404") { + return "", "", errors.New(err.Error() + "\nPlease check to ensure FENCE version is at 2.8.0 or beyond") + } + return "", "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) + } + + msg, err := g3.ParseFenceURLResponse(resp) + if err != nil { + return "", "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) + + } + + if msg.UploadID == "" || msg.GUID == "" { + return "", "", errors.New("unknown error has occurred during multipart upload initialization. Please check logs from Gen3 services") + } + return msg.UploadID, msg.GUID, err +} + +// GenerateMultipartPresignedURL helps sending requests to FENCE to get a presigned URL for a part during a multipart upload +func generateMultipartPresignedURL(ctx context.Context, g3 client.Gen3Interface, key string, uploadID string, partNumber int, bucketName string) (string, error) { + + reader, err := common.ToJSONReader( + MultipartUploadRequestObject{ + Key: key, + UploadID: uploadID, + PartNumber: partNumber, + Bucket: bucketName, + }, + ) + if err != nil { + return "", err + } + + cred := g3.GetCredential() + resp, err := g3.Do( + ctx, + &req.RequestBuilder{ + Url: cred.APIEndpoint + common.FenceDataMultipartUploadEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Method: http.MethodPost, + Body: reader, + Token: cred.AccessToken, + }, + ) + if err != nil { + return "", errors.New("Error has occurred during multipart upload presigned url generation, detailed error message: " + err.Error()) + } + + msg, err := g3.ParseFenceURLResponse(resp) + if err != nil { + return "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) + } + + if msg.PresignedURL == "" { + return "", errors.New("unknown error has occurred during multipart upload presigned url generation. Please check logs from Gen3 services") + } + return msg.PresignedURL, err +} + +// CompleteMultipartUpload helps sending requests to FENCE to complete a multipart upload +func CompleteMultipartUpload(ctx context.Context, g3 client.Gen3Interface, key string, uploadID string, parts []MultipartPartObject, bucketName string) error { + multipartCompleteObject := MultipartCompleteRequestObject{Key: key, UploadID: uploadID, Parts: parts, Bucket: bucketName} + + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(multipartCompleteObject) + if err != nil { + return errors.New("Error occurred during encoding multipart upload data: " + err.Error()) + } + + // TOOD: error check this, return resp information + cred := g3.GetCredential() + _, err = g3.Do( + ctx, + &req.RequestBuilder{ + Url: cred.APIEndpoint + common.FenceDataMultipartCompleteEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: &buf, + Method: http.MethodPost, + Token: cred.AccessToken, + }, + ) + if err != nil { + return errors.New("Error has occurred during completing multipart upload, detailed error message: " + err.Error()) + } + return nil +} + +// uploadPart now returns the ETag and error directly. +// It accepts a Context to allow for cancellation (e.g., if another part fails). +func uploadPart(ctx context.Context, url string, data io.Reader, partSize int64) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, data) + if err != nil { + return "", err + } + + req.ContentLength = partSize + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return "", fmt.Errorf("upload failed (%d): %s", resp.StatusCode, body) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return "", errors.New("no ETag returned") + } + + return strings.Trim(etag, `"`), nil +} diff --git a/client/upload/request.go b/client/upload/request.go new file mode 100644 index 0000000..80b7744 --- /dev/null +++ b/client/upload/request.go @@ -0,0 +1,125 @@ +package upload + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strings" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + req "github.com/calypr/data-client/client/request" + "github.com/vbauerster/mpb/v8" +) + +// GeneratePresignedURL handles both Shepherd and Fence fallback +func GeneratePresignedUploadURL(ctx context.Context, g3 client.Gen3Interface, filename string, metadata common.FileMetadata, bucket string) (*PresignedURLResponse, error) { + hasShepherd, err := g3.CheckForShepherdAPI(ctx) + if err != nil || !hasShepherd { + payload := map[string]string{ + "file_name": filename, + } + if bucket != "" { + payload["bucket"] = bucket + } + + buf, err := common.ToJSONReader(payload) + if err != nil { + return nil, err + } + + cred := g3.GetCredential() + resp, err := g3.Do( + ctx, + &req.RequestBuilder{ + Method: http.MethodPost, + Url: cred.APIEndpoint + common.FenceDataUploadEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: buf, + Token: cred.AccessToken, + }) + if err != nil { + return nil, err + } + msg, err := g3.ParseFenceURLResponse(resp) + return &PresignedURLResponse{msg.URL, msg.GUID}, err + } + + shepherdPayload := ShepherdInitRequestObject{ + Filename: filename, + Authz: ShepherdAuthz{ + Version: "0", ResourcePaths: metadata.Authz, + }, + Aliases: metadata.Aliases, + Metadata: metadata.Metadata, + } + + reader, err := common.ToJSONReader(shepherdPayload) + if err != nil { + return nil, err + } + + cred := g3.GetCredential() + r, err := g3.Do( + ctx, + &req.RequestBuilder{ + Url: cred.APIEndpoint + common.ShepherdEndpoint + "/objects", + Method: http.MethodPost, + Body: reader, + Token: cred.AccessToken, + }) + if err != nil || r.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("shepherd upload init failed") + } + + var res PresignedURLResponse + if err := json.NewDecoder(r.Body).Decode(&res); err != nil { + return nil, err + } + return &res, nil +} + +// GenerateUploadRequest helps preparing the HTTP request for upload and the progress bar for single part upload +func generateUploadRequest(ctx context.Context, g3 client.Gen3Interface, furObject common.FileUploadRequestObject, file *os.File, progress *mpb.Progress) (common.FileUploadRequestObject, error) { + if furObject.PresignedURL == "" { + endPointPostfix := common.FenceDataUploadEndpoint + "/" + furObject.GUID + "?file_name=" + url.QueryEscape(furObject.Filename) + + if furObject.Bucket != "" { + endPointPostfix += "&bucket=" + furObject.Bucket + } + cred := g3.GetCredential() + resp, err := g3.Do( + ctx, + &req.RequestBuilder{ + Url: cred.APIEndpoint + endPointPostfix, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Token: cred.AccessToken, + Method: http.MethodGet, + }, + ) + + msg, err := g3.ParseFenceURLResponse(resp) + if err != nil && !strings.Contains(err.Error(), "No GUID found") { + return furObject, errors.New("Upload error: " + err.Error()) + } + if msg.URL == "" { + return furObject, errors.New("Upload error: error in generating presigned URL for " + furObject.Filename) + } + furObject.PresignedURL = msg.URL + } + + fi, err := file.Stat() + if err != nil { + return furObject, errors.New("File stat error for file" + furObject.Filename + ", file may be missing or unreadable because of permissions.\n") + } + + if fi.Size() > common.FileSizeLimit { + return furObject, errors.New("The file size of file " + furObject.Filename + " exceeds the limit allowed and cannot be uploaded. The maximum allowed file size is " + FormatSize(common.FileSizeLimit) + ".\n") + } + + return furObject, err +} diff --git a/client/upload/retry.go b/client/upload/retry.go new file mode 100644 index 0000000..ca3779f --- /dev/null +++ b/client/upload/retry.go @@ -0,0 +1,171 @@ +package upload + +import ( + "context" + "os" + "path/filepath" + "time" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" +) + +// GetWaitTime calculates exponential backoff with cap +func GetWaitTime(retryCount int) time.Duration { + exp := 1 << retryCount // 2^retryCount + seconds := int64(exp) + if seconds > common.MaxWaitTime { + seconds = common.MaxWaitTime + } + return time.Duration(seconds) * time.Second +} + +// RetryFailedUploads re-uploads previously failed files with exponential backoff +func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap map[string]common.RetryObject) { + logger := g3.Logger() + if len(failedMap) == 0 { + logger.Println("No failed files to retry.") + return + } + + sb := logger.Scoreboard() + + logger.Printf("Starting retry-upload for %d failed Uploads", len(failedMap)) + retryChan := make(chan common.RetryObject, len(failedMap)) + + // Queue only non-already-succeeded files + for _, ro := range failedMap { + retryChan <- ro + } + + if len(retryChan) == 0 { + logger.Println("All previously failed files have since succeeded.") + return + } + + for ro := range retryChan { + ro.RetryCount++ + logger.Printf("#%d retry — %s\n", ro.RetryCount, ro.FilePath) + wait := GetWaitTime(ro.RetryCount) + logger.Printf("Waiting %.0f seconds before retry...\n", wait.Seconds()) + time.Sleep(wait) + + // Clean up old record if exists + if ro.GUID != "" { + if msg, err := g3.DeleteRecord( + ctx, + ro.GUID, + ); err == nil { + logger.Println(msg) + } + } + + file, err := os.Open(ro.FilePath) + if err != nil { + continue + } + + // Ensure filename is set + if ro.Filename == "" { + absPath, _ := common.GetAbsolutePath(ro.FilePath) + ro.Filename = filepath.Base(absPath) + } + + if ro.Multipart { + // Retry multipart + req := common.FileUploadRequestObject{ + FilePath: ro.FilePath, + Filename: ro.Filename, + GUID: ro.GUID, + FileMetadata: ro.FileMetadata, + Bucket: ro.Bucket, + } + err = MultipartUpload(ctx, g3, req, file, true) + if err == nil { + logger.Succeeded(ro.FilePath, req.GUID) + if sb != nil { + sb.IncrementSB(ro.RetryCount - 1) + } + continue + } + } else { + // Retry single-part + respObj, err := GeneratePresignedUploadURL(ctx, g3, ro.Filename, ro.FileMetadata, ro.Bucket) + if err != nil { + handleRetryFailure(ctx, g3, ro, retryChan, err) + continue + } + + file, err := os.Open(ro.FilePath) + if err != nil { + handleRetryFailure(ctx, g3, ro, retryChan, err) + continue + } + stat, _ := file.Stat() + file.Close() + + if stat.Size() > common.FileSizeLimit { + ro.Multipart = true + retryChan <- ro + continue + } + + fur := common.FileUploadRequestObject{ + FilePath: ro.FilePath, + Filename: ro.Filename, + FileMetadata: ro.FileMetadata, + GUID: respObj.GUID, + PresignedURL: respObj.URL, + } + + fur, err = generateUploadRequest(ctx, g3, fur, nil, nil) + if err != nil { + handleRetryFailure(ctx, g3, ro, retryChan, err) + continue + } + + err = UploadSingleFile(ctx, g3, fur, true) + if err == nil { + logger.Succeeded(ro.FilePath, fur.GUID) + if sb != nil { + sb.IncrementSB(ro.RetryCount - 1) + } + continue + } + } + + // On failure, requeue if retries remain + handleRetryFailure(ctx, g3, ro, retryChan, err) + } +} + +// handleRetryFailure logs failure and requeues if retries remain +func handleRetryFailure(ctx context.Context, g3 client.Gen3Interface, ro common.RetryObject, retryChan chan common.RetryObject, err error) { + logger := g3.Logger() + logger.Failed(ro.FilePath, ro.Filename, ro.FileMetadata, ro.GUID, ro.RetryCount, ro.Multipart) + if err != nil { + logger.Println("Retry error:", err) + } + + if ro.RetryCount < common.MaxRetryCount { + retryChan <- ro + return + } + + // Max retries reached — final cleanup + if ro.GUID != "" { + if msg, err := g3.DeleteRecord(ctx, ro.GUID); err == nil { + logger.Println("Cleaned up failed record:", msg) + } else { + logger.Println("Cleanup failed:", err) + } + } + + if sb := logger.Scoreboard(); sb != nil { + sb.IncrementSB(common.MaxRetryCount + 1) + } + + if len(retryChan) == 0 { + close(retryChan) + } +} diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go new file mode 100644 index 0000000..527fd4c --- /dev/null +++ b/client/upload/singleFile.go @@ -0,0 +1,97 @@ +package upload + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + + client "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" +) + +func UploadSingle(ctx context.Context, profile string, guid string, filePath string, bucketName string, enableLogs bool) error { + + logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog()) + if enableLogs { + logger, closer = logs.New( + profile, + logs.WithSucceededLog(), + logs.WithFailedLog(), + logs.WithScoreboard(), + logs.WithConsole(), + ) + } + defer closer() + + // Instantiate interface to Gen3 + g3i, err := client.NewGen3Interface( + profile, + logger, + ) + if err != nil { + return fmt.Errorf("failed to parse config on profile %s: %w", profile, err) + } + + filePaths, err := common.ParseFilePaths(filePath, false) + if len(filePaths) > 1 { + return errors.New("more than 1 file location has been found. Do not use \"*\" in file path or provide a folder as file path") + } + if err != nil { + return errors.New("file path parsing error: " + err.Error()) + } + if len(filePaths) == 1 { + filePath = filePaths[0] + } + filename := filepath.Base(filePath) + if _, err := os.Stat(filePath); os.IsNotExist(err) { + g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + return fmt.Errorf("[ERROR] The file you specified \"%s\" does not exist locally\n", filePath) + } + + file, err := os.Open(filePath) + if err != nil { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) + g3i.Logger().Println("File open error: " + err.Error()) + return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", filePath, err.Error()) + } + defer file.Close() + + furObject := common.FileUploadRequestObject{FilePath: filePath, Filename: filename, GUID: guid, Bucket: bucketName} + + furObject, err = generateUploadRequest(ctx, g3i, furObject, file, nil) + if err != nil { + file.Close() + g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) + return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) + } + jsonData, err := json.Marshal(furObject) + if err != nil { + return fmt.Errorf("failed to marshal furObject: %w", err) + } + + _, err = uploadPart(ctx, furObject.PresignedURL, bytes.NewReader(jsonData), int64(len(jsonData))) + if err != nil { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + return fmt.Errorf("[ERROR] Error uploading file %s: %s\n", filePath, err.Error()) + } else { + g3i.Logger().Scoreboard().IncrementSB(0) + } + g3i.Logger().Scoreboard().PrintSB() + return nil +} diff --git a/client/upload/types.go b/client/upload/types.go new file mode 100644 index 0000000..2ef2f29 --- /dev/null +++ b/client/upload/types.go @@ -0,0 +1,73 @@ +package upload + +import "github.com/calypr/data-client/client/common" + +type PresignedURLResponse struct { + GUID string `json:"guid"` + URL string `json:"upload_url"` +} + +type MultipartPartObject struct { + PartNumber int `json:"PartNumber"` + ETag string `json:"ETag"` +} + +type UploadConfig struct { + BucketName string + NumParallel int + ForceMultipart bool + IncludeSubDirName bool + HasMetadata bool + ShowProgress bool +} + +// InitRequestObject represents the payload that sends to FENCE for getting a singlepart upload presignedURL or init a multipart upload for new object file +type InitRequestObject struct { + Filename string `json:"file_name"` + Bucket string `json:"bucket,omitempty"` + GUID string `json:"guid,omitempty"` +} + +// ShepherdInitRequestObject represents the payload that sends to Shepherd for getting a singlepart upload presignedURL or init a multipart upload for new object file +type ShepherdInitRequestObject struct { + Filename string `json:"file_name"` + Authz ShepherdAuthz `json:"authz"` + Aliases []string `json:"aliases"` + // Metadata is an encoded JSON string of any arbitrary metadata the user wishes to upload. + Metadata map[string]any `json:"metadata"` +} +type ShepherdAuthz struct { + Version string `json:"version"` + ResourcePaths []string `json:"resource_paths"` +} + +// MultipartUploadRequestObject represents the payload that sends to FENCE for getting a presignedURL for a part +type MultipartUploadRequestObject struct { + Key string `json:"key"` + UploadID string `json:"uploadId"` + PartNumber int `json:"partNumber"` + Bucket string `json:"bucket,omitempty"` +} + +// MultipartCompleteRequestObject represents the payload that sends to FENCE for completeing a multipart upload +type MultipartCompleteRequestObject struct { + Key string `json:"key"` + UploadID string `json:"uploadId"` + Parts []MultipartPartObject `json:"parts"` + Bucket string `json:"bucket,omitempty"` +} + +// FileInfo is a helper struct for including subdirname as filename +type FileInfo struct { + FilePath string + Filename string + FileMetadata common.FileMetadata + ObjectId string +} + +// RenamedOrSkippedFileInfo is a helper struct for recording renamed or skipped files +type RenamedOrSkippedFileInfo struct { + GUID string + OldFilename string + NewFilename string +} diff --git a/client/upload/upload.go b/client/upload/upload.go new file mode 100644 index 0000000..b786164 --- /dev/null +++ b/client/upload/upload.go @@ -0,0 +1,125 @@ +package upload + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/request" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +// Upload is a unified catch-all function that automatically chooses between +// single-part and multipart upload based on file size. +func Upload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + g3.Logger().Printf("Processing Upload Request for: %s\n", req.FilePath) + + file, err := os.Open(req.FilePath) + if err != nil { + return fmt.Errorf("cannot open file %s: %w", req.FilePath, err) + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return fmt.Errorf("cannot stat file: %w", err) + } + + fileSize := stat.Size() + if fileSize == 0 { + return fmt.Errorf("file is empty: %s", req.Filename) + } + + // Use Single-Part if file is smaller than 5GB (or your defined limit) + if fileSize < 5*common.GB { + g3.Logger().Printf("File size %d bytes (< 5GB), performing single-part upload\n", fileSize) + UploadSingle(ctx, g3.GetCredential().Profile, req.GUID, req.FilePath, req.Bucket, true) + } + g3.Logger().Printf("File size %d bytes (>= 5GB), performing multipart upload\n", fileSize) + return MultipartUpload(ctx, g3, req, file, showProgress) +} + +func performSinglePartUpload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + // 1. Get the Presigned URL + respObj, err := GeneratePresignedUploadURL(ctx, g3, req.Filename, req.FileMetadata, req.Bucket) + if err != nil { + return fmt.Errorf("failed to generate single-part URL: %w", err) + } + + req.GUID = respObj.GUID + req.PresignedURL = respObj.URL + + // 2. Open file and setup progress + file, _ := os.Open(req.FilePath) + defer file.Close() + + var body io.Reader = file + var p *mpb.Progress + if showProgress { + p = mpb.New(mpb.WithOutput(os.Stdout)) + fi, _ := file.Stat() + bar := p.AddBar(fi.Size(), + mpb.PrependDecorators(decor.Name(req.Filename+" ")), + mpb.AppendDecorators(decor.Percentage()), + ) + body = bar.ProxyReader(file) + } + + resp, err := g3.Do(ctx, &request.RequestBuilder{ + Method: http.MethodPut, + Url: req.PresignedURL, + Body: body, + }) + + if p != nil { + p.Wait() + } + + if err != nil || resp.StatusCode != http.StatusOK { + return fmt.Errorf("single-part upload failed") + } + return nil +} + +// UploadSingleFile handles single-part upload with progress +func UploadSingleFile(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + file, err := os.Open(req.FilePath) + if err != nil { + return err + } + defer file.Close() + + fi, _ := file.Stat() + if fi.Size() > common.FileSizeLimit { + return fmt.Errorf("file exceeds 5GB limit") + } + + respObj, err := GeneratePresignedUploadURL(ctx, g3, req.Filename, req.FileMetadata, req.Bucket) + if err != nil { + return err + } + + // Generate request with progress bar + var p *mpb.Progress + if showProgress { + p = mpb.New(mpb.WithOutput(os.Stdout)) + } + + fur, err := generateUploadRequest(ctx, g3, common.FileUploadRequestObject{ + FilePath: req.FilePath, + Filename: req.Filename, + PresignedURL: respObj.URL, + GUID: respObj.GUID, + Bucket: req.Bucket, + }, file, p) + if err != nil { + return err + } + + return MultipartUpload(ctx, g3, fur, file, showProgress) +} diff --git a/client/upload/utils.go b/client/upload/utils.go new file mode 100644 index 0000000..2dbfa85 --- /dev/null +++ b/client/upload/utils.go @@ -0,0 +1,133 @@ +package upload + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" +) + +func SeparateSingleAndMultipartUploads(g3i client.Gen3Interface, objects []common.FileUploadRequestObject) ([]common.FileUploadRequestObject, []common.FileUploadRequestObject) { + fileSizeLimit := common.FileSizeLimit + + var singlepartObjects []common.FileUploadRequestObject + var multipartObjects []common.FileUploadRequestObject + + for _, object := range objects { + fi, err := os.Stat(object.FilePath) + if err != nil { + if os.IsNotExist(err) { + g3i.Logger().Printf("The file you specified \"%s\" does not exist locally\n", object.FilePath) + } else { + g3i.Logger().Println("File stat error: " + err.Error()) + } + g3i.Logger().Failed(object.FilePath, object.Filename, object.FileMetadata, object.GUID, 0, false) + continue + } + if fi.IsDir() { + continue + } + if _, ok := g3i.Logger().GetSucceededLogMap()[object.FilePath]; ok { + g3i.Logger().Println("File \"" + object.FilePath + "\" found in history. Skipping.") + continue + } + if fi.Size() > common.MultipartFileSizeLimit { + g3i.Logger().Printf("File %s exceeds max limit\n", fi.Name()) + continue + } + if fi.Size() > int64(fileSizeLimit) { + multipartObjects = append(multipartObjects, object) + } else { + singlepartObjects = append(singlepartObjects, object) + } + } + return singlepartObjects, multipartObjects +} + +// ProcessFilename returns an FileInfo object which has the information about the path and name to be used for upload of a file +func ProcessFilename(logger logs.Logger, uploadPath string, filePath string, objectId string, includeSubDirName bool, includeMetadata bool) (common.FileUploadRequestObject, error) { + var err error + filePath, err = common.GetAbsolutePath(filePath) + if err != nil { + return common.FileUploadRequestObject{}, err + } + + filename := filepath.Base(filePath) // Default to base filename + + var metadata common.FileMetadata + if includeSubDirName { + absUploadPath, err := common.GetAbsolutePath(uploadPath) + if err != nil { + return common.FileUploadRequestObject{}, err + } + + // Ensure absUploadPath is a directory path for relative calculation + // Trim the optional wildcard if present + uploadDir := strings.TrimSuffix(absUploadPath, common.PathSeparator+"*") + fileInfo, err := os.Stat(uploadDir) + if err != nil { + return common.FileUploadRequestObject{}, err + } + if fileInfo.IsDir() { + // Calculate the path of the file relative to the upload directory + relPath, err := filepath.Rel(uploadDir, filePath) + if err != nil { + return common.FileUploadRequestObject{}, err + } + filename = relPath + } + } + + if includeMetadata { + // The metadata path is the file name plus '_metadata.json' + metadataFilePath := strings.TrimSuffix(filePath, filepath.Ext(filePath)) + "_metadata.json" + var metadataFileBytes []byte + if _, err := os.Stat(metadataFilePath); err == nil { + metadataFileBytes, err = os.ReadFile(metadataFilePath) + if err != nil { + return common.FileUploadRequestObject{}, errors.New("Error reading metadata file " + metadataFilePath + ": " + err.Error()) + } + err := json.Unmarshal(metadataFileBytes, &metadata) + if err != nil { + return common.FileUploadRequestObject{}, errors.New("Error parsing metadata file " + metadataFilePath + ": " + err.Error()) + } + } else { + // No metadata file was found for this file -- proceed, but warn the user. + logger.Printf("WARNING: File metadata is enabled, but could not find the metadata file %v for file %v. Execute `data-client upload --help` for more info on file metadata.\n", metadataFilePath, filePath) + } + } + return common.FileUploadRequestObject{FilePath: filePath, Filename: filename, FileMetadata: metadata, GUID: objectId}, nil +} + +// FormatSize helps to parse a int64 size into string +func FormatSize(size int64) string { + var unitSize int64 + switch { + case size >= common.TB: + unitSize = common.TB + case size >= common.GB: + unitSize = common.GB + case size >= common.MB: + unitSize = common.MB + case size >= common.KB: + unitSize = common.KB + default: + unitSize = common.B + } + + var unitMap = map[int64]string{ + common.B: "B", + common.KB: "KB", + common.MB: "MB", + common.GB: "GB", + common.TB: "TB", + } + + return fmt.Sprintf("%.1f"+unitMap[unitSize], float64(size)/float64(unitSize)) +} diff --git a/client/g3cmd/auth.go b/cmd/auth.go similarity index 88% rename from client/g3cmd/auth.go rename to cmd/auth.go index 2dbd361..7de1b36 100644 --- a/client/g3cmd/auth.go +++ b/cmd/auth.go @@ -1,4 +1,4 @@ -package g3cmd +package cmd import ( "context" @@ -6,7 +6,7 @@ import ( "log" "sort" - client "github.com/calypr/data-client/client/gen3Client" + "github.com/calypr/data-client/client/client" "github.com/calypr/data-client/client/logs" "github.com/spf13/cobra" ) @@ -24,19 +24,19 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole()) defer logCloser() - g3i, err := client.NewGen3Interface(context.Background(), profile, logger) + g3i, err := client.NewGen3Interface(profile, logger) if err != nil { log.Fatalf("Fatal NewGen3Interface error: %s\n", err) } - host, resourceAccess, err := g3i.CheckPrivileges() + resourceAccess, err := g3i.CheckPrivileges(context.Background()) if err != nil { g3i.Logger().Fatalf("Fatal authentication error: %s\n", err) } else { if len(resourceAccess) == 0 { - g3i.Logger().Printf("\nYou don't currently have access to any resources at %s\n", host) + g3i.Logger().Printf("\nYou don't currently have access to any resources at %s\n", g3i.GetCredential().APIEndpoint) } else { - g3i.Logger().Printf("\nYou have access to the following resource(s) at %s:\n", host) + g3i.Logger().Printf("\nYou have access to the following resource(s) at %s:\n", g3i.GetCredential().APIEndpoint) // Sort by resource name resources := make([]string, 0, len(resourceAccess)) diff --git a/client/g3cmd/configure.go b/cmd/configure.go similarity index 83% rename from client/g3cmd/configure.go rename to cmd/configure.go index d9d97e2..604693d 100644 --- a/client/g3cmd/configure.go +++ b/cmd/configure.go @@ -1,11 +1,14 @@ -package g3cmd +package cmd import ( + "context" "fmt" + "github.com/calypr/data-client/client/api" "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/jwt" + "github.com/calypr/data-client/client/conf" "github.com/calypr/data-client/client/logs" + req "github.com/calypr/data-client/client/request" "github.com/spf13/cobra" ) @@ -24,7 +27,7 @@ func init() { Example: `./data-client configure --profile= --cred= --apiendpoint=https://data.mycommons.org`, Run: func(cmd *cobra.Command, args []string) { // don't initialize transmission logs for non-uploading related commands - cred := &jwt.Credential{ + cred := &conf.Credential{ Profile: profile, APIEndpoint: apiEndpoint, AccessToken: fenceToken, @@ -34,21 +37,27 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole()) defer logCloser() - conf := jwt.Configure{Logs: logger} - + configure := conf.NewConfigure(logger) if credFile != "" { - readCred, err := conf.ReadCredentials(credFile, "") + readCred, err := configure.Import(credFile, "") if err != nil { logger.Fatal(err) // or return proper error } - cred.KeyId = readCred.KeyId + cred.KeyID = readCred.KeyID cred.APIKey = readCred.APIKey if readCred.APIEndpoint != "" { cred.APIEndpoint = readCred.APIEndpoint } cred.AccessToken = "" } - err := jwt.UpdateConfig(logger, cred) + + newFunc := api.NewFunctions( + configure, + req.NewRequestInterface(logger, cred, configure), + cred, + logger, + ) + err := newFunc.ExportCredential(context.Background(), cred) if err != nil { logger.Println(err.Error()) } diff --git a/cmd/delete.go b/cmd/delete.go new file mode 100644 index 0000000..e11c92f --- /dev/null +++ b/cmd/delete.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "context" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/logs" + "github.com/spf13/cobra" +) + +//Not support yet, place holder only + +func init() { + var guid string + var deleteCmd = &cobra.Command{ // nolint:deadcode,unused,varcheck + Use: "delete", + Short: "Send DELETE HTTP Request for given URI", + Long: `Deletes a given URI from the database. +If no profile is specified, "default" profile is used for authentication.`, + Example: `./data-client delete --uri=v0/submission/bpa/test/entities/example_id + ./data-client delete --profile=user1 --uri=v0/submission/bpa/test/entities/1af1d0ab-efec-4049-98f0-ae0f4bb1bc64`, + Run: func(cmd *cobra.Command, args []string) { + + logger, logCloser := logs.New(profile, logs.WithConsole()) + defer logCloser() + + g3i, err := client.NewGen3Interface(profile, logger) + if err != nil { + logger.Fatalf("Fatal NewGen3Interface error: %s\n", err) + } + + msg, err := g3i.DeleteRecord(context.Background(), guid) + if err != nil { + logger.Fatal(err) + } + logger.Println(msg) + }, + } + + deleteCmd.Flags().StringVar(&profile, "guid", "", "Specify the profile to check your access privileges") + RootCmd.AddCommand(deleteCmd) +} diff --git a/cmd/download-multipart.go b/cmd/download-multipart.go new file mode 100644 index 0000000..3718720 --- /dev/null +++ b/cmd/download-multipart.go @@ -0,0 +1,261 @@ +package cmd + +/* +// DownloadSignedURL downloads a file from a signed URL with: +// - Resumable single-stream download (if partial file exists) +// - Concurrent multipart download for large files (>1GB) +// - Retries via go-retryablehttp +// - Progress bar support via mpb +func DownloadSignedURL(signedURL, dstPath string) error { + // Setup retryable client + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = 10 + retryClient.RetryWaitMin = 1 * time.Second + retryClient.RetryWaitMax = 30 * time.Second + retryClient.Logger = nil // silent + client := retryClient.StandardClient() + client.Timeout = 0 // no timeout for large downloads + + // HEAD to get size and Accept-Ranges support + headResp, err := client.Head(signedURL) + if err != nil { + return fmt.Errorf("HEAD request failed: %w", err) + } + defer headResp.Body.Close() + + if headResp.StatusCode != http.StatusOK { + return fmt.Errorf("HEAD failed: %s", headResp.Status) + } + + contentLength := headResp.ContentLength + if contentLength <= 0 { + return fmt.Errorf("invalid Content-Length") + } + + acceptRanges := headResp.Header.Get("Accept-Ranges") == "bytes" + if !acceptRanges { + return fmt.Errorf("server does not support range requests") + } + + // Ensure directory exists + if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil { + return fmt.Errorf("mkdir failed: %w", err) + } + + // Check if partial file exists + stat, _ := os.Stat(dstPath) + existingSize := int64(0) + if stat != nil { + existingSize = stat.Size() + } + + // If we have a partial file, resume with single stream (safer and simpler) + if existingSize > 0 && existingSize < contentLength { + return downloadResumableSingle(signedURL, dstPath, contentLength, existingSize, client) + } + + // For complete downloads: use multipart if file is large enough + if contentLength >= 5*1024*1024*1024 { + return downloadConcurrentMultipart(signedURL, dstPath, contentLength, client) + } + + // Otherwise: simple single-stream download + return downloadResumableSingle(signedURL, dstPath, contentLength, 0, client) +} + +// downloadResumableSingle handles single-stream (possibly resumed) download +func downloadResumableSingle(signedURL, dstPath string, totalSize, startByte int64, client *http.Client) error { + req, err := http.NewRequest("GET", signedURL, nil) + if err != nil { + return err + } + if startByte > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startByte)) + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("GET failed: %w", err) + } + defer resp.Body.Close() + + if startByte > 0 && resp.StatusCode != http.StatusPartialContent { + return fmt.Errorf("expected 206 Partial Content, got %d", resp.StatusCode) + } + if startByte == 0 && resp.StatusCode != http.StatusOK { + return fmt.Errorf("expected 200 OK, got %d", resp.StatusCode) + } + + file, err := os.OpenFile(dstPath, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer file.Close() + + if startByte > 0 { + if _, err := file.Seek(startByte, io.SeekStart); err != nil { + return err + } + } else { + if err := file.Truncate(0); err != nil { + return err + } + } + + var writer io.Writer = file + if progress != nil { + bar := progress.AddBar(totalSize, + mpb.PrependDecorators( + decor.Name(filepath.Base(dstPath)+" "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), "% .1f"), + ), + ) + if startByte > 0 { + bar.SetCurrent(startByte) + } + writer = bar.ProxyWriter(file) + } + + _, err = io.Copy(writer, resp.Body) + return err +} + +// downloadConcurrentMultipart downloads in parallel chunks +func downloadConcurrentMultipart(signedURL, dstPath string, totalSize int64, client *http.Client) error { + numChunks := int((totalSize + chunkSize - 1) / chunkSize) + if numChunks < defaultConcurrency { + numChunks = defaultConcurrency + } + chunkSizeActual := (totalSize + int64(numChunks) - 1) / int64(numChunks) + + // Pre-allocate file + file, err := os.Create(dstPath) + if err != nil { + return err + } + if err := file.Truncate(totalSize); err != nil { + file.Close() + return err + } + file.Close() + + var wg sync.WaitGroup + var mu sync.Mutex + var downloadErr error + + // Shared progress bar for total + var totalBar *mpb.Bar + if progress != nil { + totalBar = progress.AddBar(totalSize, + mpb.PrependDecorators( + decor.Name(filepath.Base(dstPath)+" (multipart) "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), "% .1f"), + ), + ) + } + + concurrency := defaultConcurrency + sem := make(chan struct{}, concurrency) + + for i := 0; i < int(numChunks); i++ { + start := int64(i) * chunkSizeActual + end := start + chunkSizeActual - 1 + if end >= totalSize { + end = totalSize - 1 + } + if start > end { + break + } + + wg.Add(1) + sem <- struct{}{} + + go func(start, end int64, chunkIdx int) { + defer wg.Done() + defer func() { <-sem }() + + req, _ := http.NewRequest("GET", signedURL, nil) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + + resp, err := client.Do(req) + if err != nil { + mu.Lock() + if downloadErr == nil { + downloadErr = fmt.Errorf("chunk %d failed: %w", chunkIdx, err) + } + mu.Unlock() + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusPartialContent { + mu.Lock() + if downloadErr == nil { + downloadErr = fmt.Errorf("chunk %d expected 206, got %d", chunkIdx, resp.StatusCode) + } + mu.Unlock() + return + } + + file, err := os.OpenFile(dstPath, os.O_WRONLY, 0644) + if err != nil { + mu.Lock() + downloadErr = err + mu.Unlock() + return + } + file.Seek(start, io.SeekStart) + writer := io.Writer(file) + + var chunkWriter io.Writer = writer + if progress != nil { + chunkBar := progress.AddBar(end-start+1, + mpb.BarRemoveOnComplete(), + mpb.PrependDecorators(decor.Name(fmt.Sprintf("chunk %d ", chunkIdx))), + ) + chunkWriter = chunkBar.ProxyWriter(writer) + defer file.Close() + } + + if _, err := io.Copy(chunkWriter, resp.Body); err != nil { + mu.Lock() + if downloadErr == nil { + downloadErr = fmt.Errorf("chunk %d copy failed: %w", chunkIdx, err) + } + mu.Unlock() + } else { + if totalBar != nil { + totalBar.IncrBy(int(end - start + 1)) + } + } + if progress == nil { + file.Close() + } + }(start, end, i) + } + + wg.Wait() + + if downloadErr != nil { + if totalBar != nil { + totalBar.Abort(true) + } + return downloadErr + } + + if totalBar != nil { + totalBar.SetCurrent(totalSize) + } + + return nil +} + +*/ diff --git a/cmd/download-multiple.go b/cmd/download-multiple.go new file mode 100644 index 0000000..ed59486 --- /dev/null +++ b/cmd/download-multiple.go @@ -0,0 +1,111 @@ +package cmd + +import ( + "context" + "encoding/json" + "io" + "log" + "os" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/download" + "github.com/calypr/data-client/client/logs" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + + "github.com/spf13/cobra" +) + +func init() { + var manifestPath string + var downloadPath string + var filenameFormat string + var rename bool + var noPrompt bool + var protocol string + var numParallel int + var skipCompleted bool + + var downloadMultipleCmd = &cobra.Command{ + Use: "download-multiple", + Short: "Download multiple of files from a specified manifest", + Long: `Get presigned URLs for multiple of files specified in a manifest file and then download all of them.`, + Example: `./data-client download-multiple --profile --manifest --download-path `, + Run: func(cmd *cobra.Command, args []string) { + // don't initialize transmission logs for non-uploading related commands + + logger, logCloser := logs.New(profile, logs.WithConsole(), logs.WithFailedLog(), logs.WithScoreboard(), logs.WithSucceededLog()) + defer logCloser() + + g3i, err := client.NewGen3Interface(profile, logger) + if err != nil { + log.Fatalf("Failed to parse config on profile %s, %v", profile, err) + } + + manifestPath, _ = common.GetAbsolutePath(manifestPath) + manifestFile, err := os.Open(manifestPath) + if err != nil { + g3i.Logger().Fatalf("Failed to open manifest file %s, %v\n", manifestPath, err) + } + defer manifestFile.Close() + manifestFileStat, err := manifestFile.Stat() + if err != nil { + g3i.Logger().Fatalf("Failed to get manifest file stats %s, %v\n", manifestPath, err) + } + g3i.Logger().Println("Reading manifest...") + manifestFileSize := manifestFileStat.Size() + manifestProgress := mpb.New(mpb.WithOutput(os.Stdout)) + manifestFileBar := manifestProgress.AddBar(manifestFileSize, + mpb.PrependDecorators( + decor.Name("Manifest "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators(decor.Percentage()), + ) + + manifestFileReader := manifestFileBar.ProxyReader(manifestFile) + + manifestBytes, err := io.ReadAll(manifestFileReader) + if err != nil { + g3i.Logger().Fatalf("Failed reading manifest %s, %v\n", manifestPath, err) + } + manifestProgress.Wait() + + var objects []common.ManifestObject + err = json.Unmarshal(manifestBytes, &objects) + if err != nil { + g3i.Logger().Fatalf("Error has occurred during unmarshalling manifest object: %v\n", err) + } + + err = download.DownloadMultiple( + context.Background(), + g3i, + objects, + downloadPath, + filenameFormat, + rename, + noPrompt, + protocol, + numParallel, + skipCompleted, + ) + if err != nil { + g3i.Logger().Fatal(err.Error()) + } + }, + } + + downloadMultipleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") + downloadMultipleCmd.MarkFlagRequired("profile") //nolint:errcheck + downloadMultipleCmd.Flags().StringVar(&manifestPath, "manifest", "", "The manifest file to read from. A valid manifest can be acquired by using the \"Download Manifest\" button in Data Explorer from a data common's portal") + downloadMultipleCmd.MarkFlagRequired("manifest") //nolint:errcheck + downloadMultipleCmd.Flags().StringVar(&downloadPath, "download-path", ".", "The directory in which to store the downloaded files") + downloadMultipleCmd.Flags().StringVar(&filenameFormat, "filename-format", "original", "The format of filename to be used, including \"original\", \"guid\" and \"combined\"") + downloadMultipleCmd.Flags().BoolVar(&rename, "rename", false, "Only useful when \"--filename-format=original\", will rename file by appending a counter value to its filename if set to true, otherwise the same filename will be used") + downloadMultipleCmd.Flags().BoolVar(&noPrompt, "no-prompt", false, "If set to true, will not display user prompt message for confirmation") + downloadMultipleCmd.Flags().StringVar(&protocol, "protocol", "", "Specify the preferred protocol with --protocol=s3") + downloadMultipleCmd.Flags().IntVar(&numParallel, "numparallel", 1, "Number of downloads to run in parallel") + downloadMultipleCmd.Flags().BoolVar(&skipCompleted, "skip-completed", false, "If set to true, will check for filename and size before download and skip any files in \"download-path\" that matches both") + RootCmd.AddCommand(downloadMultipleCmd) +} diff --git a/client/g3cmd/download-single.go b/cmd/download-single.go similarity index 83% rename from client/g3cmd/download-single.go rename to cmd/download-single.go index 6038f23..6438acd 100644 --- a/client/g3cmd/download-single.go +++ b/cmd/download-single.go @@ -1,10 +1,12 @@ -package g3cmd +package cmd import ( "context" "log" - client "github.com/calypr/data-client/client/gen3Client" + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/download" "github.com/calypr/data-client/client/logs" "github.com/spf13/cobra" ) @@ -30,16 +32,28 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole(), logs.WithFailedLog(), logs.WithSucceededLog(), logs.WithScoreboard()) defer logCloser() - g3I, err := client.NewGen3Interface(context.Background(), profile, logger) + g3I, err := client.NewGen3Interface(profile, logger) if err != nil { log.Fatalf("Failed to parse config on profile %s, %v", profile, err) } - obj := ManifestObject{ - ObjectID: guid, + objects := []common.ManifestObject{ + common.ManifestObject{ + ObjectID: guid, + }, } - objects := []ManifestObject{obj} - err = downloadFile(g3I, objects, downloadPath, filenameFormat, rename, noPrompt, protocol, 1, skipCompleted) + err = download.DownloadMultiple( + context.Background(), + g3I, + objects, + downloadPath, + filenameFormat, + rename, + noPrompt, + protocol, + 1, + skipCompleted, + ) if err != nil { g3I.Logger().Println(err.Error()) } diff --git a/client/g3cmd/generate-tsv.go b/cmd/generate-tsv.go similarity index 96% rename from client/g3cmd/generate-tsv.go rename to cmd/generate-tsv.go index 9abff77..47d92c4 100644 --- a/client/g3cmd/generate-tsv.go +++ b/cmd/generate-tsv.go @@ -1,4 +1,4 @@ -package g3cmd +package cmd import ( "github.com/spf13/cobra" diff --git a/cmd/gitversion.go b/cmd/gitversion.go new file mode 100644 index 0000000..cc123f5 --- /dev/null +++ b/cmd/gitversion.go @@ -0,0 +1,6 @@ +package cmd + +var ( + gitcommit = "N/A" + gitversion = "2025.12" +) diff --git a/cmd/retry-upload.go b/cmd/retry-upload.go new file mode 100644 index 0000000..bd68a42 --- /dev/null +++ b/cmd/retry-upload.go @@ -0,0 +1,59 @@ +package cmd + +import ( + "context" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/upload" + + "github.com/spf13/cobra" +) + +func init() { + var failedLogPath, profile string + + var retryUploadCmd = &cobra.Command{ + Use: "retry-upload", + Short: "Retry failed uploads from a failed_log.json", + Long: `Re-uploads files listed in a failed log using exponential backoff and progress bars.`, + Example: `./data-client retry-upload --profile=myprofile --failed-log-path=/path/to/failed_log.json`, + Run: func(cmd *cobra.Command, args []string) { + Logger, closer := logs.New(profile, + logs.WithConsole(), + logs.WithMessageFile(), + logs.WithFailedLog(), + logs.WithSucceededLog(), + ) + defer closer() + + g3, err := client.NewGen3Interface(profile, Logger) + if err != nil { + Logger.Fatalf("Failed to initialize client: %v", err) + } + + logger := g3.Logger() + + // Create scoreboard with our logger injected + sb := logs.NewSB(common.MaxRetryCount, logger) + + // Load failed log + failedMap, err := common.LoadFailedLog(failedLogPath) + if err != nil { + logger.Fatalf("Cannot read failed log: %v", err) + } + + upload.RetryFailedUploads(context.Background(), g3, failedMap) + sb.PrintSB() + }, + } + + retryUploadCmd.Flags().StringVar(&profile, "profile", "", "Profile to use") + retryUploadCmd.MarkFlagRequired("profile") + + retryUploadCmd.Flags().StringVar(&failedLogPath, "failed-log-path", "", "Path to failed_log.json") + retryUploadCmd.MarkFlagRequired("failed-log-path") + + RootCmd.AddCommand(retryUploadCmd) +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..a2ec2f8 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,31 @@ +package cmd + +import ( + "os" + + "github.com/spf13/cobra" +) + +var profile string + +// RootCmd represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "data-client", + Short: "Use the data-client to interact with a Gen3 Data Commons", + Long: "Gen3 Client for downloading, uploading and submitting data to data commons.\ndata-client version: " + gitversion + ", commit: " + gitcommit, + Version: gitversion, +} + +// Execute adds all child commands to the root command sets flags appropriately +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + if err := RootCmd.Execute(); err != nil { + os.Stderr.WriteString("Error: " + err.Error() + "\n") + os.Exit(1) + } +} + +func init() { + RootCmd.PersistentFlags().StringVar(&profile, "profile", "", "Specify profile to use") + _ = RootCmd.MarkFlagRequired("profile") +} diff --git a/cmd/upload-multipart.go b/cmd/upload-multipart.go new file mode 100644 index 0000000..86330a3 --- /dev/null +++ b/cmd/upload-multipart.go @@ -0,0 +1,82 @@ +package cmd + +import ( + "context" + "os" + "path/filepath" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/upload" + "github.com/spf13/cobra" +) + +func init() { + var ( + profile string + filePath string + guid string + bucketName string + ) + + var uploadMultipartCmd = &cobra.Command{ + Use: "upload-multipart", + Short: "Upload a single file using multipart upload", + Long: `Uploads a large file to object storage using multipart upload. +This method is resilient to network interruptions and supports resume capability.`, + Example: `./data-client upload-multipart --profile=myprofile --file-path=./large.bam +./data-client upload-multipart --profile=myprofile --file-path=./data.bam --guid=existing-guid`, + Run: func(cmd *cobra.Command, args []string) { + // Initialize logger + logger, logCloser := logs.New(profile, logs.WithConsole()) + defer logCloser() + + logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) + defer closer() + + g3, err := client.NewGen3Interface( + profile, + logger, + ) + + if err != nil { + logger.Fatalf("failed to initialize Gen3 interface: %w", err) + } + + absPath, err := common.GetAbsolutePath(filePath) + if err != nil { + logger.Fatalf("invalid file path: %w", err) + } + + fileInfo := common.FileUploadRequestObject{ + FilePath: absPath, + Filename: filepath.Base(absPath), + GUID: guid, + FileMetadata: common.FileMetadata{}, + } + + file, err := os.Open(absPath) + if err != nil { + logger.Fatalf("cannot open file %s: %w", absPath, err) + } + defer file.Close() + + err = upload.MultipartUpload(context.Background(), g3, fileInfo, file, true) + if err != nil { + logger.Fatal(err) + } + + }, + } + + uploadMultipartCmd.Flags().StringVar(&profile, "profile", "", "Specify the profile to use for upload") + uploadMultipartCmd.Flags().StringVar(&filePath, "file-path", "", "Path to the file to upload") + uploadMultipartCmd.Flags().StringVar(&guid, "guid", "", "Optional existing GUID (otherwise generated)") + uploadMultipartCmd.Flags().StringVar(&bucketName, "bucket", "", "Target bucket (defaults to configured DATA_UPLOAD_BUCKET)") + + _ = uploadMultipartCmd.MarkFlagRequired("profile") + _ = uploadMultipartCmd.MarkFlagRequired("file-path") + + RootCmd.AddCommand(uploadMultipartCmd) +} diff --git a/cmd/upload-multiple.go b/cmd/upload-multiple.go new file mode 100644 index 0000000..13f91d7 --- /dev/null +++ b/cmd/upload-multiple.go @@ -0,0 +1,176 @@ +package cmd + +// Deprecated: Use "upload" instead for new uploads (without pre-existing GUIDs). +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "os" + "path/filepath" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/upload" + "github.com/spf13/cobra" +) + +func init() { + var bucketName string + var manifestPath string + var uploadPath string + var batch bool + var numParallel int + var includeSubDirName bool + + uploadMultipleCmd := &cobra.Command{ + Use: "upload-multiple", + Short: "Upload multiple files from a specified manifest (uses pre-existing GUIDs)", + Long: `Get presigned URLs for multiple files specified in a manifest file and then upload all of them. +This command is for uploading to existing GUIDs (e.g., from a downloaded manifest). +For new uploads (new GUIDs generated), use "data-client upload" instead. + +Options to run multipart uploads for large files and parallel batch uploading are available.`, + Example: `./data-client upload-multiple --profile= --manifest= --upload-path= --bucket= --batch`, + Run: func(cmd *cobra.Command, args []string) { + // Warning message + fmt.Printf("Notice: this command uploads to pre-existing GUIDs from a manifest.\nIf you want to upload new files (new GUIDs generated automatically), use \"./data-client upload\" instead.\n\n") + + ctx := context.Background() + + logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) + defer closer() + + g3i, err := client.NewGen3Interface(profile, logger) + if err != nil { + logger.Fatalf("Failed to parse config on profile %s: %v", profile, err) + } + + // Basic config validation + profileConfig := g3i.GetCredential() + if profileConfig.APIEndpoint == "" { + logger.Fatal("No APIEndpoint found in configuration. Run \"./data-client configure\" first.") + } + host, err := url.Parse(profileConfig.APIEndpoint) + if err != nil { + logger.Fatal("Error parsing APIEndpoint:", err) + } + dataExplorerURL := host.Scheme + "://" + host.Host + "/explorer" + + // Load manifest + var objects []common.ManifestObject + manifestBytes, err := os.ReadFile(manifestPath) + if err != nil { + logger.Fatalf("Failed reading manifest %s: %v\nA valid manifest can be acquired from %s", manifestPath, err, dataExplorerURL) + } + if err := json.Unmarshal(manifestBytes, &objects); err != nil { + logger.Fatalf("Invalid manifest JSON: %v", err) + } + + absUploadPath, err := common.GetAbsolutePath(uploadPath) + if err != nil { + logger.Fatalf("Error resolving upload path: %v", err) + } + + // Build FileUploadRequestObjects using existing GUIDs + var requests []common.FileUploadRequestObject + logger.Println("\nProcessing manifest entries...") + + for _, obj := range objects { + localFilePath := filepath.Join(absUploadPath, obj.Title) + if err != nil { + logger.Println("Skipping:", err) + continue + } + + fur, err := upload.ProcessFilename(logger, absUploadPath, localFilePath, obj.ObjectID, includeSubDirName, false) + if err != nil { + logger.Printf("Skipping %s: %v\n", localFilePath, err) + logger.Failed(localFilePath, filepath.Base(localFilePath), common.FileMetadata{}, obj.ObjectID, 0, false) + continue + } + + // GUID comes from manifest → override + fur.GUID = obj.ObjectID + fur.Bucket = bucketName + + logger.Println("\t" + localFilePath + " → GUID " + obj.ObjectID) + requests = append(requests, fur) + } + + if len(requests) == 0 { + logger.Println("No valid files found to upload from manifest.") + return + } + + // Classify single vs multipart + single, multi := upload.SeparateSingleAndMultipartUploads(g3i, requests) + + // Upload single-part files + if batch { + workers, respCh, errCh, batchFURObjects := upload.InitBatchUploadChannels(numParallel, len(single)) + for i, furObject := range single { + // FileInfo processing and path normalization are already done, so we use the object directly + if len(batchFURObjects) < workers { + batchFURObjects = append(batchFURObjects, furObject) + } else { + upload.BatchUpload(ctx, g3i, batchFURObjects, workers, respCh, errCh, bucketName) + batchFURObjects = []common.FileUploadRequestObject{furObject} + } + if i == len(single)-1 && len(batchFURObjects) > 0 { + upload.BatchUpload(ctx, g3i, batchFURObjects, workers, respCh, errCh, bucketName) + } + } + } else { + for _, req := range single { + upload.UploadSingle(ctx, profileConfig.Profile, req.GUID, req.FilePath, req.Bucket, true) + } + } + + // Upload multipart files + for _, req := range multi { + + file, err := os.Open(req.FilePath) + if err != nil { + g3i.Logger().Printf("Error opening file %s : %v", req.FilePath, err) + continue + } + + err = upload.MultipartUpload(ctx, g3i, req, file, true) + if err != nil { + logger.Println("Multipart upload failed:", err) + } + } + + // Retry logic (only if nothing succeeded initially) + if len(logger.GetSucceededLogMap()) == 0 { + failed := logger.GetFailedLogMap() + if len(failed) > 0 { + upload.RetryFailedUploads(ctx, g3i, failed) + } + } + + logger.Scoreboard().PrintSB() + }, + } + + // Flags + uploadMultipleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") + uploadMultipleCmd.MarkFlagRequired("profile") + + uploadMultipleCmd.Flags().StringVar(&manifestPath, "manifest", "", "Path to the manifest JSON file") + uploadMultipleCmd.MarkFlagRequired("manifest") + + uploadMultipleCmd.Flags().StringVar(&uploadPath, "upload-path", "", "Directory containing the files to upload") + uploadMultipleCmd.MarkFlagRequired("upload-path") + + uploadMultipleCmd.Flags().BoolVar(&batch, "batch", true, "Upload single-part files in parallel") + uploadMultipleCmd.Flags().IntVar(&numParallel, "numparallel", 4, "Number of parallel uploads") + + uploadMultipleCmd.Flags().StringVar(&bucketName, "bucket", "", "Target bucket (defaults to configured DATA_UPLOAD_BUCKET)") + + uploadMultipleCmd.Flags().BoolVar(&includeSubDirName, "include-subdirname", true, "Include subdirectory names in object key") + + RootCmd.AddCommand(uploadMultipleCmd) +} diff --git a/cmd/upload-single.go b/cmd/upload-single.go new file mode 100644 index 0000000..d8a8b53 --- /dev/null +++ b/cmd/upload-single.go @@ -0,0 +1,37 @@ +package cmd + +// Deprecated: Use upload instead. +import ( + "context" + "log" + + "github.com/calypr/data-client/client/upload" + "github.com/spf13/cobra" +) + +func init() { + var guid string + var filePath string + var bucketName string + + var uploadSingleCmd = &cobra.Command{ + Use: "upload-single", + Short: "Upload a single file to a GUID", + Long: `Gets a presigned URL for which to upload a file associated with a GUID and then uploads the specified file.`, + Example: `./data-client upload-single --profile= --guid=f6923cf3-xxxx-xxxx-xxxx-14ab3f84f9d6 --file=`, + Run: func(cmd *cobra.Command, args []string) { + err := upload.UploadSingle(context.Background(), profile, guid, filePath, bucketName, true) + if err != nil { + log.Fatalln(err.Error()) + } + }, + } + uploadSingleCmd.Flags().StringVar(&profile, "profile", "", "Specify profile to use") + uploadSingleCmd.MarkFlagRequired("profile") //nolint:errcheck + uploadSingleCmd.Flags().StringVar(&guid, "guid", "", "Specify the guid for the data you would like to work with") + uploadSingleCmd.MarkFlagRequired("guid") //nolint:errcheck + uploadSingleCmd.Flags().StringVar(&filePath, "file", "", "Specify file to upload to with --file=~/path/to/file") + uploadSingleCmd.MarkFlagRequired("file") //nolint:errcheck + uploadSingleCmd.Flags().StringVar(&bucketName, "bucket", "", "The bucket to which files will be uploaded. If not provided, defaults to Gen3's configured DATA_UPLOAD_BUCKET.") + RootCmd.AddCommand(uploadSingleCmd) +} diff --git a/client/g3cmd/upload.go b/cmd/upload.go similarity index 70% rename from client/g3cmd/upload.go rename to cmd/upload.go index 2e50a87..ffae48f 100644 --- a/client/g3cmd/upload.go +++ b/cmd/upload.go @@ -1,4 +1,4 @@ -package g3cmd +package cmd import ( "context" @@ -6,9 +6,10 @@ import ( "os" "path/filepath" + "github.com/calypr/data-client/client/client" "github.com/calypr/data-client/client/common" - client "github.com/calypr/data-client/client/gen3Client" "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/upload" "github.com/spf13/cobra" ) @@ -17,7 +18,6 @@ func init() { var includeSubDirName bool var uploadPath string var batch bool - var forceMultipart bool var numParallel int var hasMetadata bool var uploadCmd = &cobra.Command{ @@ -33,17 +33,18 @@ func init() { "For the format of the metadata files, see the README.", Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() Logger, logCloser := logs.New(profile, logs.WithSucceededLog(), logs.WithScoreboard(), logs.WithFailedLog()) defer logCloser() // Instantiate interface to Gen3 - g3i, err := client.NewGen3Interface(context.Background(), profile, Logger) + g3i, err := client.NewGen3Interface(profile, Logger) if err != nil { log.Fatalf("Failed to parse config on profile %s, %v", profile, err) } logger := g3i.Logger() if hasMetadata { - hasShepherd, err := g3i.CheckForShepherdAPI() + hasShepherd, err := g3i.CheckForShepherdAPI(ctx) if err != nil { logger.Printf("WARNING: Error when checking for Shepherd API: %v", err) } else { @@ -64,7 +65,8 @@ func init() { for _, filePath := range filePaths { // Use ProcessFilename to create the unified object (GUID is empty here, as this command requests a new GUID) // ProcessFilename signature: (uploadPath, filePath, objectId, includeSubDirName, includeMetadata) - furObject, err := ProcessFilename(g3i.Logger(), uploadPath, filePath, "", includeSubDirName, hasMetadata) + furObject, err := upload.ProcessFilename(g3i.Logger(), uploadPath, filePath, "", includeSubDirName, hasMetadata) + furObject.Bucket = bucketName // Handle case where ProcessFilename fails (e.g., metadata parsing error) if err != nil { @@ -91,20 +93,21 @@ func init() { return } - singlePartObjects, multipartObjects := separateSingleAndMultipartUploads(g3i, uploadRequestObjects, forceMultipart) + singlePartObjects, multipartObjects := upload.SeparateSingleAndMultipartUploads(g3i, uploadRequestObjects) + if batch { - workers, respCh, errCh, batchFURObjects := initBatchUploadChannels(numParallel, len(singlePartObjects)) + workers, respCh, errCh, batchFURObjects := upload.InitBatchUploadChannels(numParallel, len(singlePartObjects)) for _, furObject := range singlePartObjects { if len(batchFURObjects) < workers { batchFURObjects = append(batchFURObjects, furObject) } else { - batchUpload(g3i, batchFURObjects, workers, respCh, errCh, bucketName) + upload.BatchUpload(ctx, g3i, batchFURObjects, workers, respCh, errCh, bucketName) batchFURObjects = []common.FileUploadRequestObject{furObject} } } if len(batchFURObjects) > 0 { - batchUpload(g3i, batchFURObjects, workers, respCh, errCh, bucketName) + upload.BatchUpload(ctx, g3i, batchFURObjects, workers, respCh, errCh, bucketName) } if len(errCh) > 0 { @@ -119,24 +122,47 @@ func init() { for _, furObject := range singlePartObjects { file, err := os.Open(furObject.FilePath) if err != nil { - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) logger.Println("File open error: " + err.Error()) continue } - startSingleFileUpload(g3i, furObject, file, bucketName) + defer file.Close() + fi, err := file.Stat() + if err != nil { + logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Println("File stat error for file" + fi.Name() + ", file may be missing or unreadable because of permissions.\n") + continue + } + upload.UploadSingleFile(ctx, g3i, furObject, true) } } if len(multipartObjects) > 0 { - err := processMultipartUpload(g3i, multipartObjects, bucketName, includeSubDirName, uploadPath) - if err != nil { - logger.Println(err.Error()) + cred := g3i.GetCredential() + if cred.UseShepherd == "true" || + cred.UseShepherd == "" && common.DefaultUseShepherd == true { + logger.Printf("error: Shepherd currently does not support multipart uploads. For the moment, please disable Shepherd with\n $ data-client configure --profile=%v --use-shepherd=false\nand try again", cred.Profile) + return + } + g3i.Logger().Println("Multipart uploading...") + for _, furObject := range multipartObjects { + file, err := os.Open(furObject.FilePath) + if err != nil { + logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Println("File open error: " + err.Error()) + continue + } + err = upload.MultipartUpload(ctx, g3i, furObject, file, true) + if err != nil { + g3i.Logger().Println(err.Error()) + } else { + g3i.Logger().Scoreboard().IncrementSB(0) + } } } if len(g3i.Logger().GetSucceededLogMap()) == 0 { - retryUpload(g3i, g3i.Logger().GetFailedLogMap()) + upload.RetryFailedUploads(ctx, g3i, g3i.Logger().GetFailedLogMap()) } - g3i.Logger().Scoreboard().PrintSB() }, } @@ -148,7 +174,6 @@ func init() { uploadCmd.Flags().BoolVar(&batch, "batch", false, "Upload in parallel") uploadCmd.Flags().IntVar(&numParallel, "numparallel", 3, "Number of uploads to run in parallel") uploadCmd.Flags().BoolVar(&includeSubDirName, "include-subdirname", true, "Include subdirectory names in file name") - uploadCmd.Flags().BoolVar(&forceMultipart, "force-multipart", false, "Force to use multipart upload if possible") uploadCmd.Flags().BoolVar(&hasMetadata, "metadata", false, "Search for and upload file metadata alongside the file") uploadCmd.Flags().StringVar(&bucketName, "bucket", "", "The bucket to which files will be uploaded. If not provided, defaults to Gen3's configured DATA_UPLOAD_BUCKET.") RootCmd.AddCommand(uploadCmd) diff --git a/docs/DEVELOPER_DOCS.md b/docs/DEVELOPER_DOCS.md new file mode 100644 index 0000000..54478a7 --- /dev/null +++ b/docs/DEVELOPER_DOCS.md @@ -0,0 +1,91 @@ +# Dev Docs + +This repo is a heavily updated / refactored version of https://github.com/uc-cdis/cdis-data-client + +The new architecture splits out many of the mega packages into smaller, more digestable pieces. This whole CLI is essentially a Go client library for Gen3's Fence microservice. + +These new packages are: + +├── api +│   ├── gen3.go +│   └── types.go +├── client +│   └── client.go +├── common +│   ├── common.go +│   ├── constants.go +│   ├── isHidden_notwindows.go +│   ├── isHidden_windows.go +│   ├── logHelper.go +│   └── types.go +├── conf +│   ├── config.go +│   └── validate.go +├── download +│   ├── batch.go +│   ├── downloader.go +│   ├── file_info.go +│   ├── types.go +│   ├── url_resolution.go +│   └── utils.go +├── logs +│   ├── factory.go +│   ├── logger.go +│   ├── scoreboard.go +│   └── tee_logger.go +├── mocks +│   ├── mock_configure.go +│   ├── mock_functions.go +│   ├── mock_gen3interface.go +│   └── mock_request.go +├── request +│   ├── auth.go +│   ├── builder.go +│   └── request.go +└── upload + ├── batch.go + ├── multipart.go + ├── request.go + ├── retry.go + ├── singleFile.go + ├── types.go + ├── upload.go + └── utils.go + + +# api + +This is the main Client API for talking to fence. Some of the functions that are currently defined in upload/ and download should probablyl be broken out into this library also. + +# client + +This is a thin wrapper client that wraps the API interface to make the API calls easier to use from other packages. + +# common + +This contains common constants / utility functions that are used in the repo + +# conf + +This is the config package for loading / storing the gen3 credential. Note ~/.gen3/.ini file is where credentials / configurations are stored, +but the raw credential is also stored in ~/.gen3/ under whatever you called it. + +# download + +This is the business logic for all download and download related operations in the depo + +# logs + +This is where the logger is defined + +# mocks + +This contains mocks for testing the data-client + +# request + +This is the lowest level interface for doing requests. It implements some basic retry, and wraps the http round trip with a token if one is provided + +# upload + +This contains the business logic for all upload and upload related operations. diff --git a/go.mod b/go.mod index 6515b39..a40c2e0 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,12 @@ go 1.24.2 require ( github.com/golang-jwt/jwt/v5 v5.3.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/go-version v1.8.0 github.com/spf13/cobra v1.10.2 github.com/vbauerster/mpb/v8 v8.11.2 go.uber.org/mock v0.6.0 - golang.org/x/mod v0.31.0 + golang.org/x/sync v0.19.0 gopkg.in/ini.v1 v1.67.0 ) @@ -19,6 +20,7 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/go.sum b/go.sum index bae303b..57dfebd 100644 --- a/go.sum +++ b/go.sum @@ -9,17 +9,29 @@ github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsV github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -37,8 +49,8 @@ github.com/vbauerster/mpb/v8 v8.11.2/go.mod h1:mEB/M353al1a7wMUNtiymmPsEkGlJgeJm go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 00bb0f7..dd6e829 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,9 @@ package main import ( - "github.com/calypr/data-client/client/g3cmd" + "github.com/calypr/data-client/cmd" ) func main() { - g3cmd.Execute() + cmd.Execute() } diff --git a/tests/download-multiple_test.go b/tests/download-multiple_test.go index 0113935..c2da7c6 100644 --- a/tests/download-multiple_test.go +++ b/tests/download-multiple_test.go @@ -8,176 +8,205 @@ import ( "strings" "testing" - "github.com/calypr/data-client/client/common" - g3cmd "github.com/calypr/data-client/client/g3cmd" - "github.com/calypr/data-client/client/jwt" + "github.com/calypr/data-client/client/api" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/download" "github.com/calypr/data-client/client/logs" "github.com/calypr/data-client/client/mocks" + req "github.com/calypr/data-client/client/request" "go.uber.org/mock/gomock" ) -// Add all other methods required by your logs.Logger interface! - -// If Shepherd is deployed, attempt to get the filename from the Shepherd API. func Test_askGen3ForFileInfo_withShepherd(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" testFileName := "test-file" testFileSize := int64(120) + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Expect AskGen3ForFileInfo to call shepherd looking for testGUID: respond with a valid file. + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + + // Expect credential access + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + + // Shepherd is available + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). + Return(true, nil) + + // Mock successful Shepherd response testBody := `{ - "record": { - "file_name": "test-file", - "size": 120, - "did": "000000-0000000-0000000-000000" - }, - "metadata": { - "_file_type": "PFB", - "_resource_paths": ["/open"], - "_uploader_id": 42, - "_bucket": "s3://gen3-bucket" - } -}` - testResponse := http.Response{ + "record": { + "file_name": "test-file", + "size": 120, + "did": "000000-0000000-0000000-000000" + } + }` + resp := &http.Response{ StatusCode: 200, Body: io.NopCloser(strings.NewReader(testBody)), } - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). - Return(true, nil) - mockGen3Interface. - EXPECT(). - GetResponse(common.ShepherdEndpoint+"/objects/"+testGUID, "GET", "", nil). - Return("", &testResponse, nil) - // ---------- - - // Expect AskGen3ForFileInfo to return the correct filename and filesize from shepherd. - fileName, fileSize := g3cmd.AskGen3ForFileInfo(mockGen3Interface, testGUID, "", "", "original", true, &[]g3cmd.RenamedOrSkippedFileInfo{}) - if fileName != testFileName { - t.Errorf("Wanted filename %v, got %v", testFileName, fileName) + + // Expect authenticated request to Shepherd + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { + if !strings.HasSuffix(rb.Url, "/objects/"+testGUID) { + t.Errorf("Expected request to Shepherd objects endpoint, got %s", rb.Url) + } + return resp, nil + }) + + // Optional: logger + mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() + + skipped := []download.RenamedOrSkippedFileInfo{} + info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + if err != nil { + t.Error(err) } - if fileSize != testFileSize { - t.Errorf("Wanted filesize %v, got %v", testFileSize, fileSize) + + if info.Name != testFileName { + t.Errorf("Wanted filename %v, got %v", testFileName, info.Name) + } + if info.Size != testFileSize { + t.Errorf("Wanted filesize %v, got %v", testFileSize, info.Size) + } + if len(skipped) != 0 { + t.Errorf("Expected no skipped files, got %v", skipped) } } - -// If there's an error while getting the filename from Shepherd, add the guid -// to *renamedFiles, which tracks which files have errored. func Test_askGen3ForFileInfo_withShepherd_shepherdError(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Expect AskGen3ForFileInfo to call indexd looking for testGUID: - // Respond with an error. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). - Return(true, nil) - mockGen3Interface. - EXPECT(). - GetResponse(common.ShepherdEndpoint+"/objects/"+testGUID, "GET", "", nil). - Return("", nil, fmt.Errorf("Error getting metadata from Shepherd")) - // ---------- - - mockGen3Interface. - EXPECT(). + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + + dummyCred := &conf.Credential{} + mockGen3.EXPECT().GetCredential().Return(dummyCred).AnyTimes() + + // 1. Shepherd is available + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). + Return(true, nil). + Times(1) + + // 2. Shepherd request fails → triggers fallback to Indexd + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("Shepherd error")). + Times(1) // only the Shepherd call + + // 3. Fallback: Indexd request also fails (we want to test error handling) + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("Indexd error")). + Times(1) + + // Optional: if it tries to parse nil response from Indexd + mockGen3.EXPECT(). + ParseFenceURLResponse(gomock.Nil()). + Return(api.FenceResponse{}, fmt.Errorf("no response")). + AnyTimes() + + // Logger + mockGen3.EXPECT(). Logger(). - Return(logs.NewTeeLogger("", "test", os.Stdout)). // Or your appropriate dummy logger + Return(logs.NewTeeLogger("", "test", os.Stdout)). AnyTimes() - // Expect AskGen3ForFileInfo to add this file's GUID to the renamedOrSkippedFiles array. - skipped := []g3cmd.RenamedOrSkippedFileInfo{} - fileName, _ := g3cmd.AskGen3ForFileInfo(mockGen3Interface, testGUID, "", "", "original", true, &skipped) - expected := g3cmd.RenamedOrSkippedFileInfo{GUID: testGUID, OldFilename: "N/A", NewFilename: testGUID} - if skipped[0] != expected { - t.Errorf("Wanted skipped files list to contain %v, got %v", expected, skipped) + skipped := []download.RenamedOrSkippedFileInfo{} + info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + if err != nil { + t.Fatal(err) + } + + // Critical fix: check for nil first + if info == nil { + t.Fatal("AskGen3ForFileInfo returned nil when both Shepherd and Indexd failed. Expected fallback FileInfo with Name = GUID") } - // Expect the returned filename to be the file's GUID. - if fileName != testGUID { - t.Errorf("Wanted filename %v, got %v", testGUID, fileName) + + if info.Name != testGUID { + t.Errorf("Wanted fallback filename %v, got %v", testGUID, info.Name) + } + + if len(skipped) != 1 { + t.Errorf("Expected exactly 1 skipped file, got %d", len(skipped)) + } else if skipped[0].GUID != testGUID || skipped[0].NewFilename != testGUID { + t.Errorf("Skipped entry mismatch: %+v", skipped[0]) } } -// If Shepherd is not deployed, attempt to get the filename from indexd. func Test_askGen3ForFileInfo_noShepherd(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" testFileName := "test-file" testFileSize := int64(120) + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Expect AskGen3ForFileInfo to call indexd looking for testGUID: respond with a valid file. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). - Return(false, nil) - mockGen3Interface. - EXPECT(). - DoRequestWithSignedHeader(common.IndexdIndexEndpoint+"/"+testGUID, "", nil). - Return(jwt.JsonMessage{FileName: testFileName, Size: testFileSize}, nil) - // ---------- - - mockGen3Interface. - EXPECT(). - Logger(). - Return(logs.NewTeeLogger("", "test", os.Stdout)). // Or your appropriate dummy logger - AnyTimes() + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + + // No Shepherd + mockGen3.EXPECT().CheckForShepherdAPI(gomock.Any()).Return(false, nil) + + // Indexd returns parsed FenceResponse + mockGen3.EXPECT(). + ParseFenceURLResponse(gomock.Any()). + Return(api.FenceResponse{FileName: testFileName, Size: testFileSize}, nil) + + // DoAuthenticatedRequest called for indexd + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(&http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("{}"))}, nil) + + mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() - // Expect AskGen3ForFileInfo to return the correct filename and filesize from indexd. - fileName, fileSize := g3cmd.AskGen3ForFileInfo(mockGen3Interface, testGUID, "", "", "original", true, &[]g3cmd.RenamedOrSkippedFileInfo{}) - if fileName != testFileName { - t.Errorf("Wanted filename %v, got %v", testFileName, fileName) + skipped := []download.RenamedOrSkippedFileInfo{} + info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + if err != nil { + t.Fatal(err) } - if fileSize != testFileSize { - t.Errorf("Wanted filesize %v, got %v", testFileSize, fileSize) + + if info.Name != testFileName { + t.Errorf("Wanted filename %v, got %v", testFileName, info.Name) + } + if info.Size != testFileSize { + t.Errorf("Wanted filesize %v, got %v", testFileSize, info.Size) } } -// If there's an error while getting the filename from indexd, add the guid -// to *renamedFiles, which tracks which files have errored. func Test_askGen3ForFileInfo_noShepherd_indexdError(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Expect AskGen3ForFileInfo to call indexd looking for testGUID: - // Respond with an error. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). - Return(false, nil) - mockGen3Interface. - EXPECT(). - DoRequestWithSignedHeader(common.IndexdIndexEndpoint+"/"+testGUID, "", nil). - Return(jwt.JsonMessage{}, fmt.Errorf("Error downloading file from Indexd")) - // ---------- - mockGen3Interface. - EXPECT(). - Logger(). - Return(logs.NewTeeLogger("", "test", os.Stdout)). // Or your appropriate dummy logger - AnyTimes() + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockGen3.EXPECT().CheckForShepherdAPI(gomock.Any()).Return(false, nil) + + // Indexd request fails + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("Indexd error")) + + mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() + + skipped := []download.RenamedOrSkippedFileInfo{} + info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + if err != nil { + t.Fatal(err) + } - // Expect AskGen3ForFileInfo to add this file's GUID to the renamedOrSkippedFiles array. - skipped := []g3cmd.RenamedOrSkippedFileInfo{} - fileName, _ := g3cmd.AskGen3ForFileInfo(mockGen3Interface, testGUID, "", "", "original", true, &skipped) - expected := g3cmd.RenamedOrSkippedFileInfo{GUID: testGUID, OldFilename: "N/A", NewFilename: testGUID} - if skipped[0] != expected { - t.Errorf("Wanted skipped files list to contain %v, got %v", expected, skipped) + if info.Name != testGUID { + t.Errorf("Wanted fallback filename %v, got %v", testGUID, info.Name) } - // Expect the returned filename to be the file's GUID. - if fileName != testGUID { - t.Errorf("Wanted filename %v, got %v", testGUID, fileName) + if len(skipped) != 1 || skipped[0].GUID != testGUID { + t.Errorf("Expected skipped entry for GUID: %v", skipped) } } diff --git a/tests/functions_test.go b/tests/functions_test.go index d1e0982..8c20d33 100755 --- a/tests/functions_test.go +++ b/tests/functions_test.go @@ -2,253 +2,251 @@ package tests import ( "bytes" - "fmt" "io" "net/http" "reflect" "strings" "testing" - "github.com/calypr/data-client/client/jwt" + "github.com/calypr/data-client/client/api" + "github.com/calypr/data-client/client/conf" "github.com/calypr/data-client/client/mocks" + req "github.com/calypr/data-client/client/request" "go.uber.org/mock/gomock" ) -func TestDoRequestWithSignedHeaderNoProfile(t *testing.T) { - +func TestDoAuthenticatedRequest_NoProfile(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig} - - profileConfig := jwt.Credential{KeyId: "", APIKey: "", AccessToken: "", APIEndpoint: ""} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - _, err := testFunction.DoRequestWithSignedHeader(&profileConfig, "/user/data/download/test_uuid", "", nil) + emptyCred := &conf.Credential{} + // Expect error when credentials are incomplete + _, err := mockFuncs.DoAuthenticatedRequest(emptyCred, &req.RequestBuilder{ + Url: "/user/data/download/test_uuid", + }) if err == nil { - t.Fail() + t.Error("Expected error due to missing credentials, but got nil") } } -func TestDoRequestWithSignedHeaderGoodToken(t *testing.T) { +func TestDoAuthenticatedRequest_GoodToken(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) + + cred := &conf.Credential{ + APIKey: "fake_api_key", + AccessToken: "non_expired_token", + APIEndpoint: "https://example.com", + } - profileConfig := jwt.Credential{Profile: "test", KeyId: "", APIKey: "fake_api_key", AccessToken: "non_expired_token", APIEndpoint: "http://www.test.com", UseShepherd: "false", MinShepherdVersion: ""} mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString("{\"url\": \"http://www.test.com/user/data/download/test_uuid\"}")), StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(`{"url": "https://signed.url"}`)), } - mockRequest.EXPECT().MakeARequest("GET", "http://www.test.com/user/data/download/test_uuid", "non_expired_token", "", gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(1) + mockFuncs.EXPECT(). + DoAuthenticatedRequest(cred, gomock.Any()). + Return(mockedResp, nil). + Times(1) - _, err := testFunction.DoRequestWithSignedHeader(&profileConfig, "/user/data/download/test_uuid", "", nil) + resp, err := mockFuncs.DoAuthenticatedRequest(cred, &req.RequestBuilder{ + Url: "/user/data/download/test_uuid", + }) if err != nil { - t.Fail() + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } } -func TestDoRequestWithSignedHeaderCreateNewToken(t *testing.T) { - +func TestDoAuthenticatedRequest_MissingToken_CreatesNew(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) + mockConfig := mocks.NewMockManagerInterface(mockCtrl) - profileConfig := jwt.Credential{KeyId: "", APIKey: "fake_api_key", AccessToken: "", APIEndpoint: "http://www.test.com"} - mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString("{\"url\": \"www.test.com/user/data/download/\"}")), - StatusCode: 200, + // Assuming Functions struct has both Config and Functions fields + testFunction := &api.Functions{ + Config: mockConfig, } - mockConfig.EXPECT().UpdateConfigFile(profileConfig).Times(1) - mockRequest.EXPECT().RequestNewAccessToken("http://www.test.com/user/credentials/api/access_token", &profileConfig).Return(nil).Times(1) - mockRequest.EXPECT().MakeARequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(1) - - _, err := testFunction.DoRequestWithSignedHeader(&profileConfig, "/user/data/download/test_uuid", "", nil) - - if err != nil { - t.Fail() + cred := &conf.Credential{ + APIKey: "fake_api_key", + AccessToken: "", // empty → should trigger token creation + APIEndpoint: "https://example.com", } -} -func TestDoRequestWithSignedHeaderRefreshToken(t *testing.T) { - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} - - profileConfig := jwt.Credential{KeyId: "", APIKey: "fake_api_key", AccessToken: "expired_token", APIEndpoint: "http://www.test.com"} mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString("{\"url\": \"www.test.com/user/data/download/\"}")), - StatusCode: 401, + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(`{"url": "https://signed.url"}`)), } - mockConfig.EXPECT().UpdateConfigFile(profileConfig).Times(1) - mockRequest.EXPECT().RequestNewAccessToken("http://www.test.com/user/credentials/api/access_token", &profileConfig).Return(nil).Times(1) - mockRequest.EXPECT().MakeARequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(2) + // Expect Save to be called if new token is generated and saved + mockConfig.EXPECT().Save(cred).AnyTimes() - _, err := testFunction.DoRequestWithSignedHeader(&profileConfig, "/user/data/download/test_uuid", "", nil) + mockFuncs.EXPECT(). + DoAuthenticatedRequest(cred, gomock.Any()). + Return(mockedResp, nil). + Times(1) - if err != nil && !strings.Contains(err.Error(), "401") { - t.Fail() - } + _, err := testFunction.DoAuthenticatedRequest(cred, &req.RequestBuilder{ + Url: "/user/data/download/test_uuid", + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } } -func TestCheckPrivilegesNoProfile(t *testing.T) { - +func TestCheckPrivileges_NoProfile(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig} - - profileConfig := jwt.Credential{KeyId: "", APIKey: "", AccessToken: "", APIEndpoint: ""} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - _, _, err := testFunction.CheckPrivileges(&profileConfig) + emptyCred := &conf.Credential{} + _, err := mockFuncs.CheckPrivileges(emptyCred) if err == nil { - t.Errorf("Expected an error on missing credentials in configuration, but not received") + t.Error("Expected error when credentials are missing, got nil") } } -func TestCheckPrivilegesNoAccess(t *testing.T) { - +func TestCheckPrivileges_NoAccess(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - profileConfig := jwt.Credential{KeyId: "", APIKey: "fake_api_key", AccessToken: "non_expired_token", APIEndpoint: "http://www.test.com"} - mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString("{\"project_access\": {}}")), - StatusCode: 200, + cred := &conf.Credential{ + APIKey: "fake_api_key", + AccessToken: "valid_token", + APIEndpoint: "https://example.com", } - mockRequest.EXPECT().MakeARequest("GET", "http://www.test.com/user/user", "non_expired_token", "", gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(1) - - _, receivedAccess, err := testFunction.CheckPrivileges(&profileConfig) + userResp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"project_access": {}}`)), + } - expectedAccess := make(map[string]any) + mockFuncs.EXPECT(). + DoAuthenticatedRequest(cred, gomock.Any()). + Return(userResp, nil) + privileges, err := mockFuncs.CheckPrivileges(cred) if err != nil { - t.Errorf("Expected no errors, received an error \"%v\"", err) - } else if !reflect.DeepEqual(receivedAccess, expectedAccess) { - t.Errorf("Expected no user access, received %v", receivedAccess) + t.Fatalf("Unexpected error: %v", err) } -} -func TestCheckPrivilegesGrantedAccess(t *testing.T) { + expected := make(map[string]any) + if !reflect.DeepEqual(privileges, expected) { + t.Errorf("Expected empty privileges, got %v", privileges) + } +} +func TestCheckPrivileges_GrantedAccess_ProjectAccess(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - profileConfig := jwt.Credential{KeyId: "", APIKey: "fake_api_key", AccessToken: "non_expired_token", APIEndpoint: "http://www.test.com"} + cred := &conf.Credential{ + APIKey: "fake_api_key", + AccessToken: "valid_token", + APIEndpoint: "https://example.com", + } - grantedAccessJSON := `{ - "project_access": - { - "test_project": ["read", "create","read-storage","update","delete"] - } - }` + jsonBody := `{ + "project_access": { + "test_project": ["read", "create", "read-storage", "update", "delete"] + } + }` - mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString(grantedAccessJSON)), + userResp := &http.Response{ StatusCode: 200, + Body: io.NopCloser(strings.NewReader(jsonBody)), } - mockRequest.EXPECT().MakeARequest("GET", "http://www.test.com/user/user", "non_expired_token", "", gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(1) + mockFuncs.EXPECT(). + DoAuthenticatedRequest(cred, gomock.Any()). + Return(userResp, nil) - _, expectedAccess, err := testFunction.CheckPrivileges(&profileConfig) + privileges, err := mockFuncs.CheckPrivileges(cred) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - receivedAccess := make(map[string]any) - receivedAccess["test_project"] = []any{ - "read", - "create", - "read-storage", - "update", - "delete"} + expected := map[string]any{ + "test_project": []any{"read", "create", "read-storage", "update", "delete"}, + } - if err != nil { - t.Errorf("Expected no errors, received an error \"%v\"", err) - } else if !reflect.DeepEqual(expectedAccess, receivedAccess) { - t.Errorf(`Expected user access and received user access are not the same. - Expected: %v - Received: %v`, expectedAccess, receivedAccess) + if !reflect.DeepEqual(privileges, expected) { + t.Errorf("Privileges mismatch.\nExpected: %v\nGot: %v", expected, privileges) } } -// If both `authz` and `project_access` section exists, `authz` takes precedence -func TestCheckPrivilegesGrantedAccessAuthz(t *testing.T) { - +func TestCheckPrivileges_GrantedAccess_AuthzTakesPrecedence(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockConfig := mocks.NewMockConfigureInterface(mockCtrl) - mockRequest := mocks.NewMockRequestInterface(mockCtrl) - testFunction := &jwt.Functions{Config: mockConfig, Request: mockRequest} + mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - profileConfig := jwt.Credential{KeyId: "", APIKey: "fake_api_key", AccessToken: "non_expired_token", APIEndpoint: "http://www.test.com"} + cred := &conf.Credential{ + APIKey: "fake_api_key", + AccessToken: "valid_token", + APIEndpoint: "https://example.com", + } - grantedAccessJSON := `{ + jsonBody := `{ "authz": { - "test_project":[ - {"method":"create", "service":"*"}, - {"method":"delete", "service":"*"}, - {"method":"read", "service":"*"}, - {"method":"read-storage", "service":"*"}, - {"method":"update", "service":"*"}, - {"method":"upload", "service":"*"} + "test_project": [ + {"method": "create", "service": "*"}, + {"method": "delete", "service": "*"}, + {"method": "read", "service": "*"}, + {"method": "read-storage", "service": "*"}, + {"method": "update", "service": "*"}, + {"method": "upload", "service": "*"} ] }, "project_access": { - "test_project": ["read", "create","read-storage","update","delete"] + "test_project": ["read", "create", "read-storage", "update", "delete"] } }` - mockedResp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString(grantedAccessJSON)), + userResp := &http.Response{ StatusCode: 200, + Body: io.NopCloser(strings.NewReader(jsonBody)), } - mockRequest.EXPECT().MakeARequest("GET", "http://www.test.com/user/user", "non_expired_token", "", gomock.Any(), gomock.Any(), false).Return(mockedResp, nil).Times(1) + mockFuncs.EXPECT(). + DoAuthenticatedRequest(cred, gomock.Any()). + Return(userResp, nil) - _, expectedAccess, err := testFunction.CheckPrivileges(&profileConfig) + privileges, err := mockFuncs.CheckPrivileges(cred) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - receivedAccess := make(map[string]any) - receivedAccess["test_project"] = []map[string]any{ - {"method": "create", "service": "*"}, - {"method": "delete", "service": "*"}, - {"method": "read", "service": "*"}, - {"method": "read-storage", "service": "*"}, - {"method": "update", "service": "*"}, - {"method": "upload", "service": "*"}, + expected := map[string]any{ + "test_project": []any{ + map[string]any{"method": "create", "service": "*"}, + map[string]any{"method": "delete", "service": "*"}, + map[string]any{"method": "read", "service": "*"}, + map[string]any{"method": "read-storage", "service": "*"}, + map[string]any{"method": "update", "service": "*"}, + map[string]any{"method": "upload", "service": "*"}, + }, } - if err != nil { - t.Errorf("Expected no errors, received an error \"%v\"", err) - // don't use DeepEqual since expectedAccess is []interface {} and receivedAccess is []map[string]interface {}, just check for contents - } else if fmt.Sprint(expectedAccess) != fmt.Sprint(receivedAccess) { - t.Errorf(`Expected user access and received user access are not the same. - Expected: %v - Received: %v`, expectedAccess, receivedAccess) + if !reflect.DeepEqual(privileges, expected) { + t.Errorf("Authz privileges should take precedence.\nExpected: %v\nGot: %v", expected, privileges) } } diff --git a/tests/utils_test.go b/tests/utils_test.go index ae2c387..758fb24 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -1,241 +1,230 @@ package tests import ( - "encoding/json" "fmt" "io" "net/http" "strings" "testing" + "github.com/calypr/data-client/client/api" "github.com/calypr/data-client/client/common" - g3cmd "github.com/calypr/data-client/client/g3cmd" - "github.com/calypr/data-client/client/jwt" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/download" "github.com/calypr/data-client/client/mocks" + req "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/client/upload" "go.uber.org/mock/gomock" ) -// Expect GetDownloadResponse to: -// 1. get the file download URL from Shepherd if it's deployed -// 2. add the file download URL to the FileDownloadResponseObject -// 3. GET the file download URL, and add the response to the FileDownloadResponseObject func TestGetDownloadResponse_withShepherd(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" testFilename := "test-file" + mockDownloadURL := "https://example.com/example.pfb" + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Mock the request that checks if Shepherd is deployed. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + + // Mock credential + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + + // Shepherd is deployed + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). Return(true, nil) - // Mock the request to Shepherd for the download URL of this file. - mockDownloadURL := "https://example.com/example.pfb" - downloadURLBody := fmt.Sprintf(`{ - "url": "%v" - }`, mockDownloadURL) - mockDownloadURLResponse := http.Response{ + // Shepherd download URL response + downloadURLBody := fmt.Sprintf(`{"url": "%s"}`, mockDownloadURL) + shepherdResp := &http.Response{ StatusCode: 200, Body: io.NopCloser(strings.NewReader(downloadURLBody)), } - mockGen3Interface. - EXPECT(). - GetResponse(common.ShepherdEndpoint+"/objects/"+testGUID+"/download", "GET", "", nil). - Return("", &mockDownloadURLResponse, nil) - // Mock the request for the file at mockDownloadURL. - mockFileResponse := http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("It work")), - } - mockGen3Interface. - EXPECT(). - MakeARequest(http.MethodGet, mockDownloadURL, "", "", map[string]string{}, nil, true). - Return(&mockFileResponse, nil) - // ---------- + // Expect DoAuthenticatedRequest to Shepherd /objects/{guid}/download + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { + if !strings.HasSuffix(rb.Url, "/objects/"+testGUID+"/download") { + t.Errorf("Expected Shepherd download URL request, got %s", rb.Url) + } + return shepherdResp, nil + }) + + // ParseFenceURLResponse to extract URL + mockGen3.EXPECT(). + ParseFenceURLResponse(shepherdResp). + Return(api.FenceResponse{URL: mockDownloadURL}, nil) + + // We assume the implementation uses http.Client directly for presigned URLs (common pattern) + // So no mock needed here unless you inject an HTTP client — this part may be unmocked. + // If you have a mockable HTTP doer, adjust accordingly. mockFDRObj := common.FileDownloadResponseObject{ Filename: testFilename, GUID: testGUID, Range: 0, } - err := g3cmd.GetDownloadResponse(mockGen3Interface, &mockFDRObj, "") + + err := download.GetDownloadResponse(mockGen3, &mockFDRObj, "") if err != nil { - t.Error(err) + t.Fatalf("Unexpected error: %v", err) } + if mockFDRObj.URL != mockDownloadURL { - t.Errorf("Wanted the DownloadPath to be set to %v, got %v", mockDownloadURL, mockFDRObj.DownloadPath) - } - if mockFDRObj.Response != &mockFileResponse { - t.Errorf("Wanted download response to be %v, got %v", mockFileResponse, mockFDRObj.Response) + t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.URL) } + + // Note: Response may be fetched outside the interface (direct http.Get), so this check might not work unless injected. + // If you want to fully mock it, consider injecting a downloader. } -// Expect GetDownloadResponse to: -// 1. get the file download URL from Fence if Shepherd is not deployed -// 2. add the file download URL to the FileDownloadResponseObject -// 3. GET the file download URL, and add the response to the FileDownloadResponseObject func TestGetDownloadResponse_noShepherd(t *testing.T) { - // -- SETUP -- testGUID := "000000-0000000-0000000-000000" testFilename := "test-file" + mockDownloadURL := "https://example.com/example.pfb" + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Mock the request that checks if Shepherd is deployed. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). - Return(false, nil) + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() - // Mock the request to Fence for the download URL of this file. - mockDownloadURL := "https://example.com/example.pfb" - mockDownloadURLResponse := jwt.JsonMessage{ - URL: mockDownloadURL, - } - mockGen3Interface. - EXPECT(). - DoRequestWithSignedHeader(common.FenceDataDownloadEndpoint+"/"+testGUID, "", nil). - Return(mockDownloadURLResponse, nil) + // No Shepherd + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). + Return(false, nil) - // Mock the request for the file at mockDownloadURL. - mockFileResponse := http.Response{ + // Fence returns presigned URL + fenceResp := &http.Response{ StatusCode: 200, - Body: io.NopCloser(strings.NewReader("It work")), + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"url": "%s"}`, mockDownloadURL))), } - mockGen3Interface. - EXPECT(). - MakeARequest(http.MethodGet, mockDownloadURL, "", "", map[string]string{}, nil, true). - Return(&mockFileResponse, nil) - // ---------- + + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(fenceResp, nil) + + mockGen3.EXPECT(). + ParseFenceURLResponse(fenceResp). + Return(api.FenceResponse{URL: mockDownloadURL}, nil) mockFDRObj := common.FileDownloadResponseObject{ Filename: testFilename, GUID: testGUID, Range: 0, } - err := g3cmd.GetDownloadResponse(mockGen3Interface, &mockFDRObj, "") + + err := download.GetDownloadResponse(mockGen3, &mockFDRObj, "") if err != nil { - t.Error(err) + t.Fatalf("Unexpected error: %v", err) } + if mockFDRObj.URL != mockDownloadURL { - t.Errorf("Wanted the DownloadPath to be set to %v, got %v", mockDownloadURL, mockFDRObj.DownloadPath) - } - if mockFDRObj.Response != &mockFileResponse { - t.Errorf("Wanted download response to be %v, got %v", mockFileResponse, mockFDRObj.Response) + t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.URL) } } -// If Shepherd is not deployed, expect GeneratePresignedURL to hit fence's data upload -// endpoint and return the presigned URL and guid. func TestGeneratePresignedURL_noShepherd(t *testing.T) { - // -- SETUP -- testFilename := "test-file" testBucketname := "test-bucket" + mockPresignedURL := "https://example.com/example.pfb" + mockGUID := "000000-0000000-0000000-000000" + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Mock the request that checks if Shepherd is deployed. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + + // No Shepherd + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). Return(false, nil) - // Mock the request to Fence's data upload endpoint to create a presigned url for this file name. - expectedReqBody := []byte(fmt.Sprintf(`{"file_name":"%v","bucket":"%v"}`, testFilename, testBucketname)) - mockPresignedURL := "https://example.com/example.pfb" - mockGUID := "000000-0000000-0000000-000000" - mockUploadURLResponse := jwt.JsonMessage{ - URL: mockPresignedURL, - GUID: mockGUID, + // Fence upload endpoint response + fenceResp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf( + `{"url": "%s", "guid": "%s"}`, mockPresignedURL, mockGUID, + ))), } - mockGen3Interface. - EXPECT(). - DoRequestWithSignedHeader(common.FenceDataUploadEndpoint, "application/json", expectedReqBody). - Return(mockUploadURLResponse, nil) - // ---------- - url, guid, err := g3cmd.GeneratePresignedURL(mockGen3Interface, testFilename, common.FileMetadata{}, testBucketname) + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + Return(fenceResp, nil) + + mockGen3.EXPECT(). + ParseFenceURLResponse(fenceResp). + Return(api.FenceResponse{ + URL: mockPresignedURL, + GUID: mockGUID, + }, nil) + + resp, err := upload.GeneratePresignedURL(mockGen3, testFilename, common.FileMetadata{}, testBucketname) if err != nil { - t.Error(err) + t.Fatalf("Unexpected error: %v", err) } - if url != mockPresignedURL { - t.Errorf("Wanted the presignedURL to be set to %v, got %v", mockPresignedURL, url) + + if resp.URL != mockPresignedURL { + t.Errorf("Wanted URL %s, got %s", mockPresignedURL, resp.URL) } - if guid != mockGUID { - t.Errorf("Wanted generated GUID to be %v, got %v", mockGUID, guid) + if resp.GUID != mockGUID { + t.Errorf("Wanted GUID %s, got %s", mockGUID, resp.GUID) } } -// If Shepherd is deployed, expect GeneratePresignedURL to hit Shepherd's data upload -// endpoint with the file name and file metadata. GeneratePresignedURL should then -// return the guid and file name that it gets from the endpoint. func TestGeneratePresignedURL_withShepherd(t *testing.T) { - // -- SETUP -- testFilename := "test-file" testBucketname := "test-bucket" + mockPresignedURL := "https://example.com/example.pfb" + mockGUID := "000000-0000000-0000000-000000" + testMetadata := common.FileMetadata{ Aliases: []string{"test-alias-1", "test-alias-2"}, Authz: []string{"authz-resource-1", "authz-resource-2"}, Metadata: map[string]any{"arbitrary": "metadata"}, } + mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - // Mock the request that checks if Shepherd is deployed. - mockGen3Interface := mocks.NewMockGen3Interface(mockCtrl) - mockGen3Interface. - EXPECT(). - CheckForShepherdAPI(). + mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + + // Shepherd is deployed + mockGen3.EXPECT(). + CheckForShepherdAPI(gomock.Any()). Return(true, nil) - // Mock the request to Fence's data upload endpoint to create a presigned url for this file name. - expectedReq := g3cmd.ShepherdInitRequestObject{ - Filename: testFilename, - Authz: struct { - Version string `json:"version"` - ResourcePaths []string `json:"resource_paths"` - }{ - "0", - testMetadata.Authz, - }, - Aliases: testMetadata.Aliases, - Metadata: testMetadata.Metadata, - } - expectedReqBody, err := json.Marshal(expectedReq) - if err != nil { - t.Error(err) - } - mockPresignedURL := "https://example.com/example.pfb" - mockGUID := "000000-0000000-0000000-000000" - presignedURLBody := fmt.Sprintf(`{ - "guid": "%v", - "upload_url": "%v" - }`, mockGUID, mockPresignedURL) - mockUploadURLResponse := http.Response{ + // Shepherd returns GUID and upload_url + shepherdResp := &http.Response{ StatusCode: 201, - Body: io.NopCloser(strings.NewReader(presignedURLBody)), - } - mockGen3Interface. - EXPECT(). - GetResponse(common.ShepherdEndpoint+"/objects", "POST", "", expectedReqBody). - Return("", &mockUploadURLResponse, nil) - // ---------- - - url, guid, err := g3cmd.GeneratePresignedURL(mockGen3Interface, testFilename, testMetadata, testBucketname) + Body: io.NopCloser(strings.NewReader(fmt.Sprintf( + `{"guid": "%s", "upload_url": "%s"}`, mockGUID, mockPresignedURL, + ))), + } + + mockGen3.EXPECT(). + DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { + if rb.Method != "POST" || !strings.HasSuffix(rb.Url, "/objects") { + t.Errorf("Expected POST to /objects, got %s %s", rb.Method, rb.Url) + } + // Optionally validate body here if needed + return shepherdResp, nil + }) + + respObj, err := upload.GeneratePresignedURL(mockGen3, testFilename, testMetadata, testBucketname) if err != nil { - t.Error(err) + t.Fatalf("Unexpected error: %v", err) } - if url != mockPresignedURL { - t.Errorf("Wanted the presignedURL to be set to %v, got %v", mockPresignedURL, url) + + if respObj.URL != mockPresignedURL { + t.Errorf("Wanted URL %s, got %s", mockPresignedURL, respObj.URL) } - if guid != mockGUID { - t.Errorf("Wanted generated GUID to be %v, got %v", mockGUID, guid) + if respObj.GUID != mockGUID { + t.Errorf("Wanted GUID %s, got %s", mockGUID, respObj.GUID) } } From 1130f52d2cc1d006f836ec490732f8152d3148a9 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Mon, 29 Dec 2025 13:46:17 -0800 Subject: [PATCH 03/14] update client --- client/api/gen3.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/api/gen3.go b/client/api/gen3.go index 1746638..ff3ecde 100644 --- a/client/api/gen3.go +++ b/client/api/gen3.go @@ -47,6 +47,7 @@ type FunctionInterface interface { ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) ExportCredential(ctx context.Context, cred *conf.Credential) error + NewAccessToken(ctx context.Context) error } func (f *Functions) NewAccessToken(ctx context.Context) error { From 7289156af3e1c7d33843d5c4e508f97a54742623 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Mon, 29 Dec 2025 15:34:36 -0800 Subject: [PATCH 04/14] add back bug fixes --- client/upload/singleFile.go | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go index 527fd4c..d0ff48a 100644 --- a/client/upload/singleFile.go +++ b/client/upload/singleFile.go @@ -49,20 +49,25 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str } filename := filepath.Base(filePath) if _, err := os.Stat(filePath); os.IsNotExist(err) { + if enableLogs { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() return fmt.Errorf("[ERROR] The file you specified \"%s\" does not exist locally\n", filePath) } file, err := os.Open(filePath) if err != nil { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() + if enableLogs { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) g3i.Logger().Println("File open error: " + err.Error()) + return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", filePath, err.Error()) } defer file.Close() @@ -71,11 +76,12 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str furObject, err = generateUploadRequest(ctx, g3i, furObject, file, nil) if err != nil { - file.Close() + if enableLogs { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) } @@ -86,12 +92,15 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str _, err = uploadPart(ctx, furObject.PresignedURL, bytes.NewReader(jsonData), int64(len(jsonData))) if err != nil { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) + if enableLogs { + sb := g3i.Logger().Scoreboard() + sb.IncrementSB(len(sb.Counts)) + } return fmt.Errorf("[ERROR] Error uploading file %s: %s\n", filePath, err.Error()) - } else { + } + if enableLogs { g3i.Logger().Scoreboard().IncrementSB(0) + g3i.Logger().Scoreboard().PrintSB() } - g3i.Logger().Scoreboard().PrintSB() return nil } From a95f44240b141fd907d6805017e6413dc3394775 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Tue, 30 Dec 2025 08:54:52 -0800 Subject: [PATCH 05/14] bugfix make single part / file uploader work --- client/upload/singleFile.go | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go index d0ff48a..b3f6b7a 100644 --- a/client/upload/singleFile.go +++ b/client/upload/singleFile.go @@ -1,9 +1,7 @@ package upload import ( - "bytes" "context" - "encoding/json" "errors" "fmt" "os" @@ -48,15 +46,6 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str filePath = filePaths[0] } filename := filepath.Base(filePath) - if _, err := os.Stat(filePath); os.IsNotExist(err) { - if enableLogs { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() - } - g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - return fmt.Errorf("[ERROR] The file you specified \"%s\" does not exist locally\n", filePath) - } file, err := os.Open(filePath) if err != nil { @@ -72,6 +61,12 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str } defer file.Close() + fi, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + fileSize := fi.Size() + furObject := common.FileUploadRequestObject{FilePath: filePath, Filename: filename, GUID: guid, Bucket: bucketName} furObject, err = generateUploadRequest(ctx, g3i, furObject, file, nil) @@ -85,19 +80,15 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) } - jsonData, err := json.Marshal(furObject) - if err != nil { - return fmt.Errorf("failed to marshal furObject: %w", err) - } - _, err = uploadPart(ctx, furObject.PresignedURL, bytes.NewReader(jsonData), int64(len(jsonData))) + _, err = uploadPart(ctx, furObject.PresignedURL, file, fileSize) if err != nil { if enableLogs { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) + g3i.Logger().Scoreboard().IncrementSB(1) // Increment failure } - return fmt.Errorf("[ERROR] Error uploading file %s: %s\n", filePath, err.Error()) + return fmt.Errorf("[ERROR] Error uploading file content for %s: %w", filePath, err) } + if enableLogs { g3i.Logger().Scoreboard().IncrementSB(0) g3i.Logger().Scoreboard().PrintSB() From 0aaac8b886687703b269f0483287aeea4bb6b082 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:16:33 -0800 Subject: [PATCH 06/14] Preserve error chain in upload error wrapping (#16) * fix missing error check * Initial plan * Fix error wrapping to preserve error chain using fmt.Errorf with %w Co-authored-by: bwalsh <47808+bwalsh@users.noreply.github.com> * Update PR description with completion status Co-authored-by: bwalsh <47808+bwalsh@users.noreply.github.com> * Add /bin/ to .gitignore and remove accidentally committed binary Co-authored-by: bwalsh <47808+bwalsh@users.noreply.github.com> --------- Co-authored-by: Brian Walsh Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: bwalsh <47808+bwalsh@users.noreply.github.com> --- .gitignore | 1 + client/upload/request.go | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 849f232..6aa2b55 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ # Build artifacts /build/ +/bin/ checksums.txt \ No newline at end of file diff --git a/client/upload/request.go b/client/upload/request.go index 80b7744..894db52 100644 --- a/client/upload/request.go +++ b/client/upload/request.go @@ -101,10 +101,13 @@ func generateUploadRequest(ctx context.Context, g3 client.Gen3Interface, furObje Method: http.MethodGet, }, ) + if err != nil { + return furObject, fmt.Errorf("Upload error: %w", err) + } msg, err := g3.ParseFenceURLResponse(resp) if err != nil && !strings.Contains(err.Error(), "No GUID found") { - return furObject, errors.New("Upload error: " + err.Error()) + return furObject, fmt.Errorf("Upload error: %w", err) } if msg.URL == "" { return furObject, errors.New("Upload error: error in generating presigned URL for " + furObject.Filename) From 91b365cd3c3ec933460de15e9a77b708a04730b0 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 15 Jan 2026 08:58:27 -0800 Subject: [PATCH 07/14] rework retry to be fully integrated with refresh --- client/request/auth.go | 39 +++++++++++++-------------------------- client/request/request.go | 39 +++++++++++++++++++-------------------- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/client/request/auth.go b/client/request/auth.go index 7d08f65..eb87829 100644 --- a/client/request/auth.go +++ b/client/request/auth.go @@ -66,38 +66,25 @@ type AuthTransport struct { } func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.RLock() + token := t.Cred.AccessToken + t.mu.RUnlock() - resp, err := t.Base.RoundTrip(req) - if err != nil { - return nil, err - } - - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadGateway { - resp.Body.Close() - - newToken, refreshErr := t.tryRefresh(req.Context()) - if refreshErr != nil { - return nil, refreshErr - } - - retryReq := req.Clone(req.Context()) - retryReq.Header.Set("Authorization", "Bearer "+newToken) - return t.Base.RoundTrip(retryReq) - } - - return resp, nil + // Just add the header and pass it down + req.Header.Set("Authorization", "Bearer "+token) + return t.Base.RoundTrip(req) } -func (t *AuthTransport) tryRefresh(ctx context.Context) (string, error) { - // Only one goroutine can enter this block +func (t *AuthTransport) refreshOnce(ctx context.Context) error { t.refreshMu.Lock() defer t.refreshMu.Unlock() - if err := t.NewAccessToken(ctx); err != nil { - return "", err + t.mu.RLock() + if t.Cred.AccessToken != "" { + t.mu.RUnlock() + return nil } + t.mu.RUnlock() - t.mu.RLock() - defer t.mu.RUnlock() - return t.Cred.AccessToken, nil + return t.NewAccessToken(ctx) } diff --git a/client/request/request.go b/client/request/request.go index a1076e7..c91d9a0 100644 --- a/client/request/request.go +++ b/client/request/request.go @@ -5,9 +5,7 @@ package request import ( "context" "errors" - "net" "net/http" - "time" "github.com/calypr/data-client/client/conf" "github.com/calypr/data-client/client/logs" @@ -16,7 +14,6 @@ import ( type Request struct { Logs logs.Logger - Ctx context.Context RetryClient *retryablehttp.Client } @@ -30,20 +27,7 @@ func NewRequestInterface( cred *conf.Credential, conf conf.ManagerInterface, ) RequestInterface { - retryClient := retryablehttp.NewClient() - retryClient.RetryMax = 3 - retryClient.Logger = logger - - baseTransport := &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 100, - TLSHandshakeTimeout: 5 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - } + baseTransport := &http.Transport{ /* ... your config ... */ } authTransport := &AuthTransport{ Base: baseTransport, @@ -51,9 +35,24 @@ func NewRequestInterface( Manager: conf, } - retryClient.HTTPClient = &http.Client{ - Timeout: 0, - Transport: authTransport, // The outer shell is now AuthTransport + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = 3 + retryClient.Logger = logger + retryClient.HTTPClient.Transport = authTransport + + retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { + shouldRetry, retryErr := + retryablehttp.DefaultRetryPolicy(ctx, resp, err) + + if resp != nil && + (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadGateway) { + err := authTransport.refreshOnce(ctx) + if err != nil { + return false, err + } + return true, nil + } + return shouldRetry, retryErr } return &Request{ From d42f12d104bcce5bea89ee81d81a227292870d37 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Thu, 15 Jan 2026 09:03:29 -0800 Subject: [PATCH 08/14] fix logic --- client/request/request.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/client/request/request.go b/client/request/request.go index c91d9a0..d27af3d 100644 --- a/client/request/request.go +++ b/client/request/request.go @@ -5,7 +5,9 @@ package request import ( "context" "errors" + "net" "net/http" + "time" "github.com/calypr/data-client/client/conf" "github.com/calypr/data-client/client/logs" @@ -27,18 +29,29 @@ func NewRequestInterface( cred *conf.Credential, conf conf.ManagerInterface, ) RequestInterface { - baseTransport := &http.Transport{ /* ... your config ... */ } + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = 3 + retryClient.Logger = logger + baseTransport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + TLSHandshakeTimeout: 5 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } authTransport := &AuthTransport{ Base: baseTransport, Cred: cred, Manager: conf, } - - retryClient := retryablehttp.NewClient() - retryClient.RetryMax = 3 - retryClient.Logger = logger - retryClient.HTTPClient.Transport = authTransport + retryClient.HTTPClient = &http.Client{ + Timeout: 0, + Transport: authTransport, + } retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { shouldRetry, retryErr := From 2b7b0e891d9f72125e3c50542bdfb893aa38885b Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Fri, 16 Jan 2026 07:21:03 -0800 Subject: [PATCH 09/14] fix fatalf bug --- client/request/request.go | 2 +- client/upload/singleFile.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/request/request.go b/client/request/request.go index d27af3d..c16677c 100644 --- a/client/request/request.go +++ b/client/request/request.go @@ -30,7 +30,7 @@ func NewRequestInterface( conf conf.ManagerInterface, ) RequestInterface { retryClient := retryablehttp.NewClient() - retryClient.RetryMax = 3 + retryClient.RetryMax = 5 retryClient.Logger = logger baseTransport := &http.Transport{ DialContext: (&net.Dialer{ diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go index b3f6b7a..cc9a430 100644 --- a/client/upload/singleFile.go +++ b/client/upload/singleFile.go @@ -77,7 +77,7 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str sb.PrintSB() } g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) - g3i.Logger().Fatalf("Error occurred during request generation: %s", err.Error()) + g3i.Logger().Printf("Error occurred during request generation: %s", err.Error()) return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) } From 958f16679e8cf254e89de029f49e79723f00b876 Mon Sep 17 00:00:00 2001 From: matthewpeterkort Date: Fri, 16 Jan 2026 07:52:32 -0800 Subject: [PATCH 10/14] add retry wait bounds --- client/request/request.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/request/request.go b/client/request/request.go index c16677c..585624e 100644 --- a/client/request/request.go +++ b/client/request/request.go @@ -32,6 +32,8 @@ func NewRequestInterface( retryClient := retryablehttp.NewClient() retryClient.RetryMax = 5 retryClient.Logger = logger + retryClient.RetryWaitMin = 5 * time.Second + retryClient.RetryWaitMax = 15 * time.Second baseTransport := &http.Transport{ DialContext: (&net.Dialer{ Timeout: 5 * time.Second, From b5432fe8d14e7d71dc490542f0e72cd1b9356c8c Mon Sep 17 00:00:00 2001 From: Brian Date: Mon, 26 Jan 2026 21:33:13 -0800 Subject: [PATCH 11/14] adds OptimalChunkSize, ProgressCallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MinChunkSize: 10MB * fix:improve-multipart #18 * improve OptimalChunkSize.scaleLinear * adds optimal-chunk-size * table-driven tests GetLfsCustomTransferInt * Address PR review comments: fix naming, formatting, and test coverage * Add logging for git config errors and improve documentation structure * Fix spelling: terrabytes → terabytes * adds ProgressEvent * Fix nil pointer dereferences and duplicate progress events - Added nil check for req.Progress in multipart.go to prevent panic - Removed duplicate progress event in singleFile.go (Finalize already emits) - Fixed concurrent BytesSoFar calculation using atomic counter for monotonicity * adds tests * add noop ProgressCallback --- client/common/constants.go | 57 +++++++- client/common/constants_test.go | 110 ++++++++++++++++ client/common/progress.go | 12 ++ client/common/types.go | 4 + client/download/batch.go | 64 ++++++--- client/download/progress_writer.go | 68 ++++++++++ client/download/progress_writer_test.go | 46 +++++++ client/download/transfer.go | 97 ++++++++++++++ client/download/transfer_test.go | 165 ++++++++++++++++++++++++ client/upload/multipart.go | 26 ++-- client/upload/multipart_test.go | 153 ++++++++++++++++++++++ client/upload/progress_reader.go | 68 ++++++++++ client/upload/progress_reader_test.go | 46 +++++++ client/upload/singleFile.go | 27 +++- client/upload/upload.go | 2 +- client/upload/utils.go | 56 ++++++++ client/upload/utils_test.go | 124 ++++++++++++++++++ cmd/upload-multiple.go | 4 +- cmd/upload-single.go | 4 +- docs/optimal-chunk-size.md | 152 ++++++++++++++++++++++ 20 files changed, 1246 insertions(+), 39 deletions(-) create mode 100644 client/common/constants_test.go create mode 100644 client/common/progress.go create mode 100644 client/download/progress_writer.go create mode 100644 client/download/progress_writer_test.go create mode 100644 client/download/transfer.go create mode 100644 client/download/transfer_test.go create mode 100644 client/upload/multipart_test.go create mode 100644 client/upload/progress_reader.go create mode 100644 client/upload/progress_reader_test.go create mode 100644 client/upload/utils_test.go create mode 100644 docs/optimal-chunk-size.md diff --git a/client/common/constants.go b/client/common/constants.go index aae8f10..6f9bc64 100644 --- a/client/common/constants.go +++ b/client/common/constants.go @@ -1,7 +1,12 @@ package common import ( + "fmt" + "log" "os" + "os/exec" + "strconv" + "strings" "time" ) @@ -14,7 +19,7 @@ const ( MB // GB is gigabytes GB - // TB is terrabytes + // TB is terabytes TB ) const ( @@ -71,12 +76,12 @@ const ( HeaderContentType = "Content-Type" MIMEApplicationJSON = "application/json" - // FileSizeLimit is the maximun single file size for non-multipart upload (5GB) + // FileSizeLimit is the maximum single file size for non-multipart upload (5GB) FileSizeLimit = 5 * GB - // MultipartFileSizeLimit is the maximun single file size for multipart upload (5TB) + // MultipartFileSizeLimit is the maximum single file size for multipart upload (5TB) MultipartFileSizeLimit = 5 * TB - MinMultipartChunkSize = 5 * MB + MinMultipartChunkSize = 10 * MB // MaxRetryCount is the maximum retry number per record MaxRetryCount = 5 @@ -85,5 +90,47 @@ const ( MaxMultipartParts = 10000 MaxConcurrentUploads = 10 MaxRetries = 5 - MinChunkSize = 5 * 1024 * 1024 ) + +var ( + // MinChunkSize is configurable via git config and initialized in init() + MinChunkSize int64 +) + +func init() { + v, err := GetLfsCustomTransferInt("lfs.customtransfer.drs.multipart-min-chunk-size", 10) + if err != nil { + log.Printf("Warning: Could not read git config for multipart-min-chunk-size, using default (10 MB): %v\n", err) + MinChunkSize = int64(10) * MB + return + } + + MinChunkSize = int64(v) * MB +} + +func GetLfsCustomTransferInt(key string, defaultValue int64) (int64, error) { + defaultText := strconv.FormatInt(defaultValue, 10) + // TODO cache or get all the configs at once? + cmd := exec.Command("git", "config", "--get", "--default", defaultText, key) + output, err := cmd.Output() + if err != nil { + return defaultValue, fmt.Errorf("error reading git config %s: %v", key, err) + } + + value := strings.TrimSpace(string(output)) + + parsed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return defaultValue, fmt.Errorf("invalid int value for %s: %q", key, value) + } + + if parsed < 0 { + return defaultValue, fmt.Errorf("invalid negative int value for %s: %d", key, parsed) + } + + if parsed == 0 || parsed > 500 { + return defaultValue, fmt.Errorf("invalid int value for %s: %d. Must be between 1 and 500", key, parsed) + } + + return parsed, nil +} diff --git a/client/common/constants_test.go b/client/common/constants_test.go new file mode 100644 index 0000000..8eed0e0 --- /dev/null +++ b/client/common/constants_test.go @@ -0,0 +1,110 @@ +package common + +import ( + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestGetLfsCustomTransferInt(t *testing.T) { + configDir := t.TempDir() + configPath := filepath.Join(configDir, "gitconfig") + + setConfig := func(t *testing.T, key, value string) { + t.Helper() + cmd := exec.Command("git", "config", "--file", configPath, key, value) + if err := cmd.Run(); err != nil { + t.Fatalf("set git config %s=%s: %v", key, value, err) + } + } + + setEnv := func(t *testing.T) { + t.Helper() + t.Setenv("GIT_CONFIG_GLOBAL", configPath) + t.Setenv("GIT_CONFIG_SYSTEM", os.DevNull) + t.Setenv("GIT_CONFIG_NOSYSTEM", "1") + } + + const key = "lfs.customtransfer.drs.multipart-min-chunk-size" + + tests := []struct { + name string + value string + defaultVal int64 + want int64 + wantErr bool + setValue bool + }{ + { + name: "missing uses default", + defaultVal: 10, + want: 10, + wantErr: false, + setValue: false, + }, + { + name: "valid value", + value: "25", + defaultVal: 10, + want: 25, + wantErr: false, + setValue: true, + }, + { + name: "negative value", + value: "-3", + defaultVal: 10, + want: 10, + wantErr: true, + setValue: true, + }, + { + name: "zero value", + value: "0", + defaultVal: 10, + want: 10, + wantErr: true, + setValue: true, + }, + { + name: "over max", + value: "501", + defaultVal: 10, + want: 10, + wantErr: true, + setValue: true, + }, + { + name: "non-integer", + value: "abc", + defaultVal: 10, + want: 10, + wantErr: true, + setValue: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := os.WriteFile(configPath, nil, 0o600); err != nil { + t.Fatalf("reset git config: %v", err) + } + if tt.setValue { + setConfig(t, key, tt.value) + } + setEnv(t) + + got, err := GetLfsCustomTransferInt(key, tt.defaultVal) + if tt.wantErr && err == nil { + t.Fatalf("expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Fatalf("value = %d, want %d", got, tt.want) + } + }) + } +} diff --git a/client/common/progress.go b/client/common/progress.go new file mode 100644 index 0000000..c743e7c --- /dev/null +++ b/client/common/progress.go @@ -0,0 +1,12 @@ +package common + +// ProgressEvent matches the Git LFS custom transfer progress payload. +type ProgressEvent struct { + Event string `json:"event"` + Oid string `json:"oid"` + BytesSoFar int64 `json:"bytesSoFar"` + BytesSinceLast int64 `json:"bytesSinceLast"` +} + +// ProgressCallback emits transfer progress updates. +type ProgressCallback func(ProgressEvent) error diff --git a/client/common/types.go b/client/common/types.go index 5a0ac8d..617bd38 100644 --- a/client/common/types.go +++ b/client/common/types.go @@ -15,8 +15,10 @@ type FileUploadRequestObject struct { Filename string FileMetadata FileMetadata GUID string + OID string PresignedURL string Bucket string `json:"bucket,omitempty"` + Progress ProgressCallback } // FileDownloadResponseObject defines a object for file download @@ -24,12 +26,14 @@ type FileDownloadResponseObject struct { DownloadPath string Filename string GUID string + OID string URL string Range int64 Overwrite bool Skip bool Response *http.Response Writer io.Writer + Progress ProgressCallback } // FileMetadata defines the metadata accepted by the new object management API, Shepherd diff --git a/client/download/batch.go b/client/download/batch.go index de86659..be46051 100644 --- a/client/download/batch.go +++ b/client/download/batch.go @@ -40,7 +40,18 @@ func downloadFiles( // Scoreboard: maxRetries = 0 for now (no retry logic yet) sb := logs.NewSB(0, logger) - p := mpb.New(mpb.WithOutput(os.Stdout)) + useProgressBars := true + for _, fdr := range files { + if fdr.Progress != nil { + useProgressBars = false + break + } + } + + var p *mpb.Progress + if useProgressBars { + p = mpb.New(mpb.WithOutput(os.Stdout)) + } var eg errgroup.Group eg.SetLimit(numParallel) @@ -101,29 +112,46 @@ func downloadFiles( // Progress bar for this file total := fdr.Response.ContentLength + fdr.Range - bar := p.AddBar(total, - mpb.PrependDecorators( - decor.Name(truncateFilename(fdr.Filename, 40)+" "), - decor.CountersKibiByte("% .1f / % .1f"), - ), - mpb.AppendDecorators( - decor.Percentage(), - decor.AverageSpeed(decor.SizeB1024(0), "% .1f"), - ), - ) + var writer io.Writer = file + var bar *mpb.Bar + var tracker *progressWriter + + if useProgressBars { + bar = p.AddBar(total, + mpb.PrependDecorators( + decor.Name(truncateFilename(fdr.Filename, 40)+" "), + decor.CountersKibiByte("% .1f / % .1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), "% .1f"), + ), + ) + + if fdr.Range > 0 { + bar.SetCurrent(fdr.Range) + } - if fdr.Range > 0 { - bar.SetCurrent(fdr.Range) + writer = bar.ProxyWriter(file) + } else if fdr.Progress != nil { + tracker = newProgressWriter(file, fdr.Progress, resolveDownloadOID(*fdr), total) + writer = tracker } - writer := bar.ProxyWriter(file) - _, copyErr := io.Copy(writer, fdr.Response.Body) _ = fdr.Response.Body.Close() _ = file.Close() + if tracker != nil { + if finalizeErr := tracker.Finalize(); finalizeErr != nil && copyErr == nil { + copyErr = finalizeErr + } + } + if copyErr != nil { - bar.Abort(true) + if bar != nil { + bar.Abort(true) + } err = fmt.Errorf("download failed for %s: %w", fdr.Filename, copyErr) return err } @@ -134,7 +162,9 @@ func downloadFiles( // Wait for all downloads _ = eg.Wait() - p.Wait() + if p != nil { + p.Wait() + } // Combine errors var combinedError error diff --git a/client/download/progress_writer.go b/client/download/progress_writer.go new file mode 100644 index 0000000..9ed8ab0 --- /dev/null +++ b/client/download/progress_writer.go @@ -0,0 +1,68 @@ +package download + +import ( + "io" + + "github.com/calypr/data-client/client/common" +) + +type progressWriter struct { + writer io.Writer + onProgress common.ProgressCallback + oid string + total int64 + bytesSoFar int64 +} + +func newProgressWriter(writer io.Writer, onProgress common.ProgressCallback, oid string, total int64) *progressWriter { + return &progressWriter{ + writer: writer, + onProgress: onProgress, + oid: oid, + total: total, + } +} + +func (pw *progressWriter) Write(p []byte) (int, error) { + n, err := pw.writer.Write(p) + if n > 0 && pw.onProgress != nil { + delta := int64(n) + pw.bytesSoFar += delta + if progressErr := pw.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pw.oid, + BytesSoFar: pw.bytesSoFar, + BytesSinceLast: delta, + }); progressErr != nil { + return n, progressErr + } + } + return n, err +} + +func (pw *progressWriter) Finalize() error { + if pw.onProgress == nil { + return nil + } + if pw.total == 0 || pw.bytesSoFar >= pw.total { + return nil + } + delta := pw.total - pw.bytesSoFar + pw.bytesSoFar = pw.total + return pw.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pw.oid, + BytesSoFar: pw.bytesSoFar, + BytesSinceLast: delta, + }) +} + +func resolveDownloadOID(fdr common.FileDownloadResponseObject) string { + if fdr.OID != "" { + return fdr.OID + } + if fdr.GUID != "" { + return fdr.GUID + } + return fdr.Filename +} diff --git a/client/download/progress_writer_test.go b/client/download/progress_writer_test.go new file mode 100644 index 0000000..8d573c8 --- /dev/null +++ b/client/download/progress_writer_test.go @@ -0,0 +1,46 @@ +package download + +import ( + "bytes" + "io" + "testing" + + "github.com/calypr/data-client/client/common" +) + +func TestProgressWriterFinalizes(t *testing.T) { + payload := bytes.Repeat([]byte("b"), 20) + var events []common.ProgressEvent + + writer := newProgressWriter(io.Discard, func(event common.ProgressEvent) error { + events = append(events, event) + return nil + }, "oid-456", int64(len(payload))) + + if _, err := writer.Write(payload); err != nil { + t.Fatalf("write failed: %v", err) + } + if err := writer.Finalize(); err != nil { + t.Fatalf("finalize failed: %v", err) + } + + if len(events) == 0 { + t.Fatal("expected progress events, got none") + } + + var total int64 + for _, event := range events { + if event.Event != "progress" { + t.Fatalf("unexpected event type: %s", event.Event) + } + total += event.BytesSinceLast + } + + last := events[len(events)-1] + if last.BytesSoFar != int64(len(payload)) { + t.Fatalf("expected final bytesSoFar %d, got %d", len(payload), last.BytesSoFar) + } + if total != int64(len(payload)) { + t.Fatalf("expected bytesSinceLast sum %d, got %d", len(payload), total) + } +} diff --git a/client/download/transfer.go b/client/download/transfer.go new file mode 100644 index 0000000..e54ddab --- /dev/null +++ b/client/download/transfer.go @@ -0,0 +1,97 @@ +package download + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/calypr/data-client/client/client" + "github.com/calypr/data-client/client/common" +) + +// DownloadSingleWithProgress downloads a single object while emitting progress events. +func DownloadSingleWithProgress( + ctx context.Context, + g3i client.Gen3Interface, + guid string, + downloadPath string, + protocol string, + oid string, + progress common.ProgressCallback, +) error { + var err error + downloadPath, err = common.ParseRootPath(downloadPath) + if err != nil { + return fmt.Errorf("invalid download path: %w", err) + } + if !strings.HasSuffix(downloadPath, "/") { + downloadPath += "/" + } + + renamed := make([]RenamedOrSkippedFileInfo, 0) + info, err := AskGen3ForFileInfo(ctx, g3i, guid, protocol, downloadPath, "original", false, &renamed) + if err != nil { + return err + } + + fdr := common.FileDownloadResponseObject{ + DownloadPath: downloadPath, + Filename: info.Name, + GUID: guid, + OID: oid, + Progress: progress, + } + + protocolText := "" + if protocol != "" { + protocolText = "?protocol=" + protocol + } + if err := GetDownloadResponse(ctx, g3i, &fdr, protocolText); err != nil { + return err + } + + fullPath := filepath.Join(fdr.DownloadPath, fdr.Filename) + if dir := filepath.Dir(fullPath); dir != "." { + if err = os.MkdirAll(dir, 0766); err != nil { + _ = fdr.Response.Body.Close() + return fmt.Errorf("mkdir for %s: %w", fullPath, err) + } + } + + flags := os.O_CREATE | os.O_WRONLY + if fdr.Range > 0 { + flags |= os.O_APPEND + } else if fdr.Overwrite { + flags |= os.O_TRUNC + } + + file, err := os.OpenFile(fullPath, flags, 0666) + if err != nil { + _ = fdr.Response.Body.Close() + return fmt.Errorf("open local file %s: %w", fullPath, err) + } + + total := fdr.Response.ContentLength + fdr.Range + var writer io.Writer = file + var tracker *progressWriter + if fdr.Progress != nil { + tracker = newProgressWriter(file, fdr.Progress, resolveDownloadOID(fdr), total) + writer = tracker + } + + _, copyErr := io.Copy(writer, fdr.Response.Body) + _ = fdr.Response.Body.Close() + _ = file.Close() + if tracker != nil { + if finalizeErr := tracker.Finalize(); finalizeErr != nil && copyErr == nil { + copyErr = finalizeErr + } + } + if copyErr != nil { + return fmt.Errorf("download failed for %s: %w", fdr.Filename, copyErr) + } + return nil +} diff --git a/client/download/transfer_test.go b/client/download/transfer_test.go new file mode 100644 index 0000000..7c702dc --- /dev/null +++ b/client/download/transfer_test.go @@ -0,0 +1,165 @@ +package download + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/calypr/data-client/client/api" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/request" +) + +type fakeGen3Download struct { + cred *conf.Credential + logger *logs.TeeLogger + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeGen3Download) GetCredential() *conf.Credential { return f.cred } +func (f *fakeGen3Download) Logger() *logs.TeeLogger { return f.logger } +func (f *fakeGen3Download) New(method, url string) *request.RequestBuilder { + return &request.RequestBuilder{Method: method, Url: url} +} +func (f *fakeGen3Download) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + return f.doFunc(ctx, req) +} +func (f *fakeGen3Download) CheckPrivileges(context.Context) (map[string]any, error) { + return nil, nil +} +func (f *fakeGen3Download) CheckForShepherdAPI(context.Context) (bool, error) { return false, nil } +func (f *fakeGen3Download) DeleteRecord(context.Context, string) (string, error) { + return "", nil +} +func (f *fakeGen3Download) GetDownloadPresignedUrl(context.Context, string, string) (string, error) { + return "https://download.example.com/object", nil +} +func (f *fakeGen3Download) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { + return (&api.Functions{}).ParseFenceURLResponse(resp) +} +func (f *fakeGen3Download) ExportCredential(context.Context, *conf.Credential) error { return nil } +func (f *fakeGen3Download) NewAccessToken(context.Context) error { return nil } + +func TestDownloadSingleWithProgressEmitsEvents(t *testing.T) { + payload := bytes.Repeat([]byte("d"), 64) + downloadDir := t.TempDir() + downloadPath := downloadDir + string(os.PathSeparator) + + var events []common.ProgressEvent + progress := func(event common.ProgressEvent) error { + events = append(events, event) + return nil + } + + fake := &fakeGen3Download{ + cred: &conf.Credential{APIEndpoint: "https://example.com", AccessToken: "token"}, + logger: logs.NewTeeLogger("", "", io.Discard), + doFunc: func(_ context.Context, req *request.RequestBuilder) (*http.Response, error) { + switch { + case strings.Contains(req.Url, common.IndexdIndexEndpoint): + return newDownloadJSONResponse(req.Url, `{"file_name":"payload.bin","size":64}`), nil + case strings.HasPrefix(req.Url, "https://download.example.com/"): + return newDownloadResponse(req.Url, payload, http.StatusOK), nil + default: + return nil, errors.New("unexpected request url: " + req.Url) + } + }, + } + + err := DownloadSingleWithProgress(context.Background(), fake, "guid-123", downloadPath, "", "oid-123", progress) + if err != nil { + t.Fatalf("download failed: %v", err) + } + + if len(events) == 0 { + t.Fatal("expected progress events") + } + for i := 1; i < len(events); i++ { + if events[i].BytesSoFar < events[i-1].BytesSoFar { + t.Fatalf("bytesSoFar not monotonic: %d then %d", events[i-1].BytesSoFar, events[i].BytesSoFar) + } + } + last := events[len(events)-1] + if last.BytesSoFar != int64(len(payload)) { + t.Fatalf("expected final bytesSoFar %d, got %d", len(payload), last.BytesSoFar) + } + fullPath := filepath.Join(downloadPath, "payload.bin") + if _, err := os.Stat(fullPath); err != nil { + t.Fatalf("expected file to exist: %v", err) + } +} + +func TestDownloadSingleWithProgressFinalizeOnError(t *testing.T) { + downloadDir := t.TempDir() + downloadPath := downloadDir + string(os.PathSeparator) + + var events []common.ProgressEvent + progress := func(event common.ProgressEvent) error { + events = append(events, event) + return nil + } + + fake := &fakeGen3Download{ + cred: &conf.Credential{APIEndpoint: "https://example.com", AccessToken: "token"}, + logger: logs.NewTeeLogger("", "", io.Discard), + doFunc: func(_ context.Context, req *request.RequestBuilder) (*http.Response, error) { + switch { + case strings.Contains(req.Url, common.IndexdIndexEndpoint): + return newDownloadJSONResponse(req.Url, `{"file_name":"payload.bin","size":64}`), nil + case strings.HasPrefix(req.Url, "https://download.example.com/"): + return newDownloadResponse(req.Url, []byte("short"), http.StatusOK), nil + default: + return nil, errors.New("unexpected request url: " + req.Url) + } + }, + } + + err := DownloadSingleWithProgress(context.Background(), fake, "guid-123", downloadPath, "", "oid-123", progress) + if err == nil { + t.Fatal("expected download error") + } + + if len(events) == 0 { + t.Fatal("expected progress events") + } + last := events[len(events)-1] + if last.BytesSoFar != 64 { + t.Fatalf("expected finalize bytesSoFar 64, got %d", last.BytesSoFar) + } +} + +func newDownloadJSONResponse(rawURL, body string) *http.Response { + parsedURL, err := url.Parse(rawURL) + if err != nil { + parsedURL = &url.URL{} + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: parsedURL}, + Header: make(http.Header), + } +} + +func newDownloadResponse(rawURL string, payload []byte, status int) *http.Response { + parsedURL, err := url.Parse(rawURL) + if err != nil { + parsedURL = &url.URL{} + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(bytes.NewReader(payload)), + ContentLength: int64(len(payload)), + Request: &http.Request{URL: parsedURL}, + Header: make(http.Header), + } +} diff --git a/client/upload/multipart.go b/client/upload/multipart.go index be97d69..df8d0cd 100644 --- a/client/upload/multipart.go +++ b/client/upload/multipart.go @@ -12,6 +12,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" client "github.com/calypr/data-client/client/client" "github.com/calypr/data-client/client/common" @@ -60,18 +61,8 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi key := fmt.Sprintf("%s/%s", finalGUID, req.Filename) g3.Logger().Printf("Initialized Upload: ID=%s, Key=%s\n", uploadID, key) - optimalChunkSize := func(fSize int64) int64 { - if fSize <= 512*common.MB { - return 32 * common.MB - } - chunkSize := fSize / common.MaxMultipartParts - if chunkSize < common.MinChunkSize { - chunkSize = common.MinChunkSize - } - return ((chunkSize + common.MB - 1) / common.MB) * common.MB - } + chunkSize := OptimalChunkSize(fileSize) - chunkSize := optimalChunkSize(fileSize) numChunks := int((fileSize + chunkSize - 1) / chunkSize) chunks := make(chan int, numChunks) @@ -85,6 +76,7 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi mu sync.Mutex parts []MultipartPartObject uploadErrors []error + totalBytes int64 // Atomic counter for monotonically increasing BytesSoFar ) // 3. Worker logic @@ -128,6 +120,18 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi if bar != nil { bar.IncrInt64(size) } + if req.Progress != nil { + currentTotal := atomic.AddInt64(&totalBytes, size) + err = req.Progress(common.ProgressEvent{ + Event: "progress", + Oid: req.OID, + BytesSinceLast: size, + BytesSoFar: currentTotal, + }) + if err != nil { + g3.Logger().Printf("progress callback error: %v", err) + } + } mu.Unlock() } } diff --git a/client/upload/multipart_test.go b/client/upload/multipart_test.go new file mode 100644 index 0000000..c03cbea --- /dev/null +++ b/client/upload/multipart_test.go @@ -0,0 +1,153 @@ +package upload + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "sync" + "testing" + + "github.com/calypr/data-client/client/api" + "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/client/request" +) + +type fakeGen3Upload struct { + cred *conf.Credential + logger *logs.TeeLogger + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeGen3Upload) GetCredential() *conf.Credential { return f.cred } +func (f *fakeGen3Upload) Logger() *logs.TeeLogger { return f.logger } +func (f *fakeGen3Upload) New(method, url string) *request.RequestBuilder { + return &request.RequestBuilder{Method: method, Url: url} +} +func (f *fakeGen3Upload) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + return f.doFunc(ctx, req) +} +func (f *fakeGen3Upload) CheckPrivileges(context.Context) (map[string]any, error) { + return nil, nil +} +func (f *fakeGen3Upload) CheckForShepherdAPI(context.Context) (bool, error) { return false, nil } +func (f *fakeGen3Upload) DeleteRecord(context.Context, string) (string, error) { + return "", nil +} +func (f *fakeGen3Upload) GetDownloadPresignedUrl(context.Context, string, string) (string, error) { + return "", nil +} +func (f *fakeGen3Upload) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { + return (&api.Functions{}).ParseFenceURLResponse(resp) +} +func (f *fakeGen3Upload) ExportCredential(context.Context, *conf.Credential) error { return nil } +func (f *fakeGen3Upload) NewAccessToken(context.Context) error { return nil } + +func TestMultipartUploadProgressIntegration(t *testing.T) { + ctx := context.Background() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + _, _ = io.Copy(io.Discard, r.Body) + _ = r.Body.Close() + w.Header().Set("ETag", "etag-123") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + file, err := os.CreateTemp(t.TempDir(), "multipart-*.bin") + if err != nil { + t.Fatalf("create temp file: %v", err) + } + defer file.Close() + + fileSize := int64(101 * common.MB) + if err := file.Truncate(fileSize); err != nil { + t.Fatalf("truncate file: %v", err) + } + if _, err := file.Seek(0, io.SeekStart); err != nil { + t.Fatalf("seek file: %v", err) + } + + var ( + events []common.ProgressEvent + mu sync.Mutex + ) + progress := func(event common.ProgressEvent) error { + mu.Lock() + defer mu.Unlock() + events = append(events, event) + return nil + } + + logger := logs.NewTeeLogger("", "", io.Discard) + fake := &fakeGen3Upload{ + cred: &conf.Credential{ + APIEndpoint: "https://example.com", + AccessToken: "token", + }, + logger: logger, + doFunc: func(_ context.Context, req *request.RequestBuilder) (*http.Response, error) { + switch { + case strings.Contains(req.Url, common.FenceDataMultipartInitEndpoint): + return newJSONResponse(req.Url, `{"uploadId":"upload-123","guid":"guid-123"}`), nil + case strings.Contains(req.Url, common.FenceDataMultipartUploadEndpoint): + return newJSONResponse(req.Url, fmt.Sprintf(`{"presigned_url":"%s"}`, server.URL)), nil + case strings.Contains(req.Url, common.FenceDataMultipartCompleteEndpoint): + return newJSONResponse(req.Url, `{}`), nil + default: + return nil, fmt.Errorf("unexpected request url: %s", req.Url) + } + }, + } + + requestObject := common.FileUploadRequestObject{ + FilePath: file.Name(), + Filename: "multipart.bin", + GUID: "guid-123", + OID: "oid-123", + Bucket: "bucket", + Progress: progress, + } + + if err := MultipartUpload(ctx, fake, requestObject, file, false); err != nil { + t.Fatalf("multipart upload failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if len(events) == 0 { + t.Fatal("expected progress events") + } + for i := 1; i < len(events); i++ { + if events[i].BytesSoFar < events[i-1].BytesSoFar { + t.Fatalf("bytesSoFar not monotonic: %d then %d", events[i-1].BytesSoFar, events[i].BytesSoFar) + } + } + last := events[len(events)-1] + if last.BytesSoFar != fileSize { + t.Fatalf("expected final bytesSoFar %d, got %d", fileSize, last.BytesSoFar) + } +} + +func newJSONResponse(rawURL, body string) *http.Response { + parsedURL, err := url.Parse(rawURL) + if err != nil { + parsedURL = &url.URL{} + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Request: &http.Request{URL: parsedURL}, + Header: make(http.Header), + } +} diff --git a/client/upload/progress_reader.go b/client/upload/progress_reader.go new file mode 100644 index 0000000..8262ee5 --- /dev/null +++ b/client/upload/progress_reader.go @@ -0,0 +1,68 @@ +package upload + +import ( + "io" + + "github.com/calypr/data-client/client/common" +) + +type progressReader struct { + reader io.Reader + onProgress common.ProgressCallback + oid string + total int64 + bytesSoFar int64 +} + +func newProgressReader(reader io.Reader, onProgress common.ProgressCallback, oid string, total int64) *progressReader { + return &progressReader{ + reader: reader, + onProgress: onProgress, + oid: oid, + total: total, + } +} + +func resolveUploadOID(req common.FileUploadRequestObject) string { + if req.OID != "" { + return req.OID + } + if req.GUID != "" { + return req.GUID + } + return req.Filename +} + +func (pr *progressReader) Read(p []byte) (int, error) { + n, err := pr.reader.Read(p) + if n > 0 && pr.onProgress != nil { + delta := int64(n) + pr.bytesSoFar += delta + if progressErr := pr.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pr.oid, + BytesSoFar: pr.bytesSoFar, + BytesSinceLast: delta, + }); progressErr != nil { + return n, progressErr + } + } + return n, err +} + +func (pr *progressReader) Finalize() error { + if pr.onProgress == nil { + return nil + } + if pr.total == 0 || pr.bytesSoFar >= pr.total { + return nil + } + delta := pr.total - pr.bytesSoFar + pr.bytesSoFar = pr.total + return pr.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pr.oid, + BytesSoFar: pr.bytesSoFar, + BytesSinceLast: delta, + }) +} diff --git a/client/upload/progress_reader_test.go b/client/upload/progress_reader_test.go new file mode 100644 index 0000000..de77d8e --- /dev/null +++ b/client/upload/progress_reader_test.go @@ -0,0 +1,46 @@ +package upload + +import ( + "bytes" + "io" + "testing" + + "github.com/calypr/data-client/client/common" +) + +func TestProgressReaderFinalizes(t *testing.T) { + payload := bytes.Repeat([]byte("a"), 16) + var events []common.ProgressEvent + + reader := newProgressReader(bytes.NewReader(payload), func(event common.ProgressEvent) error { + events = append(events, event) + return nil + }, "oid-123", int64(len(payload))) + + if _, err := io.Copy(io.Discard, reader); err != nil { + t.Fatalf("copy failed: %v", err) + } + if err := reader.Finalize(); err != nil { + t.Fatalf("finalize failed: %v", err) + } + + if len(events) == 0 { + t.Fatal("expected progress events, got none") + } + + var total int64 + for _, event := range events { + if event.Event != "progress" { + t.Fatalf("unexpected event type: %s", event.Event) + } + total += event.BytesSinceLast + } + + last := events[len(events)-1] + if last.BytesSoFar != int64(len(payload)) { + t.Fatalf("expected final bytesSoFar %d, got %d", len(payload), last.BytesSoFar) + } + if total != int64(len(payload)) { + t.Fatalf("expected bytesSinceLast sum %d, got %d", len(payload), total) + } +} diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go index cc9a430..32c4194 100644 --- a/client/upload/singleFile.go +++ b/client/upload/singleFile.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "os" "path/filepath" @@ -12,7 +13,7 @@ import ( "github.com/calypr/data-client/client/logs" ) -func UploadSingle(ctx context.Context, profile string, guid string, filePath string, bucketName string, enableLogs bool) error { +func UploadSingle(ctx context.Context, profile string, guid string, oid string, filePath string, bucketName string, enableLogs bool, progressCallback common.ProgressCallback) error { logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog()) if enableLogs { @@ -67,9 +68,17 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str } fileSize := fi.Size() - furObject := common.FileUploadRequestObject{FilePath: filePath, Filename: filename, GUID: guid, Bucket: bucketName} + furObject := common.FileUploadRequestObject{ + FilePath: filePath, + Filename: filename, + GUID: guid, + OID: oid, + Bucket: bucketName, + Progress: progressCallback, + } furObject, err = generateUploadRequest(ctx, g3i, furObject, file, nil) + if err != nil { if enableLogs { sb := g3i.Logger().Scoreboard() @@ -81,7 +90,19 @@ func UploadSingle(ctx context.Context, profile string, guid string, filePath str return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) } - _, err = uploadPart(ctx, furObject.PresignedURL, file, fileSize) + var reader io.Reader = file + var progressTracker *progressReader + if furObject.Progress != nil { + progressTracker = newProgressReader(file, furObject.Progress, resolveUploadOID(furObject), fileSize) + reader = progressTracker + } + + _, err = uploadPart(ctx, furObject.PresignedURL, reader, fileSize) + if progressTracker != nil { + if finalizeErr := progressTracker.Finalize(); finalizeErr != nil && err == nil { + err = finalizeErr + } + } if err != nil { if enableLogs { g3i.Logger().Scoreboard().IncrementSB(1) // Increment failure diff --git a/client/upload/upload.go b/client/upload/upload.go index b786164..14fc894 100644 --- a/client/upload/upload.go +++ b/client/upload/upload.go @@ -38,7 +38,7 @@ func Upload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadR // Use Single-Part if file is smaller than 5GB (or your defined limit) if fileSize < 5*common.GB { g3.Logger().Printf("File size %d bytes (< 5GB), performing single-part upload\n", fileSize) - UploadSingle(ctx, g3.GetCredential().Profile, req.GUID, req.FilePath, req.Bucket, true) + UploadSingle(ctx, g3.GetCredential().Profile, req.GUID, req.GUID, req.FilePath, req.Bucket, true, nil) } g3.Logger().Printf("File size %d bytes (>= 5GB), performing multipart upload\n", fileSize) return MultipartUpload(ctx, g3, req, file, showProgress) diff --git a/client/upload/utils.go b/client/upload/utils.go index 2dbfa85..c26f3fc 100644 --- a/client/upload/utils.go +++ b/client/upload/utils.go @@ -131,3 +131,59 @@ func FormatSize(size int64) string { return fmt.Sprintf("%.1f"+unitMap[unitSize], float64(size)/float64(unitSize)) } + +// OptimalChunkSize returns a recommended chunk size for the given fileSize (in bytes). +// - <= 100 MB: return fileSize (use single PUT) +// - >100 MB and <= 1 GB: 10 MB +// - >1 GB and <= 10 GB: scaled between 25 MB and 128 MB +// - >10 GB and <= 100 GB: 256 MB +// - >100 GB: scaled between 512 MB and 1024 MB (1 GB) +// See: +// https://cloud.switch.ch/-/documentation/s3/multipart-uploads/#best-practices +func OptimalChunkSize(fileSize int64) int64 { + if fileSize <= 0 { + return 1 * common.MB + } + + switch { + case fileSize <= 100*common.MB: + // Single PUT: return whole file size + return fileSize + + case fileSize <= 1*common.GB: + return 10 * common.MB + + case fileSize <= 10*common.GB: + return scaleLinear(fileSize, 1*common.GB, 10*common.GB, 25*common.MB, 128*common.MB) + + case fileSize <= 100*common.GB: + return 256 * common.MB + + default: + // Scale for very large files; cap scaling at 1 TB for ratio purposes + return scaleLinear(fileSize, 100*common.GB, 1000*common.GB, 512*common.MB, 1024*common.MB) + } +} + +// scaleLinear scales size in [minSize, maxSize] to chunk in [minChunk, maxChunk] (linear). +// Result is rounded down to nearest MB and clamped to [minChunk, maxChunk]. +func scaleLinear(size, minSize, maxSize, minChunk, maxChunk int64) int64 { + if size <= minSize { + return minChunk + } + if size >= maxSize { + return maxChunk + } + ratio := float64(size-minSize) / float64(maxSize-minSize) + chunkF := float64(minChunk) + ratio*(float64(maxChunk-minChunk)) + // round down to nearest MB + mb := int64(common.MB) + chunk := int64(chunkF) / mb * mb + if chunk < minChunk { + return minChunk + } + if chunk > maxChunk { + return maxChunk + } + return chunk +} diff --git a/client/upload/utils_test.go b/client/upload/utils_test.go new file mode 100644 index 0000000..8681096 --- /dev/null +++ b/client/upload/utils_test.go @@ -0,0 +1,124 @@ +package upload + +import ( + "testing" + + "github.com/calypr/data-client/client/common" +) + +func TestOptimalChunkSize(t *testing.T) { + tests := []struct { + name string + fileSize int64 + wantChunkSize int64 + wantParts int64 + }{ + { + name: "0 bytes", + fileSize: 0, + wantChunkSize: 1 * common.MB, + wantParts: 0, + }, + { + name: "1MB", + fileSize: 1 * common.MB, + wantChunkSize: 1 * common.MB, + wantParts: 1, + }, + { + name: "100MB", + fileSize: 100 * common.MB, + wantChunkSize: 100 * common.MB, + wantParts: 1, + }, + { + name: "100MB+1B", + fileSize: 100*common.MB + 1, + wantChunkSize: 10 * common.MB, + wantParts: 11, + }, + { + name: "500MB", + fileSize: 500 * common.MB, + wantChunkSize: 10 * common.MB, + wantParts: 50, + }, + { + name: "1GB", + fileSize: 1 * common.GB, + wantChunkSize: 10 * common.MB, + wantParts: 103, + }, + { + name: "1GB+1B", + fileSize: 1*common.GB + 1, + wantChunkSize: 25 * common.MB, + wantParts: 41, + }, + { + name: "5GB", + fileSize: 5 * common.GB, + wantChunkSize: 70 * common.MB, + wantParts: 74, + }, + { + name: "10GB", + fileSize: 10 * common.GB, + wantChunkSize: 128 * common.MB, + wantParts: 80, + }, + { + name: "10GB+1B", + fileSize: 10*common.GB + 1, + wantChunkSize: 256 * common.MB, + wantParts: 41, + }, + { + name: "50GB", + fileSize: 50 * common.GB, + wantChunkSize: 256 * common.MB, + wantParts: 200, + }, + { + name: "100GB", + fileSize: 100 * common.GB, + wantChunkSize: 256 * common.MB, + wantParts: 400, + }, + { + name: "100GB+1B", + fileSize: 100*common.GB + 1, + wantChunkSize: 512 * common.MB, + wantParts: 201, + }, + { + name: "500GB", + fileSize: 500 * common.GB, + wantChunkSize: 739 * common.MB, + wantParts: 693, + }, + { + name: "1TB", + fileSize: 1 * common.TB, + wantChunkSize: 1 * common.GB, + wantParts: 1024, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chunkSize := OptimalChunkSize(tt.fileSize) + if chunkSize != tt.wantChunkSize { + t.Fatalf("chunk size = %d, want %d", chunkSize, tt.wantChunkSize) + } + + parts := int64(0) + if tt.fileSize > 0 && chunkSize > 0 { + parts = (tt.fileSize + chunkSize - 1) / chunkSize + } + if parts != tt.wantParts { + t.Fatalf("parts = %d, want %d", parts, tt.wantParts) + } + }) + } +} diff --git a/cmd/upload-multiple.go b/cmd/upload-multiple.go index 13f91d7..66fef2e 100644 --- a/cmd/upload-multiple.go +++ b/cmd/upload-multiple.go @@ -38,6 +38,7 @@ Options to run multipart uploads for large files and parallel batch uploading ar fmt.Printf("Notice: this command uploads to pre-existing GUIDs from a manifest.\nIf you want to upload new files (new GUIDs generated automatically), use \"./data-client upload\" instead.\n\n") ctx := context.Background() + noopProgress := func(common.ProgressEvent) error { return nil } logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) defer closer() @@ -94,6 +95,7 @@ Options to run multipart uploads for large files and parallel batch uploading ar // GUID comes from manifest → override fur.GUID = obj.ObjectID fur.Bucket = bucketName + fur.Progress = noopProgress logger.Println("\t" + localFilePath + " → GUID " + obj.ObjectID) requests = append(requests, fur) @@ -124,7 +126,7 @@ Options to run multipart uploads for large files and parallel batch uploading ar } } else { for _, req := range single { - upload.UploadSingle(ctx, profileConfig.Profile, req.GUID, req.FilePath, req.Bucket, true) + upload.UploadSingle(ctx, profileConfig.Profile, req.GUID, req.GUID, req.FilePath, req.Bucket, true, noopProgress) } } diff --git a/cmd/upload-single.go b/cmd/upload-single.go index d8a8b53..0270e36 100644 --- a/cmd/upload-single.go +++ b/cmd/upload-single.go @@ -5,6 +5,7 @@ import ( "context" "log" + "github.com/calypr/data-client/client/common" "github.com/calypr/data-client/client/upload" "github.com/spf13/cobra" ) @@ -20,7 +21,8 @@ func init() { Long: `Gets a presigned URL for which to upload a file associated with a GUID and then uploads the specified file.`, Example: `./data-client upload-single --profile= --guid=f6923cf3-xxxx-xxxx-xxxx-14ab3f84f9d6 --file=`, Run: func(cmd *cobra.Command, args []string) { - err := upload.UploadSingle(context.Background(), profile, guid, filePath, bucketName, true) + noopProgress := func(common.ProgressEvent) error { return nil } + err := upload.UploadSingle(context.Background(), profile, guid, guid, filePath, bucketName, true, noopProgress) if err != nil { log.Fatalln(err.Error()) } diff --git a/docs/optimal-chunk-size.md b/docs/optimal-chunk-size.md new file mode 100644 index 0000000..86019ba --- /dev/null +++ b/docs/optimal-chunk-size.md @@ -0,0 +1,152 @@ + + +# Engineering note — Optimal Chunk Size Calculation for Multipart Uploads + +## OLD: + optimalChunkSize determines the ideal chunk/part size for multipart upload based on file size. + The chunk size (also known as "message size" or "part size") affects upload performance and + must comply with S3 constraints. + + Calculation logic: + - For files ≤ 512 MB: Returns 32 MB chunks for optimal performance + - For files > 512 MB: Calculates fileSize/maxMultipartParts, with minimum of 5 MB + - Enforces minimum of 5 MB (S3 requirement for all parts except the last) + - Rounds up to nearest MB for alignment + + This results in: + - Files ≤ 512 MB: 32 MB chunks + - Files 512 MB - ~49 GB: 5 MB chunks (minimum enforced) + The ~49 GB threshold (10,000 parts × 5 MB) is where files exceed S3's + 10,000 part limit when using the minimum chunk size + - Files > ~49 GB: Dynamically calculated to stay under 10,000 parts + + Examples: + - 100 MB file → 32 MB chunks (4 parts) + - 1 GB file → 5 MB chunks (~205 parts) + - 10 GB file → 5 MB chunks (~2,048 parts) + - 50 GB file → 6 MB chunks (~8,534 parts) + - 100 GB file → 11 MB chunks (~9,310 parts) + - 1 TB file → 105 MB chunks (~9,987 parts) + +## NEW + +OptimalChunkSize determines the ideal chunk/part size for multipart upload based on file size. +The chunk size (also known as "message size" or "part size") affects upload performance and +must comply with S3 constraints. + +Calculation logic: + - For files ≤ 100 MB: Returns the file size itself (single PUT, no multipart) + - For files > 100 MB and ≤ 1 GB: Returns 10 MB chunks + - For files > 1 GB and ≤ 10 GB: Scales linearly between 25 MB and 128 MB + - For files > 10 GB and ≤ 100 GB: Returns 256 MB chunks + - For files > 100 GB: Scales linearly between 512 MB and 1024 MB (capped at 1 TB for ratio purposes) + - All chunk sizes are rounded down to the nearest MB + - Minimum chunk size is 1 MB (for zero or negative input) + +This results in: + - Files ≤ 100 MB: Single PUT upload + - Files 100 MB - 1 GB: 10 MB chunks + - Files 1 GB - 10 GB: 25-128 MB chunks (scaled) + - Files 10 GB - 100 GB: 256 MB chunks + - Files > 100 GB: 512-1024 MB chunks (scaled) + +Examples: + - 100 MB file → 100 MB chunk (1 part, single PUT) + - 500 MB file → 10 MB chunks (50 parts) + - 1 GB file → 10 MB chunks (103 parts) + - 5 GB file → 70 MB chunks (74 parts, scaled) + - 10 GB file → 128 MB chunks (80 parts) + - 50 GB file → 256 MB chunks (200 parts) + - 100 GB file → 256 MB chunks (400 parts) + - 500 GB file → 739 MB chunks (693 parts, scaled) + - 1 TB file → 1024 MB chunks (1024 parts) + +### Testing + + +```bash +go test ./client/upload -run '^TestOptimalChunkSize$' -v + +``` + +Purpose +- Validate `OptimalChunkSize` behavior and return values (chunk size and number of parts) across thresholds, boundaries and scaled ranges. + +Key behavior to assert +1. Input type and units: sizes are `int64` bytes; tests should use `common.MB` / `common.GB` constants. +2. Parts calculation: `parts = ceil(fileSize / chunk)`; `fileSize == 0` returns `parts == 0`. +3. Scaling: scaled ranges are linear, rounded **down** to the nearest MB and clamped to range. +4. Minimum chunk clamp: result is at least `1 MB`. +5. Boundary semantics: implementation uses `<=` and some ranges start at `X + 1` — include exact, \-1 and \+1 byte checks. + +Parameterized test cases (file size ⇒ expected chunk ⇒ expected parts) +1. `0` bytes + - chunk: `1 MB` (fallback) + - parts: `0` + +2. `1 MB` + - chunk: `1 MB` (<= 100 MB) + - parts: `1` + +3. `100 MB` + - chunk: `100 MB` (<= 100 MB) + - parts: `1` + +4. `100 MB + 1 B` + - chunk: `10 MB` (> 100 MB - <= 1 GB) + - parts: ceil((100 MB + 1 B) / 10 MB) = `11` + +5. `500 MB` + - chunk: `10 MB` + - parts: `50` + +6. `1 GB` (1024 MB) + - chunk: `10 MB` (<= 1 GB) + - parts: ceil(1024 / 10) = `103` + +7. `1 GB + 1 B` + - chunk: `25 MB` (start of 1 GB - 10 GB scaled range) + - parts: ceil((1024 MB + 1 B) / 25 MB) = `41` + +8. `5 GB` (5120 MB) + - chunk: linear between `25 MB` and `128 MB` → ≈ `70 MB` (rounded down) + - parts: ceil(5120 / 70) = `74` + +9. `10 GB` (10240 MB) + - chunk: `128 MB` (end of 1 GB - 10 GB scaled range) + - parts: `80` + +10. `10 GB + 1 B` + - chunk: `256 MB` (> 10 GB - <= 100 GB fixed) + - parts: ceil((10240 MB + 1 B) / 256 MB) = `41` + +11. `50 GB` (51200 MB) + - chunk: `256 MB` + - parts: `200` + +12. `100 GB` (102400 MB) + - chunk: `256 MB` + - parts: `400` + +13. `100 GB + 1 B` + - chunk: `512 MB` (start of > 100 GB scaled range) + - parts: ceil((102400 MB + 1 B) / 512 MB) = `201` + +14. `500 GB` (512000 MB) + - chunk: linear between `512 MB` and `1024 MB` → ≈ `739 MB` (rounded down) + - parts: ceil(512000 / 739) = `693` + +15. `1 TB` (1024 GB = 1,048,576 MB) — note: use project units consistently + - chunk: `1024 MB` (max of scaled range) + - parts: 1,048,576 / 1024 = `1024` + +Test design notes (concise) +1. Use table-driven subtests in `client/upload/utils_test.go`. Include fields: name, `fileSize int64`, `wantChunk int64`, `wantParts int64`. +2. For scaled cases assert: MB alignment, clamped to min/max, and exact `wantParts`. Use integer arithmetic for parts. +3. Add explicit boundary triples for each threshold: exact, -1 byte, +1 byte. +4. Include negative and zero cases to verify fallback behavior. +5. Keep tests deterministic and fast (no external deps). + +Execution +- Run from repo root: `go test ./client/upload -v` +- Run single test: `go test ./client/upload -run '^TestOptimalChunkSize$' -v` \ No newline at end of file From 6fd53ed816f888786303ac78c5c3ea9e5fbf5c25 Mon Sep 17 00:00:00 2001 From: Matthew Peterkort <33436238+matthewpeterkort@users.noreply.github.com> Date: Tue, 3 Feb 2026 13:04:17 -0800 Subject: [PATCH 12/14] Refactor/calypr clients (#26) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MinChunkSize: 10MB * fix:improve-multipart #18 * improve OptimalChunkSize.scaleLinear * adds optimal-chunk-size * table-driven tests GetLfsCustomTransferInt * Initial plan * Address PR review comments: fix naming, formatting, and test coverage * Add logging for git config errors and improve documentation structure * Fix spelling: terrabytes → terabytes * refactor a bunch of stuff, incorporate latest commits on develop * add g3client * operational data-client * tests passing * add sower client * fix up logging to pass back to main git-drs logger when configured * addresse logging feedbacks * add requestor client and commands * add command validation for collaborator command * fix up releaser * fix PartNumber typo regression * fix partNumber --- .github/workflows/build_and_test_push.yaml | 214 ------ .github/workflows/coverage.yaml | 103 +++ .github/workflows/release.yaml | 34 + Makefile | 27 +- build.sh | 68 -- bump-tag.sh | 128 ++++ client/api/gen3.go | 367 ---------- client/api/types.go | 25 - client/client/client.go | 69 -- client/common/constants_test.go | 110 --- client/common/progress.go | 12 - client/download/url_resolution.go | 80 --- client/logs/logger.go | 41 -- client/logs/tee_logger.go | 172 ----- client/mocks/mock_gen3interface.go | 192 ------ client/upload/singleFile.go | 118 ---- client/upload/upload.go | 125 ---- cmd/auth.go | 11 +- cmd/collaborator.go | 264 ++++++++ cmd/configure.go | 20 +- cmd/delete.go | 8 +- cmd/download-multiple.go | 10 +- cmd/download-single.go | 12 +- cmd/gitversion.go | 2 +- cmd/retry-upload.go | 10 +- cmd/upload-multipart.go | 20 +- cmd/upload-multiple.go | 31 +- cmd/upload-single.go | 24 +- cmd/upload.go | 24 +- {client/common => common}/common.go | 23 + {client/common => common}/constants.go | 45 +- .../common => common}/isHidden_notwindows.go | 0 {client/common => common}/isHidden_windows.go | 0 {client/common => common}/logHelper.go | 0 common/progress.go | 52 ++ common/resource.go | 14 + {client/common => common}/types.go | 16 +- {client/conf => conf}/config.go | 29 +- conf/config_test.go | 199 ++++++ {client/conf => conf}/validate.go | 49 +- conf/validate_test.go | 130 ++++ docs/optimal-chunk-size.md | 9 +- {client/download => download}/batch.go | 21 +- {client/download => download}/downloader.go | 50 +- {client/download => download}/file_info.go | 36 +- .../download => download}/progress_writer.go | 47 +- .../progress_writer_test.go | 2 +- {client/download => download}/transfer.go | 71 +- .../download => download}/transfer_test.go | 85 ++- {client/download => download}/types.go | 6 +- download/url_resolution.go | 87 +++ {client/download => download}/utils.go | 26 +- fence/client.go | 637 ++++++++++++++++++ fence/client_test.go | 250 +++++++ fence/types.go | 93 +++ g3client/client.go | 246 +++++++ go.mod | 22 + go.sum | 45 +- indexd/add_url.go | 106 +++ indexd/client.go | 515 ++++++++++++++ indexd/client_test.go | 266 ++++++++ indexd/convert.go | 99 +++ indexd/drs/drs.go | 87 +++ indexd/drs/types.go | 56 ++ indexd/hash/hash.go | 144 ++++ indexd/hash/hash_test.go | 53 ++ indexd/records.go | 97 +++ indexd/s3_utils.go | 124 ++++ indexd/tests/add-url-integration_test.go | 68 ++ indexd/tests/client_read_test.go.todo | 134 ++++ indexd/tests/client_write_test.go.todo | 369 ++++++++++ indexd/tests/mock_servers_test.go | 610 +++++++++++++++++ indexd/types.go | 75 +++ indexd/types_test.go | 60 ++ {client/logs => logs}/factory.go | 25 +- logs/handler.go | 102 +++ logs/logger.go | 35 + logs/logger_test.go | 210 ++++++ logs/noop.go | 11 + {client/logs => logs}/scoreboard.go | 4 +- logs/tee_logger.go | 217 ++++++ {client/mocks => mocks}/mock_configure.go | 6 +- mocks/mock_fence.go | 252 +++++++ {client/mocks => mocks}/mock_functions.go | 38 +- mocks/mock_gen3interface.go | 115 ++++ mocks/mock_indexd.go | 251 +++++++ {client/mocks => mocks}/mock_request.go | 6 +- {client/request => request}/auth.go | 9 +- {client/request => request}/builder.go | 8 +- {client/request => request}/request.go | 14 +- request/request_test.go | 263 ++++++++ requestor/client.go | 265 ++++++++ requestor/client_test.go | 57 ++ requestor/policies/add-user-guppy-admin.yaml | 9 + requestor/policies/add-user-read.yaml | 5 + requestor/policies/add-user-write.yaml | 5 + requestor/types.go | 34 + sower/client.go | 148 ++++ sower/types.go | 33 + tests/download-multiple_test.go | 106 ++- tests/functions_test.go | 252 ------- tests/utils_test.go | 171 +++-- {client/upload => upload}/batch.go | 34 +- {client/upload => upload}/multipart.go | 131 +--- {client/upload => upload}/multipart_test.go | 89 ++- {client/upload => upload}/progress_reader.go | 46 +- .../upload => upload}/progress_reader_test.go | 2 +- {client/upload => upload}/request.go | 64 +- {client/upload => upload}/retry.go | 38 +- upload/singleFile.go | 96 +++ {client/upload => upload}/types.go | 31 +- upload/upload.go | 208 ++++++ {client/upload => upload}/utils.go | 20 +- {client/upload => upload}/utils_test.go | 2 +- 114 files changed, 8250 insertions(+), 2636 deletions(-) create mode 100644 .github/workflows/coverage.yaml create mode 100644 .github/workflows/release.yaml delete mode 100755 build.sh create mode 100644 bump-tag.sh delete mode 100644 client/api/gen3.go delete mode 100644 client/api/types.go delete mode 100644 client/client/client.go delete mode 100644 client/common/constants_test.go delete mode 100644 client/common/progress.go delete mode 100644 client/download/url_resolution.go delete mode 100644 client/logs/logger.go delete mode 100644 client/logs/tee_logger.go delete mode 100644 client/mocks/mock_gen3interface.go delete mode 100644 client/upload/singleFile.go delete mode 100644 client/upload/upload.go create mode 100644 cmd/collaborator.go rename {client/common => common}/common.go (79%) rename {client/common => common}/constants.go (71%) rename {client/common => common}/isHidden_notwindows.go (100%) rename {client/common => common}/isHidden_windows.go (100%) rename {client/common => common}/logHelper.go (100%) create mode 100644 common/progress.go create mode 100644 common/resource.go rename {client/common => common}/types.go (83%) rename {client/conf => conf}/config.go (91%) create mode 100644 conf/config_test.go rename {client/conf => conf}/validate.go (51%) create mode 100644 conf/validate_test.go rename {client/download => download}/batch.go (90%) rename {client/download => download}/downloader.go (67%) rename {client/download => download}/file_info.go (70%) rename {client/download => download}/progress_writer.go (53%) rename {client/download => download}/progress_writer_test.go (95%) rename {client/download => download}/transfer.go (50%) rename {client/download => download}/transfer_test.go (58%) rename {client/download => download}/types.go (90%) create mode 100644 download/url_resolution.go rename {client/download => download}/utils.go (67%) create mode 100644 fence/client.go create mode 100644 fence/client_test.go create mode 100644 fence/types.go create mode 100644 g3client/client.go create mode 100644 indexd/add_url.go create mode 100644 indexd/client.go create mode 100644 indexd/client_test.go create mode 100644 indexd/convert.go create mode 100644 indexd/drs/drs.go create mode 100644 indexd/drs/types.go create mode 100644 indexd/hash/hash.go create mode 100644 indexd/hash/hash_test.go create mode 100644 indexd/records.go create mode 100644 indexd/s3_utils.go create mode 100644 indexd/tests/add-url-integration_test.go create mode 100644 indexd/tests/client_read_test.go.todo create mode 100644 indexd/tests/client_write_test.go.todo create mode 100644 indexd/tests/mock_servers_test.go create mode 100644 indexd/types.go create mode 100644 indexd/types_test.go rename {client/logs => logs}/factory.go (66%) create mode 100644 logs/handler.go create mode 100644 logs/logger.go create mode 100644 logs/logger_test.go create mode 100644 logs/noop.go rename {client/logs => logs}/scoreboard.go (95%) create mode 100644 logs/tee_logger.go rename {client/mocks => mocks}/mock_configure.go (94%) create mode 100644 mocks/mock_fence.go rename {client/mocks => mocks}/mock_functions.go (76%) create mode 100644 mocks/mock_gen3interface.go create mode 100644 mocks/mock_indexd.go rename {client/mocks => mocks}/mock_request.go (91%) rename {client/request => request}/auth.go (90%) rename {client/request => request}/builder.go (87%) rename {client/request => request}/request.go (90%) create mode 100644 request/request_test.go create mode 100644 requestor/client.go create mode 100644 requestor/client_test.go create mode 100644 requestor/policies/add-user-guppy-admin.yaml create mode 100644 requestor/policies/add-user-read.yaml create mode 100644 requestor/policies/add-user-write.yaml create mode 100644 requestor/types.go create mode 100644 sower/client.go create mode 100644 sower/types.go delete mode 100755 tests/functions_test.go rename {client/upload => upload}/batch.go (70%) rename {client/upload => upload}/multipart.go (60%) rename {client/upload => upload}/multipart_test.go (54%) rename {client/upload => upload}/progress_reader.go (56%) rename {client/upload => upload}/progress_reader_test.go (95%) rename {client/upload => upload}/request.go (53%) rename {client/upload => upload}/retry.go (77%) create mode 100644 upload/singleFile.go rename {client/upload => upload}/types.go (52%) create mode 100644 upload/upload.go rename {client/upload => upload}/utils.go (87%) rename {client/upload => upload}/utils_test.go (98%) diff --git a/.github/workflows/build_and_test_push.yaml b/.github/workflows/build_and_test_push.yaml index 0b81f2a..4ed6a63 100644 --- a/.github/workflows/build_and_test_push.yaml +++ b/.github/workflows/build_and_test_push.yaml @@ -27,218 +27,4 @@ jobs: - name: Run Tests run: go test -v github.com/uc-cdis/gen3-client/tests - build: - env: - goarch: amd64 - needs: test - runs-on: ubuntu-latest - strategy: - matrix: - include: - - goos: linux - goarch: amd64 - zipfile: dataclient_linux.zip - - goos: darwin - goarch: amd64 - zipfile: dataclient_osx.zip - - goos: windows - goarch: amd64 - zipfile: dataclient_win64.zip - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - - name: Setup Go 1.17 - uses: actions/setup-go@v4 - with: - go-version: '1.17' - - - name: Run Setup Script - run: | - bash .github/scripts/before_install.sh - env: - GITHUB_BRANCH: ${{ github.ref_name }} - ACCESS_KEY: ${{ secrets.AWS_S3_ACCESS_KEY_ID }} - SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }} - - - - name: Run Build Script - run: | - bash .github/scripts/build.sh - env: - GOOS: ${{ matrix.goos }} - GOARCH: ${{ env.goarch }} - GITHUB_BRANCH: ${{ github.ref_name }} - GITHUB_PULL_REQUEST: ${{ github.event_name == 'pull_request' }} - - - name: Upload Artifacts - uses: actions/upload-artifact@v4 - with: - name: build-artifact-${{ matrix.goos }} - path: ~/shared/${{ matrix.zipfile }} - retention-days: 3 - - - sign: - needs: build - runs-on: macos-latest - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - name: Download OSX Artifact - uses: actions/download-artifact@v4 - with: - name: build-artifact-darwin - path: ./dist - - name: Unzip OSX Artifact and remove zip file - run: | - cd ./dist - ls - unzip dataclient_osx.zip - rm dataclient_osx.zip - - - - - name: Build executable - shell: bash - env: - APPLE_CERT_PASSWORD: ${{ secrets.APPLE_CERT_PASSWORD }} - APPLE_NOTARY_UUID: ${{ secrets.APPLE_NOTARY_UUID }} - APPLE_NOTARY_KEY: ${{ secrets.APPLE_NOTARY_KEY}} - APPLE_NOTARY_DATA: ${{ secrets.APPLE_NOTARY_DATA }} - APPLE_CERT_DATA: ${{ secrets.APPLE_CERT_DATA }} - APPLICATION_CERT_PASSWORD: ${{ secrets.APPLICATION_CERT_PASSWORD }} - APPLICATION_CERT_DATA: ${{ secrets.APPLICATION_CERT_DATA }} - APPLE_TEAM_ID: WYQ7U7YUC9 - - run: | - # Setup - SIGNFILE="$(pwd)/dist/gen3-client" - - # Export certs - echo "$APPLE_CERT_DATA" | base64 --decode > /tmp/certs.p12 - echo "$APPLE_NOTARY_DATA" | base64 --decode > /tmp/notary.p8 - echo "$APPLICATION_CERT_DATA" | base64 --decode > /tmp/app_certs.p12 - - # Create keychain - security create-keychain -p actions macos-build.keychain - security default-keychain -s macos-build.keychain - security unlock-keychain -p actions macos-build.keychain - security set-keychain-settings -t 3600 -u macos-build.keychain - - # Import certs to keychain - security import /tmp/certs.p12 -k ~/Library/Keychains/macos-build.keychain -P "$APPLE_CERT_PASSWORD" -T /usr/bin/codesign -T /usr/bin/productsign - security import /tmp/app_certs.p12 -k ~/Library/Keychains/macos-build.keychain -P "$APPLICATION_CERT_PASSWORD" -T /usr/bin/codesign -T /usr/bin/productsign - - # Key signing - security set-key-partition-list -S apple-tool:,apple: -s -k actions macos-build.keychain - - # Verify keychain things - security find-identity -v macos-build.keychain | grep "$APPLE_TEAM_ID" | grep "Developer ID Application" - security find-identity -v macos-build.keychain | grep "$APPLE_TEAM_ID" | grep "Developer ID Installer" - - # Force the codesignature - codesign --force --options=runtime --keychain "/Users/runner/Library/Keychains/macos-build.keychain-db" -s "$APPLE_TEAM_ID" "$SIGNFILE" - # Verify the code signature - codesign -v "$SIGNFILE" --verbose - - mkdir -p ./dist/pkg - cp ./dist/gen3-client ./dist/pkg/gen3-client - pkgbuild --identifier "org.uc-cdis.gen3-client.pkg" --timestamp --install-location /Applications --root ./dist/pkg installer.pkg - pwd - ls - productbuild --resources ./resources --distribution ./distribution.xml gen3-client.pkg - productsign --sign "$APPLE_TEAM_ID" --timestamp gen3-client.pkg gen3-client_signed.pkg - - xcrun notarytool store-credentials "notarytool-profile" --issuer $APPLE_NOTARY_UUID --key-id $APPLE_NOTARY_KEY --key /tmp/notary.p8 - xcrun notarytool submit gen3-client_signed.pkg --keychain-profile "notarytool-profile" --wait - xcrun stapler staple gen3-client_signed.pkg - mv gen3-client_signed.pkg dataclient_osx.pkg - - - name: Upload signed artifact - uses: actions/upload-artifact@v4 - with: - name: build-artifact-darwin-signed - path: dataclient_osx.pkg - - sync_signed_to_aws: - runs-on: ubuntu-latest - needs: sign - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - name: Run Setup Script - run: | - bash ./.github/scripts/before_install.sh - env: - GITHUB_BRANCH: ${{ github.ref_name }} - ACCESS_KEY: ${{ secrets.AWS_S3_ACCESS_KEY_ID }} - SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }} - - - name: Download OSX Artifact - uses: actions/download-artifact@v4 - with: - name: build-artifact-darwin-signed - - - name: Sync to AWS - env: - GITHUB_BRANCH: ${{ github.ref_name }} - run: | - rm ~/shared/dataclient_osx.zip - zip dataclient_osx_signed.zip dataclient_osx.pkg - mv dataclient_osx_signed.zip ~/shared/ - aws s3 sync ~/shared s3://cdis-dc-builds/$GITHUB_BRANCH - - - get_tagged_branch: - if: startsWith(github.ref, 'refs/tags/') - runs-on: ubuntu-latest - needs: [build,sign] - outputs: - branch: ${{ steps.check_step.outputs.branch }} - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Get current branch - id: check_step - # 1. Get the list of branches ref where this tag exists - # 2. Remove 'origin/' from that result - # 3. Put that string in output - # => We can now use function 'contains(list, item)'' - run: | - raw=$(git branch -r --contains ${{ github.ref }}) - branch="$(echo ${raw//origin\//} | tr -d '\n')" - echo "{name}=branch" >> $GITHUB_OUTPUT - echo "Branches where this tag exists : $branch." - - - deploy: - needs: get_tagged_branch - if: startsWith(github.ref, 'refs/tags/') && contains(${{ needs.get_tagged_branch.outputs.branch }}, 'master') - runs-on: ubuntu-latest - steps: - - name: Download Linux Artifact - uses: actions/download-artifact@v4 - with: - name: build-artifact-linux - - - name: Download OSX Artifact - uses: actions/download-artifact@v4 - with: - name: build-artifact-darwin-signed - - - name: Download Windows Artifact - uses: actions/download-artifact@v4 - with: - name: build-artifact-windows - - - name: Create Release gh cli - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GH_TAG: ${{ github.ref_name }} - run: gh release create "$GH_TAG" dataclient_linux.zip dataclient_osx.pkg dataclient_win64.zip --repo="$GITHUB_REPOSITORY" diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml new file mode 100644 index 0000000..e4e7063 --- /dev/null +++ b/.github/workflows/coverage.yaml @@ -0,0 +1,103 @@ +name: "Test Coverage Check" + +on: + pull_request: + branches: + - master + push: + branches: + - master + +jobs: + coverage: + name: Test Coverage + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.24.2' + + - name: Run Tests with Coverage + run: | + go test -coverprofile=coverage.out -covermode=atomic ./... + continue-on-error: true + + - name: Generate Coverage Report + id: coverage + run: | + # Get overall coverage + OVERALL=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//') + echo "overall=$OVERALL" >> $GITHUB_OUTPUT + + # Generate detailed report + echo "## Test Coverage Report" > coverage-report.md + echo "" >> coverage-report.md + echo "**Overall Coverage:** ${OVERALL}%" >> coverage-report.md + echo "" >> coverage-report.md + echo "### Package Coverage" >> coverage-report.md + echo "" >> coverage-report.md + echo "| Package | Coverage |" >> coverage-report.md + echo "|---------|----------|" >> coverage-report.md + + # Extract package coverage + go test -coverprofile=/dev/null -covermode=atomic ./... 2>&1 | \ + grep "coverage:" | \ + grep -v "setup failed" | \ + awk '{ + pkg=$1; + cov=$4; + gsub(/github.com\/calypr\/data-client\//, "", pkg); + if (cov ~ /statements/) { + print "| " pkg " | " cov " |" + } else { + print "| " pkg " | " cov " |" + } + }' >> coverage-report.md + + cat coverage-report.md + + - name: Check Coverage Thresholds + run: | + chmod +x ./scripts/check-coverage.sh + ./scripts/check-coverage.sh 30 20 + + - name: Upload Coverage to Codecov (Optional) + uses: codecov/codecov-action@v4 + if: always() + with: + files: ./coverage.out + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + - name: Comment PR with Coverage + if: github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + const fs = require('fs'); + const coverage = fs.readFileSync('coverage-report.md', 'utf8'); + + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: coverage + }); + + - name: Upload Coverage Artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: | + coverage.out + coverage-report.md + retention-days: 30 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..c23cf06 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,34 @@ +name: Release + +on: + push: + tags: + - '*' + workflow_dispatch: + +permissions: + contents: write + +jobs: + goreleaser: + runs-on: ubuntu-latest + steps: + - + name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + - + name: Set up Go + uses: actions/setup-go@v5 + - + name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + with: + # either 'goreleaser' (default) or 'goreleaser-pro' + distribution: goreleaser + # 'latest', 'nightly', or a semver + version: 'latest' + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Makefile b/Makefile index e524d95..c936cd2 100644 --- a/Makefile +++ b/Makefile @@ -10,9 +10,13 @@ MAIN_PACKAGE := . # The directory where the final binary will be placed BIN_DIR := ./bin +# Coverage thresholds +COVERAGE_THRESHOLD := 30 +PACKAGE_COVERAGE_THRESHOLD := 20 + # --- Targets --- -.PHONY: all build test generate tidy clean help +.PHONY: all build test test-coverage coverage-html coverage-check generate tidy clean help # The default target run when you type 'make' all: build @@ -28,6 +32,24 @@ test: @echo "--> Running all tests..." @go test -v ./... +## test-coverage: Runs tests with coverage profiling +test-coverage: + @echo "--> Running tests with coverage..." + @go test -coverprofile=coverage.out -covermode=atomic ./... + @echo "--> Coverage report generated: coverage.out" + @go tool cover -func=coverage.out | tail -1 + +## coverage-html: Generates HTML coverage report +coverage-html: test-coverage + @echo "--> Generating HTML coverage report..." + @go tool cover -html=coverage.out -o coverage.html + @echo "--> HTML coverage report generated: coverage.html" + +## coverage-check: Verifies coverage meets minimum thresholds +coverage-check: test-coverage + @echo "--> Checking coverage thresholds..." + @./scripts/check-coverage.sh $(COVERAGE_THRESHOLD) $(PACKAGE_COVERAGE_THRESHOLD) + ## generate: Runs go generate commands to create mocks, embedded assets, etc. generate: @echo "--> Running code generation (go generate)..." @@ -39,8 +61,9 @@ tidy: @go mod tidy @go fmt ./... -## clean: Removes the compiled binary +## clean: Removes the compiled binary and coverage files clean: @echo "--> Cleaning up..." @rm -f $(BIN_DIR)/$(TARGET_NAME) + @rm -f coverage.out coverage.html diff --git a/build.sh b/build.sh deleted file mode 100755 index 213b40a..0000000 --- a/build.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env bash -# -# Adapted from 'How To Build Go Executables for Multiple Platforms on Ubuntu 16.04' -# By Marko Mudrinić -# -# Usage: -# ./build.sh - -if [ "$1" == "-h" ] || [ "$1" == "--help" ]; then - echo "usage: $0" - echo "output: zipped executables to ./build directory" -fi - -package=$1 - -if [[ -z "$package" ]]; then - package='gen3-client' -fi - -package_split=(${package//\// }) -package_name=${package_split[-1]} - -platforms=( - "darwin/arm64" - "darwin/amd64" - "linux/amd64" - "windows/amd64" -) - -mkdir -p ./build -> checksums.txt -for platform in "${platforms[@]}" -do - platform_split=(${platform//\// }) - GOOS=${platform_split[0]} - GOARCH=${platform_split[1]} - output_name=$package_name'-'$GOOS'-'$GOARCH - exe_name=$package_name - - if [ $GOOS = "windows" ]; then - exe_name+='.exe' - - elif [ $GOOS = "darwin" ]; then - if [ $GOARCH = "arm64" ]; then - output_name=$package_name'-macos' - - elif [ $GOARCH = "amd64" ]; then - output_name=$package_name'-macos-intel' - fi - fi - - printf 'Building %s...' "$output_name" - env GOOS=$GOOS GOARCH=$GOARCH go build -o ./build/$exe_name . - cd build - zip -r -q $output_name $exe_name - sha256sum $output_name.zip >> checksums.txt - cd .. - - if [ $? -ne 0 ]; then - echo 'An error has occurred! Aborting the script execution...' - exit 1 - fi - echo 'OK' -done - -# Clean up build artifacts -rm build/{$package_name,$package_name.exe} - diff --git a/bump-tag.sh b/bump-tag.sh new file mode 100644 index 0000000..b38169e --- /dev/null +++ b/bump-tag.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +# File: `bump-patch.sh` +set -euo pipefail + +# Find latest tag excluding major v0 +LATEST_TAG=$(git tag --list --sort=-v:refname | grep -v '^v0' | head -n1 || true) +if [ -z "$LATEST_TAG" ]; then + echo "No suitable tag found (excluding v0). Aborting." >&2 + exit 1 +fi + +# check that the working directory is clean +if [ -n "$(git status --porcelain)" ]; then + echo "Working directory is not clean. Please commit or stash changes before running this script." >&2 + exit 1 +fi + +usage() { + cat <<-EOF +Usage: $0 [--major | --minor | --patch] + +LATEST_TAG: $LATEST_TAG + +Options: + --major Bump major (MAJOR+1, MINOR=0, PATCH=0) + --minor Bump minor (MINOR+1, PATCH=0) + --patch Bump patch (PATCH+1) [default] +EOF + exit 1 +} + +# Parse options +opt_major=false +opt_minor=false +opt_patch=false +count=0 + +while [ $# -gt 0 ]; do + case "$1" in + --major) + opt_major=true + count=$((count + 1)) + shift + ;; + --minor) + opt_minor=true + count=$((count + 1)) + shift + ;; + --patch) + opt_patch=true + count=$((count + 1)) + shift + ;; + --help|-h) + usage + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Default to patch if no option provided +if [ "$count" -eq 0 ]; then + opt_patch=true +fi + +# Disallow specifying more than one +if [ "$count" -gt 1 ]; then + echo "Specify only one of --major, --minor, or --patch" >&2 + exit 1 +fi + + +# Parse semver vMAJOR.MINOR.PATCH +if [[ "$LATEST_TAG" =~ ^v?([0-9]+)\.([0-9]+)\.([0-9]+)$ ]]; then + MAJOR="${BASH_REMATCH[1]}" + MINOR="${BASH_REMATCH[2]}" + PATCH="${BASH_REMATCH[3]}" +else + echo "Latest tag '$LATEST_TAG' is not in semver format. Aborting." >&2 + exit 1 +fi + +# Compute new version +if [ "$opt_major" = true ]; then + NEW_MAJOR=$((MAJOR + 1)) + NEW_MINOR=0 + NEW_PATCH=0 + NEW_TAG="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" + NEW_FILE_VER="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" +elif [ "$opt_minor" = true ]; then + NEW_MAJOR=$MAJOR + NEW_MINOR=$((MINOR + 1)) + NEW_PATCH=0 + NEW_TAG="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" + NEW_FILE_VER="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" +else + # patch + NEW_MAJOR=$MAJOR + NEW_MINOR=$MINOR + NEW_PATCH=$((PATCH + 1)) + NEW_TAG="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" + NEW_FILE_VER="${NEW_MAJOR}.${NEW_MINOR}.${NEW_PATCH}" +fi + +BRANCH="$(git rev-parse --abbrev-ref HEAD)" + +echo "Latest branch: $BRANCH" +echo "Latest tag: $LATEST_TAG" +echo "New tag: $NEW_TAG (files will use ${NEW_FILE_VER})" + +# Update internal version file +if [ -f cmd/gitversion.go ]; then + # sed on mac is -i '' + sed -E -i '' -e "s/(gitversion *= *\")[^\"]+(\")/\1${NEW_FILE_VER}\2/" cmd/gitversion.go + git add cmd/gitversion.go +fi + +# Commit, tag and push +git commit -m "chore(release): bump to ${NEW_TAG}" || echo "No changes to commit" +git tag -a "${NEW_TAG}" -m "Release ${NEW_TAG}" +echo "Created tag. Please push tag ${NEW_TAG} on branch ${BRANCH}." + +echo git push origin "${BRANCH}" +echo git push origin "${NEW_TAG}" diff --git a/client/api/gen3.go b/client/api/gen3.go deleted file mode 100644 index ff3ecde..0000000 --- a/client/api/gen3.go +++ /dev/null @@ -1,367 +0,0 @@ -package api - -//go:generate mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/api FunctionInterface - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/request" - "github.com/hashicorp/go-version" -) - -func NewFunctions(config conf.ManagerInterface, request request.RequestInterface, cred *conf.Credential, logger logs.Logger) FunctionInterface { - return &Functions{ - RequestInterface: request, - Cred: cred, - Config: config, - Logger: logger, - } -} - -type Functions struct { - request.RequestInterface - - Cred *conf.Credential - Config conf.ManagerInterface - Logger logs.Logger -} - -type FunctionInterface interface { - request.RequestInterface - - CheckPrivileges(ctx context.Context) (map[string]any, error) - CheckForShepherdAPI(ctx context.Context) (bool, error) - DeleteRecord(ctx context.Context, guid string) (string, error) - GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) - - ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) - ExportCredential(ctx context.Context, cred *conf.Credential) error - NewAccessToken(ctx context.Context) error -} - -func (f *Functions) NewAccessToken(ctx context.Context) error { - if f.Cred.APIKey == "" { - return errors.New("APIKey is required to refresh access token") - } - - payload, err := json.Marshal(map[string]string{"api_key": f.Cred.APIKey}) - if err != nil { - return err - } - bodyReader := bytes.NewReader(payload) - - resp, err := f.Do( - ctx, - f.New(http.MethodPost, f.Cred.APIEndpoint+common.FenceAccessTokenEndpoint). - WithHeader(common.HeaderContentType, common.MIMEApplicationJSON). - WithBody(bodyReader), - ) - - if err != nil { - return fmt.Errorf("Error when calling Request.Do: %s", err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.New("failed to refresh token, status: " + strconv.Itoa(resp.StatusCode)) - } - - var result common.AccessTokenStruct - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return errors.New("failed to parse token response: " + err.Error()) - } - - f.Cred.AccessToken = result.AccessToken - return nil -} - -func (f *Functions) GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { - hasShepherd, err := f.CheckForShepherdAPI(ctx) // error already logged upstream - if err == nil && hasShepherd { - return f.resolveFromShepherd(ctx, guid) - } - return f.resolveFromFence(ctx, guid, protocolText) -} - -// Todo: why isn't this calld in every fence response that has a body ? why is this seperated out -func (f *Functions) ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) { - msg := FenceResponse{} - if resp == nil { - return msg, errors.New("Nil response received") - } - - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - bodyStr := buf.String() - - err := json.Unmarshal(buf.Bytes(), &msg) - if err != nil { - return msg, fmt.Errorf("failed to decode JSON: %w (Raw body: %s)", err, buf.String()) - } - - if !(resp.StatusCode == 200 || resp.StatusCode == 201 || resp.StatusCode == 204) { - strUrl := resp.Request.URL.String() - switch resp.StatusCode { - case http.StatusUnauthorized: - return msg, fmt.Errorf("401 Unauthorized: %s (URL: %s)", bodyStr, strUrl) - case http.StatusForbidden: - return msg, fmt.Errorf("403 Forbidden: %s (URL: %s)", bodyStr, strUrl) - case http.StatusNotFound: - return msg, fmt.Errorf("404 Not Found: %s (URL: %s)", bodyStr, strUrl) - case http.StatusInternalServerError: - return msg, fmt.Errorf("500 Internal Server Error: %s (URL: %s)", bodyStr, strUrl) - case http.StatusServiceUnavailable: - return msg, fmt.Errorf("503 Service Unavailable: %s (URL: %s)", bodyStr, strUrl) - case http.StatusBadGateway: - return msg, fmt.Errorf("502 Bad Gateway: %s (URL: %s)", bodyStr, strUrl) - default: - return msg, fmt.Errorf("Unexpected Error (%d): %s (URL: %s)", resp.StatusCode, bodyStr, strUrl) - } - } - - // Logic for successful status codes - if strings.Contains(bodyStr, "Can't find a location for the data") { - return msg, errors.New("The provided GUID is not found") - } - - return msg, nil -} - -func (f *Functions) CheckForShepherdAPI(ctx context.Context) (bool, error) { - // Check if Shepherd is enabled - if f.Cred.UseShepherd == "false" { - return false, nil - } - if f.Cred.UseShepherd != "true" && common.DefaultUseShepherd == false { - return false, nil - } - // If Shepherd is enabled, make sure that the commons has a compatible version of Shepherd deployed. - // Compare the version returned from the Shepherd version endpoint with the minimum acceptable Shepherd version. - var minShepherdVersion string - if f.Cred.MinShepherdVersion == "" { - minShepherdVersion = common.DefaultMinShepherdVersion - } else { - minShepherdVersion = f.Cred.MinShepherdVersion - } - - res, err := f.Do(ctx, - &request.RequestBuilder{ - Url: f.Cred.APIEndpoint + common.ShepherdVersionEndpoint, - Method: http.MethodGet, - Token: f.Cred.AccessToken, - }, - ) - if err != nil { - return false, errors.New("Error occurred during generating HTTP request: " + err.Error()) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return false, nil - } - bodyBytes, err := io.ReadAll(res.Body) - if err != nil { - return false, errors.New("Error occurred when reading HTTP request: " + err.Error()) - } - body, err := strconv.Unquote(string(bodyBytes)) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) - } - // Compare the version in the response to the target version - ver, err := version.NewVersion(body) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) - } - minVer, err := version.NewVersion(minShepherdVersion) - if err != nil { - return false, fmt.Errorf("Error occurred when parsing minimum acceptable Shepherd version: %v: %v", minShepherdVersion, err) - } - if ver.GreaterThanOrEqual(minVer) { - return true, nil - } - return false, fmt.Errorf("Shepherd is enabled, but %v does not have correct Shepherd version. (Need Shepherd version >=%v, got %v)", f.Cred.APIEndpoint, minVer, ver) -} -func (f *Functions) CheckPrivileges(ctx context.Context) (map[string]any, error) { - /* - Return user privileges from specified profile - */ - var err error - var data map[string]any - - resp, err := f.Do(ctx, - &request.RequestBuilder{ - Url: f.Cred.APIEndpoint + common.FenceUserEndpoint, - Method: http.MethodGet, - Token: f.Cred.AccessToken, - }, - ) - if err != nil { - return nil, errors.New("Error occurred when getting response from remote: " + err.Error()) - } - defer resp.Body.Close() - - str := ResponseToString(resp) - err = json.Unmarshal([]byte(str), &data) - if err != nil { - return nil, errors.New("Error occurred when unmarshalling response: " + err.Error()) - } - - resourceAccess, ok := data["authz"].(map[string]any) - - // If the `authz` section (Arborist permissions) is empty or missing, try get `project_access` section (Fence permissions) - if len(resourceAccess) == 0 || !ok { - resourceAccess, ok = data["project_access"].(map[string]any) - if !ok { - return nil, errors.New("Not possible to read access privileges of user") - } - } - - return resourceAccess, err -} - -func (f *Functions) DeleteRecord(ctx context.Context, guid string) (string, error) { - endpoint := common.FenceDataEndpoint + "/" + guid - msg := "" - hasShepherd, err := f.CheckForShepherdAPI(ctx) - if err != nil { - f.Logger.Printf("WARNING: Error checking Shepherd API: %v. Falling back to Fence.\n", err) - } else if hasShepherd { - endpoint = common.ShepherdEndpoint + "/objects/" + guid - } - - resp, err := f.Do(ctx, - &request.RequestBuilder{ - Url: f.Cred.APIEndpoint + endpoint, - Method: http.MethodDelete, - Token: f.Cred.AccessToken, - }, - ) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode == 204 { - msg = "Record with GUID " + guid + " has been deleted" - } else { - _, err = f.ParseFenceURLResponse(resp) - if err != nil { - return "", err - } - } - return msg, nil -} - -func (f *Functions) ExportCredential(ctx context.Context, cred *conf.Credential) error { - - if cred.Profile == "" { - return fmt.Errorf("profile name is required") - } - if cred.APIEndpoint == "" { - return fmt.Errorf("API endpoint is required") - } - - // Normalize endpoint - cred.APIEndpoint = strings.TrimSpace(cred.APIEndpoint) - cred.APIEndpoint = strings.TrimSuffix(cred.APIEndpoint, "/") - - // Validate URL format - parsedURL, err := conf.ValidateUrl(cred.APIEndpoint) - if err != nil { - return fmt.Errorf("invalid apiendpoint URL: %w", err) - } - fenceBase := parsedURL.Scheme + "://" + parsedURL.Host - if _, err := f.Config.Load(cred.Profile); err != nil && !errors.Is(err, conf.ErrProfileNotFound) { - return err - } - - if cred.APIKey != "" { - // Always refresh the access token — ignore any old one that might be in the struct - err = f.NewAccessToken(ctx) - if err != nil { - if strings.Contains(err.Error(), "401") { - return fmt.Errorf("authentication failed (401) for %s — your API key is invalid, revoked, or expired", fenceBase) - } - if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "no such host") { - return fmt.Errorf("cannot reach Fence at %s — is this a valid Gen3 commons?", fenceBase) - } - return fmt.Errorf("failed to refresh access token: %w", err) - } - } else { - f.Logger.Printf("WARNING: Your profile will only be valid for 24 hours since you have only provided a refresh token for authentication") - } - - // Clean up shepherd flags - cred.UseShepherd = strings.TrimSpace(cred.UseShepherd) - cred.MinShepherdVersion = strings.TrimSpace(cred.MinShepherdVersion) - - if cred.MinShepherdVersion != "" { - if _, err = version.NewVersion(cred.MinShepherdVersion); err != nil { - return fmt.Errorf("invalid min-shepherd-version: %w", err) - } - } - - if err := f.Config.Save(cred); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - - return nil -} - -func (f *Functions) resolveFromShepherd(ctx context.Context, guid string) (string, error) { - // We use f.Cred.APIEndpoint because the struct owns the credential state - url := fmt.Sprintf("%s%s/objects/%s/download", f.Cred.APIEndpoint, common.ShepherdEndpoint, guid) - - // We call f.Do directly because of method promotion (embedding) - resp, err := f.Do(ctx, f.New(http.MethodGet, url)) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("shepherd error: %d", resp.StatusCode) - } - - var result struct { - URL string `json:"url"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to decode shepherd response: %w", err) - } - - return result.URL, nil -} - -func (f *Functions) resolveFromFence(ctx context.Context, guid, protocolText string) (string, error) { - resp, err := f.Do( - ctx, - &request.RequestBuilder{ - Url: f.Cred.APIEndpoint + common.FenceDataDownloadEndpoint + "/" + guid + protocolText, - Method: http.MethodGet, - Token: f.Cred.AccessToken, - }, - ) - if err != nil { - return "", errors.New("Failed to get URL from Fence via DoAuthenticatedRequest: " + err.Error()) - } - defer resp.Body.Close() - - msg, err := f.ParseFenceURLResponse(resp) - if err != nil || msg.URL == "" { - return "", errors.New("Failed to get URL from Fence via ParseFenceURLResponse: " + err.Error()) - } - - return msg.URL, nil -} diff --git a/client/api/types.go b/client/api/types.go deleted file mode 100644 index 59feec5..0000000 --- a/client/api/types.go +++ /dev/null @@ -1,25 +0,0 @@ -package api - -import ( - "bytes" - "net/http" -) - -type Message any -type Response any - -type FenceResponse struct { - URL string `json:"url"` - GUID string `json:"guid"` - UploadID string `json:"uploadId"` - PresignedURL string `json:"presigned_url"` - FileName string `json:"file_name"` - URLs []string `json:"urls"` - Size int64 `json:"size"` -} - -func ResponseToString(resp *http.Response) string { - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) // nolint: errcheck - return buf.String() -} diff --git a/client/client/client.go b/client/client/client.go deleted file mode 100644 index 41fb41f..0000000 --- a/client/client/client.go +++ /dev/null @@ -1,69 +0,0 @@ -package client - -import ( - "context" - "fmt" - - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/request" -) - -//go:generate mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/client Gen3Interface - -// Top level wrapper Interface for calling lower level interface functions. -// -// Gen3Interface contains minimum number of methods to enable calling functions in the FunctionInterface -// The credential is embedded in the implementation, so it doesn't need to be passed to each method. -type Gen3Interface interface { - GetCredential() *conf.Credential - Logger() *logs.TeeLogger - - api.FunctionInterface -} - -// Gen3Client wraps jwt.FunctionInterface and embeds the credential -type Gen3Client struct { - Ctx context.Context - api.FunctionInterface - - credential *conf.Credential - logger *logs.TeeLogger -} - -func (g *Gen3Client) Logger() *logs.TeeLogger { - return g.logger -} - -// GetCredential returns the embedded credential -func (g *Gen3Client) GetCredential() *conf.Credential { - return g.credential -} - -// NewGen3Interface returns a Gen3Client that embeds the credential and implements Gen3Interface. -// This eliminates the need to pass credentials around everywhere. -func NewGen3Interface(profile string, logger *logs.TeeLogger, opts ...func(*Gen3Client)) (Gen3Interface, error) { - config := conf.NewConfigure(logger) - cred, err := config.Load(profile) - if err != nil { - return nil, err - } - - if valid, err := config.IsValid(cred); !valid { - return nil, fmt.Errorf("invalid credential: %v", err) - } - - apiClient := api.NewFunctions( - config, - request.NewRequestInterface(logger, cred, config), - cred, - logger, - ) - - return &Gen3Client{ - FunctionInterface: apiClient, - credential: cred, - logger: logger, - }, nil -} diff --git a/client/common/constants_test.go b/client/common/constants_test.go deleted file mode 100644 index 8eed0e0..0000000 --- a/client/common/constants_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package common - -import ( - "os" - "os/exec" - "path/filepath" - "testing" -) - -func TestGetLfsCustomTransferInt(t *testing.T) { - configDir := t.TempDir() - configPath := filepath.Join(configDir, "gitconfig") - - setConfig := func(t *testing.T, key, value string) { - t.Helper() - cmd := exec.Command("git", "config", "--file", configPath, key, value) - if err := cmd.Run(); err != nil { - t.Fatalf("set git config %s=%s: %v", key, value, err) - } - } - - setEnv := func(t *testing.T) { - t.Helper() - t.Setenv("GIT_CONFIG_GLOBAL", configPath) - t.Setenv("GIT_CONFIG_SYSTEM", os.DevNull) - t.Setenv("GIT_CONFIG_NOSYSTEM", "1") - } - - const key = "lfs.customtransfer.drs.multipart-min-chunk-size" - - tests := []struct { - name string - value string - defaultVal int64 - want int64 - wantErr bool - setValue bool - }{ - { - name: "missing uses default", - defaultVal: 10, - want: 10, - wantErr: false, - setValue: false, - }, - { - name: "valid value", - value: "25", - defaultVal: 10, - want: 25, - wantErr: false, - setValue: true, - }, - { - name: "negative value", - value: "-3", - defaultVal: 10, - want: 10, - wantErr: true, - setValue: true, - }, - { - name: "zero value", - value: "0", - defaultVal: 10, - want: 10, - wantErr: true, - setValue: true, - }, - { - name: "over max", - value: "501", - defaultVal: 10, - want: 10, - wantErr: true, - setValue: true, - }, - { - name: "non-integer", - value: "abc", - defaultVal: 10, - want: 10, - wantErr: true, - setValue: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := os.WriteFile(configPath, nil, 0o600); err != nil { - t.Fatalf("reset git config: %v", err) - } - if tt.setValue { - setConfig(t, key, tt.value) - } - setEnv(t) - - got, err := GetLfsCustomTransferInt(key, tt.defaultVal) - if tt.wantErr && err == nil { - t.Fatalf("expected error, got nil") - } - if !tt.wantErr && err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != tt.want { - t.Fatalf("value = %d, want %d", got, tt.want) - } - }) - } -} diff --git a/client/common/progress.go b/client/common/progress.go deleted file mode 100644 index c743e7c..0000000 --- a/client/common/progress.go +++ /dev/null @@ -1,12 +0,0 @@ -package common - -// ProgressEvent matches the Git LFS custom transfer progress payload. -type ProgressEvent struct { - Event string `json:"event"` - Oid string `json:"oid"` - BytesSoFar int64 `json:"bytesSoFar"` - BytesSinceLast int64 `json:"bytesSinceLast"` -} - -// ProgressCallback emits transfer progress updates. -type ProgressCallback func(ProgressEvent) error diff --git a/client/download/url_resolution.go b/client/download/url_resolution.go deleted file mode 100644 index 475a55e..0000000 --- a/client/download/url_resolution.go +++ /dev/null @@ -1,80 +0,0 @@ -package download - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/request" -) - -// GetDownloadResponse gets presigned URL and prepares HTTP response -func GetDownloadResponse(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject, protocolText string) error { - url, err := g3.GetDownloadPresignedUrl(ctx, fdr.GUID, protocolText) - if err != nil { - return err - } - fdr.URL = url - - if fdr.Range > 0 && !isCloudPresignedURL(url) { - if !supportsRange(url) { - fdr.Range = 0 - } - } - - return makeDownloadRequest(ctx, g3, fdr) -} - -func isCloudPresignedURL(url string) bool { - return strings.Contains(url, "X-Amz-Signature") || strings.Contains(url, "X-Goog-Signature") -} - -func supportsRange(url string) bool { - resp, err := http.Head(url) - if err != nil || resp.Header.Get("Accept-Ranges") != "bytes" { - return false - } - return true -} - -func makeDownloadRequest(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject) error { - headers := map[string]string{} - if fdr.Range > 0 { - headers["Range"] = "bytes=" + strconv.FormatInt(fdr.Range, 10) + "-" - } - - resp, err := g3.Do( - ctx, - &request.RequestBuilder{ - Method: http.MethodGet, - Url: fdr.URL, - Headers: headers, - }, - ) - - if err != nil { - return errors.New("Request failed: " + strings.ReplaceAll(err.Error(), fdr.URL, "")) - } - - // Check for non-success status codes - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { - defer resp.Body.Close() // Ensure the body is closed - - bodyBytes, err := io.ReadAll(resp.Body) - bodyString := "" - if err == nil { - bodyString = string(bodyBytes) - } - - return fmt.Errorf("non-OK response: %d, body: %s", resp.StatusCode, bodyString) - } - - fdr.Response = resp - return nil -} diff --git a/client/logs/logger.go b/client/logs/logger.go deleted file mode 100644 index 7a6d53b..0000000 --- a/client/logs/logger.go +++ /dev/null @@ -1,41 +0,0 @@ -package logs - -import ( - "io" -) - -type Logger interface { - Printf(format string, v ...any) - Println(v ...any) - Fatalf(format string, v ...any) - Fatal(v ...any) - Writer() io.Writer -} - -type Option func(*config) - -type config struct { - console bool - messageFile bool - failedLog bool - succeededLog bool - enableScoreboard bool - baseLogger Logger -} - -func WithConsole() Option { return func(c *config) { c.console = true } } -func WithMessageFile() Option { return func(c *config) { c.messageFile = true } } -func WithFailedLog() Option { return func(c *config) { c.failedLog = true } } -func WithSucceededLog() Option { return func(c *config) { c.succeededLog = true } } -func WithScoreboard() Option { return func(c *config) { c.enableScoreboard = true } } -func WithBaseLogger(base Logger) Option { return func(c *config) { c.baseLogger = base } } - -func defaults() *config { - return &config{ - console: true, - messageFile: true, - failedLog: true, - succeededLog: true, - baseLogger: nil, - } -} diff --git a/client/logs/tee_logger.go b/client/logs/tee_logger.go deleted file mode 100644 index 08a6d0f..0000000 --- a/client/logs/tee_logger.go +++ /dev/null @@ -1,172 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "io" // Added for standard logging methods like Fatal - "maps" - "os" - "sync" - - "github.com/calypr/data-client/client/common" -) - -// --- teeLogger Implementation --- -type TeeLogger struct { - mu sync.RWMutex - writers []io.Writer - scoreboard *Scoreboard - - failedMu sync.Mutex - FailedMap map[string]common.RetryObject // Maps filePath to FileMetadata - failedPath string - - succeededMu sync.Mutex - succeededMap map[string]string // Maps filePath to GUID - succeededPath string -} - -// NewTeeLogger combines initialization and log loading (replacing initSyncLogs) -func NewTeeLogger(logDir, profile string, writers ...io.Writer) *TeeLogger { - t := &TeeLogger{ - mu: sync.RWMutex{}, - writers: writers, - scoreboard: nil, - - FailedMap: make(map[string]common.RetryObject), - succeededMap: make(map[string]string), - } - - return t -} - -// Internal helper function (replaces the global loadJSON) -func loadJSON(path string, v any) { - data, _ := os.ReadFile(path) - if len(data) > 0 { - // Error handling for Unmarshal is often omitted in utility code - // but is good practice. We keep the original style for now. - json.Unmarshal(data, v) - } -} - -// --- Public Logger Methods --- - -// Printf implements part of the standard Logger interface. -func (t *TeeLogger) Printf(format string, v ...any) { - t.write(fmt.Sprintf(format, v...)) -} - -// Println implements part of the standard Logger interface. -func (t *TeeLogger) Println(v ...any) { - t.write(fmt.Sprintln(v...)) -} - -// Fatalf implements part of the standard Logger interface and exits the program. -func (t *TeeLogger) Fatalf(format string, v ...any) { - s := fmt.Sprintf(format, v...) - t.write(s) - os.Exit(1) -} - -// Fatal implements part of the standard Logger interface and exits the program. -func (t *TeeLogger) Fatal(v ...any) { - s := fmt.Sprintln(v...) - t.write(s) - os.Exit(1) -} - -// Writer implements part of the standard Logger interface, returning a multi-writer. -func (t *TeeLogger) Writer() io.Writer { - t.mu.RLock() - defer t.mu.RUnlock() - return io.MultiWriter(t.writers...) -} - -// Scoreboard returns the embedded ScoreboardAccess. -func (t *TeeLogger) Scoreboard() *Scoreboard { - return t.scoreboard -} - -// GetSucceededLogMap returns a copy of the succeeded log map. -func (t *TeeLogger) GetSucceededLogMap() map[string]string { - t.succeededMu.Lock() - defer t.succeededMu.Unlock() - // Return a copy to prevent external modification - copiedMap := make(map[string]string, len(t.succeededMap)) - maps.Copy(copiedMap, t.succeededMap) - - return copiedMap -} - -// GetFailedLogMap returns a copy of the failed log map. -func (t *TeeLogger) GetFailedLogMap() map[string]common.RetryObject { - t.failedMu.Lock() - defer t.failedMu.Unlock() - // Return a copy to prevent external modification - copiedMap := make(map[string]common.RetryObject, len(t.FailedMap)) - maps.Copy(copiedMap, t.FailedMap) - return copiedMap -} - -func (t *TeeLogger) DeleteFromFailedLog(path string) { - t.failedMu.Lock() - defer t.failedMu.Unlock() - delete(t.FailedMap, path) -} - -// --- Internal Utility Methods --- - -// write handles writing the string to all configured writers. -func (t *TeeLogger) write(s string) { - t.mu.RLock() - defer t.mu.RUnlock() - for _, w := range t.writers { - _, _ = fmt.Fprint(w, s) - } -} - -func (t *TeeLogger) GetSucceededCount() int { - return len(t.succeededMap) -} - -func (t *TeeLogger) writeFailedSync(e common.RetryObject) { - t.failedMu.Lock() - defer t.failedMu.Unlock() - - // Store the FileMetadata part in the map - t.FailedMap[e.FilePath] = e - - data, _ := json.MarshalIndent(t.FailedMap, "", " ") - os.WriteFile(t.failedPath, data, 0644) -} - -func (t *TeeLogger) writeSucceededSync(path, guid string) { - t.succeededMu.Lock() - defer t.succeededMu.Unlock() - t.succeededMap[path] = guid - data, _ := json.MarshalIndent(t.succeededMap, "", " ") - os.WriteFile(t.succeededPath, data, 0644) -} - -// --- Tracking Methods (Part of Logger Interface) --- - -func (t *TeeLogger) Failed(filePath, filename string, metadata common.FileMetadata, guid string, retryCount int, multipart bool) { - if t.failedPath != "" { - t.writeFailedSync(common.RetryObject{ - FilePath: filePath, - Filename: filename, - FileMetadata: metadata, - GUID: guid, - RetryCount: retryCount, - Multipart: multipart, - }) - } -} - -func (t *TeeLogger) Succeeded(filePath, guid string) { - // Use t.succeededPath instead of checking the old global succeededPath - if t.succeededPath != "" { - t.writeSucceededSync(filePath, guid) - } -} diff --git a/client/mocks/mock_gen3interface.go b/client/mocks/mock_gen3interface.go deleted file mode 100644 index 99f3f25..0000000 --- a/client/mocks/mock_gen3interface.go +++ /dev/null @@ -1,192 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/client (interfaces: Gen3Interface) -// -// Generated by this command: -// -// mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/client/client Gen3Interface -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - http "net/http" - reflect "reflect" - - api "github.com/calypr/data-client/client/api" - conf "github.com/calypr/data-client/client/conf" - logs "github.com/calypr/data-client/client/logs" - request "github.com/calypr/data-client/client/request" - gomock "go.uber.org/mock/gomock" -) - -// MockGen3Interface is a mock of Gen3Interface interface. -type MockGen3Interface struct { - ctrl *gomock.Controller - recorder *MockGen3InterfaceMockRecorder - isgomock struct{} -} - -// MockGen3InterfaceMockRecorder is the mock recorder for MockGen3Interface. -type MockGen3InterfaceMockRecorder struct { - mock *MockGen3Interface -} - -// NewMockGen3Interface creates a new mock instance. -func NewMockGen3Interface(ctrl *gomock.Controller) *MockGen3Interface { - mock := &MockGen3Interface{ctrl: ctrl} - mock.recorder = &MockGen3InterfaceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockGen3Interface) EXPECT() *MockGen3InterfaceMockRecorder { - return m.recorder -} - -// CheckForShepherdAPI mocks base method. -func (m *MockGen3Interface) CheckForShepherdAPI(ctx context.Context) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckForShepherdAPI", ctx) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CheckForShepherdAPI indicates an expected call of CheckForShepherdAPI. -func (mr *MockGen3InterfaceMockRecorder) CheckForShepherdAPI(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockGen3Interface)(nil).CheckForShepherdAPI), ctx) -} - -// CheckPrivileges mocks base method. -func (m *MockGen3Interface) CheckPrivileges(ctx context.Context) (map[string]any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckPrivileges", ctx) - ret0, _ := ret[0].(map[string]any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CheckPrivileges indicates an expected call of CheckPrivileges. -func (mr *MockGen3InterfaceMockRecorder) CheckPrivileges(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockGen3Interface)(nil).CheckPrivileges), ctx) -} - -// DeleteRecord mocks base method. -func (m *MockGen3Interface) DeleteRecord(ctx context.Context, guid string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRecord", ctx, guid) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DeleteRecord indicates an expected call of DeleteRecord. -func (mr *MockGen3InterfaceMockRecorder) DeleteRecord(ctx, guid any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecord", reflect.TypeOf((*MockGen3Interface)(nil).DeleteRecord), ctx, guid) -} - -// Do mocks base method. -func (m *MockGen3Interface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Do", ctx, req) - ret0, _ := ret[0].(*http.Response) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Do indicates an expected call of Do. -func (mr *MockGen3InterfaceMockRecorder) Do(ctx, req any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockGen3Interface)(nil).Do), ctx, req) -} - -// ExportCredential mocks base method. -func (m *MockGen3Interface) ExportCredential(ctx context.Context, cred *conf.Credential) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExportCredential", ctx, cred) - ret0, _ := ret[0].(error) - return ret0 -} - -// ExportCredential indicates an expected call of ExportCredential. -func (mr *MockGen3InterfaceMockRecorder) ExportCredential(ctx, cred any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportCredential", reflect.TypeOf((*MockGen3Interface)(nil).ExportCredential), ctx, cred) -} - -// GetCredential mocks base method. -func (m *MockGen3Interface) GetCredential() *conf.Credential { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCredential") - ret0, _ := ret[0].(*conf.Credential) - return ret0 -} - -// GetCredential indicates an expected call of GetCredential. -func (mr *MockGen3InterfaceMockRecorder) GetCredential() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockGen3Interface)(nil).GetCredential)) -} - -// GetPresignedUrl mocks base method. -func (m *MockGen3Interface) GetPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPresignedUrl", ctx, guid, protocolText) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetPresignedUrl indicates an expected call of GetPresignedUrl. -func (mr *MockGen3InterfaceMockRecorder) GetPresignedUrl(ctx, guid, protocolText any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPresignedUrl", reflect.TypeOf((*MockGen3Interface)(nil).GetPresignedUrl), ctx, guid, protocolText) -} - -// Logger mocks base method. -func (m *MockGen3Interface) Logger() *logs.TeeLogger { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Logger") - ret0, _ := ret[0].(*logs.TeeLogger) - return ret0 -} - -// Logger indicates an expected call of Logger. -func (mr *MockGen3InterfaceMockRecorder) Logger() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockGen3Interface)(nil).Logger)) -} - -// New mocks base method. -func (m *MockGen3Interface) New(method, url string) *request.RequestBuilder { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "New", method, url) - ret0, _ := ret[0].(*request.RequestBuilder) - return ret0 -} - -// New indicates an expected call of New. -func (mr *MockGen3InterfaceMockRecorder) New(method, url any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockGen3Interface)(nil).New), method, url) -} - -// ParseFenceURLResponse mocks base method. -func (m *MockGen3Interface) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ParseFenceURLResponse", resp) - ret0, _ := ret[0].(api.FenceResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ParseFenceURLResponse indicates an expected call of ParseFenceURLResponse. -func (mr *MockGen3InterfaceMockRecorder) ParseFenceURLResponse(resp any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseFenceURLResponse", reflect.TypeOf((*MockGen3Interface)(nil).ParseFenceURLResponse), resp) -} diff --git a/client/upload/singleFile.go b/client/upload/singleFile.go deleted file mode 100644 index 32c4194..0000000 --- a/client/upload/singleFile.go +++ /dev/null @@ -1,118 +0,0 @@ -package upload - -import ( - "context" - "errors" - "fmt" - "io" - "os" - "path/filepath" - - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" -) - -func UploadSingle(ctx context.Context, profile string, guid string, oid string, filePath string, bucketName string, enableLogs bool, progressCallback common.ProgressCallback) error { - - logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog()) - if enableLogs { - logger, closer = logs.New( - profile, - logs.WithSucceededLog(), - logs.WithFailedLog(), - logs.WithScoreboard(), - logs.WithConsole(), - ) - } - defer closer() - - // Instantiate interface to Gen3 - g3i, err := client.NewGen3Interface( - profile, - logger, - ) - if err != nil { - return fmt.Errorf("failed to parse config on profile %s: %w", profile, err) - } - - filePaths, err := common.ParseFilePaths(filePath, false) - if len(filePaths) > 1 { - return errors.New("more than 1 file location has been found. Do not use \"*\" in file path or provide a folder as file path") - } - if err != nil { - return errors.New("file path parsing error: " + err.Error()) - } - if len(filePaths) == 1 { - filePath = filePaths[0] - } - filename := filepath.Base(filePath) - - file, err := os.Open(filePath) - if err != nil { - if enableLogs { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() - } - g3i.Logger().Failed(filePath, filename, common.FileMetadata{}, "", 0, false) - g3i.Logger().Println("File open error: " + err.Error()) - - return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", filePath, err.Error()) - } - defer file.Close() - - fi, err := file.Stat() - if err != nil { - return fmt.Errorf("failed to stat file: %w", err) - } - fileSize := fi.Size() - - furObject := common.FileUploadRequestObject{ - FilePath: filePath, - Filename: filename, - GUID: guid, - OID: oid, - Bucket: bucketName, - Progress: progressCallback, - } - - furObject, err = generateUploadRequest(ctx, g3i, furObject, file, nil) - - if err != nil { - if enableLogs { - sb := g3i.Logger().Scoreboard() - sb.IncrementSB(len(sb.Counts)) - sb.PrintSB() - } - g3i.Logger().Failed(furObject.FilePath, furObject.Filename, common.FileMetadata{}, furObject.GUID, 0, false) - g3i.Logger().Printf("Error occurred during request generation: %s", err.Error()) - return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", filePath, err.Error()) - } - - var reader io.Reader = file - var progressTracker *progressReader - if furObject.Progress != nil { - progressTracker = newProgressReader(file, furObject.Progress, resolveUploadOID(furObject), fileSize) - reader = progressTracker - } - - _, err = uploadPart(ctx, furObject.PresignedURL, reader, fileSize) - if progressTracker != nil { - if finalizeErr := progressTracker.Finalize(); finalizeErr != nil && err == nil { - err = finalizeErr - } - } - if err != nil { - if enableLogs { - g3i.Logger().Scoreboard().IncrementSB(1) // Increment failure - } - return fmt.Errorf("[ERROR] Error uploading file content for %s: %w", filePath, err) - } - - if enableLogs { - g3i.Logger().Scoreboard().IncrementSB(0) - g3i.Logger().Scoreboard().PrintSB() - } - return nil -} diff --git a/client/upload/upload.go b/client/upload/upload.go deleted file mode 100644 index 14fc894..0000000 --- a/client/upload/upload.go +++ /dev/null @@ -1,125 +0,0 @@ -package upload - -import ( - "context" - "fmt" - "io" - "net/http" - "os" - - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/request" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" -) - -// Upload is a unified catch-all function that automatically chooses between -// single-part and multipart upload based on file size. -func Upload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { - g3.Logger().Printf("Processing Upload Request for: %s\n", req.FilePath) - - file, err := os.Open(req.FilePath) - if err != nil { - return fmt.Errorf("cannot open file %s: %w", req.FilePath, err) - } - defer file.Close() - - stat, err := file.Stat() - if err != nil { - return fmt.Errorf("cannot stat file: %w", err) - } - - fileSize := stat.Size() - if fileSize == 0 { - return fmt.Errorf("file is empty: %s", req.Filename) - } - - // Use Single-Part if file is smaller than 5GB (or your defined limit) - if fileSize < 5*common.GB { - g3.Logger().Printf("File size %d bytes (< 5GB), performing single-part upload\n", fileSize) - UploadSingle(ctx, g3.GetCredential().Profile, req.GUID, req.GUID, req.FilePath, req.Bucket, true, nil) - } - g3.Logger().Printf("File size %d bytes (>= 5GB), performing multipart upload\n", fileSize) - return MultipartUpload(ctx, g3, req, file, showProgress) -} - -func performSinglePartUpload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { - // 1. Get the Presigned URL - respObj, err := GeneratePresignedUploadURL(ctx, g3, req.Filename, req.FileMetadata, req.Bucket) - if err != nil { - return fmt.Errorf("failed to generate single-part URL: %w", err) - } - - req.GUID = respObj.GUID - req.PresignedURL = respObj.URL - - // 2. Open file and setup progress - file, _ := os.Open(req.FilePath) - defer file.Close() - - var body io.Reader = file - var p *mpb.Progress - if showProgress { - p = mpb.New(mpb.WithOutput(os.Stdout)) - fi, _ := file.Stat() - bar := p.AddBar(fi.Size(), - mpb.PrependDecorators(decor.Name(req.Filename+" ")), - mpb.AppendDecorators(decor.Percentage()), - ) - body = bar.ProxyReader(file) - } - - resp, err := g3.Do(ctx, &request.RequestBuilder{ - Method: http.MethodPut, - Url: req.PresignedURL, - Body: body, - }) - - if p != nil { - p.Wait() - } - - if err != nil || resp.StatusCode != http.StatusOK { - return fmt.Errorf("single-part upload failed") - } - return nil -} - -// UploadSingleFile handles single-part upload with progress -func UploadSingleFile(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { - file, err := os.Open(req.FilePath) - if err != nil { - return err - } - defer file.Close() - - fi, _ := file.Stat() - if fi.Size() > common.FileSizeLimit { - return fmt.Errorf("file exceeds 5GB limit") - } - - respObj, err := GeneratePresignedUploadURL(ctx, g3, req.Filename, req.FileMetadata, req.Bucket) - if err != nil { - return err - } - - // Generate request with progress bar - var p *mpb.Progress - if showProgress { - p = mpb.New(mpb.WithOutput(os.Stdout)) - } - - fur, err := generateUploadRequest(ctx, g3, common.FileUploadRequestObject{ - FilePath: req.FilePath, - Filename: req.Filename, - PresignedURL: respObj.URL, - GUID: respObj.GUID, - Bucket: req.Bucket, - }, file, p) - if err != nil { - return err - } - - return MultipartUpload(ctx, g3, fur, file, showProgress) -} diff --git a/cmd/auth.go b/cmd/auth.go index 7de1b36..6e0398a 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -6,8 +6,8 @@ import ( "log" "sort" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/spf13/cobra" ) @@ -24,12 +24,15 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole()) defer logCloser() - g3i, err := client.NewGen3Interface(profile, logger) + g3i, err := g3client.NewGen3Interface( + profile, logger, + g3client.WithClients(g3client.FenceClient), + ) if err != nil { log.Fatalf("Fatal NewGen3Interface error: %s\n", err) } - resourceAccess, err := g3i.CheckPrivileges(context.Background()) + resourceAccess, err := g3i.Fence().CheckPrivileges(context.Background()) if err != nil { g3i.Logger().Fatalf("Fatal authentication error: %s\n", err) } else { diff --git a/cmd/collaborator.go b/cmd/collaborator.go new file mode 100644 index 0000000..7fc1528 --- /dev/null +++ b/cmd/collaborator.go @@ -0,0 +1,264 @@ +package cmd + +import ( + "fmt" + "os" + + "regexp" + "strings" + + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/requestor" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +var collaboratorCmd = &cobra.Command{ + Use: "collaborator", + Short: "Manage collaborators and access requests", +} + +var emailRegex = regexp.MustCompile(`^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$`) + +func validateProjectAndUser(projectID, username string) error { + if !emailRegex.MatchString(strings.ToLower(username)) { + return fmt.Errorf("invalid username '%s': must be a valid email address", username) + } + + parts := strings.Split(projectID, "-") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return fmt.Errorf("invalid project_id '%s': must be in the form 'program-project'", projectID) + } + + return nil +} + +func printRequest(r requestor.Request) { + b, err := yaml.Marshal(r) + if err != nil { + fmt.Printf("ID: %s (Error formatting details: %v)\n", r.RequestID, err) + return + } + fmt.Println(string(b)) +} + +func getRequestorClient() (requestor.RequestorInterface, func()) { + if profile == "" { + fmt.Println("Error: profile is required. Please specify a profile using the --profile flag.") + os.Exit(1) + } + + // Initialize logger + logger, logCloser := logs.New(profile) + + // Initialize Gen3Interface handles selective initialization + g3i, err := g3client.NewGen3Interface(profile, logger, g3client.WithClients(g3client.RequestorClient)) + if err != nil { + fmt.Printf("Error accessing Gen3: %v\n", err) + logCloser() + os.Exit(1) + } + + return g3i.Requestor(), logCloser +} + +var collaboratorListCmd = &cobra.Command{ + Use: "ls", + Short: "List requests", + Run: func(cmd *cobra.Command, args []string) { + mine, _ := cmd.Flags().GetBool("mine") + active, _ := cmd.Flags().GetBool("active") + username, _ := cmd.Flags().GetString("username") + + client, closer := getRequestorClient() + defer closer() + + requests, err := client.ListRequests(cmd.Context(), mine, active, username) + if err != nil { + fmt.Printf("Error listing requests: %v\n", err) + os.Exit(1) + } + + for _, r := range requests { + printRequest(r) + } + }, +} + +var collaboratorPendingCmd = &cobra.Command{ + Use: "pending", + Short: "List pending requests", + Run: func(cmd *cobra.Command, args []string) { + client, closer := getRequestorClient() + defer closer() + + // Fetch all requests + requests, err := client.ListRequests(cmd.Context(), false, false, "") + if err != nil { + fmt.Printf("Error listing requests: %v\n", err) + os.Exit(1) + } + + fmt.Println("Pending requests:") + for _, r := range requests { + if r.Status != "SIGNED" { + printRequest(r) + } + } + }, +} + +var collaboratorAddUserCmd = &cobra.Command{ + Use: "add [project_id] [username]", + Short: "Add a user to a project", + Args: func(cmd *cobra.Command, args []string) error { + if err := cobra.ExactArgs(2)(cmd, args); err != nil { + return err + } + return validateProjectAndUser(args[0], args[1]) + }, + Run: func(cmd *cobra.Command, args []string) { + projectID := args[0] + username := args[1] + write, _ := cmd.Flags().GetBool("write") + guppy, _ := cmd.Flags().GetBool("guppy") + approve, _ := cmd.Flags().GetBool("approve") + + client, closer := getRequestorClient() + defer closer() + + reqs, err := client.AddUser(cmd.Context(), projectID, username, write, guppy) + if err != nil { + fmt.Printf("Error adding user: %v\n", err) + os.Exit(1) + } + + if approve { + fmt.Println("\nAuto-approving requests...") + for _, r := range reqs { + updatedReq, err := client.UpdateRequest(cmd.Context(), r.RequestID, "SIGNED") + if err != nil { + fmt.Printf("Error approving request %s: %v\n", r.RequestID, err) + } else { + fmt.Printf("Approved request %s:\n", updatedReq.RequestID) + printRequest(*updatedReq) + } + } + } else { + fmt.Println("Created requests:") + for _, r := range reqs { + printRequest(r) + } + fmt.Printf("\nAn authorized user must approve these requests to add %s to %s\n", username, projectID) + } + }, +} + +var collaboratorRemoveUserCmd = &cobra.Command{ + Use: "rm [project_id] [username]", + Short: "Remove a user from a project", + Args: func(cmd *cobra.Command, args []string) error { + if err := cobra.ExactArgs(2)(cmd, args); err != nil { + return err + } + return validateProjectAndUser(args[0], args[1]) + }, + Run: func(cmd *cobra.Command, args []string) { + projectID := args[0] + username := args[1] + approve, _ := cmd.Flags().GetBool("approve") + + client, closer := getRequestorClient() + defer closer() + + reqs, err := client.RemoveUser(cmd.Context(), projectID, username) + if err != nil { + fmt.Printf("Error removing user: %v\n", err) + os.Exit(1) + } + + if approve { + fmt.Println("\nAuto-approving revoke requests...") + for _, r := range reqs { + updatedReq, err := client.UpdateRequest(cmd.Context(), r.RequestID, "SIGNED") + if err != nil { + fmt.Printf("Error approving request %s: %v\n", r.RequestID, err) + } else { + fmt.Printf("Approved request %s:\n", updatedReq.RequestID) + printRequest(*updatedReq) + } + } + } else { + fmt.Println("Created revoke requests:") + for _, r := range reqs { + printRequest(r) + } + } + }, +} + +var collaboratorApproveCmd = &cobra.Command{ + Use: "approve [request_id]", + Short: "Approve a request (sign it)", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + requestID := args[0] + + client, closer := getRequestorClient() + defer closer() + + req, err := client.UpdateRequest(cmd.Context(), requestID, "SIGNED") + if err != nil { + fmt.Printf("Error approving request: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Approved request %s\n", req.RequestID) + printRequest(*req) + }, +} + +var collaboratorUpdateCmd = &cobra.Command{ + Use: "update [request_id] [status]", + Short: "Update a request status", + Hidden: true, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + requestID := args[0] + status := args[1] + + client, closer := getRequestorClient() + defer closer() + + req, err := client.UpdateRequest(cmd.Context(), requestID, status) + if err != nil { + fmt.Printf("Error updating request: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Updated request %s to status %s\n", req.RequestID, req.Status) + }, +} + +func init() { + RootCmd.AddCommand(collaboratorCmd) + collaboratorCmd.AddCommand(collaboratorListCmd) + collaboratorCmd.AddCommand(collaboratorPendingCmd) + collaboratorCmd.AddCommand(collaboratorAddUserCmd) + collaboratorCmd.AddCommand(collaboratorRemoveUserCmd) + collaboratorCmd.AddCommand(collaboratorApproveCmd) + collaboratorCmd.AddCommand(collaboratorUpdateCmd) + + collaboratorListCmd.Flags().Bool("mine", false, "List my requests") + collaboratorListCmd.Flags().Bool("active", false, "List only active requests") + collaboratorListCmd.Flags().String("username", "", "List requests for user") + + collaboratorAddUserCmd.Flags().BoolP("write", "w", false, "Grant write access") + collaboratorAddUserCmd.Flags().BoolP("guppy", "g", false, "Grant guppy admin access") + collaboratorAddUserCmd.Flags().BoolP("approve", "a", false, "Automatically approve the requests") + + collaboratorRemoveUserCmd.Flags().BoolP("approve", "a", false, "Automatically approve the revoke requests") + + collaboratorCmd.PersistentFlags().StringVar(&profile, "profile", "", "Specify profile to use") +} diff --git a/cmd/configure.go b/cmd/configure.go index 604693d..b6eb564 100644 --- a/cmd/configure.go +++ b/cmd/configure.go @@ -4,11 +4,10 @@ import ( "context" "fmt" - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" - req "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/spf13/cobra" ) @@ -37,7 +36,7 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole()) defer logCloser() - configure := conf.NewConfigure(logger) + configure := conf.NewConfigure(logger.Logger) if credFile != "" { readCred, err := configure.Import(credFile, "") if err != nil { @@ -51,13 +50,8 @@ func init() { cred.AccessToken = "" } - newFunc := api.NewFunctions( - configure, - req.NewRequestInterface(logger, cred, configure), - cred, - logger, - ) - err := newFunc.ExportCredential(context.Background(), cred) + g3i := g3client.NewGen3InterfaceFromCredential(cred, logger, g3client.WithClients()) + err := g3i.ExportCredential(context.Background(), cred) if err != nil { logger.Println(err.Error()) } diff --git a/cmd/delete.go b/cmd/delete.go index e11c92f..4589577 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -3,8 +3,8 @@ package cmd import ( "context" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/spf13/cobra" ) @@ -24,12 +24,12 @@ If no profile is specified, "default" profile is used for authentication.`, logger, logCloser := logs.New(profile, logs.WithConsole()) defer logCloser() - g3i, err := client.NewGen3Interface(profile, logger) + g3i, err := g3client.NewGen3Interface(profile, logger) if err != nil { logger.Fatalf("Fatal NewGen3Interface error: %s\n", err) } - msg, err := g3i.DeleteRecord(context.Background(), guid) + msg, err := g3i.Fence().DeleteRecord(context.Background(), guid) if err != nil { logger.Fatal(err) } diff --git a/cmd/download-multiple.go b/cmd/download-multiple.go index ed59486..fa91c15 100644 --- a/cmd/download-multiple.go +++ b/cmd/download-multiple.go @@ -7,10 +7,10 @@ import ( "log" "os" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/download" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/download" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -38,7 +38,7 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole(), logs.WithFailedLog(), logs.WithScoreboard(), logs.WithSucceededLog()) defer logCloser() - g3i, err := client.NewGen3Interface(profile, logger) + g3i, err := g3client.NewGen3Interface(profile, logger) if err != nil { log.Fatalf("Failed to parse config on profile %s, %v", profile, err) } diff --git a/cmd/download-single.go b/cmd/download-single.go index 6438acd..6d1c5db 100644 --- a/cmd/download-single.go +++ b/cmd/download-single.go @@ -4,10 +4,10 @@ import ( "context" "log" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/download" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/download" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/spf13/cobra" ) @@ -32,14 +32,14 @@ func init() { logger, logCloser := logs.New(profile, logs.WithConsole(), logs.WithFailedLog(), logs.WithSucceededLog(), logs.WithScoreboard()) defer logCloser() - g3I, err := client.NewGen3Interface(profile, logger) + g3I, err := g3client.NewGen3Interface(profile, logger) if err != nil { log.Fatalf("Failed to parse config on profile %s, %v", profile, err) } objects := []common.ManifestObject{ common.ManifestObject{ - ObjectID: guid, + GUID: guid, }, } err = download.DownloadMultiple( diff --git a/cmd/gitversion.go b/cmd/gitversion.go index cc123f5..ce96e41 100644 --- a/cmd/gitversion.go +++ b/cmd/gitversion.go @@ -2,5 +2,5 @@ package cmd var ( gitcommit = "N/A" - gitversion = "2025.12" + gitversion = "2026.2" ) diff --git a/cmd/retry-upload.go b/cmd/retry-upload.go index bd68a42..69de60d 100644 --- a/cmd/retry-upload.go +++ b/cmd/retry-upload.go @@ -3,10 +3,10 @@ package cmd import ( "context" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/upload" "github.com/spf13/cobra" ) @@ -28,7 +28,7 @@ func init() { ) defer closer() - g3, err := client.NewGen3Interface(profile, Logger) + g3, err := g3client.NewGen3Interface(profile, Logger) if err != nil { Logger.Fatalf("Failed to initialize client: %v", err) } diff --git a/cmd/upload-multipart.go b/cmd/upload-multipart.go index 86330a3..5f020e5 100644 --- a/cmd/upload-multipart.go +++ b/cmd/upload-multipart.go @@ -5,10 +5,10 @@ import ( "os" "path/filepath" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/upload" "github.com/spf13/cobra" ) @@ -35,30 +35,30 @@ This method is resilient to network interruptions and supports resume capability logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) defer closer() - g3, err := client.NewGen3Interface( + g3, err := g3client.NewGen3Interface( profile, logger, ) if err != nil { - logger.Fatalf("failed to initialize Gen3 interface: %w", err) + logger.Fatalf("failed to initialize Gen3 interface: %v", err) } absPath, err := common.GetAbsolutePath(filePath) if err != nil { - logger.Fatalf("invalid file path: %w", err) + logger.Fatalf("invalid file path: %v", err) } fileInfo := common.FileUploadRequestObject{ - FilePath: absPath, - Filename: filepath.Base(absPath), + SourcePath: absPath, + ObjectKey: filepath.Base(absPath), GUID: guid, FileMetadata: common.FileMetadata{}, } file, err := os.Open(absPath) if err != nil { - logger.Fatalf("cannot open file %s: %w", absPath, err) + logger.Fatalf("cannot open file %s: %v", absPath, err) } defer file.Close() diff --git a/cmd/upload-multiple.go b/cmd/upload-multiple.go index 66fef2e..99e58ff 100644 --- a/cmd/upload-multiple.go +++ b/cmd/upload-multiple.go @@ -9,10 +9,10 @@ import ( "os" "path/filepath" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/upload" "github.com/spf13/cobra" ) @@ -38,12 +38,10 @@ Options to run multipart uploads for large files and parallel batch uploading ar fmt.Printf("Notice: this command uploads to pre-existing GUIDs from a manifest.\nIf you want to upload new files (new GUIDs generated automatically), use \"./data-client upload\" instead.\n\n") ctx := context.Background() - noopProgress := func(common.ProgressEvent) error { return nil } - logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard()) defer closer() - g3i, err := client.NewGen3Interface(profile, logger) + g3i, err := g3client.NewGen3Interface(profile, logger) if err != nil { logger.Fatalf("Failed to parse config on profile %s: %v", profile, err) } @@ -80,24 +78,19 @@ Options to run multipart uploads for large files and parallel batch uploading ar for _, obj := range objects { localFilePath := filepath.Join(absUploadPath, obj.Title) - if err != nil { - logger.Println("Skipping:", err) - continue - } - fur, err := upload.ProcessFilename(logger, absUploadPath, localFilePath, obj.ObjectID, includeSubDirName, false) + fur, err := upload.ProcessFilename(logger, absUploadPath, localFilePath, obj.GUID, includeSubDirName, false) if err != nil { logger.Printf("Skipping %s: %v\n", localFilePath, err) - logger.Failed(localFilePath, filepath.Base(localFilePath), common.FileMetadata{}, obj.ObjectID, 0, false) + logger.Failed(localFilePath, filepath.Base(localFilePath), common.FileMetadata{}, obj.GUID, 0, false) continue } // GUID comes from manifest → override - fur.GUID = obj.ObjectID + fur.GUID = obj.GUID fur.Bucket = bucketName - fur.Progress = noopProgress - logger.Println("\t" + localFilePath + " → GUID " + obj.ObjectID) + logger.Println("\t" + localFilePath + " → GUID " + obj.GUID) requests = append(requests, fur) } @@ -126,16 +119,16 @@ Options to run multipart uploads for large files and parallel batch uploading ar } } else { for _, req := range single { - upload.UploadSingle(ctx, profileConfig.Profile, req.GUID, req.GUID, req.FilePath, req.Bucket, true, noopProgress) + upload.UploadSingle(ctx, g3i, req, true) } } // Upload multipart files for _, req := range multi { - file, err := os.Open(req.FilePath) + file, err := os.Open(req.SourcePath) if err != nil { - g3i.Logger().Printf("Error opening file %s : %v", req.FilePath, err) + g3i.Logger().Printf("Error opening file %s : %v", req.SourcePath, err) continue } diff --git a/cmd/upload-single.go b/cmd/upload-single.go index 0270e36..34eb9ba 100644 --- a/cmd/upload-single.go +++ b/cmd/upload-single.go @@ -4,9 +4,12 @@ package cmd import ( "context" "log" + "path/filepath" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/upload" "github.com/spf13/cobra" ) @@ -21,8 +24,21 @@ func init() { Long: `Gets a presigned URL for which to upload a file associated with a GUID and then uploads the specified file.`, Example: `./data-client upload-single --profile= --guid=f6923cf3-xxxx-xxxx-xxxx-14ab3f84f9d6 --file=`, Run: func(cmd *cobra.Command, args []string) { - noopProgress := func(common.ProgressEvent) error { return nil } - err := upload.UploadSingle(context.Background(), profile, guid, guid, filePath, bucketName, true, noopProgress) + logger, closer := logs.New(profile, logs.WithSucceededLog(), logs.WithFailedLog(), logs.WithScoreboard(), logs.WithConsole()) + defer closer() + + g3i, err := g3client.NewGen3Interface(profile, logger) + if err != nil { + log.Fatalf("Failed to parse config on profile %s: %v", profile, err) + } + + req := common.FileUploadRequestObject{ + SourcePath: filePath, + ObjectKey: filepath.Base(filePath), + Bucket: bucketName, + GUID: guid, + } + err = upload.UploadSingle(context.Background(), g3i, req, true) if err != nil { log.Fatalln(err.Error()) } diff --git a/cmd/upload.go b/cmd/upload.go index ffae48f..a99fdc0 100644 --- a/cmd/upload.go +++ b/cmd/upload.go @@ -6,10 +6,10 @@ import ( "os" "path/filepath" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/upload" "github.com/spf13/cobra" ) @@ -37,14 +37,14 @@ func init() { Logger, logCloser := logs.New(profile, logs.WithSucceededLog(), logs.WithScoreboard(), logs.WithFailedLog()) defer logCloser() // Instantiate interface to Gen3 - g3i, err := client.NewGen3Interface(profile, Logger) + g3i, err := g3client.NewGen3Interface(profile, Logger) if err != nil { log.Fatalf("Failed to parse config on profile %s, %v", profile, err) } logger := g3i.Logger() if hasMetadata { - hasShepherd, err := g3i.CheckForShepherdAPI(ctx) + hasShepherd, err := g3i.Fence().CheckForShepherdAPI(ctx) if err != nil { logger.Printf("WARNING: Error when checking for Shepherd API: %v", err) } else { @@ -120,20 +120,20 @@ func init() { } } else { for _, furObject := range singlePartObjects { - file, err := os.Open(furObject.FilePath) + file, err := os.Open(furObject.SourcePath) if err != nil { - logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Failed(furObject.SourcePath, furObject.ObjectKey, furObject.FileMetadata, furObject.GUID, 0, false) logger.Println("File open error: " + err.Error()) continue } defer file.Close() fi, err := file.Stat() if err != nil { - logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Failed(furObject.SourcePath, furObject.ObjectKey, furObject.FileMetadata, furObject.GUID, 0, false) logger.Println("File stat error for file" + fi.Name() + ", file may be missing or unreadable because of permissions.\n") continue } - upload.UploadSingleFile(ctx, g3i, furObject, true) + upload.UploadSingle(ctx, g3i, furObject, true) } } @@ -146,9 +146,9 @@ func init() { } g3i.Logger().Println("Multipart uploading...") for _, furObject := range multipartObjects { - file, err := os.Open(furObject.FilePath) + file, err := os.Open(furObject.SourcePath) if err != nil { - logger.Failed(furObject.FilePath, furObject.Filename, furObject.FileMetadata, furObject.GUID, 0, false) + logger.Failed(furObject.SourcePath, furObject.ObjectKey, furObject.FileMetadata, furObject.GUID, 0, false) logger.Println("File open error: " + err.Error()) continue } diff --git a/client/common/common.go b/common/common.go similarity index 79% rename from client/common/common.go rename to common/common.go index 57dce2b..716625f 100644 --- a/client/common/common.go +++ b/common/common.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log" + "net/http" "os" "path/filepath" "strings" @@ -116,3 +117,25 @@ func cleanupHiddenFiles(filePaths []string) []string { } return filePaths[:i] } + +// CanDownloadFile checks if a file can be downloaded from the given signed URL +// by issuing a ranged GET for a single byte to mimic HEAD behavior. +func CanDownloadFile(signedURL string) error { + req, err := http.NewRequest("GET", signedURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Range", "bytes=0-0") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("error while sending the request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusPartialContent || resp.StatusCode == http.StatusOK { + return nil + } + + return fmt.Errorf("failed to access file, HTTP status: %d", resp.StatusCode) +} diff --git a/client/common/constants.go b/common/constants.go similarity index 71% rename from client/common/constants.go rename to common/constants.go index 6f9bc64..6299a2c 100644 --- a/client/common/constants.go +++ b/common/constants.go @@ -1,12 +1,7 @@ package common import ( - "fmt" - "log" "os" - "os/exec" - "strconv" - "strings" "time" ) @@ -94,43 +89,5 @@ const ( var ( // MinChunkSize is configurable via git config and initialized in init() - MinChunkSize int64 + MinChunkSize = 10 * MB ) - -func init() { - v, err := GetLfsCustomTransferInt("lfs.customtransfer.drs.multipart-min-chunk-size", 10) - if err != nil { - log.Printf("Warning: Could not read git config for multipart-min-chunk-size, using default (10 MB): %v\n", err) - MinChunkSize = int64(10) * MB - return - } - - MinChunkSize = int64(v) * MB -} - -func GetLfsCustomTransferInt(key string, defaultValue int64) (int64, error) { - defaultText := strconv.FormatInt(defaultValue, 10) - // TODO cache or get all the configs at once? - cmd := exec.Command("git", "config", "--get", "--default", defaultText, key) - output, err := cmd.Output() - if err != nil { - return defaultValue, fmt.Errorf("error reading git config %s: %v", key, err) - } - - value := strings.TrimSpace(string(output)) - - parsed, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return defaultValue, fmt.Errorf("invalid int value for %s: %q", key, value) - } - - if parsed < 0 { - return defaultValue, fmt.Errorf("invalid negative int value for %s: %d", key, parsed) - } - - if parsed == 0 || parsed > 500 { - return defaultValue, fmt.Errorf("invalid int value for %s: %d. Must be between 1 and 500", key, parsed) - } - - return parsed, nil -} diff --git a/client/common/isHidden_notwindows.go b/common/isHidden_notwindows.go similarity index 100% rename from client/common/isHidden_notwindows.go rename to common/isHidden_notwindows.go diff --git a/client/common/isHidden_windows.go b/common/isHidden_windows.go similarity index 100% rename from client/common/isHidden_windows.go rename to common/isHidden_windows.go diff --git a/client/common/logHelper.go b/common/logHelper.go similarity index 100% rename from client/common/logHelper.go rename to common/logHelper.go diff --git a/common/progress.go b/common/progress.go new file mode 100644 index 0000000..f07c856 --- /dev/null +++ b/common/progress.go @@ -0,0 +1,52 @@ +package common + +import ( + "context" +) + +// ProgressEvent matches the Git LFS custom transfer progress payload. +type ProgressEvent struct { + Event string `json:"event"` + Oid string `json:"oid"` + BytesSoFar int64 `json:"bytesSoFar"` + BytesSinceLast int64 `json:"bytesSinceLast"` + Message string `json:"message,omitempty"` + Level string `json:"level,omitempty"` + Attrs map[string]any `json:"attrs,omitempty"` +} + +// ProgressCallback emits transfer progress updates. +type ProgressCallback func(ProgressEvent) error + +type contextKey string + +const ( + progressKey contextKey = "progressCallback" + oidKey contextKey = "activeOid" +) + +// WithProgress returns a new context with the provided ProgressCallback. +func WithProgress(ctx context.Context, cb ProgressCallback) context.Context { + return context.WithValue(ctx, progressKey, cb) +} + +// GetProgress returns the ProgressCallback from the context, or nil if not found. +func GetProgress(ctx context.Context) ProgressCallback { + if cb, ok := ctx.Value(progressKey).(ProgressCallback); ok { + return cb + } + return nil +} + +// WithOid returns a new context with the provided OID. +func WithOid(ctx context.Context, oid string) context.Context { + return context.WithValue(ctx, oidKey, oid) +} + +// GetOid returns the OID from the context, or empty string if not found. +func GetOid(ctx context.Context) string { + if oid, ok := ctx.Value(oidKey).(string); ok { + return oid + } + return "" +} diff --git a/common/resource.go b/common/resource.go new file mode 100644 index 0000000..9e0d011 --- /dev/null +++ b/common/resource.go @@ -0,0 +1,14 @@ +package common + +import ( + "fmt" + "strings" +) + +func ProjectToResource(project string) (string, error) { + if !strings.Contains(project, "-") { + return "", fmt.Errorf("error: invalid project ID %s, ID should look like -", project) + } + projectIdArr := strings.SplitN(project, "-", 2) + return "/programs/" + projectIdArr[0] + "/projects/" + projectIdArr[1], nil +} diff --git a/client/common/types.go b/common/types.go similarity index 83% rename from client/common/types.go rename to common/types.go index 617bd38..4626c44 100644 --- a/client/common/types.go +++ b/common/types.go @@ -11,14 +11,12 @@ type AccessTokenStruct struct { // FileUploadRequestObject defines a object for file upload type FileUploadRequestObject struct { - FilePath string - Filename string + SourcePath string + ObjectKey string FileMetadata FileMetadata GUID string - OID string PresignedURL string Bucket string `json:"bucket,omitempty"` - Progress ProgressCallback } // FileDownloadResponseObject defines a object for file download @@ -26,14 +24,12 @@ type FileDownloadResponseObject struct { DownloadPath string Filename string GUID string - OID string - URL string + PresignedURL string Range int64 Overwrite bool Skip bool Response *http.Response Writer io.Writer - Progress ProgressCallback } // FileMetadata defines the metadata accepted by the new object management API, Shepherd @@ -46,8 +42,8 @@ type FileMetadata struct { // RetryObject defines a object for retry upload type RetryObject struct { - FilePath string - Filename string + SourcePath string + ObjectKey string FileMetadata FileMetadata GUID string RetryCount int @@ -56,7 +52,7 @@ type RetryObject struct { } type ManifestObject struct { - ObjectID string `json:"object_id"` + GUID string `json:"object_id"` SubjectID string `json:"subject_id"` Title string `json:"title"` Size int64 `json:"size"` diff --git a/client/conf/config.go b/conf/config.go similarity index 91% rename from client/conf/config.go rename to conf/config.go index 4297f50..6c40967 100644 --- a/client/conf/config.go +++ b/conf/config.go @@ -1,17 +1,17 @@ package conf -//go:generate mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/conf ManagerInterface +//go:generate mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/conf ManagerInterface import ( "encoding/json" "errors" "fmt" + "log/slog" "os" "path" "strings" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" "gopkg.in/ini.v1" ) @@ -28,10 +28,10 @@ type Credential struct { } type Manager struct { - Logger logs.Logger + Logger *slog.Logger } -func NewConfigure(logs logs.Logger) ManagerInterface { +func NewConfigure(logs *slog.Logger) ManagerInterface { return &Manager{ Logger: logs, } @@ -46,7 +46,8 @@ type ManagerInterface interface { Save(cred *Credential) error EnsureExists() error - IsValid(*Credential) (bool, error) + IsCredentialValid(*Credential) (bool, error) + IsTokenValid(string) (bool, error) } func (man *Manager) configPath() (string, error) { @@ -97,7 +98,7 @@ func (man *Manager) Load(profile string) (*Credential, error) { homeDir, err := os.UserHomeDir() if err != nil { errs := fmt.Errorf("Error occurred when getting home directory: %s", err.Error()) - man.Logger.Printf(errs.Error()) + man.Logger.Error(errs.Error()) return nil, errs } configPath := path.Join(homeDir + common.PathSeparator + ".gen3" + common.PathSeparator + "gen3_client_config.ini") @@ -151,13 +152,13 @@ func (man *Manager) Save(profileConfig *Credential) error { configPath, err := man.configPath() if err != nil { errs := fmt.Errorf("error occurred when getting config path: %s", err.Error()) - man.Logger.Println(errs.Error()) + man.Logger.Error(errs.Error()) return errs } cfg, err := ini.Load(configPath) if err != nil { errs := fmt.Errorf("error occurred when loading config file: %s", err.Error()) - man.Logger.Println(errs.Error()) + man.Logger.Error(errs.Error()) return errs } @@ -180,7 +181,7 @@ func (man *Manager) Save(profileConfig *Credential) error { err = cfg.SaveTo(configPath) if err != nil { errs := fmt.Errorf("error occurred when saving config file: %s", err.Error()) - man.Logger.Println(errs.Error()) + man.Logger.Error(errs.Error()) return fmt.Errorf("error occurred when saving config file: %s", err.Error()) } return nil @@ -222,16 +223,16 @@ func (man *Manager) Import(filePath, fenceToken string) (*Credential, error) { if filePath != "" { fullPath, err := common.GetAbsolutePath(filePath) if err != nil { - man.Logger.Println("error parsing credential file path: " + err.Error()) + man.Logger.Error("error parsing credential file path: " + err.Error()) return nil, err } content, err := os.ReadFile(fullPath) if err != nil { if os.IsNotExist(err) { - man.Logger.Println("File not found: " + fullPath) + man.Logger.Error("File not found: " + fullPath) } else { - man.Logger.Println("error reading file: " + err.Error()) + man.Logger.Error("error reading file: " + err.Error()) } return nil, err } @@ -243,7 +244,7 @@ func (man *Manager) Import(filePath, fenceToken string) (*Credential, error) { if err := json.Unmarshal([]byte(jsonStr), &cred); err != nil { errMsg := fmt.Errorf("cannot parse JSON credential file: %w", err) - man.Logger.Println(errMsg.Error()) + man.Logger.Error(errMsg.Error()) return nil, errMsg } } else if fenceToken != "" { diff --git a/conf/config_test.go b/conf/config_test.go new file mode 100644 index 0000000..1806184 --- /dev/null +++ b/conf/config_test.go @@ -0,0 +1,199 @@ +package conf + +import ( + "log/slog" + "os" + "path" + "testing" +) + +func TestNewConfigure(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := NewConfigure(logger) + + if manager == nil { + t.Fatal("Expected non-nil manager") + } + + // Type assertion to verify it's a *Manager + if _, ok := manager.(*Manager); !ok { + t.Error("Expected manager to be of type *Manager") + } +} + +func TestConfigPath(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + configPath, err := manager.configPath() + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if configPath == "" { + t.Error("Expected non-empty config path") + } + + // Verify path contains expected components + if !contains(configPath, ".gen3") { + t.Error("Expected config path to contain .gen3 directory") + } + + if !contains(configPath, "gen3_client_config.ini") { + t.Error("Expected config path to contain gen3_client_config.ini") + } +} + +func TestImport_WithCredentialFile(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + // Create a temporary credential file + tmpDir := t.TempDir() + credFile := path.Join(tmpDir, "cred.json") + + credContent := `{ + "KeyID": "test-key-id", + "APIKey": "test-api-key" + }` + + if err := os.WriteFile(credFile, []byte(credContent), 0644); err != nil { + t.Fatalf("Failed to create test credential file: %v", err) + } + + cred, err := manager.Import(credFile, "") + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if cred == nil { + t.Fatal("Expected non-nil credential") + } + + if cred.KeyID != "test-key-id" { + t.Errorf("Expected KeyID 'test-key-id', got '%s'", cred.KeyID) + } + + if cred.APIKey != "test-api-key" { + t.Errorf("Expected APIKey 'test-api-key', got '%s'", cred.APIKey) + } +} + +func TestImport_WithFenceToken(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + token := "test-fence-token-12345" + cred, err := manager.Import("", token) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if cred == nil { + t.Fatal("Expected non-nil credential") + } + + if cred.AccessToken != token { + t.Errorf("Expected AccessToken '%s', got '%s'", token, cred.AccessToken) + } +} + +func TestImport_NoCredentialOrToken(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + _, err := manager.Import("", "") + + if err == nil { + t.Fatal("Expected error when neither credential file nor token provided") + } + + if !contains(err.Error(), "either credential file or fence token must be provided") { + t.Errorf("Expected specific error message, got: %v", err) + } +} + +func TestImport_InvalidCredentialFile(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + // Test with non-existent file + _, err := manager.Import("/nonexistent/path/cred.json", "") + + if err == nil { + t.Fatal("Expected error for non-existent file") + } +} + +func TestImport_InvalidJSON(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + // Create a temporary file with invalid JSON + tmpDir := t.TempDir() + credFile := path.Join(tmpDir, "invalid.json") + + invalidJSON := `{invalid json content` + + if err := os.WriteFile(credFile, []byte(invalidJSON), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + _, err := manager.Import(credFile, "") + + if err == nil { + t.Fatal("Expected error for invalid JSON") + } + + if !contains(err.Error(), "cannot parse JSON credential file") { + t.Errorf("Expected JSON parse error, got: %v", err) + } +} + +func TestEnsureExists(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + // This test is tricky because it modifies the user's home directory + // We'll just verify it doesn't panic and returns a reasonable error or nil + err := manager.EnsureExists() + + // We accept either success or a reasonable error + if err != nil { + // Just log the error, don't fail the test + t.Logf("EnsureExists returned error (may be expected): %v", err) + } +} + +func TestLoad_ProfileNotFound(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + manager := &Manager{Logger: logger} + + // Try to load a profile that doesn't exist + _, err := manager.Load("nonexistent-profile") + + if err == nil { + t.Fatal("Expected error for non-existent profile") + } + + // Should contain profile not found error + if !contains(err.Error(), "profile not found") && !contains(err.Error(), "Need to run") { + t.Logf("Got error (may be expected): %v", err) + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/client/conf/validate.go b/conf/validate.go similarity index 51% rename from client/conf/validate.go rename to conf/validate.go index d50362b..41d4a46 100644 --- a/client/conf/validate.go +++ b/conf/validate.go @@ -20,30 +20,30 @@ func ValidateUrl(apiEndpoint string) (*url.URL, error) { return parsedURL, nil } -func (man *Manager) IsValid(profileConfig *Credential) (bool, error) { - if profileConfig == nil { - return false, fmt.Errorf("profileConfig is nil") +func (man *Manager) IsTokenValid(tokenStr string) (bool, error) { + if tokenStr == "" { + return false, fmt.Errorf("token is empty") } - /* Checks to see if credential in credential file is still valid */ // Parse the token without verifying the signature to access the claims. - token, _, err := new(jwt.Parser).ParseUnverified(profileConfig.APIKey, jwt.MapClaims{}) + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) if err != nil { return false, fmt.Errorf("invalid token format: %v", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return false, fmt.Errorf("unable to parse claims from provided token %#v", token) + return false, fmt.Errorf("unable to parse claims from provided token") } exp, ok := claims["exp"].(float64) if !ok { - return false, fmt.Errorf("'exp' claim not found or is not a number for claims %s", claims) + return false, fmt.Errorf("'exp' claim not found or is not a number") } iat, ok := claims["iat"].(float64) if !ok { - return false, fmt.Errorf("'iat' claim not found or is not a number for claims %s", claims) + // iat is not strictly required for validity in all cases, but we'll keep it for now as per original code + return false, fmt.Errorf("'iat' claim not found or is not a number") } now := time.Now().UTC() @@ -51,19 +51,44 @@ func (man *Manager) IsValid(profileConfig *Credential) (bool, error) { iatTime := time.Unix(int64(iat), 0).UTC() if expTime.Before(now) { - return false, fmt.Errorf("key %s expired %s < %s", profileConfig.APIKey, expTime.Format(time.RFC3339), now.Format(time.RFC3339)) + return false, fmt.Errorf("token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339)) } if iatTime.After(now) { - return false, fmt.Errorf("key %s not yet valid %s > %s", profileConfig.APIKey, iatTime.Format(time.RFC3339), now.Format(time.RFC3339)) + return false, fmt.Errorf("token not yet valid: iat %s > now %s", iatTime.Format(time.RFC3339), now.Format(time.RFC3339)) } delta := expTime.Sub(now) // threshold days set to 10 if delta > 0 && delta.Hours() < float64(10*24) { daysUntilExpiration := int(delta.Hours() / 24) - if daysUntilExpiration > 0 { - return true, fmt.Errorf("warning %s: Key will expire in %d days, on %s", profileConfig.APIKey, daysUntilExpiration, expTime.Format(time.RFC3339)) + if daysUntilExpiration > 0 && man.Logger != nil { + man.Logger.Warn(fmt.Sprintf("Token will expire in %d days, on %s", daysUntilExpiration, expTime.Format(time.RFC3339))) } } + return true, nil } + +func (man *Manager) IsCredentialValid(profileConfig *Credential) (bool, error) { + if profileConfig == nil { + return false, fmt.Errorf("profileConfig is nil") + } + + accessTokenValid, accessErr := man.IsTokenValid(profileConfig.AccessToken) + apiKeyValid, apiErr := man.IsTokenValid(profileConfig.APIKey) + + if !accessTokenValid && !apiKeyValid { + return false, fmt.Errorf("both access_token and api_key are invalid: %v; %v", accessErr, apiErr) + } + + if !accessTokenValid && apiKeyValid { + return false, fmt.Errorf("access_token is invalid but api_key is valid: %v", accessErr) + } + + return true, nil +} + +func (man *Manager) IsValid(profileConfig *Credential) (bool, error) { + // Maintain backward compatibility by checking APIKey as before, but using the new helper + return man.IsTokenValid(profileConfig.APIKey) +} diff --git a/conf/validate_test.go b/conf/validate_test.go new file mode 100644 index 0000000..9c0fdb3 --- /dev/null +++ b/conf/validate_test.go @@ -0,0 +1,130 @@ +package conf + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func createTestToken(exp time.Time, iat time.Time) string { + claims := jwt.MapClaims{ + "exp": exp.Unix(), + "iat": iat.Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + // We don't need a real signature for ParseUnverified + tokenString, _ := token.SignedString([]byte("secret")) + return tokenString +} + +func TestIsTokenValid(t *testing.T) { + man := &Manager{} + now := time.Now().UTC() + + tests := []struct { + name string + token string + want bool + wantErr bool + }{ + { + name: "Valid Token", + token: createTestToken(now.Add(time.Hour), now.Add(-time.Hour)), + want: true, + wantErr: false, + }, + { + name: "Expired Token", + token: createTestToken(now.Add(-time.Hour), now.Add(-2*time.Hour)), + want: false, + wantErr: true, + }, + { + name: "Not Yet Valid Token", + token: createTestToken(now.Add(2*time.Hour), now.Add(time.Hour)), + want: false, + wantErr: true, + }, + { + name: "Empty Token", + token: "", + want: false, + wantErr: true, + }, + { + name: "Invalid Token Format", + token: "not.a.token", + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := man.IsTokenValid(tt.token) + if (err != nil) != tt.wantErr { + t.Errorf("IsTokenValid() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("IsTokenValid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsCredentialValid(t *testing.T) { + man := &Manager{} + now := time.Now().UTC() + validToken := createTestToken(now.Add(time.Hour), now.Add(-time.Hour)) + expiredToken := createTestToken(now.Add(-time.Hour), now.Add(-2*time.Hour)) + + tests := []struct { + name string + cred *Credential + want bool + wantErr bool + }{ + { + name: "Both Valid", + cred: &Credential{ + AccessToken: validToken, + APIKey: validToken, + }, + want: true, + wantErr: false, + }, + { + name: "AccessToken Invalid, APIKey Valid (Needs Refresh)", + cred: &Credential{ + AccessToken: expiredToken, + APIKey: validToken, + }, + want: false, + wantErr: true, + }, + { + name: "Both Invalid", + cred: &Credential{ + AccessToken: expiredToken, + APIKey: expiredToken, + }, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := man.IsCredentialValid(tt.cred) + if (err != nil) != tt.wantErr { + t.Errorf("IsCredentialValid() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("IsCredentialValid() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/docs/optimal-chunk-size.md b/docs/optimal-chunk-size.md index 86019ba..55cb850 100644 --- a/docs/optimal-chunk-size.md +++ b/docs/optimal-chunk-size.md @@ -1,5 +1,4 @@ - # Engineering note — Optimal Chunk Size Calculation for Multipart Uploads ## OLD: @@ -65,7 +64,7 @@ Examples: ```bash -go test ./client/upload -run '^TestOptimalChunkSize$' -v +go test ./upload -run '^TestOptimalChunkSize$' -v ``` @@ -141,12 +140,12 @@ Parameterized test cases (file size ⇒ expected chunk ⇒ expected parts) - parts: 1,048,576 / 1024 = `1024` Test design notes (concise) -1. Use table-driven subtests in `client/upload/utils_test.go`. Include fields: name, `fileSize int64`, `wantChunk int64`, `wantParts int64`. +1. Use table-driven subtests in `upload/utils_test.go`. Include fields: name, `fileSize int64`, `wantChunk int64`, `wantParts int64`. 2. For scaled cases assert: MB alignment, clamped to min/max, and exact `wantParts`. Use integer arithmetic for parts. 3. Add explicit boundary triples for each threshold: exact, -1 byte, +1 byte. 4. Include negative and zero cases to verify fallback behavior. 5. Keep tests deterministic and fast (no external deps). Execution -- Run from repo root: `go test ./client/upload -v` -- Run single test: `go test ./client/upload -run '^TestOptimalChunkSize$' -v` \ No newline at end of file +- Run from repo root: `go test ./upload -v` +- Run single test: `go test ./upload -run '^TestOptimalChunkSize$' -v` \ No newline at end of file diff --git a/client/download/batch.go b/download/batch.go similarity index 90% rename from client/download/batch.go rename to download/batch.go index be46051..967f16a 100644 --- a/client/download/batch.go +++ b/download/batch.go @@ -9,9 +9,9 @@ import ( "sync" "sync/atomic" - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" "github.com/hashicorp/go-multierror" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -21,7 +21,7 @@ import ( // downloadFiles performs bounded parallel downloads and collects ALL errors func downloadFiles( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, files []common.FileDownloadResponseObject, numParallel int, protocol string, @@ -40,13 +40,8 @@ func downloadFiles( // Scoreboard: maxRetries = 0 for now (no retry logic yet) sb := logs.NewSB(0, logger) - useProgressBars := true - for _, fdr := range files { - if fdr.Progress != nil { - useProgressBars = false - break - } - } + progress := common.GetProgress(ctx) + useProgressBars := (progress == nil) var p *mpb.Progress if useProgressBars { @@ -133,8 +128,8 @@ func downloadFiles( } writer = bar.ProxyWriter(file) - } else if fdr.Progress != nil { - tracker = newProgressWriter(file, fdr.Progress, resolveDownloadOID(*fdr), total) + } else if progress != nil { + tracker = newProgressWriter(file, progress, fdr.GUID, total) writer = tracker } diff --git a/client/download/downloader.go b/download/downloader.go similarity index 67% rename from client/download/downloader.go rename to download/downloader.go index 92683e2..044eeb8 100644 --- a/client/download/downloader.go +++ b/download/downloader.go @@ -3,12 +3,12 @@ package download import ( "context" "fmt" + "log/slog" "os" "strings" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" ) @@ -16,7 +16,7 @@ import ( // DownloadMultiple is the public entry point called from g3cmd func DownloadMultiple( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, objects []common.ManifestObject, downloadPath string, filenameFormat string, @@ -47,12 +47,12 @@ func DownloadMultiple( return fmt.Errorf("filename-format must be one of: original, guid, combined") } if (filenameFormat == "guid" || filenameFormat == "combined") && rename { - logger.Println("NOTICE: rename flag is ignored in guid/combined mode") + logger.WarnContext(ctx, "NOTICE: rename flag is ignored in guid/combined mode") rename = false } // === Warnings and user confirmation === - if err := handleWarningsAndConfirmation(logger, downloadPath, filenameFormat, rename, noPrompt); err != nil { + if err := handleWarningsAndConfirmation(ctx, logger.Logger, downloadPath, filenameFormat, rename, noPrompt); err != nil { return err // aborted by user } @@ -67,39 +67,41 @@ func DownloadMultiple( return err } - logger.Printf("Total objects: %d | To download: %d | Skipped: %d\n", - len(objects), len(toDownload), len(skipped)) + logger.InfoContext(ctx, "Summary", + "Total objects", len(objects), + "To download", len(toDownload), + "Skipped", len(skipped)) // === Download phase === downloaded, downloadErr := downloadFiles(ctx, g3i, toDownload, numParallel, protocol) // === Final summary === - logger.Printf("%d files downloaded successfully.\n", downloaded) - printRenamed(logger, renamed) - printSkipped(logger, skipped) + logger.InfoContext(ctx, fmt.Sprintf("%d files downloaded successfully.", downloaded)) + printRenamed(ctx, logger.Logger, renamed) + printSkipped(ctx, logger.Logger, skipped) if downloadErr != nil { - logger.Printf("Some downloads failed. See errors above.\n") + logger.WarnContext(ctx, "Some downloads failed. See errors above.") } return nil // we log failures but don't fail the whole command unless critical } // handleWarningsAndConfirmation prints warnings and asks for confirmation if needed -func handleWarningsAndConfirmation(logger logs.Logger, downloadPath, filenameFormat string, rename, noPrompt bool) error { +func handleWarningsAndConfirmation(ctx context.Context, logger *slog.Logger, downloadPath, filenameFormat string, rename, noPrompt bool) error { if filenameFormat == "guid" || filenameFormat == "combined" { - logger.Printf("WARNING: in %q mode, duplicate files in %q will be overwritten\n", filenameFormat, downloadPath) + logger.WarnContext(ctx, fmt.Sprintf("WARNING: in %q mode, duplicate files in %q will be overwritten", filenameFormat, downloadPath)) } else if !rename { - logger.Printf("WARNING: rename=false in original mode – duplicates in %q will be overwritten\n", downloadPath) + logger.WarnContext(ctx, fmt.Sprintf("WARNING: rename=false in original mode – duplicates in %q will be overwritten", downloadPath)) } else { - logger.Printf("NOTICE: rename=true in original mode – duplicates in %q will be renamed with a counter\n", downloadPath) + logger.InfoContext(ctx, fmt.Sprintf("NOTICE: rename=true in original mode – duplicates in %q will be renamed with a counter", downloadPath)) } if noPrompt { return nil } if !AskForConfirmation(logger, "Proceed? (y/N)") { - logger.Fatal("Aborted by user") + return fmt.Errorf("aborted by user") } return nil } @@ -107,7 +109,7 @@ func handleWarningsAndConfirmation(logger logs.Logger, downloadPath, filenameFor // prepareFiles gathers metadata, checks local files, collects skips/renames func prepareFiles( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, objects []common.ManifestObject, downloadPath, filenameFormat string, rename, skipCompleted bool, @@ -125,8 +127,8 @@ func prepareFiles( ) for _, obj := range objects { - if obj.ObjectID == "" { - logger.Println("Empty GUID, skipping entry") + if obj.GUID == "" { + logger.WarnContext(ctx, "Empty GUID, skipping entry") bar.Increment() continue } @@ -135,7 +137,7 @@ func prepareFiles( var err error if info.Name == "" || info.Size == 0 { // Very strict object id checking - info, err = AskGen3ForFileInfo(ctx, g3i, obj.ObjectID, protocol, downloadPath, filenameFormat, rename, &renamed) + info, err = AskGen3ForFileInfo(ctx, g3i, obj.GUID, protocol, downloadPath, filenameFormat, rename, &renamed) if err != nil { return nil, nil, nil, err } @@ -144,7 +146,7 @@ func prepareFiles( fdr := common.FileDownloadResponseObject{ DownloadPath: downloadPath, Filename: info.Name, - GUID: obj.ObjectID, + GUID: obj.GUID, } if !rename { @@ -152,7 +154,7 @@ func prepareFiles( } if fdr.Skip { - logger.Printf("Skipping %q (GUID: %s) – complete local copy exists\n", fdr.Filename, fdr.GUID) + logger.InfoContext(ctx, fmt.Sprintf("Skipping %q (GUID: %s) – complete local copy exists", fdr.Filename, fdr.GUID)) skipped = append(skipped, RenamedOrSkippedFileInfo{GUID: fdr.GUID, OldFilename: fdr.Filename}) } else { toDownload = append(toDownload, fdr) @@ -161,6 +163,6 @@ func prepareFiles( bar.Increment() } p.Wait() - logger.Println("Preparation complete") + logger.InfoContext(ctx, "Preparation complete") return toDownload, skipped, renamed, nil } diff --git a/client/download/file_info.go b/download/file_info.go similarity index 70% rename from client/download/file_info.go rename to download/file_info.go index e3b6a89..8fb8134 100644 --- a/client/download/file_info.go +++ b/download/file_info.go @@ -6,19 +6,19 @@ import ( "fmt" "net/http" - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/request" ) func AskGen3ForFileInfo( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, guid, protocol, downloadPath, filenameFormat string, rename bool, renamedFiles *[]RenamedOrSkippedFileInfo, ) (*IndexdResponse, error) { - hasShepherd, err := g3i.CheckForShepherdAPI(ctx) + hasShepherd, err := g3i.Fence().CheckForShepherdAPI(ctx) if err != nil { g3i.Logger().Println("Error checking Shepherd API: " + err.Error()) g3i.Logger().Println("Falling back to Indexd...") @@ -26,19 +26,29 @@ func AskGen3ForFileInfo( } if hasShepherd { - return fetchFromShepherd(ctx, g3i, guid, downloadPath, filenameFormat, renamedFiles) + info, err := fetchFromShepherd(ctx, g3i, guid, downloadPath, filenameFormat, renamedFiles) + if err == nil { + return info, nil + } + g3i.Logger().Printf("Shepherd fetch failed for %s: %v. Falling back to Indexd...\n", guid, err) + } + info, err := fetchFromIndexd(ctx, g3i, http.MethodGet, guid, protocol, downloadPath, filenameFormat, rename, renamedFiles) + if err != nil { + g3i.Logger().Printf("All meta-data lookups failed for %s: %v. Using GUID as default filename.\n", guid, err) + *renamedFiles = append(*renamedFiles, RenamedOrSkippedFileInfo{GUID: guid, OldFilename: guid, NewFilename: guid}) + return &IndexdResponse{guid, 0}, nil } - return fetchFromIndexd(ctx, g3i, http.MethodGet, guid, protocol, downloadPath, filenameFormat, rename, renamedFiles) + return info, nil } func fetchFromShepherd( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, guid, downloadPath, filenameFormat string, renamedFiles *[]RenamedOrSkippedFileInfo, ) (*IndexdResponse, error) { cred := g3i.GetCredential() - res, err := g3i.Do(ctx, + res, err := g3i.Fence().Do(ctx, &request.RequestBuilder{ Url: cred.APIEndpoint + "/" + cred.AccessToken + common.ShepherdEndpoint + "/objects/" + guid, Method: http.MethodGet, @@ -64,14 +74,14 @@ func fetchFromShepherd( func fetchFromIndexd( ctx context.Context, - g3i client.Gen3Interface, method, + g3i g3client.Gen3Interface, method, guid, protocol, downloadPath, filenameFormat string, rename bool, renamedFiles *[]RenamedOrSkippedFileInfo, ) (*IndexdResponse, error) { cred := g3i.GetCredential() - resp, err := g3i.Do( + resp, err := g3i.Fence().Do( ctx, &request.RequestBuilder{ Url: cred.APIEndpoint + common.IndexdIndexEndpoint + "/" + guid, @@ -80,11 +90,11 @@ func fetchFromIndexd( }, ) if err != nil { - return nil, fmt.Errorf("Error in fetch FromIndexd: %s", err) + return nil, fmt.Errorf("error in fetch FromIndexd: %s", err) } defer resp.Body.Close() - msg, err := g3i.ParseFenceURLResponse(resp) + msg, err := g3i.Fence().ParseFenceURLResponse(resp) if err != nil { return nil, err } diff --git a/client/download/progress_writer.go b/download/progress_writer.go similarity index 53% rename from client/download/progress_writer.go rename to download/progress_writer.go index 9ed8ab0..dd1abf0 100644 --- a/client/download/progress_writer.go +++ b/download/progress_writer.go @@ -1,24 +1,25 @@ package download import ( + "fmt" "io" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) type progressWriter struct { writer io.Writer onProgress common.ProgressCallback - oid string + hash string total int64 bytesSoFar int64 } -func newProgressWriter(writer io.Writer, onProgress common.ProgressCallback, oid string, total int64) *progressWriter { +func newProgressWriter(writer io.Writer, onProgress common.ProgressCallback, hash string, total int64) *progressWriter { return &progressWriter{ writer: writer, onProgress: onProgress, - oid: oid, + hash: hash, total: total, } } @@ -30,7 +31,7 @@ func (pw *progressWriter) Write(p []byte) (int, error) { pw.bytesSoFar += delta if progressErr := pw.onProgress(common.ProgressEvent{ Event: "progress", - Oid: pw.oid, + Oid: pw.hash, BytesSoFar: pw.bytesSoFar, BytesSinceLast: delta, }); progressErr != nil { @@ -41,28 +42,18 @@ func (pw *progressWriter) Write(p []byte) (int, error) { } func (pw *progressWriter) Finalize() error { - if pw.onProgress == nil { - return nil - } - if pw.total == 0 || pw.bytesSoFar >= pw.total { - return nil - } - delta := pw.total - pw.bytesSoFar - pw.bytesSoFar = pw.total - return pw.onProgress(common.ProgressEvent{ - Event: "progress", - Oid: pw.oid, - BytesSoFar: pw.bytesSoFar, - BytesSinceLast: delta, - }) -} - -func resolveDownloadOID(fdr common.FileDownloadResponseObject) string { - if fdr.OID != "" { - return fdr.OID - } - if fdr.GUID != "" { - return fdr.GUID + if pw.total > 0 && pw.bytesSoFar < pw.total { + delta := pw.total - pw.bytesSoFar + pw.bytesSoFar = pw.total + if pw.onProgress != nil { + _ = pw.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pw.hash, + BytesSoFar: pw.bytesSoFar, + BytesSinceLast: delta, + }) + } + return fmt.Errorf("download incomplete: %d/%d bytes", pw.bytesSoFar-delta, pw.total) } - return fdr.Filename + return nil } diff --git a/client/download/progress_writer_test.go b/download/progress_writer_test.go similarity index 95% rename from client/download/progress_writer_test.go rename to download/progress_writer_test.go index 8d573c8..b11af3d 100644 --- a/client/download/progress_writer_test.go +++ b/download/progress_writer_test.go @@ -5,7 +5,7 @@ import ( "io" "testing" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) func TestProgressWriterFinalizes(t *testing.T) { diff --git a/client/download/transfer.go b/download/transfer.go similarity index 50% rename from client/download/transfer.go rename to download/transfer.go index e54ddab..d171313 100644 --- a/client/download/transfer.go +++ b/download/transfer.go @@ -8,20 +8,19 @@ import ( "path/filepath" "strings" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/g3client" ) // DownloadSingleWithProgress downloads a single object while emitting progress events. func DownloadSingleWithProgress( ctx context.Context, - g3i client.Gen3Interface, + g3i g3client.Gen3Interface, guid string, downloadPath string, protocol string, - oid string, - progress common.ProgressCallback, ) error { + progress := common.GetProgress(ctx) var err error downloadPath, err = common.ParseRootPath(downloadPath) if err != nil { @@ -41,8 +40,6 @@ func DownloadSingleWithProgress( DownloadPath: downloadPath, Filename: info.Name, GUID: guid, - OID: oid, - Progress: progress, } protocolText := "" @@ -74,11 +71,11 @@ func DownloadSingleWithProgress( return fmt.Errorf("open local file %s: %w", fullPath, err) } - total := fdr.Response.ContentLength + fdr.Range + total := info.Size var writer io.Writer = file var tracker *progressWriter - if fdr.Progress != nil { - tracker = newProgressWriter(file, fdr.Progress, resolveDownloadOID(fdr), total) + if progress != nil { + tracker = newProgressWriter(file, progress, guid, total) writer = tracker } @@ -95,3 +92,57 @@ func DownloadSingleWithProgress( } return nil } + +// DownloadToPath downloads a single object by GUID to a specific destination file path. +// It bypasses the name lookup from Gen3 and uses the provided dstPath directly. +func DownloadToPath( + ctx context.Context, + g3i g3client.Gen3Interface, + guid string, + dstPath string, +) error { + progress := common.GetProgress(ctx) + hash := common.GetOid(ctx) + logger := g3i.Logger() + // logger.Printf("Downloading %s to %s\n", guid, dstPath) + + fdr := common.FileDownloadResponseObject{ + GUID: guid, + } + + if err := GetDownloadResponse(ctx, g3i, &fdr, ""); err != nil { + logger.FailedContext(ctx, dstPath, filepath.Base(dstPath), common.FileMetadata{}, guid, 0, false) + return err + } + defer fdr.Response.Body.Close() + + if dir := filepath.Dir(dstPath); dir != "." { + if err := os.MkdirAll(dir, 0766); err != nil { + logger.FailedContext(ctx, dstPath, filepath.Base(dstPath), common.FileMetadata{}, guid, 0, false) + return fmt.Errorf("mkdir for %s: %w", dstPath, err) + } + } + + file, err := os.Create(dstPath) + if err != nil { + logger.FailedContext(ctx, dstPath, filepath.Base(dstPath), common.FileMetadata{}, guid, 0, false) + return fmt.Errorf("create local file %s: %w", dstPath, err) + } + defer file.Close() + + var writer io.Writer = file + if progress != nil { + total := fdr.Response.ContentLength + tracker := newProgressWriter(file, progress, hash, total) + writer = tracker + defer tracker.Finalize() + } + + if _, err := io.Copy(writer, fdr.Response.Body); err != nil { + logger.FailedContext(ctx, dstPath, filepath.Base(dstPath), common.FileMetadata{}, guid, 0, false) + return fmt.Errorf("copy to %s: %w", dstPath, err) + } + + logger.SucceededContext(ctx, dstPath, guid) + return nil +} diff --git a/client/download/transfer_test.go b/download/transfer_test.go similarity index 58% rename from client/download/transfer_test.go rename to download/transfer_test.go index 7c702dc..aab05e7 100644 --- a/client/download/transfer_test.go +++ b/download/transfer_test.go @@ -3,6 +3,7 @@ package download import ( "bytes" "context" + "encoding/json" "errors" "io" "net/http" @@ -12,42 +13,76 @@ import ( "strings" "testing" - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/indexd" + "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/request" + "github.com/calypr/data-client/sower" ) type fakeGen3Download struct { cred *conf.Credential - logger *logs.TeeLogger + logger *logs.Gen3Logger doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) } func (f *fakeGen3Download) GetCredential() *conf.Credential { return f.cred } -func (f *fakeGen3Download) Logger() *logs.TeeLogger { return f.logger } -func (f *fakeGen3Download) New(method, url string) *request.RequestBuilder { - return &request.RequestBuilder{Method: method, Url: url} +func (f *fakeGen3Download) Logger() *logs.Gen3Logger { return f.logger } +func (f *fakeGen3Download) ExportCredential(ctx context.Context, cred *conf.Credential) error { + return nil } -func (f *fakeGen3Download) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { +func (f *fakeGen3Download) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } +func (f *fakeGen3Download) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } +func (f *fakeGen3Download) Sower() sower.SowerInterface { return nil } + +type fakeFence struct { + fence.FenceInterface + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeFence) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { return f.doFunc(ctx, req) } -func (f *fakeGen3Download) CheckPrivileges(context.Context) (map[string]any, error) { - return nil, nil +func (f *fakeFence) New(method, url string) *request.RequestBuilder { + return &request.RequestBuilder{Method: method, Url: url, Headers: make(map[string]string)} } -func (f *fakeGen3Download) CheckForShepherdAPI(context.Context) (bool, error) { return false, nil } -func (f *fakeGen3Download) DeleteRecord(context.Context, string) (string, error) { - return "", nil +func (f *fakeFence) CheckForShepherdAPI(ctx context.Context) (bool, error) { return false, nil } +func (f *fakeFence) ResolveOID(ctx context.Context, oid string) (fence.FenceResponse, error) { + return fence.FenceResponse{}, nil } -func (f *fakeGen3Download) GetDownloadPresignedUrl(context.Context, string, string) (string, error) { +func (f *fakeFence) GetDownloadPresignedUrl(ctx context.Context, guid, protocol string) (string, error) { + if guid == "test-fallback" { + return "", errors.New("fence fallback") + } return "https://download.example.com/object", nil } -func (f *fakeGen3Download) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { - return (&api.Functions{}).ParseFenceURLResponse(resp) +func (f *fakeFence) ParseFenceURLResponse(resp *http.Response) (fence.FenceResponse, error) { + var msg fence.FenceResponse + if resp != nil && resp.Body != nil { + json.NewDecoder(resp.Body).Decode(&msg) + } + return msg, nil +} + +type fakeIndexd struct { + indexd.IndexdInterface + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeIndexd) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + return f.doFunc(ctx, req) +} + +func (f *fakeIndexd) New(method, url string) *request.RequestBuilder { + return &request.RequestBuilder{Method: method, Url: url, Headers: make(map[string]string)} +} + +func (f *fakeIndexd) GetDownloadURL(ctx context.Context, did string, accessType string) (*drs.AccessURL, error) { + return &drs.AccessURL{URL: "https://download.example.com/object"}, nil } -func (f *fakeGen3Download) ExportCredential(context.Context, *conf.Credential) error { return nil } -func (f *fakeGen3Download) NewAccessToken(context.Context) error { return nil } func TestDownloadSingleWithProgressEmitsEvents(t *testing.T) { payload := bytes.Repeat([]byte("d"), 64) @@ -62,7 +97,7 @@ func TestDownloadSingleWithProgressEmitsEvents(t *testing.T) { fake := &fakeGen3Download{ cred: &conf.Credential{APIEndpoint: "https://example.com", AccessToken: "token"}, - logger: logs.NewTeeLogger("", "", io.Discard), + logger: logs.NewGen3Logger(nil, "", ""), doFunc: func(_ context.Context, req *request.RequestBuilder) (*http.Response, error) { switch { case strings.Contains(req.Url, common.IndexdIndexEndpoint): @@ -75,7 +110,8 @@ func TestDownloadSingleWithProgressEmitsEvents(t *testing.T) { }, } - err := DownloadSingleWithProgress(context.Background(), fake, "guid-123", downloadPath, "", "oid-123", progress) + ctx := common.WithProgress(context.Background(), progress) + err := DownloadSingleWithProgress(ctx, fake, "guid-123", downloadPath, "") if err != nil { t.Fatalf("download failed: %v", err) } @@ -110,7 +146,7 @@ func TestDownloadSingleWithProgressFinalizeOnError(t *testing.T) { fake := &fakeGen3Download{ cred: &conf.Credential{APIEndpoint: "https://example.com", AccessToken: "token"}, - logger: logs.NewTeeLogger("", "", io.Discard), + logger: logs.NewGen3Logger(nil, "", ""), doFunc: func(_ context.Context, req *request.RequestBuilder) (*http.Response, error) { switch { case strings.Contains(req.Url, common.IndexdIndexEndpoint): @@ -123,7 +159,8 @@ func TestDownloadSingleWithProgressFinalizeOnError(t *testing.T) { }, } - err := DownloadSingleWithProgress(context.Background(), fake, "guid-123", downloadPath, "", "oid-123", progress) + ctx := common.WithProgress(context.Background(), progress) + err := DownloadSingleWithProgress(ctx, fake, "guid-123", downloadPath, "") if err == nil { t.Fatal("expected download error") } diff --git a/client/download/types.go b/download/types.go similarity index 90% rename from client/download/types.go rename to download/types.go index 651b97e..c910b67 100644 --- a/client/download/types.go +++ b/download/types.go @@ -3,8 +3,8 @@ package download import ( "os" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/logs" ) type IndexdResponse struct { @@ -18,7 +18,7 @@ type RenamedOrSkippedFileInfo struct { } func validateLocalFileStat( - logger logs.Logger, + logger *logs.Gen3Logger, fdr *common.FileDownloadResponseObject, filesize int64, skipCompleted bool, diff --git a/download/url_resolution.go b/download/url_resolution.go new file mode 100644 index 0000000..d7427c3 --- /dev/null +++ b/download/url_resolution.go @@ -0,0 +1,87 @@ +package download + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" +) + +// GetDownloadResponse gets presigned URL and prepares HTTP response +func GetDownloadResponse(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject, protocolText string) error { + // 1. Try Fence first + url, err := g3.Fence().GetDownloadPresignedUrl(ctx, fdr.GUID, protocolText) + if err == nil && url != "" { + fdr.PresignedURL = url + } else { + // 2. Fallback to IndexD DRS endpoint + accessType := "s3" + if strings.HasPrefix(protocolText, "?protocol=") { + accessType = strings.TrimPrefix(protocolText, "?protocol=") + } else if protocolText == "?protocol=gs" { + accessType = "gs" + } + + accessURL, errIdx := g3.Indexd().GetDownloadURL(ctx, fdr.GUID, accessType) + if errIdx == nil && accessURL != nil && accessURL.URL != "" { + fdr.PresignedURL = accessURL.URL + // Some DRS providers might return required headers + // This is not currently used by makeDownloadRequest but good to have for future + } else { + if err != nil { + return err + } + if errIdx != nil { + return errIdx + } + return fmt.Errorf("failed to resolve download URL for %s", fdr.GUID) + } + } + + return makeDownloadRequest(ctx, g3, fdr) +} + +func isCloudPresignedURL(url string) bool { + return strings.Contains(url, "X-Amz-Signature") || + strings.Contains(url, "X-Goog-Signature") || + strings.Contains(url, "Signature=") || + strings.Contains(url, "AWSAccessKeyId=") || + strings.Contains(url, "Expires=") +} + +func makeDownloadRequest(ctx context.Context, g3 client.Gen3Interface, fdr *common.FileDownloadResponseObject) error { + skipAuth := isCloudPresignedURL(fdr.PresignedURL) + rb := g3.Fence().New(http.MethodGet, fdr.PresignedURL).WithSkipAuth(skipAuth) + + if fdr.Range > 0 { + rb.WithHeader("Range", "bytes="+strconv.FormatInt(fdr.Range, 10)+"-") + } + + resp, err := g3.Fence().Do(ctx, rb) + + if err != nil { + return errors.New("Request failed: " + strings.ReplaceAll(err.Error(), fdr.PresignedURL, "")) + } + + // Check for non-success status codes + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + defer resp.Body.Close() // Ensure the body is closed + + bodyBytes, err := io.ReadAll(resp.Body) + bodyString := "" + if err == nil { + bodyString = string(bodyBytes) + } + + return fmt.Errorf("non-OK response: %d, body: %s", resp.StatusCode, bodyString) + } + + fdr.Response = resp + return nil +} diff --git a/client/download/utils.go b/download/utils.go similarity index 67% rename from client/download/utils.go rename to download/utils.go index 864a0c6..7209c44 100644 --- a/client/download/utils.go +++ b/download/utils.go @@ -7,19 +7,22 @@ import ( "strconv" "strings" - "github.com/calypr/data-client/client/logs" + "context" + "fmt" + "log/slog" ) // AskForConfirmation asks user for confirmation before proceed, will wait if user entered garbage -func AskForConfirmation(logger logs.Logger, s string) bool { +func AskForConfirmation(logger *slog.Logger, s string) bool { reader := bufio.NewReader(os.Stdin) for { - logger.Printf("%s [y/n]: ", s) + logger.Info(fmt.Sprintf("%s [y/n]: ", s)) response, err := reader.ReadString('\n') if err != nil { - logger.Fatal("Error occurred during parsing user's confirmation: " + err.Error()) + logger.Error("Error occurred during parsing user's confirmation: " + err.Error()) + os.Exit(1) } switch strings.ToLower(strings.TrimSpace(response)) { @@ -60,20 +63,15 @@ func truncateFilename(name string, max int) string { } // printRenamed shows renamed files in final summary -func printRenamed(logger logs.Logger, renamed []RenamedOrSkippedFileInfo) { - if len(renamed) == 0 { - return - } - logger.Printf("%d files renamed:\n", len(renamed)) +func printRenamed(ctx context.Context, logger *slog.Logger, renamed []RenamedOrSkippedFileInfo) { for _, r := range renamed { - logger.Printf(" %q (GUID: %s) → %q\n", r.OldFilename, r.GUID, r.NewFilename) + logger.InfoContext(ctx, fmt.Sprintf("Renamed %q to %q (GUID: %s)", r.OldFilename, r.NewFilename, r.GUID)) } } // printSkipped shows skipped files in final summary -func printSkipped(logger logs.Logger, skipped []RenamedOrSkippedFileInfo) { - if len(skipped) == 0 { - return +func printSkipped(ctx context.Context, logger *slog.Logger, skipped []RenamedOrSkippedFileInfo) { + for _, s := range skipped { + logger.InfoContext(ctx, fmt.Sprintf("Skipped %q (GUID: %s)", s.OldFilename, s.GUID)) } - logger.Printf("%d files skipped (complete local copy exists)\n", len(skipped)) } diff --git a/fence/client.go b/fence/client.go new file mode 100644 index 0000000..4a5cacf --- /dev/null +++ b/fence/client.go @@ -0,0 +1,637 @@ +package fence + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "log/slog" + + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/request" + "github.com/hashicorp/go-version" +) + +// FenceBucketEndpoint is the endpoint postfix for FENCE bucket list +const FenceBucketEndpoint = "/user/data/buckets" + +//go:generate mockgen -destination=../mocks/mock_fence.go -package=mocks github.com/calypr/data-client/fence FenceInterface + +// FenceInterface defines the interface for Fence client +type FenceInterface interface { + request.RequestInterface + + NewAccessToken(ctx context.Context) (string, error) + CheckPrivileges(ctx context.Context) (map[string]any, error) + CheckForShepherdAPI(ctx context.Context) (bool, error) + DeleteRecord(ctx context.Context, guid string) (string, error) + GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) + + UserPing(ctx context.Context) (*PingResp, error) + + // Bucket details + GetBucketDetails(ctx context.Context, bucket string) (*S3Bucket, error) + + // Upload methods + InitUpload(ctx context.Context, filename string, bucket string, guid string) (FenceResponse, error) + GetUploadPresignedUrl(ctx context.Context, guid string, filename string, bucket string) (FenceResponse, error) + + // Multipart methods + InitMultipartUpload(ctx context.Context, filename string, bucket string, guid string) (FenceResponse, error) + GenerateMultipartPresignedURL(ctx context.Context, key string, uploadID string, partNumber int, bucket string) (string, error) + CompleteMultipartUpload(ctx context.Context, key string, uploadID string, parts []MultipartPart, bucket string) error + ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) + + RefreshToken(ctx context.Context) error +} + +// FenceClient implements FenceInterface +// FenceClient implements FenceInterface +type FenceClient struct { + request.RequestInterface + cred *conf.Credential + logger *slog.Logger +} + +// NewFenceClient creates a new FenceClient +func NewFenceClient(req request.RequestInterface, cred *conf.Credential, logger *slog.Logger) FenceInterface { + return &FenceClient{ + RequestInterface: req, + cred: cred, + logger: logger, + } +} + +func (f *FenceClient) NewAccessToken(ctx context.Context) (string, error) { + if f.cred.APIKey == "" { + return "", errors.New("APIKey is required to refresh access token") + } + + payload, err := json.Marshal(map[string]string{"api_key": f.cred.APIKey}) + if err != nil { + return "", err + } + bodyReader := bytes.NewReader(payload) + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Method: http.MethodPost, + Url: f.cred.APIEndpoint + common.FenceAccessTokenEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: bodyReader, + }, + ) + + if err != nil { + return "", fmt.Errorf("error when calling Request.Do: %s", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", errors.New("failed to refresh token, status: " + strconv.Itoa(resp.StatusCode)) + } + + var result common.AccessTokenStruct + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", errors.New("failed to parse token response: " + err.Error()) + } + + return result.AccessToken, nil +} + +func (f *FenceClient) RefreshToken(ctx context.Context) error { + token, err := f.NewAccessToken(ctx) + if err != nil { + return err + } + f.cred.AccessToken = token + return nil +} + +func (f *FenceClient) CheckPrivileges(ctx context.Context) (map[string]any, error) { + resp, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceUserEndpoint, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return nil, errors.New("error occurred when getting response from remote: " + err.Error()) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var data map[string]any + err = json.Unmarshal(bodyBytes, &data) + if err != nil { + return nil, errors.New("error occurred when unmarshalling response: " + err.Error()) + } + + resourceAccess, ok := data["authz"].(map[string]any) + + // If the `authz` section (Arborist permissions) is empty or missing, try get `project_access` section (Fence permissions) + if len(resourceAccess) == 0 || !ok { + resourceAccess, ok = data["project_access"].(map[string]any) + if !ok { + return nil, errors.New("not possible to read access privileges of user") + } + } + + return resourceAccess, nil +} + +func (f *FenceClient) CheckForShepherdAPI(ctx context.Context) (bool, error) { + // Check if Shepherd is enabled + if f.cred.UseShepherd == "false" { + return false, nil + } + if f.cred.UseShepherd != "true" && common.DefaultUseShepherd == false { + return false, nil + } + // If Shepherd is enabled, make sure that the commons has a compatible version of Shepherd deployed. + // Compare the version returned from the Shepherd version endpoint with the minimum acceptable Shepherd version. + var minShepherdVersion string + if f.cred.MinShepherdVersion == "" { + minShepherdVersion = common.DefaultMinShepherdVersion + } else { + minShepherdVersion = f.cred.MinShepherdVersion + } + + res, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.ShepherdVersionEndpoint, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return false, errors.New("Error occurred during generating HTTP request: " + err.Error()) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return false, nil + } + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + return false, errors.New("Error occurred when reading HTTP request: " + err.Error()) + } + body, err := strconv.Unquote(string(bodyBytes)) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) + } + // Compare the version in the response to the target version + ver, err := version.NewVersion(body) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing version from Shepherd: %v: %v", string(body), err) + } + minVer, err := version.NewVersion(minShepherdVersion) + if err != nil { + return false, fmt.Errorf("Error occurred when parsing minimum acceptable Shepherd version: %v: %v", minShepherdVersion, err) + } + if ver.GreaterThanOrEqual(minVer) { + return true, nil + } + return false, fmt.Errorf("Shepherd is enabled, but %v does not have correct Shepherd version. (Need Shepherd version >=%v, got %v)", f.cred.APIEndpoint, minVer, ver) +} + +func (f *FenceClient) DeleteRecord(ctx context.Context, guid string) (string, error) { + hasShepherd, err := f.CheckForShepherdAPI(ctx) + if err != nil { + f.logger.Warn(fmt.Sprintf("WARNING: Error checking Shepherd API: %v. Falling back to Fence.\n", err)) + } else if hasShepherd { + resp, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.ShepherdEndpoint + "/objects/" + guid, + Method: http.MethodDelete, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode == 204 { + return "Record with GUID " + guid + " has been deleted", nil + } + return "", fmt.Errorf("shepherd delete failed: %d", resp.StatusCode) + } + + resp, err := f.Do(ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceDataEndpoint + "/" + guid, + Method: http.MethodDelete, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNoContent { + return "Record with GUID " + guid + " has been deleted", nil + } + + _, err = f.ParseFenceURLResponse(resp) + if err != nil { + return "", err + } + return "Record with GUID " + guid + " has been deleted", nil +} + +func (f *FenceClient) GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { + hasShepherd, err := f.CheckForShepherdAPI(ctx) + if err == nil && hasShepherd { + return f.resolveFromShepherd(ctx, guid) + } + return f.resolveFromFence(ctx, guid, protocolText) +} + +func (f *FenceClient) resolveFromShepherd(ctx context.Context, guid string) (string, error) { + url := fmt.Sprintf("%s%s/objects/%s/download", f.cred.APIEndpoint, common.ShepherdEndpoint, guid) + resp, err := f.Do(ctx, &request.RequestBuilder{ + Url: url, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("shepherd error: %d", resp.StatusCode) + } + + var result struct { + URL string `json:"url"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to decode shepherd response: %w", err) + } + + return result.URL, nil +} + +func (f *FenceClient) resolveFromFence(ctx context.Context, guid, protocolText string) (string, error) { + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceDataDownloadEndpoint + "/" + guid + protocolText, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return "", errors.New("failed to get URL from Fence via Do: " + err.Error()) + } + defer resp.Body.Close() + + msg, err := f.ParseFenceURLResponse(resp) + if err != nil || msg.URL == "" { + return "", errors.New("failed to get URL from Fence via ParseFenceURLResponse: " + err.Error()) + } + + return msg.URL, nil +} + +func (f *FenceClient) GetBucketDetails(ctx context.Context, bucket string) (*S3Bucket, error) { + url := f.cred.APIEndpoint + "/user/data/buckets" + resp, err := f.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Token: f.cred.AccessToken, + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch bucket information: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var bucketInfo S3BucketsResponse + if err := json.NewDecoder(resp.Body).Decode(&bucketInfo); err != nil { + return nil, fmt.Errorf("failed to decode bucket information: %w", err) + } + + if info, exists := bucketInfo.S3Buckets[bucket]; exists { + if info.EndpointURL != "" && info.Region != "" { + return info, nil + } + return nil, errors.New("endpoint_url or region not found for bucket") + } + + return nil, nil +} + +func (f *FenceClient) InitUpload(ctx context.Context, filename string, bucket string, guid string) (FenceResponse, error) { + payload := map[string]string{ + "file_name": filename, + } + if bucket != "" { + payload["bucket"] = bucket + } + if guid != "" { + payload["guid"] = guid + } + + buf, err := common.ToJSONReader(payload) + if err != nil { + return FenceResponse{}, err + } + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Method: http.MethodPost, + Url: f.cred.APIEndpoint + common.FenceDataUploadEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: buf, + Token: f.cred.AccessToken, + }) + if err != nil { + return FenceResponse{}, err + } + defer resp.Body.Close() + + return f.ParseFenceURLResponse(resp) +} + +func (f *FenceClient) GetUploadPresignedUrl(ctx context.Context, guid string, filename string, bucket string) (FenceResponse, error) { + endPointPostfix := common.FenceDataUploadEndpoint + "/" + guid + "?file_name=" + url.QueryEscape(filename) + if bucket != "" { + endPointPostfix += "&bucket=" + bucket + } + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + endPointPostfix, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Token: f.cred.AccessToken, + Method: http.MethodGet, + }, + ) + if err != nil { + return FenceResponse{}, err + } + defer resp.Body.Close() + + return f.ParseFenceURLResponse(resp) +} + +func (f *FenceClient) InitMultipartUpload(ctx context.Context, filename string, bucket string, guid string) (FenceResponse, error) { + reader, err := common.ToJSONReader( + InitRequestObject{ + Filename: filename, + Bucket: bucket, + GUID: guid, + }, + ) + if err != nil { + return FenceResponse{}, err + } + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Method: http.MethodPost, + Url: f.cred.APIEndpoint + common.FenceDataMultipartInitEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: reader, + Token: f.cred.AccessToken, + }, + ) + + if err != nil { + return FenceResponse{}, err + } + defer resp.Body.Close() + + return f.ParseFenceURLResponse(resp) +} + +func (f *FenceClient) GenerateMultipartPresignedURL(ctx context.Context, key string, uploadID string, partNumber int, bucket string) (string, error) { + reader, err := common.ToJSONReader( + MultipartUploadRequestObject{ + Key: key, + UploadID: uploadID, + PartNumber: partNumber, + Bucket: bucket, + }, + ) + if err != nil { + return "", err + } + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceDataMultipartUploadEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Method: http.MethodPost, + Body: reader, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return "", err + } + defer resp.Body.Close() + + msg, err := f.ParseFenceURLResponse(resp) + if err != nil { + return "", err + } + + return msg.PresignedURL, nil +} + +func (f *FenceClient) CompleteMultipartUpload(ctx context.Context, key string, uploadID string, parts []MultipartPart, bucket string) error { + multipartCompleteObject := MultipartCompleteRequestObject{Key: key, UploadID: uploadID, Parts: parts, Bucket: bucket} + + reader, err := common.ToJSONReader(multipartCompleteObject) + if err != nil { + return err + } + + resp, err := f.Do( + ctx, + &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceDataMultipartCompleteEndpoint, + Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, + Body: reader, + Method: http.MethodPost, + Token: f.cred.AccessToken, + }, + ) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusNoContent { + return nil + } + + _, err = f.ParseFenceURLResponse(resp) + return err +} + +func (f *FenceClient) ParseFenceURLResponse(resp *http.Response) (FenceResponse, error) { + msg := FenceResponse{} + if resp == nil { + return msg, errors.New("nil response received") + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return msg, fmt.Errorf("failed to read response body: %w", err) + } + bodyStr := string(bodyBytes) + + if len(bodyBytes) > 0 { + err = json.Unmarshal(bodyBytes, &msg) + if err != nil { + return msg, fmt.Errorf("failed to decode JSON: %w (Raw body: %s)", err, bodyStr) + } + } + + if !(resp.StatusCode == 200 || resp.StatusCode == 201 || resp.StatusCode == 204) { + strUrl := resp.Request.URL.String() + switch resp.StatusCode { + case http.StatusUnauthorized: + return msg, fmt.Errorf("401 Unauthorized: %s (URL: %s)", bodyStr, strUrl) + case http.StatusForbidden: + return msg, fmt.Errorf("403 Forbidden: %s (URL: %s)", bodyStr, strUrl) + case http.StatusNotFound: + return msg, fmt.Errorf("404 Not Found: %s (URL: %s)", bodyStr, strUrl) + case http.StatusInternalServerError: + return msg, fmt.Errorf("500 Internal Server Error: %s (URL: %s)", bodyStr, strUrl) + case http.StatusServiceUnavailable: + return msg, fmt.Errorf("503 Service Unavailable: %s (URL: %s)", bodyStr, strUrl) + case http.StatusBadGateway: + return msg, fmt.Errorf("502 Bad Gateway: %s (URL: %s)", bodyStr, strUrl) + default: + return msg, fmt.Errorf("unexpected error (%d): %s (URL: %s)", resp.StatusCode, bodyStr, strUrl) + } + } + + if strings.Contains(bodyStr, "Can't find a location for the data") { + return msg, errors.New("the provided GUID is not found") + } + + return msg, nil +} + +func (f *FenceClient) UserPing(ctx context.Context) (*PingResp, error) { + resp, err := f.Do(ctx, &request.RequestBuilder{ + Url: f.cred.APIEndpoint + common.FenceUserEndpoint, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user info, status: %d", resp.StatusCode) + } + + var uResp FenceUserResp + if err := json.NewDecoder(resp.Body).Decode(&uResp); err != nil { + return nil, err + } + + bucketResp, err := f.Do(ctx, &request.RequestBuilder{ + Url: f.cred.APIEndpoint + FenceBucketEndpoint, + Method: http.MethodGet, + Token: f.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer bucketResp.Body.Close() + + if bucketResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get bucket info, status: %d", bucketResp.StatusCode) + } + + var bResp S3BucketsResponse + if err := json.NewDecoder(bucketResp.Body).Decode(&bResp); err != nil { + return nil, err + } + + return &PingResp{ + Profile: f.cred.Profile, + Username: uResp.Username, + Endpoint: f.cred.APIEndpoint, + BucketPrograms: ParseBucketResp(bResp), + YourAccess: ParseUserResp(uResp), + }, nil +} + +func ParseBucketResp(resp S3BucketsResponse) map[string]string { + bucketsByProgram := make(map[string]string) + + // Check both S3_BUCKETS and s3_buckets + s3Buckets := resp.S3Buckets + if len(s3Buckets) == 0 { + s3Buckets = resp.S3BucketsLower + } + + for bucketName, bucketInfo := range s3Buckets { + var programs strings.Builder + if len(bucketInfo.Programs) > 1 { + for i, p := range bucketInfo.Programs { + if i > 0 { + programs.WriteString(",") + } + programs.WriteString(p) + } + } else if len(bucketInfo.Programs) == 1 { + programs.WriteString(bucketInfo.Programs[0]) + } + bucketsByProgram[bucketName] = programs.String() + } + return bucketsByProgram +} + +func ParseUserResp(resp FenceUserResp) map[string]string { + servicesByPath := make(map[string]string) + for path, permissions := range resp.Authz { + var services strings.Builder + seenServices := make(map[string]bool) + for _, p := range permissions { + if !seenServices[p.Method] { + if services.Len() > 0 { + services.WriteString(",") + } + services.WriteString(p.Method) + seenServices[p.Method] = true + } + } + if services.Len() > 0 { + servicesByPath[path] = services.String() + } + } + return servicesByPath +} diff --git a/fence/client_test.go b/fence/client_test.go new file mode 100644 index 0000000..6a85de3 --- /dev/null +++ b/fence/client_test.go @@ -0,0 +1,250 @@ +package fence + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/request" +) + +type mockFenceServer struct{} + +func (m *mockFenceServer) handler(t *testing.T) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + switch { + case r.Method == http.MethodPost && path == common.FenceAccessTokenEndpoint: + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(common.AccessTokenStruct{AccessToken: "new-access-token"}) + return + case r.Method == http.MethodGet && path == common.FenceUserEndpoint: + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "username": "test-user", + "authz": map[string]any{ + "/resource": []map[string]string{ + {"method": "read", "service": "fence"}, + }, + }, + }) + return + case r.Method == http.MethodGet && path == common.ShepherdVersionEndpoint: + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`"2.0.0"`)) + return + case r.Method == http.MethodDelete && strings.HasPrefix(path, common.ShepherdEndpoint+"/objects/"): + w.WriteHeader(http.StatusNoContent) + return + case r.Method == http.MethodDelete && strings.HasPrefix(path, common.FenceDataEndpoint+"/"): + w.WriteHeader(http.StatusNoContent) + return + case r.Method == http.MethodGet && strings.HasSuffix(path, "/download"): + // Shepherd download + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"url": "https://download.url"}) + return + case r.Method == http.MethodGet && strings.Contains(path, common.FenceDataDownloadEndpoint+"/"): + // Fence download + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(FenceResponse{URL: "https://download.url"}) + return + case r.Method == http.MethodGet && path == "/user/data/buckets": + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(S3BucketsResponse{ + S3Buckets: map[string]*S3Bucket{ + "test-bucket": { + EndpointURL: "https://s3.amazonaws.com", + Region: "us-east-1", + }, + }, + }) + return + case r.Method == http.MethodPost && path == common.FenceDataUploadEndpoint: + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(FenceResponse{GUID: "new-guid", URL: "https://upload.url"}) + return + case r.Method == http.MethodGet && strings.HasPrefix(path, common.FenceDataUploadEndpoint+"/"): + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(FenceResponse{URL: "https://upload.url"}) + return + } + + w.WriteHeader(http.StatusNotFound) + } +} + +func newTestClient(server *httptest.Server) FenceInterface { + cred := &conf.Credential{APIEndpoint: server.URL, Profile: "test", AccessToken: "test-token", APIKey: "test-key"} + logger, _ := logs.New("test") + config := conf.NewConfigure(logger.Logger) + req := request.NewRequestInterface(logger, cred, config) + return NewFenceClient(req, cred, logger.Logger) +} + +func TestFenceClient_NewAccessToken(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + token, err := client.NewAccessToken(context.Background()) + if err != nil { + t.Fatalf("NewAccessToken error: %v", err) + } + if token != "new-access-token" { + t.Errorf("expected token new-access-token, got %s", token) + } +} + +func TestFenceClient_CheckPrivileges(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + privs, err := client.CheckPrivileges(context.Background()) + if err != nil { + t.Fatalf("CheckPrivileges error: %v", err) + } + if _, ok := privs["/resource"]; !ok { + t.Errorf("expected /resource privilege") + } +} + +func TestFenceClient_CheckForShepherdAPI(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + cred := &conf.Credential{ + APIEndpoint: server.URL, + UseShepherd: "true", + } + logger, _ := logs.New("test") + req := request.NewRequestInterface(logger, cred, conf.NewConfigure(logger.Logger)) + client := NewFenceClient(req, cred, logger.Logger) + + hasShepherd, err := client.CheckForShepherdAPI(context.Background()) + if err != nil { + t.Fatalf("CheckForShepherdAPI error: %v", err) + } + if !hasShepherd { + t.Errorf("expected Shepherd to be detected") + } +} + +func TestFenceClient_DeleteRecord(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + // Test Fence fallback (shepherd check returns false or handled by mock behavior) + msg, err := client.DeleteRecord(context.Background(), "guid-1") + if err != nil { + t.Fatalf("DeleteRecord error: %v", err) + } + if !strings.Contains(msg, "has been deleted") { + t.Errorf("unexpected message: %s", msg) + } +} + +func TestFenceClient_GetBucketDetails(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + info, err := client.GetBucketDetails(context.Background(), "test-bucket") + if err != nil { + t.Fatalf("GetBucketDetails error: %v", err) + } + if info.Region != "us-east-1" { + t.Errorf("expected region us-east-1, got %s", info.Region) + } + + info, err = client.GetBucketDetails(context.Background(), "unknown-bucket") + if err != nil { + t.Fatalf("unexpected error for unknown bucket: %v", err) + } + if info != nil { + t.Errorf("expected nil info for unknown bucket") + } +} + +func TestFenceClient_UploadFlow(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + resp, err := client.InitUpload(context.Background(), "file.txt", "bucket", "") + if err != nil { + t.Fatalf("InitUpload error: %v", err) + } + if resp.URL != "https://upload.url" { + t.Errorf("expected upload URL, got %s", resp.URL) + } + + resp, err = client.GetUploadPresignedUrl(context.Background(), "guid-1", "file.txt", "bucket") + if err != nil { + t.Fatalf("GetUploadPresignedUrl error: %v", err) + } + if resp.URL != "https://upload.url" { + t.Errorf("expected upload URL, got %s", resp.URL) + } +} + +func TestFenceClient_GetDownloadPresignedUrl_Fence(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + url, err := client.GetDownloadPresignedUrl(context.Background(), "guid-1", "") + if err != nil { + t.Fatalf("GetDownloadPresignedUrl error: %v", err) + } + if url != "https://download.url" { + t.Errorf("expected download URL, got %s", url) + } +} + +func TestFenceClient_UserPing(t *testing.T) { + mock := &mockFenceServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + resp, err := client.UserPing(context.Background()) + if err != nil { + t.Fatalf("UserPing error: %v", err) + } + + if resp.Username != "test-user" { + t.Errorf("expected username test-user, got %s", resp.Username) + } + + if _, ok := resp.YourAccess["/resource"]; !ok { + t.Errorf("expected /resource access") + } + + if resp.BucketPrograms["test-bucket"] != "" { + // Our mock for /user/data/buckets returns a bucket but no programs by default unless we update it + // In my update to types.go, I added Programs to S3Bucket. + } +} diff --git a/fence/types.go b/fence/types.go new file mode 100644 index 0000000..2352dbb --- /dev/null +++ b/fence/types.go @@ -0,0 +1,93 @@ +package fence + +// MultipartPart represents a part of a multipart upload +type MultipartPart struct { + PartNumber int `json:"PartNumber"` + ETag string `json:"ETag"` +} + +// FenceResponse represents the standard response from Fence data endpoints +type FenceResponse struct { + URL string `json:"url"` + UploadURL string `json:"upload_url"` // Alias found in some Fence versions + GUID string `json:"guid"` + UploadID string `json:"uploadId"` + PresignedURL string `json:"presigned_url"` + FileName string `json:"file_name"` + URLs []string `json:"urls"` + Size int64 `json:"size"` +} + +// InitRequestObject represents the payload for initializing an upload +type InitRequestObject struct { + Filename string `json:"file_name"` + Bucket string `json:"bucket,omitempty"` + GUID string `json:"guid,omitempty"` +} + +// MultipartUploadRequestObject represents the payload for getting a presigned URL for a part +type MultipartUploadRequestObject struct { + Key string `json:"key"` + UploadID string `json:"uploadId"` + PartNumber int `json:"partNumber"` + Bucket string `json:"bucket,omitempty"` +} + +// MultipartCompleteRequestObject represents the payload for completing a multipart upload +type MultipartCompleteRequestObject struct { + Key string `json:"key"` + UploadID string `json:"uploadId"` + Parts []MultipartPart `json:"parts"` + Bucket string `json:"bucket,omitempty"` +} + +type S3Bucket struct { + EndpointURL string `json:"endpoint_url"` + Programs []string `json:"programs,omitempty"` + Region string `json:"region"` +} + +type S3BucketsResponse struct { + GSBuckets map[string]any `json:"GS_BUCKETS,omitempty"` + S3Buckets map[string]*S3Bucket `json:"S3_BUCKETS,omitempty"` + // Some versions of fence use lowercase + S3BucketsLower map[string]*S3Bucket `json:"s3_buckets,omitempty"` +} + +type UserPermission struct { + Method string `json:"method"` + Service string `json:"service"` +} + +type FenceUserResp struct { + Active bool `json:"active"` + Authz map[string][]UserPermission `json:"authz"` + Azp *string `json:"azp"` + CertificatesUploaded []any `json:"certificates_uploaded"` + DisplayName string `json:"display_name"` + Email string `json:"email"` + Ga4GhPassportV1 []any `json:"ga4gh_passport_v1"` + Groups []any `json:"groups"` + Idp string `json:"idp"` + IsAdmin bool `json:"is_admin"` + Message string `json:"message"` + Name string `json:"name"` + PhoneNumber string `json:"phone_number"` + PreferredUsername string `json:"preferred_username"` + PrimaryGoogleServiceAccount *string `json:"primary_google_service_account"` + ProjectAccess map[string]any `json:"project_access"` + Resources []string `json:"resources"` + ResourcesGranted []any `json:"resources_granted"` + Role string `json:"role"` + Sub string `json:"sub"` + UserID int `json:"user_id"` + Username string `json:"username"` +} + +type PingResp struct { + Profile string `yaml:"profile" json:"profile"` + Username string `yaml:"username" json:"username"` + Endpoint string `yaml:"endpoint" json:"endpoint"` + BucketPrograms map[string]string `yaml:"bucket_programs" json:"bucket_programs"` + YourAccess map[string]string `yaml:"your_access" json:"your_access"` +} diff --git a/g3client/client.go b/g3client/client.go new file mode 100644 index 0000000..2741aa1 --- /dev/null +++ b/g3client/client.go @@ -0,0 +1,246 @@ +package g3client + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/indexd" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/request" + "github.com/calypr/data-client/requestor" + "github.com/calypr/data-client/sower" + version "github.com/hashicorp/go-version" +) + +//go:generate mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/g3client Gen3Interface + +type Gen3Interface interface { + GetCredential() *conf.Credential + Logger() *logs.Gen3Logger + ExportCredential(ctx context.Context, cred *conf.Credential) error + Fence() fence.FenceInterface + Indexd() indexd.IndexdInterface + Sower() sower.SowerInterface + Requestor() requestor.RequestorInterface +} + +func NewGen3InterfaceFromCredential(cred *conf.Credential, logger *logs.Gen3Logger, opts ...Option) Gen3Interface { + config := conf.NewConfigure(logger.Logger) + reqInterface := request.NewRequestInterface(logger, cred, config) + + client := &Gen3Client{ + config: config, + RequestInterface: reqInterface, + credential: cred, + logger: logger, + } + + for _, opt := range opts { + opt(client) + } + + client.initializeClients() + + return client +} + +func (g *Gen3Client) initializeClients() { + shouldInit := func(ct ClientType) bool { + if len(g.requestedClients) == 0 { + return true + } + for _, c := range g.requestedClients { + if c == ct { + return true + } + } + return false + } + + if shouldInit(FenceClient) { + g.fence = fence.NewFenceClient(g.RequestInterface, g.credential, g.logger.Logger) + } + if shouldInit(IndexdClient) { + g.indexd = indexd.NewIndexdClient(g.RequestInterface, g.credential, g.logger.Logger) + } + if shouldInit(SowerClient) { + g.sower = sower.NewSowerClient(g.RequestInterface, g.credential.APIEndpoint) + } + if shouldInit(RequestorClient) { + g.requestor = requestor.NewRequestorClient(g.RequestInterface, g.credential) + } +} + +type Gen3Client struct { + Ctx context.Context + fence fence.FenceInterface + indexd indexd.IndexdInterface + sower sower.SowerInterface + requestor requestor.RequestorInterface + config conf.ManagerInterface + request.RequestInterface + + credential *conf.Credential + logger *logs.Gen3Logger + + requestedClients []ClientType +} + +type ClientType string + +const ( + FenceClient ClientType = "fence" + IndexdClient ClientType = "indexd" + SowerClient ClientType = "sower" + RequestorClient ClientType = "requestor" +) + +type Option func(*Gen3Client) + +func WithClients(clients ...ClientType) Option { + return func(g *Gen3Client) { + g.requestedClients = clients + } +} + +func (g *Gen3Client) Fence() fence.FenceInterface { + return g.fence +} + +func (g *Gen3Client) Indexd() indexd.IndexdInterface { + return g.indexd +} + +func (g *Gen3Client) Sower() sower.SowerInterface { + return g.sower +} + +func (g *Gen3Client) Requestor() requestor.RequestorInterface { + return g.requestor +} + +func (g *Gen3Client) Logger() *logs.Gen3Logger { + return g.logger +} + +func (g *Gen3Client) GetCredential() *conf.Credential { + return g.credential +} + +func (g *Gen3Client) ExportCredential(ctx context.Context, cred *conf.Credential) error { + if cred.Profile == "" { + return fmt.Errorf("profile name is required") + } + if cred.APIEndpoint == "" { + return fmt.Errorf("API endpoint is required") + } + + // Normalize endpoint + cred.APIEndpoint = strings.TrimSpace(cred.APIEndpoint) + cred.APIEndpoint = strings.TrimSuffix(cred.APIEndpoint, "/") + + // Validate URL format + parsedURL, err := conf.ValidateUrl(cred.APIEndpoint) + if err != nil { + return fmt.Errorf("invalid apiendpoint URL: %w", err) + } + fenceBase := parsedURL.Scheme + "://" + parsedURL.Host + if _, err := g.config.Load(cred.Profile); err != nil && !errors.Is(err, conf.ErrProfileNotFound) { + return err + } + + if cred.APIKey != "" { + // Always refresh the access token — ignore any old one that might be in the struct + token, err := g.fence.NewAccessToken(ctx) + if err != nil { + if strings.Contains(err.Error(), "401") { + return fmt.Errorf("authentication failed (401) for %s — your API key is invalid, revoked, or expired", fenceBase) + } + if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "no such host") { + return fmt.Errorf("cannot reach Fence at %s — is this a valid Gen3 commons?", fenceBase) + } + return fmt.Errorf("failed to refresh access token: %w", err) + } + g.credential.AccessToken = token + } else { + g.logger.Warn("WARNING: Your profile will only be valid for 24 hours since you have only provided a refresh token for authentication") + } + + // Clean up shepherd flags + cred.UseShepherd = strings.TrimSpace(cred.UseShepherd) + cred.MinShepherdVersion = strings.TrimSpace(cred.MinShepherdVersion) + + if cred.MinShepherdVersion != "" { + if _, err = version.NewVersion(cred.MinShepherdVersion); err != nil { + return fmt.Errorf("invalid min-shepherd-version: %w", err) + } + } + + if err := g.config.Save(cred); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + + return nil +} + +// EnsureValidCredential checks if the credential is valid and refreshes it if the access token is expired but the API key is valid. +// It accepts an optional fClient; if nil, it will initialize one internally if needed for refresh. +func EnsureValidCredential(ctx context.Context, cred *conf.Credential, config conf.ManagerInterface, logger *logs.Gen3Logger, fClient fence.FenceInterface) error { + if valid, err := config.IsCredentialValid(cred); !valid { + if strings.Contains(err.Error(), "access_token is invalid but api_key is valid") { + // Try to refresh the token + if fClient == nil { + reqInterface := request.NewRequestInterface(logger, cred, config) + fClient = fence.NewFenceClient(reqInterface, cred, logger.Logger) + } + newToken, refreshErr := fClient.NewAccessToken(ctx) + if refreshErr == nil { + cred.AccessToken = newToken + err = config.Save(cred) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to save refreshed token: %v", err)) + } + return nil + } + return fmt.Errorf("failed to refresh access token: %v (original error: %v)", refreshErr, err) + } + return fmt.Errorf("invalid credential: %v", err) + } + return nil +} + +// NewGen3Interface returns a Gen3Client that embeds the credential and implements Gen3Interface. +func NewGen3Interface(profile string, logger *logs.Gen3Logger, opts ...Option) (Gen3Interface, error) { + config := conf.NewConfigure(logger.Logger) + cred, err := config.Load(profile) + if err != nil { + return nil, err + } + + reqInterface := request.NewRequestInterface(logger, cred, config) + + // We need a temporary Fence client to refresh tokens if needed + fClient := fence.NewFenceClient(reqInterface, cred, logger.Logger) + if err := EnsureValidCredential(context.Background(), cred, config, logger, fClient); err != nil { + return nil, err + } + + client := &Gen3Client{ + config: config, + RequestInterface: reqInterface, + credential: cred, + logger: logger, + } + + for _, opt := range opts { + opt(client) + } + + client.initializeClients() + + return client, nil +} diff --git a/go.mod b/go.mod index a40c2e0..c39b763 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,12 @@ module github.com/calypr/data-client go 1.24.2 require ( + github.com/aws/aws-sdk-go-v2 v1.41.1 + github.com/aws/aws-sdk-go-v2/config v1.32.7 + github.com/aws/aws-sdk-go-v2/credentials v1.19.7 + github.com/aws/aws-sdk-go-v2/service/s3 v1.95.1 github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/go-version v1.8.0 @@ -12,11 +17,27 @@ require ( go.uber.org/mock v0.6.0 golang.org/x/sync v0.19.0 gopkg.in/ini.v1 v1.67.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.17 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect + github.com/aws/smithy-go v1.24.0 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -24,5 +45,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/stretchr/testify v1.11.1 // indirect golang.org/x/sys v0.39.0 // indirect ) diff --git a/go.sum b/go.sum index 57dfebd..d4cffb0 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,44 @@ github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1o github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= +github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= +github.com/aws/aws-sdk-go-v2 v1.41.1/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 h1:489krEF9xIGkOaaX3CE/Be2uWjiXrkCH6gUX+bZA/BU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4/go.mod h1:IOAPF6oT9KCsceNTvvYMNHy0+kMF8akOjeDvPENWxp4= +github.com/aws/aws-sdk-go-v2/config v1.32.7 h1:vxUyWGUwmkQ2g19n7JY/9YL8MfAIl7bTesIUykECXmY= +github.com/aws/aws-sdk-go-v2/config v1.32.7/go.mod h1:2/Qm5vKUU/r7Y+zUk/Ptt2MDAEKAfUtKc1+3U1Mo3oY= +github.com/aws/aws-sdk-go-v2/credentials v1.19.7 h1:tHK47VqqtJxOymRrNtUXN5SP/zUTvZKeLx4tH6PGQc8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.7/go.mod h1:qOZk8sPDrxhf+4Wf4oT2urYJrYt3RejHSzgAquYeppw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 h1:I0GyV8wiYrP8XpA70g1HBcQO1JlQxCMTW9npl5UbDHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17/go.mod h1:tyw7BOl5bBe/oqvoIeECFJjMdzXoa/dfVz3QQ5lgHGA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 h1:xOLELNKGp2vsiteLsvLPwxC+mYmO6OZ8PYgiuPJzF8U= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17/go.mod h1:5M5CI3D12dNOtH3/mk6minaRwI2/37ifCURZISxA/IQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 h1:WWLqlh79iO48yLkj1v3ISRNiv+3KdQoZ6JWyfcsyQik= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17/go.mod h1:EhG22vHRrvF8oXSTYStZhJc1aUgKtnJe+aOiFEV90cM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17 h1:JqcdRG//czea7Ppjb+g/n4o8i/R50aTBHkA7vu0lK+k= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17/go.mod h1:CO+WeGmIdj/MlPel2KwID9Gt7CNq4M65HUfBW97liM0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8 h1:Z5EiPIzXKewUQK0QTMkutjiaPVeVYXX7KIqhXu/0fXs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8/go.mod h1:FsTpJtvC4U1fyDXk7c71XoDv3HlRm8V3NiYLeYLh5YE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 h1:RuNSMoozM8oXlgLG/n6WLaFGoea7/CddrCfIiSA+xdY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17/go.mod h1:F2xxQ9TZz5gDWsclCtPQscGpP0VUOc8RqgFM3vDENmU= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.17 h1:bGeHBsGZx0Dvu/eJC0Lh9adJa3M1xREcndxLNZlve2U= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.17/go.mod h1:dcW24lbU0CzHusTE8LLHhRLI42ejmINN8Lcr22bwh/g= +github.com/aws/aws-sdk-go-v2/service/s3 v1.95.1 h1:C2dUPSnEpy4voWFIq3JNd8gN0Y5vYGDo44eUE58a/p8= +github.com/aws/aws-sdk-go-v2/service/s3 v1.95.1/go.mod h1:5jggDlZ2CLQhwJBiZJb4vfk4f0GxWdEDruWKEJ1xOdo= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 h1:VrhDvQib/i0lxvr3zqlUwLwJP4fpmpyD9wYG1vfSu+Y= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5/go.mod h1:k029+U8SY30/3/ras4G/Fnv/b88N4mAfliNn08Dem4M= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 h1:v6EiMvhEYBoHABfbGB4alOYmCIrcgyPPiBE1wZAEbqk= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9/go.mod h1:yifAsgBxgJWn3ggx70A3urX2AN49Y5sJTD1UQFlfqBw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 h1:gd84Omyu9JLriJVCbGApcLzVR3XtmC4ZDPcAI6Ftvds= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13/go.mod h1:sTGThjphYE4Ohw8vJiRStAcu3rbjtXRsdNB0TvZ5wwo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/pJ1jOWYlFDJTjRQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= @@ -13,6 +51,8 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -42,8 +82,8 @@ github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiT github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/vbauerster/mpb/v8 v8.11.2 h1:OqLoHznUVU7SKS/WV+1dB5/hm20YLheYupiHhL5+M1Y= github.com/vbauerster/mpb/v8 v8.11.2/go.mod h1:mEB/M353al1a7wMUNtiymmPsEkGlJgeJmtlbY5adCJ8= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= @@ -53,6 +93,7 @@ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= diff --git a/indexd/add_url.go b/indexd/add_url.go new file mode 100644 index 0000000..af85298 --- /dev/null +++ b/indexd/add_url.go @@ -0,0 +1,106 @@ +package indexd + +import ( + "context" + "fmt" + "slices" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/indexd/drs" +) + +// UpsertIndexdRecord creates or updates an indexd record with a new URL. +func (c *IndexdClient) UpsertIndexdRecord(ctx context.Context, url string, sha256 string, fileSize int64, projectId string) (*drs.DRSObject, error) { + uuid := drs.DrsUUID(projectId, sha256) + + records, err := c.GetObjectByHash(ctx, "sha256", sha256) + if err != nil { + return nil, fmt.Errorf("error querying indexd server: %v", err) + } + + var matchingRecord *drs.DRSObject + for i := range records { + if records[i].Id == uuid { + matchingRecord = &records[i] + break + } + } + + if matchingRecord != nil { + existingURLs := IndexdURLFromDrsAccessURLs(matchingRecord.AccessMethods) + if slices.Contains(existingURLs, url) { + c.logger.Debug("Nothing to do: file already registered") + return matchingRecord, nil + } + + c.logger.Debug("updating existing record with new url") + updatedRecord := drs.DRSObject{AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: url}}}} + return c.UpdateRecord(ctx, &updatedRecord, matchingRecord.Id) + } + + // If no record exists, create one + c.logger.Debug("creating new record") + _, key, err := ParseS3URL(url) + if err != nil { + return nil, err + } + + drsObj, err := drs.BuildDrsObj(key, sha256, fileSize, uuid, "placeholder-bucket", projectId) + if err != nil { + return nil, err + } + + return c.RegisterRecord(ctx, drsObj) +} + +// AddURL implements the AddURL logic ported from git-drs. +func (c *IndexdClient) AddURL( + ctx context.Context, + fClient fence.FenceInterface, + s3URL string, + sha256 string, + awsAccessKey string, + awsSecretKey string, + region string, + endpoint string, + s3Client *s3.Client, +) (S3Meta, error) { + if err := ValidateInputs(s3URL, sha256); err != nil { + return S3Meta{}, err + } + + bucket, _, err := ParseS3URL(s3URL) + if err != nil { + return S3Meta{}, err + } + + var bucketDetails *fence.S3Bucket + if fClient != nil { + bucketDetails, err = fClient.GetBucketDetails(ctx, bucket) + if err != nil { + c.logger.Debug(fmt.Sprintf("Warning: unable to get bucket details from Gen3: %v", err)) + } + } + + size, modifiedDate, err := FetchS3MetadataWithBucketDetails( + ctx, s3URL, awsAccessKey, awsSecretKey, region, endpoint, bucketDetails, s3Client, c.logger, + ) + if err != nil { + return S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) + } + + // This part needs project ID. In git-drs it was in the client config. + projectId := "unknown-project" + // ... (logic to get project ID) + + _, err = c.UpsertIndexdRecord(ctx, s3URL, sha256, size, projectId) + if err != nil { + return S3Meta{}, fmt.Errorf("failed to upsert indexd record: %w", err) + } + + return S3Meta{ + Size: size, + LastModified: modifiedDate, + }, nil +} diff --git a/indexd/client.go b/indexd/client.go new file mode 100644 index 0000000..29ea378 --- /dev/null +++ b/indexd/client.go @@ -0,0 +1,515 @@ +package indexd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/request" +) + +//go:generate mockgen -destination=../mocks/mock_indexd.go -package=mocks github.com/calypr/data-client/indexd IndexdInterface + +// IndexdInterface defines the interface for Indexd client +type IndexdInterface interface { + request.RequestInterface + + GetObject(ctx context.Context, id string) (*drs.DRSObject, error) + RegisterIndexdRecord(ctx context.Context, indexdObj *IndexdRecord) (*drs.DRSObject, error) + DeleteIndexdRecord(ctx context.Context, did string) error + GetObjectByHash(ctx context.Context, hashType, hashValue string) ([]drs.DRSObject, error) + GetDownloadURL(ctx context.Context, did string, accessType string) (*drs.AccessURL, error) + ListObjectsByProject(ctx context.Context, projectId string) (chan drs.DRSObjectResult, error) + UpdateRecord(ctx context.Context, updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) + + ListObjects(ctx context.Context) (chan drs.DRSObjectResult, error) + GetProjectSample(ctx context.Context, projectId string, limit int) ([]drs.DRSObject, error) + DeleteRecordsByProject(ctx context.Context, projectId string) error + DeleteRecordByHash(ctx context.Context, hashValue string, projectId string) error + RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) + UpsertIndexdRecord(ctx context.Context, url string, sha256 string, fileSize int64, projectId string) (*drs.DRSObject, error) + AddURL(ctx context.Context, fClient fence.FenceInterface, s3URL, sha256, awsAccessKey, awsSecretKey, region, endpoint string, s3Client *s3.Client) (S3Meta, error) +} + +// IndexdClient implements IndexdInterface +type IndexdClient struct { + request.RequestInterface + cred *conf.Credential + logger *slog.Logger +} + +// NewIndexdClient creates a new IndexdClient +func NewIndexdClient(req request.RequestInterface, cred *conf.Credential, logger *slog.Logger) IndexdInterface { + return &IndexdClient{ + RequestInterface: req, + cred: cred, + logger: logger, + } +} + +func (c *IndexdClient) GetObject(ctx context.Context, id string) (*drs.DRSObject, error) { + url := fmt.Sprintf("%s/ga4gh/drs/v1/objects/%s", c.cred.APIEndpoint, id) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("object %s not found", id) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to get object %s: %s (status: %d)", id, string(body), resp.StatusCode) + } + + var out OutputObject + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + return ConvertOutputObjectToDRSObject(&out), nil +} + +func (c *IndexdClient) RegisterIndexdRecord(ctx context.Context, indexdObj *IndexdRecord) (*drs.DRSObject, error) { + indexdObjForm := IndexdRecordForm{ + IndexdRecord: *indexdObj, + Form: "object", + } + + jsonBytes, err := json.Marshal(indexdObjForm) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/index/index", c.cred.APIEndpoint) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodPost, + Url: url, + Body: bytes.NewBuffer(jsonBytes), + Headers: map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + }, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to register record %s: %s (status: %d)", indexdObj.Did, string(body), resp.StatusCode) + } + + return IndexdRecordToDrsObject(indexdObj) +} + +func (c *IndexdClient) DeleteIndexdRecord(ctx context.Context, did string) error { + // First get the record to get the revision (rev) + record, err := c.getIndexdRecordByDID(ctx, did) + if err != nil { + return err + } + + url := fmt.Sprintf("%s/index/index/%s?rev=%s", c.cred.APIEndpoint, did, record.Rev) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodDelete, + Url: url, + Headers: map[string]string{ + "Accept": "application/json", + }, + Token: c.cred.AccessToken, + }) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to delete record %s: %s (status: %d)", did, string(body), resp.StatusCode) + } + + return nil +} + +func (c *IndexdClient) getIndexdRecordByDID(ctx context.Context, did string) (*OutputInfo, error) { + url := fmt.Sprintf("%s/index/index/%s", c.cred.APIEndpoint, did) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to get indexd record %s: %s (status: %d)", did, string(body), resp.StatusCode) + } + + var info OutputInfo + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return nil, err + } + return &info, nil +} + +func (c *IndexdClient) GetObjectByHash(ctx context.Context, hashType, hashValue string) ([]drs.DRSObject, error) { + url := fmt.Sprintf("%s/index/index?hash=%s:%s", c.cred.APIEndpoint, hashType, hashValue) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Headers: map[string]string{ + "Accept": "application/json", + }, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to query by hash %s:%s: %s (status: %d)", hashType, hashValue, string(body), resp.StatusCode) + } + + var records ListRecords + if err := json.NewDecoder(resp.Body).Decode(&records); err != nil { + return nil, err + } + + out := make([]drs.DRSObject, 0, len(records.Records)) + for _, r := range records.Records { + drsObj, err := IndexdRecordToDrsObject(r.ToIndexdRecord()) + if err != nil { + return nil, err + } + out = append(out, *drsObj) + } + return out, nil +} + +func (c *IndexdClient) GetDownloadURL(ctx context.Context, did string, accessType string) (*drs.AccessURL, error) { + url := fmt.Sprintf("%s/ga4gh/drs/v1/objects/%s/access/%s", c.cred.APIEndpoint, did, accessType) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to get download URL for %s: %s (status: %d)", did, string(body), resp.StatusCode) + } + + var accessURL drs.AccessURL + if err := json.NewDecoder(resp.Body).Decode(&accessURL); err != nil { + return nil, err + } + return &accessURL, nil +} + +func (c *IndexdClient) ListObjectsByProject(ctx context.Context, projectId string) (chan drs.DRSObjectResult, error) { + const PAGESIZE = 50 + + resourcePath, err := drs.ProjectToResource(projectId) + if err != nil { + return nil, err + } + + out := make(chan drs.DRSObjectResult, PAGESIZE) + + go func() { + defer close(out) + pageNum := 0 + active := true + + for active { + url := fmt.Sprintf("%s/index/index?authz=%s&limit=%d&page=%d", + c.cred.APIEndpoint, resourcePath, PAGESIZE, pageNum) + + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: url, + Headers: map[string]string{ + "Accept": "application/json", + }, + Token: c.cred.AccessToken, + }) + + if err != nil { + out <- drs.DRSObjectResult{Error: err} + break + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + out <- drs.DRSObjectResult{Error: fmt.Errorf("api error %d: %s", resp.StatusCode, string(body))} + break + } + + var page ListRecords + err = json.NewDecoder(resp.Body).Decode(&page) + resp.Body.Close() + + if err != nil { + out <- drs.DRSObjectResult{Error: err} + break + } + + if len(page.Records) == 0 { + active = false + break + } + + for _, elem := range page.Records { + drsObj, err := elem.ToIndexdRecord().ToDrsObject() + if err != nil { + out <- drs.DRSObjectResult{Error: err} + continue + } + out <- drs.DRSObjectResult{Object: drsObj} + } + pageNum++ + } + }() + + return out, nil +} + +func (c *IndexdClient) UpdateRecord(ctx context.Context, updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) { + // Get current revision from existing record + record, err := c.getIndexdRecordByDID(ctx, did) + if err != nil { + return nil, fmt.Errorf("could not retrieve existing record for DID %s: %v", did, err) + } + + // Build update payload starting with existing record values + updatePayload := UpdateInputInfo{ + URLs: record.URLs, + FileName: record.FileName, + Version: record.Version, + Authz: record.Authz, + ACL: record.ACL, + Metadata: record.Metadata, + } + + // Apply updates from updateInfo + if len(updateInfo.AccessMethods) > 0 { + newURLs := make([]string, 0, len(updateInfo.AccessMethods)) + for _, a := range updateInfo.AccessMethods { + newURLs = append(newURLs, a.AccessURL.URL) + } + updatePayload.URLs = appendUnique(updatePayload.URLs, newURLs) + + authz := IndexdAuthzFromDrsAccessMethods(updateInfo.AccessMethods) + updatePayload.Authz = appendUnique(updatePayload.Authz, authz) + } + + if updateInfo.Name != "" { + updatePayload.FileName = updateInfo.Name + } + + if updateInfo.Version != "" { + updatePayload.Version = updateInfo.Version + } + + if updateInfo.Description != "" { + if updatePayload.Metadata == nil { + updatePayload.Metadata = make(map[string]any) + } + updatePayload.Metadata["description"] = updateInfo.Description + } + + jsonBytes, err := json.Marshal(updatePayload) + if err != nil { + return nil, fmt.Errorf("error marshaling indexd update payload: %v", err) + } + + url := fmt.Sprintf("%s/index/index/%s?rev=%s", c.cred.APIEndpoint, did, record.Rev) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodPut, + Url: url, + Body: bytes.NewBuffer(jsonBytes), + Headers: map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + }, + Token: c.cred.AccessToken, + }) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to update record %s: %s (status: %d)", did, string(body), resp.StatusCode) + } + + return c.GetObject(ctx, did) +} + +func (c *IndexdClient) ListObjects(ctx context.Context) (chan drs.DRSObjectResult, error) { + url := fmt.Sprintf("%s/ga4gh/drs/v1/objects", c.cred.APIEndpoint) + const PAGESIZE = 50 + out := make(chan drs.DRSObjectResult, 10) + + go func() { + defer close(out) + pageNum := 0 + active := true + for active { + fullURL := fmt.Sprintf("%s?limit=%d&page=%d", url, PAGESIZE, pageNum) + resp, err := c.Do(ctx, &request.RequestBuilder{ + Method: http.MethodGet, + Url: fullURL, + Token: c.cred.AccessToken, + }) + + if err != nil { + out <- drs.DRSObjectResult{Error: err} + return + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + out <- drs.DRSObjectResult{Error: fmt.Errorf("api error %d: %s", resp.StatusCode, string(body))} + return + } + + var page drs.DRSPage + err = json.NewDecoder(resp.Body).Decode(&page) + resp.Body.Close() + + if err != nil { + out <- drs.DRSObjectResult{Error: err} + return + } + + if len(page.DRSObjects) == 0 { + active = false + break + } + + for _, elem := range page.DRSObjects { + out <- drs.DRSObjectResult{Object: &elem} + } + pageNum++ + } + }() + return out, nil +} + +func (c *IndexdClient) GetProjectSample(ctx context.Context, projectId string, limit int) ([]drs.DRSObject, error) { + if limit <= 0 { + limit = 1 + } + + objChan, err := c.ListObjectsByProject(ctx, projectId) + if err != nil { + return nil, err + } + + result := make([]drs.DRSObject, 0, limit) + for objResult := range objChan { + if objResult.Error != nil { + return nil, objResult.Error + } + result = append(result, *objResult.Object) + + if len(result) >= limit { + go func() { + for range objChan { + } + }() + break + } + } + + return result, nil +} + +func (c *IndexdClient) DeleteRecordsByProject(ctx context.Context, projectId string) error { + recs, err := c.ListObjectsByProject(ctx, projectId) + if err != nil { + return err + } + for rec := range recs { + if rec.Error != nil { + return rec.Error + } + err := c.DeleteIndexdRecord(ctx, rec.Object.Id) + if err != nil { + c.logger.Error(fmt.Sprintf("DeleteRecordsByProject Error for %s: %v", rec.Object.Id, err)) + continue + } + } + return nil +} + +func (c *IndexdClient) DeleteRecordByHash(ctx context.Context, hashValue string, projectId string) error { + records, err := c.GetObjectByHash(ctx, "sha256", hashValue) + if err != nil { + return fmt.Errorf("error getting records for hash %s: %v", hashValue, err) + } + if len(records) == 0 { + return fmt.Errorf("no records found for hash %s", hashValue) + } + + matchingRecord, err := drs.FindMatchingRecord(records, projectId) + if err != nil { + return fmt.Errorf("error finding matching record for project %s: %v", projectId, err) + } + if matchingRecord == nil { + return fmt.Errorf("no matching record found for project %s", projectId) + } + + return c.DeleteIndexdRecord(ctx, matchingRecord.Id) +} + +func (c *IndexdClient) RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) { + indexdRecord, err := IndexdRecordFromDrsObject(record) + if err != nil { + return nil, fmt.Errorf("error converting DRS object to indexd record: %v", err) + } + + return c.RegisterIndexdRecord(ctx, indexdRecord) +} + +func appendUnique(existing []string, toAdd []string) []string { + seen := make(map[string]bool) + for _, v := range existing { + seen[v] = true + } + for _, v := range toAdd { + if !seen[v] { + existing = append(existing, v) + seen[v] = true + } + } + return existing +} diff --git a/indexd/client_test.go b/indexd/client_test.go new file mode 100644 index 0000000..2fb76ed --- /dev/null +++ b/indexd/client_test.go @@ -0,0 +1,266 @@ +package indexd + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/calypr/data-client/conf" + drs "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/request" +) + +type mockIndexdServer struct { + mu sync.Mutex + listProjectPages int + listObjectsPages int + lastUpdatePayload UpdateInputInfo +} + +func (m *mockIndexdServer) handler(t *testing.T) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + switch { + case r.Method == http.MethodGet && path == "/index/index": + if hashQuery := r.URL.Query().Get("hash"); hashQuery != "" { + record := sampleOutputInfo() + page := ListRecords{Records: []OutputInfo{record}} + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(page) + return + } + if r.URL.Query().Get("authz") != "" { + m.mu.Lock() + page := m.listProjectPages + m.listProjectPages++ + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + if page == 0 { + _ = json.NewEncoder(w).Encode(ListRecords{Records: []OutputInfo{sampleOutputInfo()}}) + } else { + _ = json.NewEncoder(w).Encode(ListRecords{Records: []OutputInfo{}}) + } + return + } + + case r.Method == http.MethodPost && path == "/index/index": + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"did":"did-1"}`)) + return + case r.Method == http.MethodGet && strings.HasPrefix(path, "/ga4gh/drs/v1/objects"): + if path == "/ga4gh/drs/v1/objects" { + m.mu.Lock() + page := m.listObjectsPages + m.listObjectsPages++ + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + if page == 0 { + _ = json.NewEncoder(w).Encode(drs.DRSPage{DRSObjects: []drs.DRSObject{sampleDRSObject()}}) + } else { + _ = json.NewEncoder(w).Encode(drs.DRSPage{DRSObjects: []drs.DRSObject{}}) + } + return + } + obj := sampleOutputObject() + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(obj) + return + case r.Method == http.MethodGet && strings.HasPrefix(path, "/index/index/"): + record := sampleOutputInfo() + record.Rev = "rev-1" + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(record) + return + case r.Method == http.MethodPut && strings.HasPrefix(path, "/index/index/"): + body, _ := io.ReadAll(r.Body) + payload := UpdateInputInfo{} + _ = json.Unmarshal(body, &payload) + m.mu.Lock() + m.lastUpdatePayload = payload + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + return + case r.Method == http.MethodDelete && strings.HasPrefix(path, "/index/index/"): + w.WriteHeader(http.StatusNoContent) + return + } + w.WriteHeader(http.StatusNotFound) + } +} + +func sampleOutputInfo() OutputInfo { + return OutputInfo{ + Did: "did-1", + FileName: "file.txt", + URLs: []string{"s3://bucket/key"}, + Authz: []string{"/programs/test/projects/proj"}, + Hashes: hash.HashInfo{SHA256: "sha-256"}, + Size: 123, + } +} + +func sampleDRSObject() drs.DRSObject { + return drs.DRSObject{ + Id: "did-1", + Name: "file.txt", + Size: 123, + Checksums: hash.HashInfo{ + SHA256: "sha-256", + }, + AccessMethods: []drs.AccessMethod{ + { + Type: "s3", + AccessURL: drs.AccessURL{URL: "s3://bucket/key"}, + Authorizations: &drs.Authorizations{Value: "/programs/test/projects/proj"}, + }, + }, + } +} + +func sampleOutputObject() OutputObject { + return OutputObject{ + Id: "did-1", + Name: "file.txt", + Size: 123, + Checksums: []hash.Checksum{ + {Checksum: "sha-256", Type: hash.ChecksumTypeSHA256}, + }, + } +} + +func newTestClient(server *httptest.Server) IndexdInterface { + cred := &conf.Credential{APIEndpoint: server.URL, Profile: "test", AccessToken: "test-token"} + logger, _ := logs.New("test") + config := conf.NewConfigure(logger.Logger) + req := request.NewRequestInterface(logger, cred, config) + return NewIndexdClient(req, cred, logger.Logger) +} + +func TestIndexdClient_ListAndQueryDirect(t *testing.T) { + mock := &mockIndexdServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + records, err := client.GetObjectByHash(context.Background(), "sha256", "sha-256") + if err != nil { + t.Fatalf("GetObjectByHash error: %v", err) + } + if len(records) != 1 || records[0].Id != "did-1" { + t.Fatalf("unexpected records: %+v", records) + } + + objChan, err := client.ListObjectsByProject(context.Background(), "test-proj") + if err != nil { + t.Fatalf("ListObjectsByProject error: %v", err) + } + var found bool + for res := range objChan { + if res.Error != nil { + t.Fatalf("ListObjectsByProject result error: %v", res.Error) + } + if res.Object != nil && res.Object.Id == "did-1" { + found = true + } + } + if !found { + t.Fatalf("expected object from ListObjectsByProject") + } + + listChan, err := client.ListObjects(context.Background()) + if err != nil { + t.Fatalf("ListObjects error: %v", err) + } + var listCount int + for res := range listChan { + if res.Error != nil { + t.Fatalf("ListObjects result error: %v", res.Error) + } + if res.Object != nil { + listCount++ + } + } + if listCount != 1 { + t.Fatalf("expected 1 object from ListObjects, got %d", listCount) + } +} + +func TestIndexdClient_RegisterAndUpdateDirect(t *testing.T) { + mock := &mockIndexdServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + drsObj := &drs.DRSObject{ + Id: "did-1", + Name: "file.txt", + Size: 123, + Checksums: hash.HashInfo{SHA256: "sha-256"}, + AccessMethods: []drs.AccessMethod{ + { + Type: "s3", + AccessURL: drs.AccessURL{URL: "s3://bucket/key"}, + Authorizations: &drs.Authorizations{Value: "/programs/test/projects/proj"}, + }, + }, + } + + obj, err := client.RegisterRecord(context.Background(), drsObj) + if err != nil { + t.Fatalf("RegisterRecord error: %v", err) + } + if obj.Id != "did-1" { + t.Fatalf("unexpected DRS object: %+v", obj) + } + + update := &drs.DRSObject{ + Name: "file-updated.txt", + Version: "v2", + Description: "updated", + AccessMethods: []drs.AccessMethod{ + { + Type: "s3", + AccessURL: drs.AccessURL{URL: "s3://bucket/other"}, + Authorizations: &drs.Authorizations{Value: "/programs/test/projects/proj"}, + }, + }, + } + + _, err = client.UpdateRecord(context.Background(), update, "did-1") + if err != nil { + t.Fatalf("UpdateRecord error: %v", err) + } + + mock.mu.Lock() + payload := mock.lastUpdatePayload + mock.mu.Unlock() + + if len(payload.URLs) != 2 { + t.Fatalf("expected URLs to include appended entries, got %+v", payload.URLs) + } +} + +func TestIndexdClient_GetObjectDirect(t *testing.T) { + mock := &mockIndexdServer{} + server := httptest.NewServer(mock.handler(t)) + defer server.Close() + + client := newTestClient(server) + + record, err := client.GetObject(context.Background(), "did-1") + if err != nil { + t.Fatalf("GetObject error: %v", err) + } + if record.Id != "did-1" { + t.Fatalf("unexpected record: %+v", record) + } +} diff --git a/indexd/convert.go b/indexd/convert.go new file mode 100644 index 0000000..0fb44d9 --- /dev/null +++ b/indexd/convert.go @@ -0,0 +1,99 @@ +package indexd + +// Conversion functions between drs.DRSObject and IndexdRecord + +import ( + "fmt" + "net/url" + + "github.com/calypr/data-client/indexd/drs" +) + +// IndexdRecordFromDrsObject represents a simplified version of an indexd record for conversion purposes +func IndexdRecordFromDrsObject(drsObj *drs.DRSObject) (*IndexdRecord, error) { + indexdObj := &IndexdRecord{ + Did: drsObj.Id, + Size: drsObj.Size, + FileName: drsObj.Name, + URLs: IndexdURLFromDrsAccessURLs(drsObj.AccessMethods), + Authz: IndexdAuthzFromDrsAccessMethods(drsObj.AccessMethods), + Hashes: drsObj.Checksums, + } + return indexdObj, nil +} + +func IndexdRecordToDrsObject(indexdObj *IndexdRecord) (*drs.DRSObject, error) { + accessMethods, err := DRSAccessMethodsFromIndexdURLs(indexdObj.URLs, indexdObj.Authz) + if err != nil { + return nil, err + } + for _, am := range accessMethods { + if am.Authorizations == nil || am.Authorizations.Value == "" { + return nil, fmt.Errorf("access method missing authorization %v, %v", indexdObj, indexdObj.Authz) + } + } + + return &drs.DRSObject{ + Id: indexdObj.Did, + Size: indexdObj.Size, + Name: indexdObj.FileName, + AccessMethods: accessMethods, + Checksums: indexdObj.Hashes, + }, nil +} + +func DRSAccessMethodsFromIndexdURLs(urls []string, authz []string) ([]drs.AccessMethod, error) { + var accessMethods []drs.AccessMethod + for _, urlString := range urls { + var method drs.AccessMethod + method.AccessURL = drs.AccessURL{URL: urlString} + + parsed, err := url.Parse(urlString) + if err != nil { + return nil, fmt.Errorf("failed to parse url %q: %v", urlString, err) + } + if parsed.Scheme == "" { + // default to https if no scheme or parse error + method.Type = "https" + } else { + method.Type = parsed.Scheme + } + + // check if authz is null or 0-length, then error + if authz == nil { + return nil, fmt.Errorf("authz is required") + } + + // NOTE: a record can only have 1 authz entry atm + method.Authorizations = &drs.Authorizations{Value: authz[0]} + accessMethods = append(accessMethods, method) + } + return accessMethods, nil +} + +// IndexdAuthzFromDrsAccessMethods extracts authz values from DRS access methods +func IndexdAuthzFromDrsAccessMethods(accessMethods []drs.AccessMethod) []string { + var authz []string + for _, drsURL := range accessMethods { + if drsURL.Authorizations != nil { + authz = append(authz, drsURL.Authorizations.Value) + } + } + return authz +} + +func IndexdURLFromDrsAccessURLs(accessMethods []drs.AccessMethod) []string { + var urls []string + for _, drsURL := range accessMethods { + urls = append(urls, drsURL.AccessURL.URL) + } + return urls +} + +func (inr *IndexdRecord) ToDrsObject() (*drs.DRSObject, error) { + o, err := IndexdRecordToDrsObject(inr) + if err != nil { + return nil, err + } + return o, nil +} diff --git a/indexd/drs/drs.go b/indexd/drs/drs.go new file mode 100644 index 0000000..46ea800 --- /dev/null +++ b/indexd/drs/drs.go @@ -0,0 +1,87 @@ +package drs + +import ( + "fmt" + "strings" + + "github.com/calypr/data-client/indexd/hash" + "github.com/google/uuid" +) + +// NAMESPACE is the UUID namespace used for generating DRS UUIDs +var NAMESPACE = uuid.NewMD5(uuid.NameSpaceURL, []byte("calypr.org")) + +func ProjectToResource(project string) (string, error) { + if !strings.Contains(project, "-") { + return "", fmt.Errorf("error: invalid project ID %s, ID should look like -", project) + } + projectIdArr := strings.SplitN(project, "-", 2) + return "/programs/" + projectIdArr[0] + "/projects/" + projectIdArr[1], nil +} + +// From git-drs/drsmap/drs_map.go + +func DrsUUID(projectId string, hash string) string { + // create UUID based on project ID and hash + hashStr := fmt.Sprintf("%s:%s", projectId, hash) + return uuid.NewSHA1(NAMESPACE, []byte(hashStr)).String() +} + +func FindMatchingRecord(records []DRSObject, projectId string) (*DRSObject, error) { + if len(records) == 0 { + return nil, nil + } + + // Convert project ID to resource path format for comparison + expectedAuthz, err := ProjectToResource(projectId) + if err != nil { + return nil, fmt.Errorf("error converting project ID to resource format: %v", err) + } + + for _, record := range records { + for _, access := range record.AccessMethods { + if access.Authorizations != nil && access.Authorizations.Value == expectedAuthz { + return &record, nil + } + } + } + + return nil, nil +} + +// DRS UUID generation using SHA1 (compatible with git-drs) +func GenerateDrsID(projectId, hash string) string { + return DrsUUID(projectId, hash) +} + +func BuildDrsObj(fileName string, checksum string, size int64, drsId string, bucketName string, projectId string) (*DRSObject, error) { + if bucketName == "" { + return nil, fmt.Errorf("error: bucket name is empty") + } + + fileURL := fmt.Sprintf("s3://%s/%s/%s", bucketName, drsId, checksum) + + authzStr, err := ProjectToResource(projectId) + if err != nil { + return nil, err + } + authorizations := Authorizations{ + Value: authzStr, + } + + drsObj := DRSObject{ + Id: drsId, + Name: fileName, + AccessMethods: []AccessMethod{{ + Type: "s3", + AccessURL: AccessURL{ + URL: fileURL, + }, + Authorizations: &authorizations, + }}, + Checksums: hash.HashInfo{SHA256: checksum}, + Size: size, + } + + return &drsObj, nil +} diff --git a/indexd/drs/types.go b/indexd/drs/types.go new file mode 100644 index 0000000..d17cd45 --- /dev/null +++ b/indexd/drs/types.go @@ -0,0 +1,56 @@ +package drs + +import ( + "github.com/calypr/data-client/indexd/hash" +) + +type ChecksumType = hash.ChecksumType +type Checksum = hash.Checksum +type HashInfo = hash.HashInfo + +type AccessURL struct { + URL string `json:"url"` + Headers []string `json:"headers"` +} + +type Authorizations struct { + Value string `json:"value"` +} + +type AccessMethod struct { + Type string `json:"type"` + AccessURL AccessURL `json:"access_url"` + AccessID string `json:"access_id,omitempty"` + Cloud string `json:"cloud,omitempty"` + Region string `json:"region,omitempty"` + Available string `json:"available,omitempty"` + Authorizations *Authorizations `json:"Authorizations,omitempty"` +} + +type Contents struct { +} + +type DRSPage struct { + DRSObjects []DRSObject `json:"drs_objects"` +} + +type DRSObjectResult struct { + Object *DRSObject + Error error +} + +type DRSObject struct { + Id string `json:"id"` + Name string `json:"name"` + SelfURI string `json:"self_uri,omitempty"` + Size int64 `json:"size"` + CreatedTime string `json:"created_time,omitempty"` + UpdatedTime string `json:"updated_time,omitempty"` + Version string `json:"version,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Checksums hash.HashInfo `json:"checksums"` + AccessMethods []AccessMethod `json:"access_methods"` + Contents []Contents `json:"contents,omitempty"` + Description string `json:"description,omitempty"` + Aliases []string `json:"aliases,omitempty"` +} diff --git a/indexd/hash/hash.go b/indexd/hash/hash.go new file mode 100644 index 0000000..11ee3a1 --- /dev/null +++ b/indexd/hash/hash.go @@ -0,0 +1,144 @@ +package hash + +import ( + "encoding/json" + "fmt" +) + +// ChecksumType represents the digest method used to create the checksum +type ChecksumType string + +// IANA Named Information Hash Algorithm Registry values and other common types +const ( + ChecksumTypeSHA1 ChecksumType = "sha1" + ChecksumTypeSHA256 ChecksumType = "sha256" + ChecksumTypeSHA512 ChecksumType = "sha512" + ChecksumTypeMD5 ChecksumType = "md5" + ChecksumTypeETag ChecksumType = "etag" + ChecksumTypeCRC32C ChecksumType = "crc32c" + ChecksumTypeTrunc512 ChecksumType = "trunc512" +) + +// IsValid checks if the checksum type is a known/recommended value +func (ct ChecksumType) IsValid() bool { + switch ct { + case ChecksumTypeSHA256, ChecksumTypeSHA512, ChecksumTypeSHA1, ChecksumTypeMD5, + ChecksumTypeETag, ChecksumTypeCRC32C, ChecksumTypeTrunc512: + return true + default: + return false + } +} + +// String returns the string representation of the checksum type +func (ct ChecksumType) String() string { + return string(ct) +} + +var SupportedChecksums = map[string]bool{ + string(ChecksumTypeSHA1): true, + string(ChecksumTypeSHA256): true, + string(ChecksumTypeSHA512): true, + string(ChecksumTypeMD5): true, + string(ChecksumTypeETag): true, + string(ChecksumTypeCRC32C): true, + string(ChecksumTypeTrunc512): true, +} + +type Checksum struct { + Checksum string `json:"checksum"` + Type ChecksumType `json:"type"` +} + +type HashInfo struct { + MD5 string `json:"md5,omitempty"` + SHA string `json:"sha,omitempty"` + SHA256 string `json:"sha256,omitempty"` + SHA512 string `json:"sha512,omitempty"` + CRC string `json:"crc,omitempty"` + ETag string `json:"etag,omitempty"` +} + +// UnmarshalJSON accepts both the DRS map-based schema and the array-of-checksums schema. +func (h *HashInfo) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *h = HashInfo{} + return nil + } + + var mapPayload map[string]string + if err := json.Unmarshal(data, &mapPayload); err == nil { + *h = ConvertStringMapToHashInfo(mapPayload) + return nil + } + + var checksumPayload []Checksum + if err := json.Unmarshal(data, &checksumPayload); err == nil { + *h = ConvertChecksumsToHashInfo(checksumPayload) + return nil + } + + return fmt.Errorf("unsupported HashInfo payload: %s", string(data)) +} + +func ConvertStringMapToHashInfo(inputHashes map[string]string) HashInfo { + hashInfo := HashInfo{} + + for key, value := range inputHashes { + if !SupportedChecksums[key] { + continue // Disregard unsupported types + } + switch key { + case string(ChecksumTypeMD5): + hashInfo.MD5 = value + case string(ChecksumTypeSHA1): + hashInfo.SHA = value + case string(ChecksumTypeSHA256): + hashInfo.SHA256 = value + case string(ChecksumTypeSHA512): + hashInfo.SHA512 = value + case string(ChecksumTypeCRC32C): + hashInfo.CRC = value + case string(ChecksumTypeETag): + hashInfo.ETag = value + } + } + + return hashInfo +} + +func ConvertHashInfoToMap(hashes HashInfo) map[string]string { + result := make(map[string]string) + if hashes.MD5 != "" { + result["md5"] = hashes.MD5 + } + if hashes.SHA != "" { + result["sha"] = hashes.SHA + } + if hashes.SHA256 != "" { + result["sha256"] = hashes.SHA256 + } + if hashes.SHA512 != "" { + result["sha512"] = hashes.SHA512 + } + if hashes.CRC != "" { + result["crc"] = hashes.CRC + } + if hashes.ETag != "" { + result["etag"] = hashes.ETag + } + return result +} + +func ConvertChecksumsToMap(checksums []Checksum) map[string]string { + result := make(map[string]string, len(checksums)) + for _, c := range checksums { + result[string(c.Type)] = c.Checksum + } + return result +} + +func ConvertChecksumsToHashInfo(checksums []Checksum) HashInfo { + checksumMap := ConvertChecksumsToMap(checksums) + return ConvertStringMapToHashInfo(checksumMap) +} diff --git a/indexd/hash/hash_test.go b/indexd/hash/hash_test.go new file mode 100644 index 0000000..f08c7ea --- /dev/null +++ b/indexd/hash/hash_test.go @@ -0,0 +1,53 @@ +package hash + +import ( + "encoding/json" + "testing" +) + +func TestChecksumType_IsValid(t *testing.T) { + tests := []struct { + name string + ct ChecksumType + want bool + }{ + {"valid sha256", ChecksumTypeSHA256, true}, + {"valid md5", ChecksumTypeMD5, true}, + {"invalid type", "invalid", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.ct.IsValid(); got != tt.want { + t.Errorf("ChecksumType.IsValid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHashInfo_UnmarshalJSON_Map(t *testing.T) { + jsonMap := `{"sha256": "hash-val", "md5": "md5-val"}` + var h HashInfo + if err := json.Unmarshal([]byte(jsonMap), &h); err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + if h.SHA256 != "hash-val" { + t.Errorf("expected SHA256 hash-val, got %s", h.SHA256) + } + if h.MD5 != "md5-val" { + t.Errorf("expected MD5 md5-val, got %s", h.MD5) + } +} + +func TestHashInfo_UnmarshalJSON_List(t *testing.T) { + jsonList := `[{"type": "sha256", "checksum": "hash-val"}, {"type": "md5", "checksum": "md5-val"}]` + var h HashInfo + if err := json.Unmarshal([]byte(jsonList), &h); err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + if h.SHA256 != "hash-val" { + t.Errorf("expected SHA256 hash-val, got %s", h.SHA256) + } + if h.MD5 != "md5-val" { + t.Errorf("expected MD5 md5-val, got %s", h.MD5) + } +} diff --git a/indexd/records.go b/indexd/records.go new file mode 100644 index 0000000..72e2de6 --- /dev/null +++ b/indexd/records.go @@ -0,0 +1,97 @@ +package indexd + +// https://github.com/uc-cdis/indexd/blob/master/openapis/swagger.yaml + +import ( + "github.com/calypr/data-client/indexd/hash" +) + +// subset of the OpenAPI spec for the InputInfo object in indexd +// TODO: make another object based on VersionInputInfo that has content_created_date and so can handle a POST of dates via indexd/ +type IndexdRecord struct { + // Unique identifier for the record (UUID) + Did string `json:"did"` + + // Human-readable file name + FileName string `json:"file_name,omitempty"` + + // List of URLs where the file can be accessed + URLs []string `json:"urls"` + + // Hashes of the file (e.g., md5, sha256) + Size int64 `json:"size"` + + // List of access control lists (ACLs) + ACL []string `json:"acl,omitempty"` + + // List of authorization policies + Authz []string `json:"authz,omitempty"` + + Hashes hash.HashInfo `json:"hashes,omitzero"` + + // Additional metadata as key-value pairs + Metadata map[string]string `json:"metadata,omitempty"` + + // Version of the record (optional) + Version string `json:"version,omitempty"` +} + +// create indexd record struct used for POSTs that is IndexdRecord with form field +type IndexdRecordForm struct { + IndexdRecord + Form string `json:"form"` + Rev string `json:"rev,omitempty"` +} + +type ListRecordsResult struct { + Record *OutputInfo + Error error +} + +type ListRecords struct { + IDs []string `json:"ids"` + Records []OutputInfo `json:"records"` + Size int64 `json:"size"` + Start int64 `json:"start"` + Limit int64 `json:"limit"` + FileName string `json:"file_name"` + URLs []string `json:"urls"` + ACL []string `json:"acl"` + Authz []string `json:"authz"` + Hashes hash.HashInfo `json:"hashes"` + Metadata map[string]any `json:"metadata"` + Version string `json:"version"` +} + +type OutputInfo struct { + Did string `json:"did"` + BaseID string `json:"baseid"` + Rev string `json:"rev"` + Form string `json:"form"` + Size int64 `json:"size"` + FileName string `json:"file_name"` + Version string `json:"version"` + Uploader string `json:"uploader"` + URLs []string `json:"urls"` + ACL []string `json:"acl"` + Authz []string `json:"authz"` + Hashes hash.HashInfo `json:"hashes"` + UpdatedDate string `json:"updated_date"` + CreatedDate string `json:"created_date"` + Metadata map[string]any `json:"metadata"` + URLsMetadata map[string]any `json:"urls_metadata"` +} + +func (outputInfo *OutputInfo) ToIndexdRecord() *IndexdRecord { + return &IndexdRecord{ + Did: outputInfo.Did, + Size: outputInfo.Size, + FileName: outputInfo.FileName, + URLs: outputInfo.URLs, + ACL: outputInfo.ACL, + Authz: outputInfo.Authz, + Hashes: outputInfo.Hashes, + //Metadata: outputInfo.Metadata, //TODO: re-enable metadata. One is map[string]string, the other is map[string]interface{} + Version: outputInfo.Version, + } +} diff --git a/indexd/s3_utils.go b/indexd/s3_utils.go new file mode 100644 index 0000000..09997c8 --- /dev/null +++ b/indexd/s3_utils.go @@ -0,0 +1,124 @@ +package indexd + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/data-client/fence" +) + +// ParseS3URL parses a URL like s3://bucket/key and returns (bucket, key, error). +func ParseS3URL(s3url string) (string, string, error) { + s3Prefix := "s3://" + if !strings.HasPrefix(s3url, s3Prefix) { + return "", "", fmt.Errorf("S3 URL requires prefix 's3://': %s", s3url) + } + trimmed := strings.TrimPrefix(s3url, s3Prefix) + slashIndex := strings.Index(trimmed, "/") + if slashIndex == -1 || slashIndex == len(trimmed)-1 { + return "", "", fmt.Errorf("invalid S3 file URL: %s", s3url) + } + return trimmed[:slashIndex], trimmed[slashIndex+1:], nil +} + +// ValidateInputs checks if S3 URL and SHA256 hash are valid. +func ValidateInputs(s3URL, sha256 string) error { + if s3URL == "" { + return fmt.Errorf("S3 URL is required") + } + if sha256 == "" { + return fmt.Errorf("SHA256 hash is required") + } + if !strings.HasPrefix(s3URL, "s3://") { + return fmt.Errorf("invalid S3 URL: must start with s3://") + } + if len(sha256) != 64 { + return fmt.Errorf("invalid SHA256 hash: must be 64 characters") + } + return nil +} + +// FetchS3MetadataWithBucketDetails fetches S3 metadata (size and modified date) for a given S3 URL. +func FetchS3MetadataWithBucketDetails( + ctx context.Context, + s3URL string, + awsAccessKey string, + awsSecretKey string, + region string, + endpoint string, + bucketDetails *fence.S3Bucket, + s3Client *s3.Client, + logger *slog.Logger, +) (int64, string, error) { + bucket, key, err := ParseS3URL(s3URL) + if err != nil { + return 0, "", err + } + + if s3Client == nil { + var configOptions []func(*awsConfig.LoadOptions) error + if awsAccessKey != "" && awsSecretKey != "" { + configOptions = append(configOptions, + awsConfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(awsAccessKey, awsSecretKey, "")), + ) + } + + regionToUse := "" + if region != "" { + regionToUse = region + } else if bucketDetails != nil && bucketDetails.Region != "" { + regionToUse = bucketDetails.Region + } + if regionToUse != "" { + configOptions = append(configOptions, awsConfig.WithRegion(regionToUse)) + } + + cfg, err := awsConfig.LoadDefaultConfig(ctx, configOptions...) + if err != nil { + return 0, "", fmt.Errorf("unable to load AWS SDK config: %w", err) + } + + endpointToUse := "" + if endpoint != "" { + endpointToUse = endpoint + } else if bucketDetails != nil && bucketDetails.EndpointURL != "" { + endpointToUse = bucketDetails.EndpointURL + } + + s3Client = s3.NewFromConfig(cfg, func(o *s3.Options) { + if endpointToUse != "" { + o.BaseEndpoint = aws.String(endpointToUse) + } + o.UsePathStyle = true + }) + } + + input := &s3.HeadObjectInput{ + Bucket: &bucket, + Key: aws.String(key), + } + + resp, err := s3Client.HeadObject(ctx, input) + if err != nil { + return 0, "", fmt.Errorf("failed to head object: %w", err) + } + + var contentLength int64 + if resp.ContentLength != nil { + contentLength = *resp.ContentLength + } + + var lastModified string + if resp.LastModified != nil { + lastModified = resp.LastModified.Format(time.RFC3339) + } + + return contentLength, lastModified, nil +} diff --git a/indexd/tests/add-url-integration_test.go b/indexd/tests/add-url-integration_test.go new file mode 100644 index 0000000..0500b10 --- /dev/null +++ b/indexd/tests/add-url-integration_test.go @@ -0,0 +1,68 @@ +package indexd_tests + +// // TODO: fix this during add-url fix +// import ( +// "testing" + +// "github.com/calypr/git-drs/utils" +// "github.com/stretchr/testify/require" +// ) + +// //////////////////// +// // E2E TESTS // +// // & MISC TESTS // +// //////////////////// + +// // TestAddURL_E2E_IdempotentSameURL tests end-to-end idempotency +// func TestAddURL_E2E_IdempotentSameURL(t *testing.T) { +// // Arrange: Start mock servers +// gen3Mock := NewMockGen3Server(t, "http://localhost:9000") +// defer gen3Mock.Close() + +// s3Mock := NewMockS3Server(t) +// defer s3Mock.Close() + +// indexdMock := NewMockIndexdServer(t) +// defer indexdMock.Close() + +// // Pre-populate S3 with test object +// s3Mock.AddObject("test-bucket", "sample.bam", 1024) + +// // TODO: This test is limited because AddURL has hardcoded config.LoadConfig() +// // In a real scenario, we'd need to mock that too or refactor AddURL to accept config +// t.Skip("Requires AddURL refactoring to accept config parameter") +// } + +// // TestAddURL_E2E_UpdateDifferentURL tests updating record with different URL +// // TODO: stubbed +// func TestAddURL_E2E_UpdateDifferentURL(t *testing.T) { +// // TODO: This test is skipped because it requires AddURL refactoring +// // See TestAddURL_E2E_IdempotentSameURL for explanation +// t.Skip("Requires AddURL refactoring to accept config parameter") +// } + +// // TestAddURL_E2E_LFSNotTracked tests LFS validation +// func TestAddURL_E2E_LFSNotTracked(t *testing.T) { +// // This test validates the LFS tracking check +// // The actual utils.IsLFSTracked function is tested separately in utils package + +// // Test the pattern matching logic by verifying ParseGitAttributes works +// gitattributesContent := `*.bam filter=lfs diff=lfs merge=lfs -text +// *.vcf filter=lfs diff=lfs merge=lfs -text` + +// attributes, err := utils.ParseGitAttributes(gitattributesContent) +// require.NoError(t, err) +// require.GreaterOrEqual(t, len(attributes), 2) + +// // Verify .bam pattern exists +// found := false +// for _, attr := range attributes { +// if attr.Pattern == "*.bam" { +// if filter, exists := attr.Attributes["filter"]; exists { +// require.Equal(t, "lfs", filter) +// found = true +// } +// } +// } +// require.True(t, found, "*.bam pattern with lfs filter should exist") +// } diff --git a/indexd/tests/client_read_test.go.todo b/indexd/tests/client_read_test.go.todo new file mode 100644 index 0000000..51857e1 --- /dev/null +++ b/indexd/tests/client_read_test.go.todo @@ -0,0 +1,134 @@ +package indexd_tests + +import ( + "testing" + + "github.com/calypr/git-drs/drs/hash" + "github.com/stretchr/testify/require" +) + +/////////////////// +// READ TESTS // +/////////////////// + +// Integration tests for READ operations on IndexdClient using mock indexd server. +// These tests verify non-mutating operations that query and retrieve data: +// - GetRecord / GetIndexdRecordByDID - Retrieve a single record by DID +// - GetObjectsByHash - Query records by hash +// - GetDownloadURL - Get signed download URLs +// - GetProjectId - Simple getter for project ID + +// TestIndexdClient_GetRecord tests retrieving a record via the client method with mocked auth +func TestIndexdClient_GetRecord(t *testing.T) { + // Arrange: Start mock server + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + // Pre-populate mock with test record + testRecord := newTestRecord("uuid-test-123") + addRecordToMockServer(mockServer, testRecord) + + // Act: Use client method with mocked auth (tests actual client logic) + client := testIndexdClientWithMockAuth(mockServer.URL()) + record, err := client.GetIndexdRecordByDID(testRecord.Did) + + // Assert: Test actual client logic + require.NoError(t, err) + require.NotNil(t, record) + require.Equal(t, testRecord.Did, record.Did) + require.Equal(t, testRecord.Size, record.Size) + require.Equal(t, testRecord.FileName, record.FileName) +} + +// TestIndexdClient_GetRecord_NotFound tests error handling for non-existent records +func TestIndexdClient_GetRecord_NotFound(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + // Act: Use client method to request non-existent record + client := testIndexdClientWithMockAuth(mockServer.URL()) + record, err := client.GetIndexdRecordByDID("does-not-exist") + + // Assert: Client should handle 404 errors properly + require.Error(t, err) + require.Nil(t, record) + require.Contains(t, err.Error(), "failed to get record") +} + +/////////////////////////////// +// GetObjectsByHash Tests +/////////////////////////////// + +// TestIndexdClient_GetObjectsByHash tests hash-based queries via client method with mocked auth +func TestIndexdClient_GetObjectsByHash(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + testRecord := newTestRecord("uuid-test-456", withTestRecordSize(2048)) + sha256 := testRecord.Hashes["sha256"] + addRecordWithHashIndex(mockServer, testRecord, "sha256", sha256) + + // Create client with mocked auth + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act: Call the actual client method + results, err := client.GetObjectByHash(&hash.Checksum{Type: "sha256", Checksum: sha256}) + + // Assert: Verify client method works end-to-end + require.NoError(t, err) + require.Len(t, results, 1) + + // Verify correct record was returned + record := results[0] + require.Equal(t, testRecord.Did, record.Id) + require.Equal(t, testRecord.Size, record.Size) + require.Equal(t, sha256, record.Checksums.SHA256) + + require.Equal(t, testRecord.URLs[0], record.AccessMethods[0].AccessURL.URL) + require.Equal(t, testRecord.Authz[0], record.AccessMethods[0].Authorizations.Value) + + // Test: Query with non-existent hash + emptyResults, err := client.GetObjectByHash(&hash.Checksum{Type: "sha256", Checksum: "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}) + require.NoError(t, err) + require.Len(t, emptyResults, 0) +} + +/////////////////////////////// +// GetProjectId Tests +/////////////////////////////// + +// TestIndexdClient_GetProjectId tests the simple getter for project ID +func TestIndexdClient_GetProjectId(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act + projectId := client.GetProjectId() + + // Assert: Should return the project ID set during client creation + require.Equal(t, "test-project", projectId, "Should return configured project ID") +} + +// TestIndexdClient_GetProjectId_ConsistentAcrossCalls tests that GetProjectId is consistent +func TestIndexdClient_GetProjectId_ConsistentAcrossCalls(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act: Call multiple times + projectId1 := client.GetProjectId() + projectId2 := client.GetProjectId() + projectId3 := client.GetProjectId() + + // Assert: All calls should return the same value + require.Equal(t, projectId1, projectId2, "GetProjectId should be consistent") + require.Equal(t, projectId2, projectId3, "GetProjectId should be consistent") + require.Equal(t, "test-project", projectId1) +} diff --git a/indexd/tests/client_write_test.go.todo b/indexd/tests/client_write_test.go.todo new file mode 100644 index 0000000..1f6ee62 --- /dev/null +++ b/indexd/tests/client_write_test.go.todo @@ -0,0 +1,369 @@ +package indexd_tests + +import ( + "testing" + + indexd_client "github.com/calypr/git-drs/client/indexd" + "github.com/calypr/git-drs/drs" + "github.com/calypr/git-drs/drs/hash" + "github.com/stretchr/testify/require" +) + +/////////////////// +// WRITE TESTS // +/////////////////// + +// Integration tests for WRITE operations on IndexdClient using mock indexd server. +// These tests verify mutating operations that create, update, or delete data: +// - RegisterRecord / RegisterIndexdRecord - Create new records +// - UpdateRecord / UpdateRecord - Modify existing records +// - DeleteRecord / DeleteIndexdRecord - Remove records + +// TestIndexdClient_RegisterRecord tests the high-level RegisterRecord method +// which converts a DRSObject to IndexdRecord and registers it +func TestIndexdClient_RegisterRecord(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Create a DRS object to register + drsObject := &drs.DRSObject{ + Id: "uuid-drs-register-test", + Name: "test-file.bam", + Size: 3000, + Checksums: hash.HashInfo{ + SHA256: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + AccessMethods: []drs.AccessMethod{ + { + AccessURL: drs.AccessURL{ + URL: "s3://drs-test-bucket/test-file.bam", + }, + Authorizations: &drs.Authorizations{ + Value: "/programs/drs-test/projects/test", + }, + }, + }, + } + + // Act: Call RegisterRecord which should: + // 1. Convert DRSObject to IndexdRecord + // 2. Call RegisterIndexdRecord + // 3. Return the registered DRSObject + result, err := client.RegisterRecord(drsObject) + + // Assert + require.NoError(t, err, "RegisterRecord should succeed") + require.NotNil(t, result, "Should return a valid DRSObject") + + // Verify the record was created in the mock server + storedRecord := mockServer.GetRecord(drsObject.Id) + require.NotNil(t, storedRecord, "Record should be stored in mock server") + require.Equal(t, drsObject.Name, storedRecord.FileName) + require.Equal(t, drsObject.Size, storedRecord.Size) + require.Contains(t, storedRecord.URLs, "s3://drs-test-bucket/test-file.bam") + + // Verify the returned DRS object matches + require.Equal(t, drsObject.Id, result.Id) + require.Equal(t, drsObject.Name, result.Name) + require.Equal(t, drsObject.Size, result.Size) +} + +// TestIndexdClient_RegisterRecord_MissingDID tests error handling when DID is missing +func TestIndexdClient_RegisterRecord_MissingDID(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Create a DRS object without ID (mock server will reject it) + invalidDrsObject := &drs.DRSObject{ + Name: "test-file.bam", + Size: 3000, + // Missing Id field - mock server should reject + } + + // Act + result, err := client.RegisterRecord(invalidDrsObject) + + // Assert: Should fail when registering with server (missing DID) + require.Error(t, err, "Should fail when DID is missing") + require.Nil(t, result) + require.Contains(t, err.Error(), "Missing required field: did") +} + +// TestIndexdClient_RegisterIndexdRecord_CreatesNewRecord tests record creation via client method +func TestIndexdClient_RegisterIndexdRecord_CreatesNewRecord(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Create input record to register + // IndexdRecord used here is the client-side object + // We don't use the newTestRecord helper bc that's the [mock] server-side object + newRecord := &indexd_client.IndexdRecord{ + Did: "uuid-register-test", + FileName: "new-file.bam", + Size: 5000, + URLs: []string{"s3://bucket/new-file.bam"}, + Authz: []string{"/workspace/test"}, + Hashes: hash.HashInfo{ + SHA256: "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + }, + Metadata: map[string]string{ + "source": "test", + }, + } + + // Act: Call the RegisterIndexdRecord client method + // This tests: + // 1. Wrapping IndexdRecord in IndexdRecordForm with form="object" + // 2. Setting correct headers (Content-Type, accept) + // 3. Injecting auth header via MockAuthHandler + // 4. POSTing to /index/index endpoint + // 5. Handling 200 OK response + // 6. Querying the new record via GET /ga4gh/drs/v1/objects/{did} + // 7. Returning a valid DRSObject + drsObj, err := client.RegisterIndexdRecord(newRecord) + + // Assert: Verify the client method executed successfully + require.NoError(t, err, "RegisterIndexdRecord should succeed") + require.NotNil(t, drsObj, "Should return a valid DRSObject") + + // Verify the stored record matches input + storedRecord := mockServer.GetRecord(newRecord.Did) + require.NotNil(t, storedRecord, "Record should be stored in mock server after POST") + require.Equal(t, newRecord.FileName, storedRecord.FileName) + require.Equal(t, newRecord.Size, storedRecord.Size) + require.Equal(t, newRecord.URLs, storedRecord.URLs) + require.Equal(t, newRecord.Hashes.SHA256, storedRecord.Hashes["sha256"]) + + // Verify the returned DRS object matches input + require.Equal(t, newRecord.Did, drsObj.Id, "DRS object ID should match DID") + require.Equal(t, newRecord.FileName, drsObj.Name, "DRS object name should match FileName") + require.Equal(t, newRecord.Size, drsObj.Size, "DRS object size should match") + require.NotEmpty(t, drsObj.Checksums.SHA256, "Should have SHA256 checksum") + require.Equal(t, newRecord.Hashes.SHA256, drsObj.Checksums.SHA256) + require.Len(t, drsObj.AccessMethods, 1, "Should have one access method") + require.Equal(t, newRecord.URLs[0], drsObj.AccessMethods[0].AccessURL.URL) +} + +/////////////////////////////// +// UpdateRecord / UpdateRecord Tests +/////////////////////////////// + +// TestIndexdClient_UpdateRecord_AppendsURLs tests updating record via client method +func TestIndexdClient_UpdateRecord_AppendsURLs(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + originalRecord := newTestRecord("uuid-update-test", + withTestRecordFileName("file.bam"), + withTestRecordSize(2048), + withTestRecordURLs("s3://original-bucket/file.bam"), + withTestRecordHash("sha256", "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd")) + addRecordToMockServer(mockServer, originalRecord) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Create update info with new URL + newURL := "s3://new-bucket/file-v2.bam" + updateInfo := &drs.DRSObject{ + AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: newURL}}}, + } + + // Act: Call the UpdateRecord client method + // This tests: + // 1. Getting the existing record via GET /index/{did} + // 2. Appending new URLs to existing URLs + // 3. Marshaling UpdateInputInfo to JSON + // 4. Setting correct headers (Content-Type, accept) + // 5. Injecting auth header via MockAuthHandler + // 6. PUTting to /index/index/{did} endpoint with new URLs + // 7. Handling 200 OK response + // 8. Querying the updated record via GET /ga4gh/drs/v1/objects/{did} + // 9. Returning a valid DRSObject + drsObj, err := client.UpdateRecord(updateInfo, originalRecord.Did) + + // Assert: Verify the client method executed successfully + require.NoError(t, err, "UpdateRecord should succeed") + require.NotNil(t, drsObj, "Should return a valid DRSObject") + + // Verify the URLs were appended correctly + updatedRecord := mockServer.GetRecord(originalRecord.Did) + require.NotNil(t, updatedRecord) + require.Equal(t, 2, len(updatedRecord.URLs), "Should have appended new URL to existing") + require.Contains(t, updatedRecord.URLs, originalRecord.URLs[0]) + require.Contains(t, updatedRecord.URLs, newURL) + + // Verify the returned DRS object + require.Equal(t, originalRecord.Did, drsObj.Id, "DRS object ID should match DID") + require.Equal(t, originalRecord.FileName, drsObj.Name, "DRS object name should match FileName") + require.Equal(t, originalRecord.Size, drsObj.Size, "DRS object size should match") + require.NotEmpty(t, drsObj.Checksums.SHA256, "Should have SHA256 checksum") + require.Equal(t, originalRecord.Hashes["sha256"], drsObj.Checksums.SHA256) + require.Len(t, drsObj.AccessMethods, 2, "Should have two access methods (URLs)") + urls := []string{drsObj.AccessMethods[0].AccessURL.URL, drsObj.AccessMethods[1].AccessURL.URL} + require.Contains(t, urls, originalRecord.URLs[0]) + require.Contains(t, urls, newURL) +} + +// TestIndexdClient_RegisterFile_UsesSingleHashQuery verifies RegisterFile reuses +// the initial hash lookup when checking downloadability. +func TestIndexdClient_RegisterFile_UsesSingleHashQuery(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + mockServer.signedURLBase = mockServer.URL() + "/signed" + + record := newTestRecord("uuid-register-file-test", + withTestRecordHash("sha256", testSHA256Hash), + withTestRecordURLs("s3://test-bucket/test-file.bam")) + addRecordWithHashIndex(mockServer, record, "sha256", testSHA256Hash) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act + result, err := client.RegisterFile(testSHA256Hash) + + // Assert + require.NoError(t, err, "RegisterFile should not error when file is downloadable") + require.NotNil(t, result, "RegisterFile should return the existing DRS object") + require.Equal(t, 1, mockServer.HashQueryCount(), "expected a single hash query during RegisterFile") +} + +// TestIndexdClient_UpdateRecord_Idempotent tests URL appending idempotency via client method +func TestIndexdClient_UpdateRecord_Idempotent(t *testing.T) { + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + originalRecord := newTestRecord("uuid-update-idempotent", + withTestRecordURLs("s3://bucket1/file.bam"), + withTestRecordHash("sha256", "aaaa...")) + addRecordToMockServer(mockServer, originalRecord) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Create update info with same URL (should be idempotent) + updateInfo := &drs.DRSObject{ + AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: originalRecord.URLs[0]}}}, + } + + // call the UpdateRecord client method + drsObj, err := client.UpdateRecord(updateInfo, originalRecord.Did) + require.NoError(t, err) + + // Verify URL wasn't duplicated + updated := mockServer.GetRecord(drsObj.Id) + require.NotNil(t, updated) + require.Equal(t, 1, len(updated.URLs)) + require.Equal(t, originalRecord.URLs[0], updated.URLs[0]) +} + +/////////////////////////////// +// DeleteRecord / DeleteIndexdRecord Tests +/////////////////////////////// + +// TestIndexdClient_DeleteRecord tests deleting a record by OID +func TestIndexdClient_DeleteRecord(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + // Pre-populate with a test record + testHash := "1111111111111111111111111111111111111111111111111111111111111111" + testRecord := newTestRecord("uuid-delete-by-oid", + withTestRecordFileName("delete-me.bam"), + withTestRecordSize(4096), + withTestRecordHash("sha256", testHash)) + addRecordWithHashIndex(mockServer, testRecord, "sha256", testHash) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Verify record exists before deletion + recordBefore := mockServer.GetRecord(testRecord.Did) + require.NotNil(t, recordBefore, "Record should exist before deletion") + + // Act: Delete by OID (which is the hash) + err := client.DeleteRecord(testHash) + + // Assert + require.NoError(t, err, "DeleteRecord should succeed") + + // Verify record was deleted + recordAfter := mockServer.GetRecord(testRecord.Did) + require.Nil(t, recordAfter, "Record should be deleted") +} + +// TestIndexdClient_DeleteRecord_NotFound tests deleting a non-existent record +func TestIndexdClient_DeleteRecord_NotFound(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act: Try to delete a record that doesn't exist + nonExistentHash := "9999999999999999999999999999999999999999999999999999999999999999" + err := client.DeleteRecord(nonExistentHash) + + // Assert: Should return error + require.Error(t, err, "Should fail when record doesn't exist") + require.Contains(t, err.Error(), "no records found for OID") +} + +// TestIndexdClient_DeleteRecord_NoMatchingProject tests deletion when record exists but for different project +func TestIndexdClient_DeleteRecord_NoMatchingProject(t *testing.T) { + // Arrange + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + // Create a record with a DIFFERENT project authorization + testHash := "2222222222222222222222222222222222222222222222222222222222222222" + differentProjectAuthz := "/programs/other-program/projects/other-project" + testRecord := newTestRecord("uuid-different-project", + withTestRecordFileName("other-project.bam"), + withTestRecordHash("sha256", testHash)) + testRecord.Authz = []string{differentProjectAuthz} // Override with different project + addRecordWithHashIndex(mockServer, testRecord, "sha256", testHash) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Act: Try to delete - should fail because project doesn't match + err := client.DeleteRecord(testHash) + + // Assert + require.Error(t, err, "Should fail when no matching project") + require.Contains(t, err.Error(), "no matching record found for project") + + // Verify record still exists (wasn't deleted) + recordAfter := mockServer.GetRecord(testRecord.Did) + require.NotNil(t, recordAfter, "Record should still exist") +} + +// TestIndexdClient_DeleteIndexdRecord_Removes tests record deletion via client method +func TestIndexdClient_DeleteIndexdRecord_Removes(t *testing.T) { + mockServer := NewMockIndexdServer(t) + defer mockServer.Close() + + testRecord := newTestRecord("uuid-delete-test", withTestRecordURLs("s3://bucket/file.bam")) + addRecordToMockServer(mockServer, testRecord) + + client := testIndexdClientWithMockAuth(mockServer.URL()) + + // Delete record via client method + err := client.DeleteIndexdRecord(testRecord.Did) + + require.NoError(t, err) + + // Verify it's gone + deletedRecord := mockServer.GetRecord(testRecord.Did) + require.Nil(t, deletedRecord) +} diff --git a/indexd/tests/mock_servers_test.go b/indexd/tests/mock_servers_test.go new file mode 100644 index 0000000..869cd51 --- /dev/null +++ b/indexd/tests/mock_servers_test.go @@ -0,0 +1,610 @@ +package indexd_tests + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "testing" + "time" + + indexd_client "github.com/calypr/data-client/indexd" + "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/indexd/hash" +) + +////////////////// +// MOCK SERVERS // +////////////////// + +// MockIndexdRecord represents a stored Indexd record in memory +type MockIndexdRecord struct { + Did string `json:"did"` + FileName string `json:"file_name"` + Size int64 `json:"size"` + Hashes map[string]string `json:"hashes"` + URLs []string `json:"urls"` + Authz []string `json:"authz"` + Metadata map[string]string `json:"metadata"` + CreatedAt time.Time `json:"-"` // Not serialized +} + +// MockIndexdServer simulates an Indexd server with in-memory storage +type MockIndexdServer struct { + httpServer *httptest.Server + records map[string]*MockIndexdRecord + hashIndex map[string][]string // hash -> [DIDs] + signedURLBase string + hashQueryCount int + recordMutex sync.RWMutex +} + +// NewMockIndexdServer creates and starts a mock Indexd server +func NewMockIndexdServer(t *testing.T) *MockIndexdServer { + mis := &MockIndexdServer{ + records: make(map[string]*MockIndexdRecord), + hashIndex: make(map[string][]string), + signedURLBase: "https://signed-url.example.com", + } + + mux := http.NewServeMux() + + // Register handlers for /index and /index/ paths + // /index matches exact path and query params (POST, GET with ?hash=) + mux.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) { + // POST /index - create record + if r.Method == http.MethodPost { + mis.handleCreateRecord(w, r) + return + } + + // GET /index?hash=... - query by hash + if r.Method == http.MethodGet { + mis.handleQueryByHash(w, r) + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + // /index/index handles /index/index for POST and /index/index?hash= for GET + mux.HandleFunc("/index/index", func(w http.ResponseWriter, r *http.Request) { + // POST /index/index - create record + if r.Method == http.MethodPost { + mis.handleCreateRecord(w, r) + return + } + + // GET /index/index?hash=... - query by hash + if r.Method == http.MethodGet { + mis.handleQueryByHash(w, r) + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + // /ga4gh/drs/v1/objects/ handles GET requests for DRS object and signed URLs + mux.HandleFunc("/ga4gh/drs/v1/objects/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract path after /ga4gh/drs/v1/objects/ + path := strings.TrimPrefix(r.URL.Path, "/ga4gh/drs/v1/objects/") + if path == "" { + http.Error(w, "Missing object ID", http.StatusBadRequest) + return + } + + // Split path to determine if this is object request or access request + pathParts := strings.Split(path, "/") + + if len(pathParts) == 1 { + // GET /ga4gh/drs/v1/objects/{id} - get DRS object + mis.handleGetDRSObject(w, r, pathParts[0]) + } else if len(pathParts) == 3 && pathParts[1] == "access" { + // GET /ga4gh/drs/v1/objects/{id}/access/{accessId} - get signed URL + mis.handleGetSignedURL(w, r, pathParts[0], pathParts[2]) + } else { + http.Error(w, "Invalid path", http.StatusBadRequest) + } + }) + + mux.HandleFunc("/signed/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // /index/ matches /index/{guid} (trailing slash pattern) + mux.HandleFunc("/index/", func(w http.ResponseWriter, r *http.Request) { + // Extract DID from path: /index/{guid} -> {guid} + // This handles both /index/{id} and /index/index/{id} + path := r.URL.Path + var did string + + if strings.HasPrefix(path, "/index/index/") { + did = strings.TrimPrefix(path, "/index/index/") + } else { + did = strings.TrimPrefix(path, "/index/") + } + + if did == "" || did == "index" { + http.Error(w, "Missing DID", http.StatusBadRequest) + return + } + + switch r.Method { + case http.MethodGet: + mis.handleGetRecord(w, r, did) + case http.MethodPut: + mis.handleUpdateRecord(w, r, did) + case http.MethodDelete: + mis.handleDeleteRecord(w, r, did) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mis.httpServer = httptest.NewServer(mux) + return mis +} + +func (mis *MockIndexdServer) handleGetRecord(w http.ResponseWriter, r *http.Request, did string) { + mis.recordMutex.RLock() + record, exists := mis.records[did] + mis.recordMutex.RUnlock() + + if !exists { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{"error": "Record not found"}) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(record) +} + +func (mis *MockIndexdServer) handleGetDRSObject(w http.ResponseWriter, r *http.Request, id string) { + mis.recordMutex.RLock() + record, exists := mis.records[id] + mis.recordMutex.RUnlock() + if !exists { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{"error": "Object not found"}) + return + } + + // Build standard DRS checksums array + checksums := []map[string]string{} + for typ, sum := range record.Hashes { + if sum != "" { + checksums = append(checksums, map[string]string{ + "type": strings.ToLower(typ), + "checksum": sum, + }) + } + } + + // Build access methods + accessMethods := []map[string]any{} + for i, url := range record.URLs { + am := map[string]any{ + "type": "https", + "access_id": fmt.Sprintf("https-%d", i), + "access_url": map[string]string{"url": url}, + } + // Only add authorizations if present, and as a SINGLE object (not array) + if len(record.Authz) > 0 { + am["authorizations"] = map[string]string{ + "value": record.Authz[0], + } + } + accessMethods = append(accessMethods, am) + } + + // Full response + response := map[string]any{ + "id": record.Did, + "name": record.FileName, + "size": record.Size, + "created_time": record.CreatedAt.Format(time.RFC3339), + "checksums": checksums, + "access_methods": accessMethods, + "description": "Mock DRS object from Indexd record", + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) +} + +func (mis *MockIndexdServer) handleGetSignedURL(w http.ResponseWriter, r *http.Request, objectId, accessId string) { + mis.recordMutex.RLock() + _, exists := mis.records[objectId] + mis.recordMutex.RUnlock() + + if !exists { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{"error": "Object not found"}) + return + } + + // Create a mock signed URL + base := strings.TrimSuffix(mis.signedURLBase, "/") + signedURL := drs.AccessURL{ + URL: fmt.Sprintf("%s/%s/%s", base, objectId, accessId), + Headers: []string{}, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(signedURL) +} + +func (mis *MockIndexdServer) handleCreateRecord(w http.ResponseWriter, r *http.Request) { + // Handle IndexdRecordForm (client sends this with POST) + var form struct { + indexd_client.IndexdRecord + Form string `json:"form"` + Rev string `json:"rev"` + } + + if err := json.NewDecoder(r.Body).Decode(&form); err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + // Extract the core record data + record := MockIndexdRecord{ + Did: form.Did, + FileName: form.FileName, + Size: form.Size, + URLs: form.URLs, + Authz: form.Authz, + Hashes: hash.ConvertHashInfoToMap(form.Hashes), + Metadata: form.Metadata, // Already map[string]string from IndexdRecord + CreatedAt: time.Now(), + } + + if record.Did == "" { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "Missing required field: did"}) + return + } + + mis.recordMutex.Lock() + defer mis.recordMutex.Unlock() + + if _, exists := mis.records[record.Did]; exists { + w.WriteHeader(http.StatusConflict) + json.NewEncoder(w).Encode(map[string]string{"error": "Record already exists"}) + return + } + + // Index by hash for queryability + for hashType, hash := range record.Hashes { + if hash != "" { // Only index non-empty hashes + key := hashType + ":" + hash + mis.hashIndex[key] = append(mis.hashIndex[key], record.Did) + } + } + + mis.records[record.Did] = &record + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(record) +} + +func (mis *MockIndexdServer) handleUpdateRecord(w http.ResponseWriter, r *http.Request, did string) { + mis.recordMutex.Lock() + defer mis.recordMutex.Unlock() + + record, exists := mis.records[did] + if !exists { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{"error": "Record not found"}) + return + } + + var update struct { + URLs []string `json:"urls"` + } + if err := json.NewDecoder(r.Body).Decode(&update); err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + // Append new URLs (avoid duplicates) + for _, newURL := range update.URLs { + if !slices.Contains(record.URLs, newURL) { + record.URLs = append(record.URLs, newURL) + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(record) +} + +func (mis *MockIndexdServer) handleQueryByHash(w http.ResponseWriter, r *http.Request) { + hashQuery := r.URL.Query().Get("hash") // format: "sha256:aaaa..." + + mis.recordMutex.Lock() + mis.hashQueryCount++ + mis.recordMutex.Unlock() + + mis.recordMutex.RLock() + dids, exists := mis.hashIndex[hashQuery] + mis.recordMutex.RUnlock() + + outputRecords := []indexd_client.OutputInfo{} + if exists { + mis.recordMutex.RLock() + for _, did := range dids { + if record, ok := mis.records[did]; ok { + // Convert sha256 hash string to HashInfo struct + hashes := hash.HashInfo{} + if sha256, ok := record.Hashes["sha256"]; ok { + hashes.SHA256 = sha256 + } + + // Convert metadata + metadata := make(map[string]any) + for k, v := range record.Metadata { + metadata[k] = v + } + + outputRecords = append(outputRecords, indexd_client.OutputInfo{ + Did: record.Did, + Size: record.Size, + Hashes: hashes, + URLs: record.URLs, + Authz: record.Authz, + Metadata: metadata, + }) + } + } + mis.recordMutex.RUnlock() + } + + w.Header().Set("Content-Type", "application/json") + // Return wrapped in ListRecords object matching Indexd API + response := indexd_client.ListRecords{ + Records: outputRecords, + IDs: dids, + Size: int64(len(outputRecords)), + } + json.NewEncoder(w).Encode(response) +} + +func (mis *MockIndexdServer) handleDeleteRecord(w http.ResponseWriter, r *http.Request, did string) { + mis.recordMutex.Lock() + defer mis.recordMutex.Unlock() + + _, exists := mis.records[did] + if !exists { + w.WriteHeader(http.StatusNotFound) + return + } + + delete(mis.records, did) + w.WriteHeader(http.StatusNoContent) +} + +// URL returns the mock server URL +func (mis *MockIndexdServer) URL() string { + return mis.httpServer.URL +} + +// Close closes the mock server +func (mis *MockIndexdServer) Close() { + mis.httpServer.Close() +} + +// GetAllRecords returns all records for testing purposes +func (mis *MockIndexdServer) GetAllRecords() []*MockIndexdRecord { + mis.recordMutex.RLock() + defer mis.recordMutex.RUnlock() + + records := make([]*MockIndexdRecord, 0, len(mis.records)) + for _, record := range mis.records { + records = append(records, record) + } + return records +} + +// GetRecord retrieves a single record by DID +func (mis *MockIndexdServer) GetRecord(did string) *MockIndexdRecord { + mis.recordMutex.RLock() + defer mis.recordMutex.RUnlock() + return mis.records[did] +} + +// HashQueryCount returns the number of hash query requests observed by the mock server. +func (mis *MockIndexdServer) HashQueryCount() int { + mis.recordMutex.RLock() + defer mis.recordMutex.RUnlock() + return mis.hashQueryCount +} + +// MockGen3Server simulates Gen3 /user/data/buckets endpoint +type MockGen3Server struct { + httpServer *httptest.Server + s3Endpoint string +} + +// NewMockGen3Server creates and starts a mock Gen3 server +func NewMockGen3Server(t *testing.T, s3Endpoint string) *MockGen3Server { + mgs := &MockGen3Server{ + s3Endpoint: s3Endpoint, + } + + mux := http.NewServeMux() + + mux.HandleFunc("/user/data/buckets", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + response := map[string]any{ + "S3_BUCKETS": map[string]any{ + "test-bucket": map[string]any{ + "region": "us-west-2", + "endpoint_url": mgs.s3Endpoint, + "programs": []string{"test-program"}, + }, + }, + "GS_BUCKETS": map[string]any{}, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }) + + mgs.httpServer = httptest.NewServer(mux) + return mgs +} + +// URL returns the mock server URL +func (mgs *MockGen3Server) URL() string { + return mgs.httpServer.URL +} + +// Client returns the mock server HTTP client +func (mgs *MockGen3Server) Client() *http.Client { + return mgs.httpServer.Client() +} + +// Close closes the mock server +func (mgs *MockGen3Server) Close() { + mgs.httpServer.Close() +} + +// MockS3Object represents a stored S3 object +type MockS3Object struct { + Size int64 + LastModified time.Time + ContentType string +} + +// MockS3Server simulates S3 HEAD object endpoint +type MockS3Server struct { + httpServer *httptest.Server + objects map[string]*MockS3Object // "bucket/key" -> object + objMutex sync.RWMutex +} + +// NewMockS3Server creates and starts a mock S3 server +func NewMockS3Server(t *testing.T) *MockS3Server { + mss := &MockS3Server{ + objects: make(map[string]*MockS3Object), + } + + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/") + if path == "" { + http.Error(w, "Not found", http.StatusNotFound) + return + } + + if r.Method == http.MethodHead || r.Method == http.MethodGet { + mss.handleHeadObject(w, r, path) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mss.httpServer = httptest.NewServer(mux) + return mss +} + +func (mss *MockS3Server) handleHeadObject(w http.ResponseWriter, r *http.Request, path string) { + mss.objMutex.RLock() + object, exists := mss.objects[path] + mss.objMutex.RUnlock() + + if !exists { + w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", object.Size)) + w.Header().Set("Last-Modified", object.LastModified.UTC().Format(http.TimeFormat)) + w.Header().Set("Content-Type", object.ContentType) + w.Header().Set("ETag", fmt.Sprintf("\"%x\"", object.LastModified.Unix())) + + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusOK) + w.Write(make([]byte, 0)) + } +} + +// AddObject adds a mock S3 object for testing +func (mss *MockS3Server) AddObject(bucket, key string, size int64) { + path := bucket + "/" + key + mss.objMutex.Lock() + defer mss.objMutex.Unlock() + + mss.objects[path] = &MockS3Object{ + Size: size, + LastModified: time.Now(), + ContentType: "application/octet-stream", + } +} + +// URL returns the mock server URL +func (mss *MockS3Server) URL() string { + return mss.httpServer.URL +} + +// Close closes the mock server +func (mss *MockS3Server) Close() { + mss.httpServer.Close() +} + +// Helper functions for type conversion +func convertMockRecordToDRSObject(record *MockIndexdRecord) *drs.DRSObject { + + // Convert URLs to AccessMethods + accessMethods := make([]drs.AccessMethod, 0) + for i, url := range record.URLs { + // Get the first authz as the authorization for this access method + var authzPtr *drs.Authorizations + if len(record.Authz) > 0 { + authzPtr = &drs.Authorizations{ + Value: record.Authz[0], + } + } + + accessMethods = append(accessMethods, drs.AccessMethod{ + Type: "https", + AccessID: fmt.Sprintf("access-method-%d", i), + AccessURL: drs.AccessURL{ + URL: url, + Headers: []string{}, + }, + Authorizations: authzPtr, + }) + } + + return &drs.DRSObject{ + Id: record.Did, + Name: record.FileName, + Size: record.Size, + Checksums: hash.ConvertStringMapToHashInfo(record.Hashes), + AccessMethods: accessMethods, + CreatedTime: record.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + Description: "DRS object created from Indexd record", + } +} diff --git a/indexd/types.go b/indexd/types.go new file mode 100644 index 0000000..dff0e48 --- /dev/null +++ b/indexd/types.go @@ -0,0 +1,75 @@ +package indexd + +import ( + "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/indexd/hash" +) + +type OutputObject struct { + Id string `json:"id"` + Name string `json:"name"` + SelfURI string `json:"self_uri,omitempty"` + Size int64 `json:"size"` + CreatedTime string `json:"created_time,omitempty"` + UpdatedTime string `json:"updated_time,omitempty"` + Version string `json:"version,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Checksums []hash.Checksum `json:"checksums"` + AccessMethods []drs.AccessMethod `json:"access_methods"` + Contents []drs.Contents `json:"contents,omitempty"` + Description string `json:"description,omitempty"` + Aliases []string `json:"aliases,omitempty"` +} + +func ConvertOutputObjectToDRSObject(in *OutputObject) *drs.DRSObject { + if in == nil { + return nil + } + + hashInfo := hash.ConvertChecksumsToHashInfo(in.Checksums) + + return &drs.DRSObject{ + Id: in.Id, + Name: in.Name, + SelfURI: in.SelfURI, + Size: in.Size, + CreatedTime: in.CreatedTime, + UpdatedTime: in.UpdatedTime, + Version: in.Version, + MimeType: in.MimeType, + Checksums: hashInfo, + AccessMethods: in.AccessMethods, + Contents: in.Contents, + Description: in.Description, + Aliases: in.Aliases, + } +} + +// UpdateInputInfo is the put object for index records +type UpdateInputInfo struct { + // Human-readable file name + FileName string `json:"file_name,omitempty"` + + // Additional metadata as key-value pairs + Metadata map[string]any `json:"metadata,omitempty"` + + // URL-specific metadata as key-value pairs + URLsMetadata map[string]any `json:"urls_metadata,omitempty"` + + // Version of the record + Version string `json:"version,omitempty"` + + // List of URLs where the file can be accessed + URLs []string `json:"urls,omitempty"` + + // List of access control lists (ACLs) + ACL []string `json:"acl,omitempty"` + + // List of authorization policies + Authz []string `json:"authz,omitempty"` +} + +type S3Meta struct { + Size int64 + LastModified string +} diff --git a/indexd/types_test.go b/indexd/types_test.go new file mode 100644 index 0000000..3125f03 --- /dev/null +++ b/indexd/types_test.go @@ -0,0 +1,60 @@ +package indexd + +import ( + "testing" + + "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/indexd/hash" +) + +func TestConvertOutputObjectToDRSObject(t *testing.T) { + out := &OutputObject{ + Id: "did-1", + Name: "file.txt", + SelfURI: "drs://server/did-1", + Size: 12345, + CreatedTime: "2023-01-01T00:00:00Z", + UpdatedTime: "2023-01-02T00:00:00Z", + Version: "v1", + MimeType: "text/plain", + Checksums: []hash.Checksum{ + {Type: hash.ChecksumTypeSHA256, Checksum: "sha256-hash"}, + {Type: hash.ChecksumTypeMD5, Checksum: "md5-hash"}, + }, + AccessMethods: []drs.AccessMethod{ + { + Type: "s3", + AccessURL: drs.AccessURL{ + URL: "s3://bucket/key", + }, + }, + }, + Description: "A test file", + Aliases: []string{"alias1"}, + } + + drsObj := ConvertOutputObjectToDRSObject(out) + + if drsObj.Id != out.Id { + t.Errorf("expected Id %s, got %s", out.Id, drsObj.Id) + } + if drsObj.Name != out.Name { + t.Errorf("expected Name %s, got %s", out.Name, drsObj.Name) + } + if drsObj.Size != out.Size { + t.Errorf("expected Size %d, got %d", out.Size, drsObj.Size) + } + // Verify Checksums conversion (slice to HashInfo) + if drsObj.Checksums.SHA256 != "sha256-hash" { + t.Errorf("expected SHA256 %s, got %s", "sha256-hash", drsObj.Checksums.SHA256) + } + if drsObj.Checksums.MD5 != "md5-hash" { + t.Errorf("expected MD5 %s, got %s", "md5-hash", drsObj.Checksums.MD5) + } + if len(drsObj.AccessMethods) != 1 { + t.Errorf("expected 1 access method, got %d", len(drsObj.AccessMethods)) + } + if drsObj.AccessMethods[0].AccessURL.URL != "s3://bucket/key" { + t.Errorf("expected access URL s3://bucket/key, got %s", drsObj.AccessMethods[0].AccessURL.URL) + } +} diff --git a/client/logs/factory.go b/logs/factory.go similarity index 66% rename from client/logs/factory.go rename to logs/factory.go index f63a63b..5a428f5 100644 --- a/client/logs/factory.go +++ b/logs/factory.go @@ -2,14 +2,14 @@ package logs import ( "fmt" - "io" + "log/slog" "os" "os/user" "path/filepath" "time" ) -func New(profile string, opts ...Option) (*TeeLogger, func()) { +func New(profile string, opts ...Option) (*Gen3Logger, func()) { cfg := defaults() for _, o := range opts { o(cfg) @@ -19,15 +19,15 @@ func New(profile string, opts ...Option) (*TeeLogger, func()) { logDir := filepath.Join(usr.HomeDir, ".gen3", "logs") os.MkdirAll(logDir, 0755) - var writers []io.Writer + var handlers []slog.Handler var messageFile *os.File if cfg.baseLogger != nil { - writers = append(writers, cfg.baseLogger.Writer()) + handlers = append(handlers, cfg.baseLogger.Handler()) } if cfg.console { - writers = append(writers, os.Stderr) + handlers = append(handlers, slog.NewTextHandler(os.Stderr, nil)) } if cfg.messageFile { @@ -39,12 +39,23 @@ func New(profile string, opts ...Option) (*TeeLogger, func()) { f, err := os.OpenFile(filepath.Join(logDir, filename), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err == nil { messageFile = f - writers = append(writers, f) + handlers = append(handlers, slog.NewTextHandler(f, nil)) fmt.Fprintf(f, "[%s] Message log started\n", time.Now().Format(time.RFC3339)) } } - t := NewTeeLogger(logDir, profile, writers...) + var rootHandler slog.Handler + if len(handlers) == 0 { + rootHandler = slog.NewTextHandler(os.Stderr, nil) + } else if len(handlers) == 1 { + rootHandler = handlers[0] + } else { + rootHandler = NewTeeHandler(handlers...) + } + + sl := slog.New(NewProgressHandler(rootHandler)) + + t := NewGen3Logger(sl, logDir, profile) if cfg.enableScoreboard { t.scoreboard = NewSB(5, t) diff --git a/logs/handler.go b/logs/handler.go new file mode 100644 index 0000000..d751714 --- /dev/null +++ b/logs/handler.go @@ -0,0 +1,102 @@ +package logs + +import ( + "context" + "log/slog" + + "github.com/calypr/data-client/common" +) + +// ProgressHandler is a slog.Handler that captures log messages and +// forwards them to a ProgressCallback if one is present in the context. +type ProgressHandler struct { + next slog.Handler +} + +func NewProgressHandler(next slog.Handler) *ProgressHandler { + if next == nil { + next = slog.Default().Handler() + } + return &ProgressHandler{next: next} +} + +func (h *ProgressHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.next.Enabled(ctx, level) +} + +func (h *ProgressHandler) Handle(ctx context.Context, r slog.Record) error { + // Call the next handler in the chain (original logging) + err := h.next.Handle(ctx, r) + + // In addition, try to bubble up to progress callback + cb := common.GetProgress(ctx) + if cb != nil { + oid := common.GetOid(ctx) + // We send an event of type "log" + attrs := make(map[string]any) + r.Attrs(func(a slog.Attr) bool { + attrs[a.Key] = a.Value.Any() + return true + }) + _ = cb(common.ProgressEvent{ + Event: "log", + Oid: oid, + Message: r.Message, + Level: r.Level.String(), + Attrs: attrs, + }) + } + + return err +} + +func (h *ProgressHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &ProgressHandler{next: h.next.WithAttrs(attrs)} +} + +func (h *ProgressHandler) WithGroup(name string) slog.Handler { + return &ProgressHandler{next: h.next.WithGroup(name)} +} + +// TeeHandler fans out log records to multiple handlers +type TeeHandler struct { + handlers []slog.Handler +} + +func NewTeeHandler(handlers ...slog.Handler) slog.Handler { + return &TeeHandler{handlers: handlers} +} + +func (h *TeeHandler) Enabled(ctx context.Context, level slog.Level) bool { + for _, hand := range h.handlers { + if hand.Enabled(ctx, level) { + return true + } + } + return false +} + +func (h *TeeHandler) Handle(ctx context.Context, r slog.Record) error { + for _, hand := range h.handlers { + if hand.Enabled(ctx, r.Level) { + _ = hand.Handle(ctx, r) + } + } + return nil +} + +func (h *TeeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + newHandlers := make([]slog.Handler, len(h.handlers)) + for i, hand := range h.handlers { + newHandlers[i] = hand.WithAttrs(attrs) + } + return &TeeHandler{handlers: newHandlers} +} + +func (h *TeeHandler) WithGroup(name string) slog.Handler { + newHandlers := make([]slog.Handler, len(h.handlers)) + for i, hand := range h.handlers { + newHandlers[i] = hand.WithGroup(name) + } + return &TeeHandler{handlers: newHandlers} +} diff --git a/logs/logger.go b/logs/logger.go new file mode 100644 index 0000000..f6a55f0 --- /dev/null +++ b/logs/logger.go @@ -0,0 +1,35 @@ +package logs + +import ( + "log/slog" +) + +type Option func(*config) + +type config struct { + console bool + messageFile bool + failedLog bool + succeededLog bool + enableScoreboard bool + baseLogger *slog.Logger +} + +func WithConsole() Option { return func(c *config) { c.console = true } } +func WithNoConsole() Option { return func(c *config) { c.console = false } } +func WithMessageFile() Option { return func(c *config) { c.messageFile = true } } +func WithNoMessageFile() Option { return func(c *config) { c.messageFile = false } } +func WithFailedLog() Option { return func(c *config) { c.failedLog = true } } +func WithSucceededLog() Option { return func(c *config) { c.succeededLog = true } } +func WithScoreboard() Option { return func(c *config) { c.enableScoreboard = true } } +func WithBaseLogger(base *slog.Logger) Option { return func(c *config) { c.baseLogger = base } } + +func defaults() *config { + return &config{ + console: true, + messageFile: true, + failedLog: true, + succeededLog: true, + baseLogger: nil, + } +} diff --git a/logs/logger_test.go b/logs/logger_test.go new file mode 100644 index 0000000..7e689f8 --- /dev/null +++ b/logs/logger_test.go @@ -0,0 +1,210 @@ +package logs + +import ( + "io" + "log/slog" + "os" + "testing" +) + +func TestNewSlogNoOpLogger(t *testing.T) { + logger := NewSlogNoOpLogger() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + // Verify it's a valid slog.Logger + logger.Info("test message") // Should not panic + logger.Error("test error") // Should not panic +} + +func TestNew_WithDefaults(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + if logger.Logger == nil { + t.Error("Expected non-nil embedded slog logger") + } +} + +func TestNew_WithConsoleOption(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile, WithConsole()) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + // Test that we can log without errors + logger.Info("test console message") +} + +func TestNew_WithMessageFileOption(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile, WithMessageFile()) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + // Test that we can log without errors + logger.Info("test file message") +} + +func TestNew_WithScoreboardOption(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile, WithScoreboard()) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + if logger.scoreboard == nil { + t.Error("Expected non-nil scoreboard when WithScoreboard option is used") + } +} + +func TestNew_WithFailedLogOption(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile, WithFailedLog()) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + if logger.failedPath == "" { + t.Error("Expected non-empty failed path when WithFailedLog option is used") + } +} + +func TestNew_WithSucceededLogOption(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile, WithSucceededLog()) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + if logger.succeededPath == "" { + t.Error("Expected non-empty succeeded path when WithSucceededLog option is used") + } +} + +func TestNew_WithBaseLogger(t *testing.T) { + profile := "test-profile" + baseLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + logger, cleanup := New(profile, WithBaseLogger(baseLogger)) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + // Test that we can log without errors + logger.Info("test with base logger") +} + +func TestNew_WithMultipleOptions(t *testing.T) { + profile := "test-profile" + baseLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + logger, cleanup := New(profile, + WithBaseLogger(baseLogger), + WithConsole(), + WithMessageFile(), + WithScoreboard(), + ) + defer cleanup() + + if logger == nil { + t.Fatal("Expected non-nil logger") + } + + if logger.Logger == nil { + t.Error("Expected non-nil embedded slog logger") + } + + if logger.scoreboard == nil { + t.Error("Expected non-nil scoreboard") + } + + // Test that we can log without errors + logger.Info("test with multiple options") +} + +func TestGen3Logger_Info(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Info("test info message") +} + +func TestGen3Logger_Error(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Error("test error message") +} + +func TestGen3Logger_Warn(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Warn("test warning message") +} + +func TestGen3Logger_Debug(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Debug("test debug message") +} + +func TestGen3Logger_Printf(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Printf("test printf message: %s", "value") +} + +func TestGen3Logger_Println(t *testing.T) { + profile := "test-profile" + logger, cleanup := New(profile) + defer cleanup() + + // Should not panic + logger.Println("test println message") +} + +// testLogger implements the Logger interface for testing +type testLogger struct { + writer io.Writer +} + +func (l *testLogger) Printf(format string, v ...any) {} +func (l *testLogger) Println(v ...any) {} +func (l *testLogger) Fatalf(format string, v ...any) {} +func (l *testLogger) Fatal(v ...any) {} +func (l *testLogger) Writer() io.Writer { return l.writer } diff --git a/logs/noop.go b/logs/noop.go new file mode 100644 index 0000000..f705772 --- /dev/null +++ b/logs/noop.go @@ -0,0 +1,11 @@ +package logs + +import ( + "io" + "log/slog" +) + +// NewSlogNoOpLogger creates a no-op slog logger for testing. +func NewSlogNoOpLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} diff --git a/client/logs/scoreboard.go b/logs/scoreboard.go similarity index 95% rename from client/logs/scoreboard.go rename to logs/scoreboard.go index c2f4d25..bf43083 100644 --- a/client/logs/scoreboard.go +++ b/logs/scoreboard.go @@ -10,12 +10,12 @@ import ( type Scoreboard struct { mu sync.Mutex Counts []int // index 0 = success on first try, 1 = after 1 retry, ..., last = failed - log Logger + log *Gen3Logger } // New creates a new scoreboard // maxRetryCount = how many retries you allow before giving up -func NewSB(maxRetryCount int, log Logger) *Scoreboard { +func NewSB(maxRetryCount int, log *Gen3Logger) *Scoreboard { return &Scoreboard{ Counts: make([]int, maxRetryCount+2), // +2: one for success-on-first, one for final failure log: log, diff --git a/logs/tee_logger.go b/logs/tee_logger.go new file mode 100644 index 0000000..228e06b --- /dev/null +++ b/logs/tee_logger.go @@ -0,0 +1,217 @@ +package logs + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "os" + "runtime" + "sync" + "time" + + "log/slog" + + "github.com/calypr/data-client/common" +) + +// --- Gen3Logger Implementation --- +type Gen3Logger struct { + *slog.Logger + mu sync.RWMutex + scoreboard *Scoreboard + + failedMu sync.Mutex + FailedMap map[string]common.RetryObject // Maps filePath to FileMetadata + failedPath string + + succeededMu sync.Mutex + succeededMap map[string]string // Maps filePath to GUID + succeededPath string +} + +// NewGen3Logger creates a new Gen3Logger wrapping the provided slog.Logger. +func NewGen3Logger(logger *slog.Logger, logDir, profile string) *Gen3Logger { + if logger == nil { + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } + return &Gen3Logger{ + Logger: logger, + FailedMap: make(map[string]common.RetryObject), + succeededMap: make(map[string]string), + } +} + +// loadJSON is an internal helper to load JSON from a file path. +func loadJSON(path string, v any) { + data, _ := os.ReadFile(path) + if len(data) > 0 { + json.Unmarshal(data, v) + } +} + +// --- Core logging helper --- + +// logWithSkip logs a message at the given level, skipping `skip` stack frames for source attribution. +func (t *Gen3Logger) logWithSkip(ctx context.Context, level slog.Level, skip int, msg string, args ...any) { + if !t.Enabled(ctx, level) { + return + } + var pcs [1]uintptr + runtime.Callers(skip, pcs[:]) + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.Add(args...) + _ = t.Handler().Handle(ctx, r) +} + +// --- slog.Logger Method Overrides for accurate source attribution --- + +func (t *Gen3Logger) Info(msg string, args ...any) { + t.logWithSkip(context.Background(), slog.LevelInfo, 3, msg, args...) +} + +func (t *Gen3Logger) InfoContext(ctx context.Context, msg string, args ...any) { + t.logWithSkip(ctx, slog.LevelInfo, 3, msg, args...) +} + +func (t *Gen3Logger) Error(msg string, args ...any) { + t.logWithSkip(context.Background(), slog.LevelError, 3, msg, args...) +} + +func (t *Gen3Logger) ErrorContext(ctx context.Context, msg string, args ...any) { + t.logWithSkip(ctx, slog.LevelError, 3, msg, args...) +} + +func (t *Gen3Logger) Warn(msg string, args ...any) { + t.logWithSkip(context.Background(), slog.LevelWarn, 3, msg, args...) +} + +func (t *Gen3Logger) WarnContext(ctx context.Context, msg string, args ...any) { + t.logWithSkip(ctx, slog.LevelWarn, 3, msg, args...) +} + +func (t *Gen3Logger) Debug(msg string, args ...any) { + t.logWithSkip(context.Background(), slog.LevelDebug, 3, msg, args...) +} + +func (t *Gen3Logger) DebugContext(ctx context.Context, msg string, args ...any) { + t.logWithSkip(ctx, slog.LevelDebug, 3, msg, args...) +} + +// --- Legacy fmt-style methods --- + +func (t *Gen3Logger) Printf(format string, v ...any) { + t.logWithSkip(context.Background(), slog.LevelInfo, 3, fmt.Sprintf(format, v...)) +} + +func (t *Gen3Logger) Println(v ...any) { + t.logWithSkip(context.Background(), slog.LevelInfo, 3, fmt.Sprint(v...)) +} + +func (t *Gen3Logger) Fatalf(format string, v ...any) { + t.logWithSkip(context.Background(), slog.LevelError, 3, fmt.Sprintf(format, v...)) + os.Exit(1) +} + +func (t *Gen3Logger) Fatal(v ...any) { + t.logWithSkip(context.Background(), slog.LevelError, 3, fmt.Sprint(v...)) + os.Exit(1) +} + +// Writer returns os.Stderr for legacy compatibility (used by Scoreboard's tabwriter). +func (t *Gen3Logger) Writer() io.Writer { + return os.Stderr +} + +// Scoreboard returns the embedded Scoreboard. +func (t *Gen3Logger) Scoreboard() *Scoreboard { + return t.scoreboard +} + +// --- Succeeded/Failed log map methods --- + +func (t *Gen3Logger) GetSucceededLogMap() map[string]string { + t.succeededMu.Lock() + defer t.succeededMu.Unlock() + copiedMap := make(map[string]string, len(t.succeededMap)) + maps.Copy(copiedMap, t.succeededMap) + return copiedMap +} + +func (t *Gen3Logger) GetFailedLogMap() map[string]common.RetryObject { + t.failedMu.Lock() + defer t.failedMu.Unlock() + copiedMap := make(map[string]common.RetryObject, len(t.FailedMap)) + maps.Copy(copiedMap, t.FailedMap) + return copiedMap +} + +func (t *Gen3Logger) DeleteFromFailedLog(path string) { + t.failedMu.Lock() + defer t.failedMu.Unlock() + delete(t.FailedMap, path) +} + +func (t *Gen3Logger) GetSucceededCount() int { + return len(t.succeededMap) +} + +func (t *Gen3Logger) writeFailedSync(e common.RetryObject) { + t.failedMu.Lock() + defer t.failedMu.Unlock() + t.FailedMap[e.SourcePath] = e + data, _ := json.MarshalIndent(t.FailedMap, "", " ") + os.WriteFile(t.failedPath, data, 0644) +} + +func (t *Gen3Logger) writeSucceededSync(path, guid string) { + t.succeededMu.Lock() + defer t.succeededMu.Unlock() + t.succeededMap[path] = guid + data, _ := json.MarshalIndent(t.succeededMap, "", " ") + os.WriteFile(t.succeededPath, data, 0644) +} + +// --- Tracking Methods --- + +// --- Tracking Methods --- + +func (t *Gen3Logger) Failed(filePath, filename string, metadata common.FileMetadata, guid string, retryCount int, multipart bool) { + t.failedHelper(context.Background(), filePath, filename, metadata, guid, retryCount, multipart, 4) +} + +func (t *Gen3Logger) FailedContext(ctx context.Context, filePath, filename string, metadata common.FileMetadata, guid string, retryCount int, multipart bool) { + t.failedHelper(ctx, filePath, filename, metadata, guid, retryCount, multipart, 4) +} + +func (t *Gen3Logger) failedHelper(ctx context.Context, filePath, filename string, metadata common.FileMetadata, guid string, retryCount int, multipart bool, skip int) { + msg := fmt.Sprintf("Failed: %s (GUID: %s, Retry: %d)", filePath, guid, retryCount) + t.logWithSkip(ctx, slog.LevelError, skip, msg) + if t.failedPath != "" { + t.writeFailedSync(common.RetryObject{ + SourcePath: filePath, + ObjectKey: filename, + FileMetadata: metadata, + GUID: guid, + RetryCount: retryCount, + Multipart: multipart, + }) + } +} + +func (t *Gen3Logger) Succeeded(filePath, guid string) { + t.succeededHelper(context.Background(), filePath, guid, 4) +} + +func (t *Gen3Logger) SucceededContext(ctx context.Context, filePath, guid string) { + t.succeededHelper(ctx, filePath, guid, 4) +} + +func (t *Gen3Logger) succeededHelper(ctx context.Context, filePath, guid string, skip int) { + msg := fmt.Sprintf("Succeeded: %s (GUID: %s)", filePath, guid) + t.logWithSkip(ctx, slog.LevelInfo, skip, msg) + if t.succeededPath != "" { + t.writeSucceededSync(filePath, guid) + } +} diff --git a/client/mocks/mock_configure.go b/mocks/mock_configure.go similarity index 94% rename from client/mocks/mock_configure.go rename to mocks/mock_configure.go index 4ff0813..48aa6bc 100644 --- a/client/mocks/mock_configure.go +++ b/mocks/mock_configure.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/conf (interfaces: ManagerInterface) +// Source: github.com/calypr/data-client/conf (interfaces: ManagerInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/client/conf ManagerInterface +// mockgen -destination=../mocks/mock_configure.go -package=mocks github.com/calypr/data-client/conf ManagerInterface // // Package mocks is a generated GoMock package. @@ -12,7 +12,7 @@ package mocks import ( reflect "reflect" - conf "github.com/calypr/data-client/client/conf" + conf "github.com/calypr/data-client/conf" gomock "go.uber.org/mock/gomock" ) diff --git a/mocks/mock_fence.go b/mocks/mock_fence.go new file mode 100644 index 0000000..f2577d0 --- /dev/null +++ b/mocks/mock_fence.go @@ -0,0 +1,252 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/calypr/data-client/fence (interfaces: FenceInterface) +// +// Generated by this command: +// +// mockgen -destination=../mocks/mock_fence.go -package=mocks github.com/calypr/data-client/fence FenceInterface +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + http "net/http" + reflect "reflect" + + fence "github.com/calypr/data-client/fence" + request "github.com/calypr/data-client/request" + gomock "go.uber.org/mock/gomock" +) + +// MockFenceInterface is a mock of FenceInterface interface. +type MockFenceInterface struct { + ctrl *gomock.Controller + recorder *MockFenceInterfaceMockRecorder + isgomock struct{} +} + +// MockFenceInterfaceMockRecorder is the mock recorder for MockFenceInterface. +type MockFenceInterfaceMockRecorder struct { + mock *MockFenceInterface +} + +// NewMockFenceInterface creates a new mock instance. +func NewMockFenceInterface(ctrl *gomock.Controller) *MockFenceInterface { + mock := &MockFenceInterface{ctrl: ctrl} + mock.recorder = &MockFenceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFenceInterface) EXPECT() *MockFenceInterfaceMockRecorder { + return m.recorder +} + +// CheckForShepherdAPI mocks base method. +func (m *MockFenceInterface) CheckForShepherdAPI(ctx context.Context) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckForShepherdAPI", ctx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckForShepherdAPI indicates an expected call of CheckForShepherdAPI. +func (mr *MockFenceInterfaceMockRecorder) CheckForShepherdAPI(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckForShepherdAPI", reflect.TypeOf((*MockFenceInterface)(nil).CheckForShepherdAPI), ctx) +} + +// CheckPrivileges mocks base method. +func (m *MockFenceInterface) CheckPrivileges(ctx context.Context) (map[string]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckPrivileges", ctx) + ret0, _ := ret[0].(map[string]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckPrivileges indicates an expected call of CheckPrivileges. +func (mr *MockFenceInterfaceMockRecorder) CheckPrivileges(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckPrivileges", reflect.TypeOf((*MockFenceInterface)(nil).CheckPrivileges), ctx) +} + +// CompleteMultipartUpload mocks base method. +func (m *MockFenceInterface) CompleteMultipartUpload(ctx context.Context, key, uploadID string, parts []fence.MultipartPart, bucket string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CompleteMultipartUpload", ctx, key, uploadID, parts, bucket) + ret0, _ := ret[0].(error) + return ret0 +} + +// CompleteMultipartUpload indicates an expected call of CompleteMultipartUpload. +func (mr *MockFenceInterfaceMockRecorder) CompleteMultipartUpload(ctx, key, uploadID, parts, bucket any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompleteMultipartUpload", reflect.TypeOf((*MockFenceInterface)(nil).CompleteMultipartUpload), ctx, key, uploadID, parts, bucket) +} + +// DeleteRecord mocks base method. +func (m *MockFenceInterface) DeleteRecord(ctx context.Context, guid string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRecord", ctx, guid) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteRecord indicates an expected call of DeleteRecord. +func (mr *MockFenceInterfaceMockRecorder) DeleteRecord(ctx, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecord", reflect.TypeOf((*MockFenceInterface)(nil).DeleteRecord), ctx, guid) +} + +// Do mocks base method. +func (m *MockFenceInterface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", ctx, req) + ret0, _ := ret[0].(*http.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Do indicates an expected call of Do. +func (mr *MockFenceInterfaceMockRecorder) Do(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockFenceInterface)(nil).Do), ctx, req) +} + +// GenerateMultipartPresignedURL mocks base method. +func (m *MockFenceInterface) GenerateMultipartPresignedURL(ctx context.Context, key, uploadID string, partNumber int, bucket string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateMultipartPresignedURL", ctx, key, uploadID, partNumber, bucket) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateMultipartPresignedURL indicates an expected call of GenerateMultipartPresignedURL. +func (mr *MockFenceInterfaceMockRecorder) GenerateMultipartPresignedURL(ctx, key, uploadID, partNumber, bucket any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateMultipartPresignedURL", reflect.TypeOf((*MockFenceInterface)(nil).GenerateMultipartPresignedURL), ctx, key, uploadID, partNumber, bucket) +} + +// GetBucketDetails mocks base method. +func (m *MockFenceInterface) GetBucketDetails(ctx context.Context, bucket string) (*fence.S3Bucket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBucketDetails", ctx, bucket) + ret0, _ := ret[0].(*fence.S3Bucket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBucketDetails indicates an expected call of GetBucketDetails. +func (mr *MockFenceInterfaceMockRecorder) GetBucketDetails(ctx, bucket any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBucketDetails", reflect.TypeOf((*MockFenceInterface)(nil).GetBucketDetails), ctx, bucket) +} + +// GetDownloadPresignedUrl mocks base method. +func (m *MockFenceInterface) GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDownloadPresignedUrl", ctx, guid, protocolText) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDownloadPresignedUrl indicates an expected call of GetDownloadPresignedUrl. +func (mr *MockFenceInterfaceMockRecorder) GetDownloadPresignedUrl(ctx, guid, protocolText any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDownloadPresignedUrl", reflect.TypeOf((*MockFenceInterface)(nil).GetDownloadPresignedUrl), ctx, guid, protocolText) +} + +// GetUploadPresignedUrl mocks base method. +func (m *MockFenceInterface) GetUploadPresignedUrl(ctx context.Context, guid, filename, bucket string) (fence.FenceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUploadPresignedUrl", ctx, guid, filename, bucket) + ret0, _ := ret[0].(fence.FenceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUploadPresignedUrl indicates an expected call of GetUploadPresignedUrl. +func (mr *MockFenceInterfaceMockRecorder) GetUploadPresignedUrl(ctx, guid, filename, bucket any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUploadPresignedUrl", reflect.TypeOf((*MockFenceInterface)(nil).GetUploadPresignedUrl), ctx, guid, filename, bucket) +} + +// InitMultipartUpload mocks base method. +func (m *MockFenceInterface) InitMultipartUpload(ctx context.Context, filename, bucket, guid string) (fence.FenceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InitMultipartUpload", ctx, filename, bucket, guid) + ret0, _ := ret[0].(fence.FenceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InitMultipartUpload indicates an expected call of InitMultipartUpload. +func (mr *MockFenceInterfaceMockRecorder) InitMultipartUpload(ctx, filename, bucket, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitMultipartUpload", reflect.TypeOf((*MockFenceInterface)(nil).InitMultipartUpload), ctx, filename, bucket, guid) +} + +// InitUpload mocks base method. +func (m *MockFenceInterface) InitUpload(ctx context.Context, filename, bucket, guid string) (fence.FenceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InitUpload", ctx, filename, bucket, guid) + ret0, _ := ret[0].(fence.FenceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InitUpload indicates an expected call of InitUpload. +func (mr *MockFenceInterfaceMockRecorder) InitUpload(ctx, filename, bucket, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitUpload", reflect.TypeOf((*MockFenceInterface)(nil).InitUpload), ctx, filename, bucket, guid) +} + +// New mocks base method. +func (m *MockFenceInterface) New(method, url string) *request.RequestBuilder { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "New", method, url) + ret0, _ := ret[0].(*request.RequestBuilder) + return ret0 +} + +// New indicates an expected call of New. +func (mr *MockFenceInterfaceMockRecorder) New(method, url any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockFenceInterface)(nil).New), method, url) +} + +// NewAccessToken mocks base method. +func (m *MockFenceInterface) NewAccessToken(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAccessToken", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewAccessToken indicates an expected call of NewAccessToken. +func (mr *MockFenceInterfaceMockRecorder) NewAccessToken(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAccessToken", reflect.TypeOf((*MockFenceInterface)(nil).NewAccessToken), ctx) +} + +// ParseFenceURLResponse mocks base method. +func (m *MockFenceInterface) ParseFenceURLResponse(resp *http.Response) (fence.FenceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseFenceURLResponse", resp) + ret0, _ := ret[0].(fence.FenceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ParseFenceURLResponse indicates an expected call of ParseFenceURLResponse. +func (mr *MockFenceInterfaceMockRecorder) ParseFenceURLResponse(resp any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseFenceURLResponse", reflect.TypeOf((*MockFenceInterface)(nil).ParseFenceURLResponse), resp) +} diff --git a/client/mocks/mock_functions.go b/mocks/mock_functions.go similarity index 76% rename from client/mocks/mock_functions.go rename to mocks/mock_functions.go index de3b1bc..9f905fd 100644 --- a/client/mocks/mock_functions.go +++ b/mocks/mock_functions.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/api (interfaces: FunctionInterface) +// Source: github.com/calypr/data-client/api (interfaces: FunctionInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/client/api FunctionInterface +// mockgen -destination=../mocks/mock_functions.go -package=mocks github.com/calypr/data-client/api FunctionInterface // // Package mocks is a generated GoMock package. @@ -14,9 +14,8 @@ import ( http "net/http" reflect "reflect" - api "github.com/calypr/data-client/client/api" - conf "github.com/calypr/data-client/client/conf" - request "github.com/calypr/data-client/client/request" + conf "github.com/calypr/data-client/conf" + request "github.com/calypr/data-client/request" gomock "go.uber.org/mock/gomock" ) @@ -118,19 +117,19 @@ func (mr *MockFunctionInterfaceMockRecorder) ExportCredential(ctx, cred any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportCredential", reflect.TypeOf((*MockFunctionInterface)(nil).ExportCredential), ctx, cred) } -// GetPresignedUrl mocks base method. -func (m *MockFunctionInterface) GetPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { +// GetDownloadPresignedUrl mocks base method. +func (m *MockFunctionInterface) GetDownloadPresignedUrl(ctx context.Context, guid, protocolText string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPresignedUrl", ctx, guid, protocolText) + ret := m.ctrl.Call(m, "GetDownloadPresignedUrl", ctx, guid, protocolText) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetPresignedUrl indicates an expected call of GetPresignedUrl. -func (mr *MockFunctionInterfaceMockRecorder) GetPresignedUrl(ctx, guid, protocolText any) *gomock.Call { +// GetDownloadPresignedUrl indicates an expected call of GetDownloadPresignedUrl. +func (mr *MockFunctionInterfaceMockRecorder) GetDownloadPresignedUrl(ctx, guid, protocolText any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPresignedUrl", reflect.TypeOf((*MockFunctionInterface)(nil).GetPresignedUrl), ctx, guid, protocolText) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDownloadPresignedUrl", reflect.TypeOf((*MockFunctionInterface)(nil).GetDownloadPresignedUrl), ctx, guid, protocolText) } // New mocks base method. @@ -147,17 +146,16 @@ func (mr *MockFunctionInterfaceMockRecorder) New(method, url any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockFunctionInterface)(nil).New), method, url) } -// ParseFenceURLResponse mocks base method. -func (m *MockFunctionInterface) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { +// NewAccessToken mocks base method. +func (m *MockFunctionInterface) NewAccessToken(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ParseFenceURLResponse", resp) - ret0, _ := ret[0].(api.FenceResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "NewAccessToken", ctx) + ret0, _ := ret[0].(error) + return ret0 } -// ParseFenceURLResponse indicates an expected call of ParseFenceURLResponse. -func (mr *MockFunctionInterfaceMockRecorder) ParseFenceURLResponse(resp any) *gomock.Call { +// NewAccessToken indicates an expected call of NewAccessToken. +func (mr *MockFunctionInterfaceMockRecorder) NewAccessToken(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseFenceURLResponse", reflect.TypeOf((*MockFunctionInterface)(nil).ParseFenceURLResponse), resp) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAccessToken", reflect.TypeOf((*MockFunctionInterface)(nil).NewAccessToken), ctx) } diff --git a/mocks/mock_gen3interface.go b/mocks/mock_gen3interface.go new file mode 100644 index 0000000..a627c69 --- /dev/null +++ b/mocks/mock_gen3interface.go @@ -0,0 +1,115 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/calypr/data-client/g3client (interfaces: Gen3Interface) +// +// Generated by this command: +// +// mockgen -destination=../mocks/mock_gen3interface.go -package=mocks github.com/calypr/data-client/g3client Gen3Interface +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + conf "github.com/calypr/data-client/conf" + fence "github.com/calypr/data-client/fence" + indexd "github.com/calypr/data-client/indexd" + logs "github.com/calypr/data-client/logs" + gomock "go.uber.org/mock/gomock" +) + +// MockGen3Interface is a mock of Gen3Interface interface. +type MockGen3Interface struct { + ctrl *gomock.Controller + recorder *MockGen3InterfaceMockRecorder + isgomock struct{} +} + +// MockGen3InterfaceMockRecorder is the mock recorder for MockGen3Interface. +type MockGen3InterfaceMockRecorder struct { + mock *MockGen3Interface +} + +// NewMockGen3Interface creates a new mock instance. +func NewMockGen3Interface(ctrl *gomock.Controller) *MockGen3Interface { + mock := &MockGen3Interface{ctrl: ctrl} + mock.recorder = &MockGen3InterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockGen3Interface) EXPECT() *MockGen3InterfaceMockRecorder { + return m.recorder +} + +// ExportCredential mocks base method. +func (m *MockGen3Interface) ExportCredential(ctx context.Context, cred *conf.Credential) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExportCredential", ctx, cred) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExportCredential indicates an expected call of ExportCredential. +func (mr *MockGen3InterfaceMockRecorder) ExportCredential(ctx, cred any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportCredential", reflect.TypeOf((*MockGen3Interface)(nil).ExportCredential), ctx, cred) +} + +// Fence mocks base method. +func (m *MockGen3Interface) Fence() fence.FenceInterface { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fence") + ret0, _ := ret[0].(fence.FenceInterface) + return ret0 +} + +// Fence indicates an expected call of Fence. +func (mr *MockGen3InterfaceMockRecorder) Fence() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fence", reflect.TypeOf((*MockGen3Interface)(nil).Fence)) +} + +// GetCredential mocks base method. +func (m *MockGen3Interface) GetCredential() *conf.Credential { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCredential") + ret0, _ := ret[0].(*conf.Credential) + return ret0 +} + +// GetCredential indicates an expected call of GetCredential. +func (mr *MockGen3InterfaceMockRecorder) GetCredential() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockGen3Interface)(nil).GetCredential)) +} + +// Indexd mocks base method. +func (m *MockGen3Interface) Indexd() indexd.IndexdInterface { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Indexd") + ret0, _ := ret[0].(indexd.IndexdInterface) + return ret0 +} + +// Indexd indicates an expected call of Indexd. +func (mr *MockGen3InterfaceMockRecorder) Indexd() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexd", reflect.TypeOf((*MockGen3Interface)(nil).Indexd)) +} + +// Logger mocks base method. +func (m *MockGen3Interface) Logger() *logs.Gen3Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*logs.Gen3Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockGen3InterfaceMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockGen3Interface)(nil).Logger)) +} diff --git a/mocks/mock_indexd.go b/mocks/mock_indexd.go new file mode 100644 index 0000000..6c0d5e2 --- /dev/null +++ b/mocks/mock_indexd.go @@ -0,0 +1,251 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/calypr/data-client/indexd (interfaces: IndexdInterface) +// +// Generated by this command: +// +// mockgen -destination=../mocks/mock_indexd.go -package=mocks github.com/calypr/data-client/indexd IndexdInterface +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + http "net/http" + reflect "reflect" + + indexd "github.com/calypr/data-client/indexd" + drs "github.com/calypr/data-client/indexd/drs" + request "github.com/calypr/data-client/request" + gomock "go.uber.org/mock/gomock" +) + +// MockIndexdInterface is a mock of IndexdInterface interface. +type MockIndexdInterface struct { + ctrl *gomock.Controller + recorder *MockIndexdInterfaceMockRecorder + isgomock struct{} +} + +// MockIndexdInterfaceMockRecorder is the mock recorder for MockIndexdInterface. +type MockIndexdInterfaceMockRecorder struct { + mock *MockIndexdInterface +} + +// NewMockIndexdInterface creates a new mock instance. +func NewMockIndexdInterface(ctrl *gomock.Controller) *MockIndexdInterface { + mock := &MockIndexdInterface{ctrl: ctrl} + mock.recorder = &MockIndexdInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIndexdInterface) EXPECT() *MockIndexdInterfaceMockRecorder { + return m.recorder +} + +// DeleteIndexdRecord mocks base method. +func (m *MockIndexdInterface) DeleteIndexdRecord(ctx context.Context, did string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteIndexdRecord", ctx, did) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteIndexdRecord indicates an expected call of DeleteIndexdRecord. +func (mr *MockIndexdInterfaceMockRecorder) DeleteIndexdRecord(ctx, did any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteIndexdRecord", reflect.TypeOf((*MockIndexdInterface)(nil).DeleteIndexdRecord), ctx, did) +} + +// DeleteRecordByHash mocks base method. +func (m *MockIndexdInterface) DeleteRecordByHash(ctx context.Context, hashValue, projectId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRecordByHash", ctx, hashValue, projectId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRecordByHash indicates an expected call of DeleteRecordByHash. +func (mr *MockIndexdInterfaceMockRecorder) DeleteRecordByHash(ctx, hashValue, projectId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecordByHash", reflect.TypeOf((*MockIndexdInterface)(nil).DeleteRecordByHash), ctx, hashValue, projectId) +} + +// DeleteRecordsByProject mocks base method. +func (m *MockIndexdInterface) DeleteRecordsByProject(ctx context.Context, projectId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRecordsByProject", ctx, projectId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRecordsByProject indicates an expected call of DeleteRecordsByProject. +func (mr *MockIndexdInterfaceMockRecorder) DeleteRecordsByProject(ctx, projectId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRecordsByProject", reflect.TypeOf((*MockIndexdInterface)(nil).DeleteRecordsByProject), ctx, projectId) +} + +// Do mocks base method. +func (m *MockIndexdInterface) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", ctx, req) + ret0, _ := ret[0].(*http.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Do indicates an expected call of Do. +func (mr *MockIndexdInterfaceMockRecorder) Do(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockIndexdInterface)(nil).Do), ctx, req) +} + +// GetDownloadURL mocks base method. +func (m *MockIndexdInterface) GetDownloadURL(ctx context.Context, did, accessType string) (*drs.AccessURL, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDownloadURL", ctx, did, accessType) + ret0, _ := ret[0].(*drs.AccessURL) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDownloadURL indicates an expected call of GetDownloadURL. +func (mr *MockIndexdInterfaceMockRecorder) GetDownloadURL(ctx, did, accessType any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDownloadURL", reflect.TypeOf((*MockIndexdInterface)(nil).GetDownloadURL), ctx, did, accessType) +} + +// GetObject mocks base method. +func (m *MockIndexdInterface) GetObject(ctx context.Context, id string) (*drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetObject", ctx, id) + ret0, _ := ret[0].(*drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetObject indicates an expected call of GetObject. +func (mr *MockIndexdInterfaceMockRecorder) GetObject(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObject", reflect.TypeOf((*MockIndexdInterface)(nil).GetObject), ctx, id) +} + +// GetObjectByHash mocks base method. +func (m *MockIndexdInterface) GetObjectByHash(ctx context.Context, hashType, hashValue string) ([]drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetObjectByHash", ctx, hashType, hashValue) + ret0, _ := ret[0].([]drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetObjectByHash indicates an expected call of GetObjectByHash. +func (mr *MockIndexdInterfaceMockRecorder) GetObjectByHash(ctx, hashType, hashValue any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectByHash", reflect.TypeOf((*MockIndexdInterface)(nil).GetObjectByHash), ctx, hashType, hashValue) +} + +// GetProjectSample mocks base method. +func (m *MockIndexdInterface) GetProjectSample(ctx context.Context, projectId string, limit int) ([]drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProjectSample", ctx, projectId, limit) + ret0, _ := ret[0].([]drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProjectSample indicates an expected call of GetProjectSample. +func (mr *MockIndexdInterfaceMockRecorder) GetProjectSample(ctx, projectId, limit any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProjectSample", reflect.TypeOf((*MockIndexdInterface)(nil).GetProjectSample), ctx, projectId, limit) +} + +// ListObjects mocks base method. +func (m *MockIndexdInterface) ListObjects(ctx context.Context) (chan drs.DRSObjectResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListObjects", ctx) + ret0, _ := ret[0].(chan drs.DRSObjectResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListObjects indicates an expected call of ListObjects. +func (mr *MockIndexdInterfaceMockRecorder) ListObjects(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListObjects", reflect.TypeOf((*MockIndexdInterface)(nil).ListObjects), ctx) +} + +// ListObjectsByProject mocks base method. +func (m *MockIndexdInterface) ListObjectsByProject(ctx context.Context, projectId string) (chan drs.DRSObjectResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListObjectsByProject", ctx, projectId) + ret0, _ := ret[0].(chan drs.DRSObjectResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListObjectsByProject indicates an expected call of ListObjectsByProject. +func (mr *MockIndexdInterfaceMockRecorder) ListObjectsByProject(ctx, projectId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListObjectsByProject", reflect.TypeOf((*MockIndexdInterface)(nil).ListObjectsByProject), ctx, projectId) +} + +// New mocks base method. +func (m *MockIndexdInterface) New(method, url string) *request.RequestBuilder { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "New", method, url) + ret0, _ := ret[0].(*request.RequestBuilder) + return ret0 +} + +// New indicates an expected call of New. +func (mr *MockIndexdInterfaceMockRecorder) New(method, url any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockIndexdInterface)(nil).New), method, url) +} + +// RegisterIndexdRecord mocks base method. +func (m *MockIndexdInterface) RegisterIndexdRecord(ctx context.Context, indexdObj *indexd.IndexdRecord) (*drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterIndexdRecord", ctx, indexdObj) + ret0, _ := ret[0].(*drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegisterIndexdRecord indicates an expected call of RegisterIndexdRecord. +func (mr *MockIndexdInterfaceMockRecorder) RegisterIndexdRecord(ctx, indexdObj any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterIndexdRecord", reflect.TypeOf((*MockIndexdInterface)(nil).RegisterIndexdRecord), ctx, indexdObj) +} + +// RegisterRecord mocks base method. +func (m *MockIndexdInterface) RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterRecord", ctx, record) + ret0, _ := ret[0].(*drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegisterRecord indicates an expected call of RegisterRecord. +func (mr *MockIndexdInterfaceMockRecorder) RegisterRecord(ctx, record any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterRecord", reflect.TypeOf((*MockIndexdInterface)(nil).RegisterRecord), ctx, record) +} + +// UpdateRecord mocks base method. +func (m *MockIndexdInterface) UpdateRecord(ctx context.Context, updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateRecord", ctx, updateInfo, did) + ret0, _ := ret[0].(*drs.DRSObject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateRecord indicates an expected call of UpdateRecord. +func (mr *MockIndexdInterfaceMockRecorder) UpdateRecord(ctx, updateInfo, did any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRecord", reflect.TypeOf((*MockIndexdInterface)(nil).UpdateRecord), ctx, updateInfo, did) +} diff --git a/client/mocks/mock_request.go b/mocks/mock_request.go similarity index 91% rename from client/mocks/mock_request.go rename to mocks/mock_request.go index 1021d18..8ccd2a0 100644 --- a/client/mocks/mock_request.go +++ b/mocks/mock_request.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/calypr/data-client/client/request (interfaces: RequestInterface) +// Source: github.com/calypr/data-client/request (interfaces: RequestInterface) // // Generated by this command: // -// mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/request RequestInterface +// mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/request RequestInterface // // Package mocks is a generated GoMock package. @@ -14,7 +14,7 @@ import ( http "net/http" reflect "reflect" - request "github.com/calypr/data-client/client/request" + request "github.com/calypr/data-client/request" gomock "go.uber.org/mock/gomock" ) diff --git a/client/request/auth.go b/request/auth.go similarity index 90% rename from client/request/auth.go rename to request/auth.go index eb87829..cc93723 100644 --- a/client/request/auth.go +++ b/request/auth.go @@ -9,8 +9,8 @@ import ( "strconv" "sync" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" ) func (t *AuthTransport) NewAccessToken(ctx context.Context) error { @@ -66,6 +66,11 @@ type AuthTransport struct { } func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Header.Get("X-Skip-Auth") == "true" { + req.Header.Del("X-Skip-Auth") + return t.Base.RoundTrip(req) + } + t.mu.RLock() token := t.Cred.AccessToken t.mu.RUnlock() diff --git a/client/request/builder.go b/request/builder.go similarity index 87% rename from client/request/builder.go rename to request/builder.go index 1280fb7..e12e923 100644 --- a/client/request/builder.go +++ b/request/builder.go @@ -3,7 +3,7 @@ package request import ( "io" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) // New addition to your request package @@ -15,6 +15,7 @@ type RequestBuilder struct { Headers map[string]string Token string PartSize int64 + SkipAuth bool } func (r *Request) New(method, url string) *RequestBuilder { @@ -52,3 +53,8 @@ func (ar *RequestBuilder) WithHeader(key, value string) *RequestBuilder { ar.Headers[key] = value return ar } + +func (ar *RequestBuilder) WithSkipAuth(skip bool) *RequestBuilder { + ar.SkipAuth = skip + return ar +} diff --git a/client/request/request.go b/request/request.go similarity index 90% rename from client/request/request.go rename to request/request.go index 585624e..82711ba 100644 --- a/client/request/request.go +++ b/request/request.go @@ -1,6 +1,6 @@ package request -//go:generate mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/client/request RequestInterface +//go:generate mockgen -destination=../mocks/mock_request.go -package=mocks github.com/calypr/data-client/request RequestInterface import ( "context" @@ -9,13 +9,13 @@ import ( "net/http" "time" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/logs" "github.com/hashicorp/go-retryablehttp" ) type Request struct { - Logs logs.Logger + Logs *logs.Gen3Logger RetryClient *retryablehttp.Client } @@ -25,7 +25,7 @@ type RequestInterface interface { } func NewRequestInterface( - logger logs.Logger, + logger *logs.Gen3Logger, cred *conf.Credential, conf conf.ManagerInterface, ) RequestInterface { @@ -88,6 +88,10 @@ func (r *Request) Do(ctx context.Context, rb *RequestBuilder) (*http.Response, e httpReq.Header.Add(key, value) } + if rb.SkipAuth { + httpReq.Header.Set("X-Skip-Auth", "true") + } + if rb.Token != "" { httpReq.Header.Set("Authorization", "Bearer "+rb.Token) } diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 0000000..019b097 --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,263 @@ +package request + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/logs" +) + +func TestNewRequestInterface(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: "https://example.com", + } + + // Create a mock config manager + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + + if reqInterface == nil { + t.Fatal("Expected non-nil request interface") + } + + req, ok := reqInterface.(*Request) + if !ok { + t.Fatal("Expected request interface to be of type *Request") + } + + if req.RetryClient == nil { + t.Error("Expected non-nil retry client") + } + + if req.Logs == nil { + t.Error("Expected non-nil logger") + } +} + +func TestRequestBuilder_New(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: "https://example.com", + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + builder := req.New("GET", "https://example.com/api/test") + + if builder == nil { + t.Fatal("Expected non-nil request builder") + } + + if builder.Method != "GET" { + t.Errorf("Expected method 'GET', got '%s'", builder.Method) + } + + if builder.Url != "https://example.com/api/test" { + t.Errorf("Expected URL 'https://example.com/api/test', got '%s'", builder.Url) + } +} + +func TestRequestBuilder_WithHeaders(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: "https://example.com", + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + builder := req.New("GET", "https://example.com/api/test") + builder = builder.WithHeader("Content-Type", "application/json") + builder = builder.WithHeader("X-Custom-Header", "test-value") + + if len(builder.Headers) != 2 { + t.Errorf("Expected 2 headers, got %d", len(builder.Headers)) + } + + if builder.Headers["Content-Type"] != "application/json" { + t.Error("Expected Content-Type header to be set") + } + + if builder.Headers["X-Custom-Header"] != "test-value" { + t.Error("Expected X-Custom-Header to be set") + } +} + +func TestRequestBuilder_WithToken(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: "https://example.com", + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + token := "test-bearer-token-12345" + builder := req.New("GET", "https://example.com/api/test") + builder = builder.WithToken(token) + + if builder.Token != token { + t.Errorf("Expected token '%s', got '%s'", token, builder.Token) + } +} + +func TestRequestBuilder_WithBody(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: "https://example.com", + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + body := strings.NewReader("test body content") + builder := req.New("POST", "https://example.com/api/test") + builder = builder.WithBody(body) + + if builder.Body == nil { + t.Error("Expected non-nil body") + } +} + +func TestRequest_Do_Success(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request + if r.Method != "GET" { + t.Errorf("Expected GET method, got %s", r.Method) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "success"}`)) + })) + defer server.Close() + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: server.URL, + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + builder := req.New("GET", server.URL+"/api/test") + builder = builder.WithToken("test-token") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + resp, err := req.Do(ctx, builder) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "success") { + t.Error("Expected response body to contain 'success'") + } +} + +func TestRequest_Do_WithCustomHeaders(t *testing.T) { + // Create a test server that checks for custom headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + customHeader := r.Header.Get("X-Custom-Header") + if customHeader != "test-value" { + t.Errorf("Expected X-Custom-Header 'test-value', got '%s'", customHeader) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + cred := &conf.Credential{ + KeyID: "test-key", + APIKey: "test-secret", + APIEndpoint: server.URL, + } + mockConf := &mockConfigManager{} + + reqInterface := NewRequestInterface(logs.NewGen3Logger(logger, "", ""), cred, mockConf) + req := reqInterface.(*Request) + + builder := req.New("GET", server.URL+"/api/test") + builder = builder.WithHeader("X-Custom-Header", "test-value") + + ctx := context.Background() + resp, err := req.Do(ctx, builder) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + resp.Body.Close() +} + +// Mock config manager for testing +type mockConfigManager struct{} + +func (m *mockConfigManager) Import(filePath, fenceToken string) (*conf.Credential, error) { + return &conf.Credential{}, nil +} + +func (m *mockConfigManager) Load(profile string) (*conf.Credential, error) { + return &conf.Credential{}, nil +} + +func (m *mockConfigManager) Save(cred *conf.Credential) error { + return nil +} + +func (m *mockConfigManager) EnsureExists() error { + return nil +} + +func (m *mockConfigManager) IsCredentialValid(cred *conf.Credential) (bool, error) { + return true, nil +} + +func (m *mockConfigManager) IsTokenValid(token string) (bool, error) { + return true, nil +} diff --git a/requestor/client.go b/requestor/client.go new file mode 100644 index 0000000..c379c66 --- /dev/null +++ b/requestor/client.go @@ -0,0 +1,265 @@ +package requestor + +import ( + "context" + "embed" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strings" + + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/request" + "gopkg.in/yaml.v3" +) + +//go:embed policies/*.yaml +var policyFS embed.FS + +type RequestorClient struct { + request.RequestInterface + Endpoint string +} + +func NewRequestorClient(req request.RequestInterface, creds *conf.Credential) *RequestorClient { + return &RequestorClient{ + RequestInterface: req, + Endpoint: creds.APIEndpoint, + } +} + +// Ensure interface compliance +var _ RequestorInterface = &RequestorClient{} + +type RequestorInterface interface { + ListRequests(ctx context.Context, mine bool, active bool, username string) ([]Request, error) + CreateRequest(ctx context.Context, req CreateRequestRequest, revoke bool) (*Request, error) + UpdateRequest(ctx context.Context, requestID string, status string) (*Request, error) + AddUser(ctx context.Context, projectID string, username string, write bool, guppy bool) ([]Request, error) + RemoveUser(ctx context.Context, projectID string, username string) ([]Request, error) +} + +func (c *RequestorClient) ListRequests(ctx context.Context, mine bool, active bool, username string) ([]Request, error) { + url := c.Endpoint + "/requestor/request" + if mine { + url += "/user" + } + + params := []string{} + if active { + params = append(params, "active") + } + if username != "" && !mine { + params = append(params, fmt.Sprintf("username=%s", username)) + } + + if len(params) > 0 { + url += "?" + strings.Join(params, "&") + } + + rb := c.New(http.MethodGet, url) + resp, err := c.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to list requests: status %d", resp.StatusCode) + } + + var requests []Request + if err := json.NewDecoder(resp.Body).Decode(&requests); err != nil { + return nil, err + } + return requests, nil +} + +func (c *RequestorClient) CreateRequest(ctx context.Context, reqPayload CreateRequestRequest, revoke bool) (*Request, error) { + url := c.Endpoint + "/requestor/request" + if revoke { + url += "?revoke" + } + + rb := c.New(http.MethodPost, url) + rb, err := rb.WithJSONBody(reqPayload) + if err != nil { + return nil, err + } + + resp, err := c.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to create request: status %d, body: %s", resp.StatusCode, string(bodyBytes)) + } + + var createdRequest Request + if err := json.NewDecoder(resp.Body).Decode(&createdRequest); err != nil { + return nil, err + } + return &createdRequest, nil +} + +func (c *RequestorClient) UpdateRequest(ctx context.Context, requestID string, status string) (*Request, error) { + url := fmt.Sprintf("%s/requestor/request/%s", c.Endpoint, requestID) + payload := UpdateRequestRequest{Status: status} + + rb := c.New(http.MethodPut, url) + rb, err := rb.WithJSONBody(payload) + if err != nil { + return nil, err + } + + resp, err := c.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to update request: status %d, body: %s", resp.StatusCode, string(bodyBytes)) + } + + var updatedRequest Request + if err := json.NewDecoder(resp.Body).Decode(&updatedRequest); err != nil { + return nil, err + } + return &updatedRequest, nil +} + +func loadPolicies(filename string) ([]CreateRequestRequest, error) { + content, err := policyFS.ReadFile("policies/" + filename) + if err != nil { + return nil, err + } + + var config PolicyConfig + if err := yaml.Unmarshal(content, &config); err != nil { + return nil, err + } + return config.Policies, nil +} + +func formatPolicy(policy CreateRequestRequest, projectID string, username string) CreateRequestRequest { + p := policy + if username != "" { + p.Username = username + } + + if projectID != "" { + parts := strings.Split(projectID, "-") + if len(parts) >= 2 { + program := parts[0] + project := parts[1] + + newPaths := make([]string, len(p.ResourcePaths)) + for i, path := range p.ResourcePaths { + r := strings.ReplaceAll(path, "PROGRAM", program) + r = strings.ReplaceAll(r, "PROJECT", project) + newPaths[i] = r + } + p.ResourcePaths = newPaths + } + p.ResourceDisplayName = projectID + } + return p +} + +func (c *RequestorClient) getPolicyKey(p CreateRequestRequest) string { + roles := make([]string, len(p.RoleIDs)) + copy(roles, p.RoleIDs) + sort.Strings(roles) + + paths := make([]string, len(p.ResourcePaths)) + copy(paths, p.ResourcePaths) + sort.Strings(paths) + + return fmt.Sprintf("%s:%s:%s", p.PolicyID, strings.Join(roles, ","), strings.Join(paths, ",")) +} + +func (c *RequestorClient) AddUser(ctx context.Context, projectID string, username string, write bool, guppy bool) ([]Request, error) { + uniquePolicies := make(map[string]CreateRequestRequest) + + addFrom := func(fileName string) error { + pols, err := loadPolicies(fileName) + if err != nil { + return err + } + for _, p := range pols { + formatted := formatPolicy(p, projectID, username) + key := c.getPolicyKey(formatted) + uniquePolicies[key] = formatted + } + return nil + } + + // Always add read + if err := addFrom("add-user-read.yaml"); err != nil { + return nil, fmt.Errorf("failed to load read policy: %w", err) + } + + if write { + if err := addFrom("add-user-write.yaml"); err != nil { + return nil, fmt.Errorf("failed to load write policy: %w", err) + } + } + if guppy { + if err := addFrom("add-user-guppy-admin.yaml"); err != nil { + return nil, fmt.Errorf("failed to load guppy policy: %w", err) + } + } + + var createdRequests []Request + for _, formatted := range uniquePolicies { + req, err := c.CreateRequest(ctx, formatted, false) + if err != nil { + return createdRequests, fmt.Errorf("failed to create request for policy %v: %w", formatted, err) + } + createdRequests = append(createdRequests, *req) + } + return createdRequests, nil +} + +func (c *RequestorClient) RemoveUser(ctx context.Context, projectID string, username string) ([]Request, error) { + uniquePolicies := make(map[string]CreateRequestRequest) + + addFrom := func(fileName string) error { + pols, err := loadPolicies(fileName) + if err != nil { + return err + } + for _, p := range pols { + formatted := formatPolicy(p, projectID, username) + key := c.getPolicyKey(formatted) + uniquePolicies[key] = formatted + } + return nil + } + + // Revoke read and write + if err := addFrom("add-user-read.yaml"); err != nil { + return nil, fmt.Errorf("failed to load read policy: %w", err) + } + + if err := addFrom("add-user-write.yaml"); err != nil { + return nil, fmt.Errorf("failed to load write policy: %w", err) + } + + var createdRequests []Request + for _, formatted := range uniquePolicies { + req, err := c.CreateRequest(ctx, formatted, true) // revoke=true + if err != nil { + return createdRequests, fmt.Errorf("failed to revoke request: %w", err) + } + createdRequests = append(createdRequests, *req) + } + return createdRequests, nil +} diff --git a/requestor/client_test.go b/requestor/client_test.go new file mode 100644 index 0000000..1daef5e --- /dev/null +++ b/requestor/client_test.go @@ -0,0 +1,57 @@ +package requestor + +import ( + "testing" +) + +func TestGetPolicyKey(t *testing.T) { + c := &RequestorClient{} + + p1 := CreateRequestRequest{ + PolicyID: "p1", + RoleIDs: []string{"reader"}, + ResourcePaths: []string{"/path1"}, + } + p2 := CreateRequestRequest{ + PolicyID: "p1", + RoleIDs: []string{"reader"}, + ResourcePaths: []string{"/path1"}, + } + p3 := CreateRequestRequest{ + PolicyID: "p2", + RoleIDs: []string{"reader"}, + ResourcePaths: []string{"/path1"}, + } + + if c.getPolicyKey(p1) != c.getPolicyKey(p2) { + t.Errorf("Expected p1 and p2 to have same key") + } + if c.getPolicyKey(p1) == c.getPolicyKey(p3) { + t.Errorf("Expected p1 and p3 to have different keys (PolicyID differs)") + } + + p4 := CreateRequestRequest{ + RoleIDs: []string{"a", "b"}, + ResourcePaths: []string{"/p1", "/p2"}, + } + p5 := CreateRequestRequest{ + RoleIDs: []string{"b", "a"}, + ResourcePaths: []string{"/p2", "/p1"}, + } + if c.getPolicyKey(p4) != c.getPolicyKey(p5) { + t.Errorf("Expected p4 and p5 to have same key (sorting check)") + } + + // Empty PolicyID check + p6 := CreateRequestRequest{ + RoleIDs: []string{"reader"}, + ResourcePaths: []string{"/path1"}, + } + p7 := CreateRequestRequest{ + RoleIDs: []string{"reader"}, + ResourcePaths: []string{"/path1"}, + } + if c.getPolicyKey(p6) != c.getPolicyKey(p7) { + t.Errorf("Expected p6 and p7 (empty PolicyID) to have same key") + } +} diff --git a/requestor/policies/add-user-guppy-admin.yaml b/requestor/policies/add-user-guppy-admin.yaml new file mode 100644 index 0000000..fb544df --- /dev/null +++ b/requestor/policies/add-user-guppy-admin.yaml @@ -0,0 +1,9 @@ +policies: +- role_ids: + - writer + resource_paths: + - /programs/PROGRAM/projects/PROJECT +- role_ids: + - guppy_admin_user + resource_paths: + - /guppy_admin diff --git a/requestor/policies/add-user-read.yaml b/requestor/policies/add-user-read.yaml new file mode 100644 index 0000000..e7acb34 --- /dev/null +++ b/requestor/policies/add-user-read.yaml @@ -0,0 +1,5 @@ +policies: +- role_ids: + - reader + resource_paths: + - /programs/PROGRAM/projects/PROJECT diff --git a/requestor/policies/add-user-write.yaml b/requestor/policies/add-user-write.yaml new file mode 100644 index 0000000..8fda383 --- /dev/null +++ b/requestor/policies/add-user-write.yaml @@ -0,0 +1,5 @@ +policies: +- role_ids: + - writer + resource_paths: + - /programs/PROGRAM/projects/PROJECT diff --git a/requestor/types.go b/requestor/types.go new file mode 100644 index 0000000..4649124 --- /dev/null +++ b/requestor/types.go @@ -0,0 +1,34 @@ +package requestor + +// Request represents a requestor request object +type Request struct { + RequestID string `json:"request_id,omitempty" yaml:"request_id,omitempty"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` + PolicyID string `json:"policy_id,omitempty" yaml:"policy_id,omitempty"` + ResourcePaths []string `json:"resource_paths,omitempty" yaml:"resource_paths,omitempty"` + RoleIDs []string `json:"role_ids,omitempty" yaml:"role_ids,omitempty"` + ResourceID string `json:"resource_id,omitempty" yaml:"resource_id,omitempty"` + ResourceDisplay string `json:"resource_display_name,omitempty" yaml:"resource_display_name,omitempty"` + Status string `json:"status,omitempty" yaml:"status,omitempty"` + CreatedTime string `json:"created_time,omitempty" yaml:"created_time,omitempty"` + UpdatedTime string `json:"updated_time,omitempty" yaml:"updated_time,omitempty"` + Revoke bool `json:"revoke,omitempty" yaml:"revoke,omitempty"` +} + +// CreateRequestRequest represents the payload to create a request +type CreateRequestRequest struct { + Username string `json:"username,omitempty" yaml:"username,omitempty"` + PolicyID string `json:"policy_id,omitempty" yaml:"policy_id,omitempty"` + ResourcePaths []string `json:"resource_paths,omitempty" yaml:"resource_paths,omitempty"` + RoleIDs []string `json:"role_ids,omitempty" yaml:"role_ids,omitempty"` + ResourceDisplayName string `json:"resource_display_name,omitempty" yaml:"resource_display_name,omitempty"` +} + +// UpdateRequestRequest represents the payload to update a request +type UpdateRequestRequest struct { + Status string `json:"status" yaml:"status"` +} + +type PolicyConfig struct { + Policies []CreateRequestRequest `yaml:"policies"` +} diff --git a/sower/client.go b/sower/client.go new file mode 100644 index 0000000..770e891 --- /dev/null +++ b/sower/client.go @@ -0,0 +1,148 @@ +package sower + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/calypr/data-client/request" +) + +const ( + sowerDispatch = "/job/dispatch" + sowerStatus = "/job/status" + sowerList = "/job/list" + sowerJobOutput = "/job/output" +) + +type SowerInterface interface { + DispatchJob(ctx context.Context, name string, args *DispatchArgs) (*StatusResp, error) + Status(ctx context.Context, uid string) (*StatusResp, error) + List(ctx context.Context) ([]StatusResp, error) + Output(ctx context.Context, uid string) (*OutputResp, error) +} + +type SowerClient struct { + request.RequestInterface + Endpoint string +} + +func NewSowerClient(req request.RequestInterface, endpoint string) *SowerClient { + return &SowerClient{ + RequestInterface: req, + Endpoint: endpoint, + } +} + +func (sc *SowerClient) fullURL(path string) string { + u, _ := url.Parse(sc.Endpoint) + u.Path = path + return u.String() +} + +func (sc *SowerClient) DispatchJob(ctx context.Context, name string, args *DispatchArgs) (*StatusResp, error) { + body := JobArgs{ + Action: name, + Input: *args, + } + + rb := sc.New(http.MethodPost, sc.fullURL(sowerDispatch)) + rb, err := rb.WithJSONBody(body) + if err != nil { + return nil, err + } + + resp, err := sc.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("sower dispatch failed: %d %s", resp.StatusCode, string(b)) + } + + statusResp := &StatusResp{} + err = json.NewDecoder(resp.Body).Decode(statusResp) + if err != nil { + return nil, err + } + return statusResp, nil +} + +func (sc *SowerClient) Status(ctx context.Context, uid string) (*StatusResp, error) { + u, _ := url.Parse(sc.fullURL(sowerStatus)) + q := u.Query() + q.Add("UID", uid) + u.RawQuery = q.Encode() + + rb := sc.New(http.MethodGet, u.String()) + resp, err := sc.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("sower status failed: %d %s", resp.StatusCode, string(b)) + } + + statusResp := &StatusResp{} + err = json.NewDecoder(resp.Body).Decode(statusResp) + if err != nil { + return nil, err + } + return statusResp, nil +} + +func (sc *SowerClient) Output(ctx context.Context, uid string) (*OutputResp, error) { + u, _ := url.Parse(sc.fullURL(sowerJobOutput)) + q := u.Query() + q.Add("UID", uid) + u.RawQuery = q.Encode() + + rb := sc.New(http.MethodGet, u.String()) + resp, err := sc.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("sower output failed: %d %s", resp.StatusCode, string(b)) + } + + var outputResp OutputResp + err = json.NewDecoder(resp.Body).Decode(&outputResp) + if err != nil { + return nil, err + } + return &outputResp, nil +} + +func (sc *SowerClient) List(ctx context.Context) ([]StatusResp, error) { + rb := sc.New(http.MethodGet, sc.fullURL(sowerList)) + resp, err := sc.Do(ctx, rb) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("sower list failed: %d %s", resp.StatusCode, string(b)) + } + + var listResp []StatusResp + err = json.NewDecoder(resp.Body).Decode(&listResp) + if err != nil { + return nil, err + } + return listResp, nil +} diff --git a/sower/types.go b/sower/types.go new file mode 100644 index 0000000..7b735a7 --- /dev/null +++ b/sower/types.go @@ -0,0 +1,33 @@ +package sower + +type StatusResp struct { + Uid string `json:"uid"` + Name string `json:"name"` + Status string `json:"status"` +} + +type OutputResp struct { + Output string `json:"output"` +} + +type File struct { + FileTitle string `json:"fileTitle,omitempty"` + FilePath string `json:"filePath"` +} + +type DispatchArgs struct { + Method string `json:"method"` + ProjectId string `json:"projectId"` + Profile string `json:"profile"` + BucketName string `json:"bucketName"` + APIEndpoint string `json:"APIEndpoint"` + GHCommitHash string `json:"ghCommitHash"` + GHPAccessToken string `json:"ghToken"` + GHUserName string `json:"ghUserName"` + GHRepoURL string `json:"ghRepoUrl"` +} + +type JobArgs struct { + Input DispatchArgs `json:"input"` + Action string `json:"action"` +} diff --git a/tests/download-multiple_test.go b/tests/download-multiple_test.go index c2da7c6..84169b7 100644 --- a/tests/download-multiple_test.go +++ b/tests/download-multiple_test.go @@ -1,19 +1,19 @@ package tests import ( + "context" "fmt" "io" "net/http" - "os" "strings" "testing" - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/download" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/mocks" - req "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/download" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/mocks" + req "github.com/calypr/data-client/request" "go.uber.org/mock/gomock" ) @@ -26,12 +26,14 @@ func Test_askGen3ForFileInfo_withShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockFence := mocks.NewMockFenceInterface(mockCtrl) // Expect credential access mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() // Shepherd is available - mockGen3.EXPECT(). + mockFence.EXPECT(). CheckForShepherdAPI(gomock.Any()). Return(true, nil) @@ -48,10 +50,10 @@ func Test_askGen3ForFileInfo_withShepherd(t *testing.T) { Body: io.NopCloser(strings.NewReader(testBody)), } - // Expect authenticated request to Shepherd - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { + // Expect request to Shepherd + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx any, rb *req.RequestBuilder) (*http.Response, error) { if !strings.HasSuffix(rb.Url, "/objects/"+testGUID) { t.Errorf("Expected request to Shepherd objects endpoint, got %s", rb.Url) } @@ -59,10 +61,10 @@ func Test_askGen3ForFileInfo_withShepherd(t *testing.T) { }) // Optional: logger - mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() + mockGen3.EXPECT().Logger().Return(logs.NewGen3Logger(nil, "", "test")).AnyTimes() skipped := []download.RenamedOrSkippedFileInfo{} - info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + info, err := download.AskGen3ForFileInfo(context.Background(), mockGen3, testGUID, "", "", "original", true, &skipped) if err != nil { t.Error(err) } @@ -77,6 +79,7 @@ func Test_askGen3ForFileInfo_withShepherd(t *testing.T) { t.Errorf("Expected no skipped files, got %v", skipped) } } + func Test_askGen3ForFileInfo_withShepherd_shepherdError(t *testing.T) { testGUID := "000000-0000000-0000000-000000" @@ -84,47 +87,42 @@ func Test_askGen3ForFileInfo_withShepherd_shepherdError(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockFence := mocks.NewMockFenceInterface(mockCtrl) dummyCred := &conf.Credential{} mockGen3.EXPECT().GetCredential().Return(dummyCred).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() // 1. Shepherd is available - mockGen3.EXPECT(). + mockFence.EXPECT(). CheckForShepherdAPI(gomock.Any()). Return(true, nil). Times(1) // 2. Shepherd request fails → triggers fallback to Indexd - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). Return(nil, fmt.Errorf("Shepherd error")). Times(1) // only the Shepherd call - // 3. Fallback: Indexd request also fails (we want to test error handling) - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + // 3. Fallback: Indexd request also fails + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). Return(nil, fmt.Errorf("Indexd error")). Times(1) - // Optional: if it tries to parse nil response from Indexd - mockGen3.EXPECT(). - ParseFenceURLResponse(gomock.Nil()). - Return(api.FenceResponse{}, fmt.Errorf("no response")). - AnyTimes() - // Logger mockGen3.EXPECT(). Logger(). - Return(logs.NewTeeLogger("", "test", os.Stdout)). + Return(logs.NewGen3Logger(nil, "", "test")). AnyTimes() skipped := []download.RenamedOrSkippedFileInfo{} - info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + info, err := download.AskGen3ForFileInfo(context.Background(), mockGen3, testGUID, "", "", "original", true, &skipped) if err != nil { t.Fatal(err) } - // Critical fix: check for nil first if info == nil { t.Fatal("AskGen3ForFileInfo returned nil when both Shepherd and Indexd failed. Expected fallback FileInfo with Name = GUID") } @@ -149,25 +147,28 @@ func Test_askGen3ForFileInfo_noShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockFence := mocks.NewMockFenceInterface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() // No Shepherd - mockGen3.EXPECT().CheckForShepherdAPI(gomock.Any()).Return(false, nil) + mockFence.EXPECT().CheckForShepherdAPI(gomock.Any()).Return(false, nil) // Indexd returns parsed FenceResponse - mockGen3.EXPECT(). + mockFence.EXPECT(). ParseFenceURLResponse(gomock.Any()). - Return(api.FenceResponse{FileName: testFileName, Size: testFileSize}, nil) + Return(fence.FenceResponse{FileName: testFileName, Size: testFileSize}, nil) - // DoAuthenticatedRequest called for indexd - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). + // Do called for indexd + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). Return(&http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("{}"))}, nil) - mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() + mockGen3.EXPECT().Logger().Return(logs.NewGen3Logger(nil, "", "test")).AnyTimes() skipped := []download.RenamedOrSkippedFileInfo{} - info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) + info, err := download.AskGen3ForFileInfo(context.Background(), mockGen3, testGUID, "", "", "original", true, &skipped) if err != nil { t.Fatal(err) } @@ -179,34 +180,3 @@ func Test_askGen3ForFileInfo_noShepherd(t *testing.T) { t.Errorf("Wanted filesize %v, got %v", testFileSize, info.Size) } } - -func Test_askGen3ForFileInfo_noShepherd_indexdError(t *testing.T) { - testGUID := "000000-0000000-0000000-000000" - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockGen3 := mocks.NewMockGen3Interface(mockCtrl) - mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() - mockGen3.EXPECT().CheckForShepherdAPI(gomock.Any()).Return(false, nil) - - // Indexd request fails - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - Return(nil, fmt.Errorf("Indexd error")) - - mockGen3.EXPECT().Logger().Return(logs.NewTeeLogger("", "test", os.Stdout)).AnyTimes() - - skipped := []download.RenamedOrSkippedFileInfo{} - info, err := download.AskGen3ForFileInfo(mockGen3, testGUID, "", "", "original", true, &skipped) - if err != nil { - t.Fatal(err) - } - - if info.Name != testGUID { - t.Errorf("Wanted fallback filename %v, got %v", testGUID, info.Name) - } - if len(skipped) != 1 || skipped[0].GUID != testGUID { - t.Errorf("Expected skipped entry for GUID: %v", skipped) - } -} diff --git a/tests/functions_test.go b/tests/functions_test.go deleted file mode 100755 index 8c20d33..0000000 --- a/tests/functions_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package tests - -import ( - "bytes" - "io" - "net/http" - "reflect" - "strings" - "testing" - - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/mocks" - req "github.com/calypr/data-client/client/request" - "go.uber.org/mock/gomock" -) - -func TestDoAuthenticatedRequest_NoProfile(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - emptyCred := &conf.Credential{} - - // Expect error when credentials are incomplete - _, err := mockFuncs.DoAuthenticatedRequest(emptyCred, &req.RequestBuilder{ - Url: "/user/data/download/test_uuid", - }) - if err == nil { - t.Error("Expected error due to missing credentials, but got nil") - } -} - -func TestDoAuthenticatedRequest_GoodToken(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - cred := &conf.Credential{ - APIKey: "fake_api_key", - AccessToken: "non_expired_token", - APIEndpoint: "https://example.com", - } - - mockedResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewBufferString(`{"url": "https://signed.url"}`)), - } - - mockFuncs.EXPECT(). - DoAuthenticatedRequest(cred, gomock.Any()). - Return(mockedResp, nil). - Times(1) - - resp, err := mockFuncs.DoAuthenticatedRequest(cred, &req.RequestBuilder{ - Url: "/user/data/download/test_uuid", - }) - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("Expected status 200, got %d", resp.StatusCode) - } -} - -func TestDoAuthenticatedRequest_MissingToken_CreatesNew(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - mockConfig := mocks.NewMockManagerInterface(mockCtrl) - - // Assuming Functions struct has both Config and Functions fields - testFunction := &api.Functions{ - Config: mockConfig, - } - - cred := &conf.Credential{ - APIKey: "fake_api_key", - AccessToken: "", // empty → should trigger token creation - APIEndpoint: "https://example.com", - } - - mockedResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewBufferString(`{"url": "https://signed.url"}`)), - } - - // Expect Save to be called if new token is generated and saved - mockConfig.EXPECT().Save(cred).AnyTimes() - - mockFuncs.EXPECT(). - DoAuthenticatedRequest(cred, gomock.Any()). - Return(mockedResp, nil). - Times(1) - - _, err := testFunction.DoAuthenticatedRequest(cred, &req.RequestBuilder{ - Url: "/user/data/download/test_uuid", - }) - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } -} - -func TestCheckPrivileges_NoProfile(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - emptyCred := &conf.Credential{} - - _, err := mockFuncs.CheckPrivileges(emptyCred) - if err == nil { - t.Error("Expected error when credentials are missing, got nil") - } -} - -func TestCheckPrivileges_NoAccess(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - cred := &conf.Credential{ - APIKey: "fake_api_key", - AccessToken: "valid_token", - APIEndpoint: "https://example.com", - } - - userResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"project_access": {}}`)), - } - - mockFuncs.EXPECT(). - DoAuthenticatedRequest(cred, gomock.Any()). - Return(userResp, nil) - - privileges, err := mockFuncs.CheckPrivileges(cred) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - expected := make(map[string]any) - if !reflect.DeepEqual(privileges, expected) { - t.Errorf("Expected empty privileges, got %v", privileges) - } -} - -func TestCheckPrivileges_GrantedAccess_ProjectAccess(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - cred := &conf.Credential{ - APIKey: "fake_api_key", - AccessToken: "valid_token", - APIEndpoint: "https://example.com", - } - - jsonBody := `{ - "project_access": { - "test_project": ["read", "create", "read-storage", "update", "delete"] - } - }` - - userResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(jsonBody)), - } - - mockFuncs.EXPECT(). - DoAuthenticatedRequest(cred, gomock.Any()). - Return(userResp, nil) - - privileges, err := mockFuncs.CheckPrivileges(cred) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - expected := map[string]any{ - "test_project": []any{"read", "create", "read-storage", "update", "delete"}, - } - - if !reflect.DeepEqual(privileges, expected) { - t.Errorf("Privileges mismatch.\nExpected: %v\nGot: %v", expected, privileges) - } -} - -func TestCheckPrivileges_GrantedAccess_AuthzTakesPrecedence(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockFuncs := mocks.NewMockFunctionInterface(mockCtrl) - - cred := &conf.Credential{ - APIKey: "fake_api_key", - AccessToken: "valid_token", - APIEndpoint: "https://example.com", - } - - jsonBody := `{ - "authz": { - "test_project": [ - {"method": "create", "service": "*"}, - {"method": "delete", "service": "*"}, - {"method": "read", "service": "*"}, - {"method": "read-storage", "service": "*"}, - {"method": "update", "service": "*"}, - {"method": "upload", "service": "*"} - ] - }, - "project_access": { - "test_project": ["read", "create", "read-storage", "update", "delete"] - } - }` - - userResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(jsonBody)), - } - - mockFuncs.EXPECT(). - DoAuthenticatedRequest(cred, gomock.Any()). - Return(userResp, nil) - - privileges, err := mockFuncs.CheckPrivileges(cred) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - expected := map[string]any{ - "test_project": []any{ - map[string]any{"method": "create", "service": "*"}, - map[string]any{"method": "delete", "service": "*"}, - map[string]any{"method": "read", "service": "*"}, - map[string]any{"method": "read-storage", "service": "*"}, - map[string]any{"method": "update", "service": "*"}, - map[string]any{"method": "upload", "service": "*"}, - }, - } - - if !reflect.DeepEqual(privileges, expected) { - t.Errorf("Authz privileges should take precedence.\nExpected: %v\nGot: %v", expected, privileges) - } -} diff --git a/tests/utils_test.go b/tests/utils_test.go index 758fb24..fa330d6 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -1,19 +1,20 @@ package tests import ( + "context" "fmt" "io" "net/http" "strings" "testing" - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/download" - "github.com/calypr/data-client/client/mocks" - req "github.com/calypr/data-client/client/request" - "github.com/calypr/data-client/client/upload" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/download" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/mocks" + "github.com/calypr/data-client/request" + "github.com/calypr/data-client/upload" "go.uber.org/mock/gomock" ) @@ -26,40 +27,33 @@ func TestGetDownloadResponse_withShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockFence := mocks.NewMockFenceInterface(mockCtrl) // Mock credential mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() - - // Shepherd is deployed - mockGen3.EXPECT(). - CheckForShepherdAPI(gomock.Any()). - Return(true, nil) - - // Shepherd download URL response - downloadURLBody := fmt.Sprintf(`{"url": "%s"}`, mockDownloadURL) - shepherdResp := &http.Response{ + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() + + mockFence.EXPECT(). + GetDownloadPresignedUrl(gomock.Any(), testGUID, ""). + Return(mockDownloadURL, nil) + + mockFence.EXPECT(). + New(http.MethodGet, mockDownloadURL). + Return(&request.RequestBuilder{ + Method: http.MethodGet, + Url: mockDownloadURL, + Headers: make(map[string]string), + }). + AnyTimes() + + // Mock successful response from the presigned URL + mockResp := &http.Response{ StatusCode: 200, - Body: io.NopCloser(strings.NewReader(downloadURLBody)), + Body: io.NopCloser(strings.NewReader("content")), } - - // Expect DoAuthenticatedRequest to Shepherd /objects/{guid}/download - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { - if !strings.HasSuffix(rb.Url, "/objects/"+testGUID+"/download") { - t.Errorf("Expected Shepherd download URL request, got %s", rb.Url) - } - return shepherdResp, nil - }) - - // ParseFenceURLResponse to extract URL - mockGen3.EXPECT(). - ParseFenceURLResponse(shepherdResp). - Return(api.FenceResponse{URL: mockDownloadURL}, nil) - - // We assume the implementation uses http.Client directly for presigned URLs (common pattern) - // So no mock needed here unless you inject an HTTP client — this part may be unmocked. - // If you have a mockable HTTP doer, adjust accordingly. + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). + Return(mockResp, nil) mockFDRObj := common.FileDownloadResponseObject{ Filename: testFilename, @@ -67,17 +61,14 @@ func TestGetDownloadResponse_withShepherd(t *testing.T) { Range: 0, } - err := download.GetDownloadResponse(mockGen3, &mockFDRObj, "") + err := download.GetDownloadResponse(context.Background(), mockGen3, &mockFDRObj, "") if err != nil { t.Fatalf("Unexpected error: %v", err) } - if mockFDRObj.URL != mockDownloadURL { - t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.URL) + if mockFDRObj.PresignedURL != mockDownloadURL { + t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.PresignedURL) } - - // Note: Response may be fetched outside the interface (direct http.Get), so this check might not work unless injected. - // If you want to fully mock it, consider injecting a downloader. } func TestGetDownloadResponse_noShepherd(t *testing.T) { @@ -89,26 +80,32 @@ func TestGetDownloadResponse_noShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) - mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockFence := mocks.NewMockFenceInterface(mockCtrl) - // No Shepherd - mockGen3.EXPECT(). - CheckForShepherdAPI(gomock.Any()). - Return(false, nil) - - // Fence returns presigned URL - fenceResp := &http.Response{ + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() + + mockFence.EXPECT(). + GetDownloadPresignedUrl(gomock.Any(), testGUID, ""). + Return(mockDownloadURL, nil) + + mockFence.EXPECT(). + New(http.MethodGet, mockDownloadURL). + Return(&request.RequestBuilder{ + Method: http.MethodGet, + Url: mockDownloadURL, + Headers: make(map[string]string), + }). + AnyTimes() + + // Mock successful response + mockResp := &http.Response{ StatusCode: 200, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"url": "%s"}`, mockDownloadURL))), + Body: io.NopCloser(strings.NewReader("content")), } - - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - Return(fenceResp, nil) - - mockGen3.EXPECT(). - ParseFenceURLResponse(fenceResp). - Return(api.FenceResponse{URL: mockDownloadURL}, nil) + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). + Return(mockResp, nil) mockFDRObj := common.FileDownloadResponseObject{ Filename: testFilename, @@ -116,17 +113,17 @@ func TestGetDownloadResponse_noShepherd(t *testing.T) { Range: 0, } - err := download.GetDownloadResponse(mockGen3, &mockFDRObj, "") + err := download.GetDownloadResponse(context.Background(), mockGen3, &mockFDRObj, "") if err != nil { t.Fatalf("Unexpected error: %v", err) } - if mockFDRObj.URL != mockDownloadURL { - t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.URL) + if mockFDRObj.PresignedURL != mockDownloadURL { + t.Errorf("Wanted URL %s, got %s", mockDownloadURL, mockFDRObj.PresignedURL) } } -func TestGeneratePresignedURL_noShepherd(t *testing.T) { +func TestGeneratePresignedUploadURL_noShepherd(t *testing.T) { testFilename := "test-file" testBucketname := "test-bucket" mockPresignedURL := "https://example.com/example.pfb" @@ -136,33 +133,24 @@ func TestGeneratePresignedURL_noShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) + mockFence := mocks.NewMockFenceInterface(mockCtrl) + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() // No Shepherd - mockGen3.EXPECT(). + mockFence.EXPECT(). CheckForShepherdAPI(gomock.Any()). Return(false, nil) - // Fence upload endpoint response - fenceResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf( - `{"url": "%s", "guid": "%s"}`, mockPresignedURL, mockGUID, - ))), - } - - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - Return(fenceResp, nil) - - mockGen3.EXPECT(). - ParseFenceURLResponse(fenceResp). - Return(api.FenceResponse{ + mockFence.EXPECT(). + InitUpload(gomock.Any(), testFilename, testBucketname, ""). + Return(fence.FenceResponse{ URL: mockPresignedURL, GUID: mockGUID, }, nil) - resp, err := upload.GeneratePresignedURL(mockGen3, testFilename, common.FileMetadata{}, testBucketname) + resp, err := upload.GeneratePresignedUploadURL(context.Background(), mockGen3, testFilename, common.FileMetadata{}, testBucketname) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -175,7 +163,7 @@ func TestGeneratePresignedURL_noShepherd(t *testing.T) { } } -func TestGeneratePresignedURL_withShepherd(t *testing.T) { +func TestGeneratePresignedUploadURL_withShepherd(t *testing.T) { testFilename := "test-file" testBucketname := "test-bucket" mockPresignedURL := "https://example.com/example.pfb" @@ -191,10 +179,13 @@ func TestGeneratePresignedURL_withShepherd(t *testing.T) { defer mockCtrl.Finish() mockGen3 := mocks.NewMockGen3Interface(mockCtrl) - mockGen3.EXPECT().GetCredential().Return(&conf.Credential{}).AnyTimes() + mockFence := mocks.NewMockFenceInterface(mockCtrl) + + mockGen3.EXPECT().GetCredential().Return(&conf.Credential{AccessToken: "token"}).AnyTimes() + mockGen3.EXPECT().Fence().Return(mockFence).AnyTimes() // Shepherd is deployed - mockGen3.EXPECT(). + mockFence.EXPECT(). CheckForShepherdAPI(gomock.Any()). Return(true, nil) @@ -206,17 +197,11 @@ func TestGeneratePresignedURL_withShepherd(t *testing.T) { ))), } - mockGen3.EXPECT(). - DoAuthenticatedRequest(gomock.Any(), gomock.Any()). - DoAndReturn(func(cred *conf.Credential, rb *req.RequestBuilder) (*http.Response, error) { - if rb.Method != "POST" || !strings.HasSuffix(rb.Url, "/objects") { - t.Errorf("Expected POST to /objects, got %s %s", rb.Method, rb.Url) - } - // Optionally validate body here if needed - return shepherdResp, nil - }) + mockFence.EXPECT(). + Do(gomock.Any(), gomock.Any()). + Return(shepherdResp, nil) - respObj, err := upload.GeneratePresignedURL(mockGen3, testFilename, testMetadata, testBucketname) + respObj, err := upload.GeneratePresignedUploadURL(context.Background(), mockGen3, testFilename, testMetadata, testBucketname) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/client/upload/batch.go b/upload/batch.go similarity index 70% rename from client/upload/batch.go rename to upload/batch.go index 5e08af4..41aea65 100644 --- a/client/upload/batch.go +++ b/upload/batch.go @@ -8,9 +8,9 @@ import ( "os" "sync" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/request" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" ) @@ -63,21 +63,21 @@ func BatchUpload( for fur := range workCh { // --- Ensure presigned URL --- if fur.PresignedURL == "" { - resp, err := GeneratePresignedUploadURL(ctx, g3i, fur.Filename, fur.FileMetadata, fur.Bucket) + resp, err := GeneratePresignedUploadURL(ctx, g3i, fur.ObjectKey, fur.FileMetadata, fur.Bucket) if err != nil { - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, "", 0, false) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, "", 0, false) errCh <- err continue } fur.PresignedURL = resp.URL fur.GUID = resp.GUID - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, resp.GUID, 0, false) // update log + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, resp.GUID, 0, false) // update log } // --- Open file --- - file, err := os.Open(fur.FilePath) + file, err := os.Open(fur.SourcePath) if err != nil { - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, fur.GUID, 0, false) errCh <- fmt.Errorf("file open error: %w", err) continue } @@ -85,22 +85,22 @@ func BatchUpload( fi, err := file.Stat() if err != nil { file.Close() - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, fur.GUID, 0, false) errCh <- fmt.Errorf("file stat error: %w", err) continue } if fi.Size() > common.FileSizeLimit { file.Close() - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) - errCh <- fmt.Errorf("file size exceeds limit: %s", fur.Filename) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, fur.GUID, 0, false) + errCh <- fmt.Errorf("file size exceeds limit: %s", fur.ObjectKey) continue } // --- Progress bar --- bar := progress.AddBar(fi.Size(), mpb.PrependDecorators( - decor.Name(fur.Filename+" "), + decor.Name(fur.ObjectKey+" "), decor.CountersKibiByte("% .1f / % .1f"), ), mpb.AppendDecorators( @@ -112,7 +112,7 @@ func BatchUpload( proxyReader := bar.ProxyReader(file) // --- Upload using DoAuthenticatedRequest (no manual http.Request!) --- - resp, err := g3i.Do( + resp, err := g3i.Fence().Do( ctx, &request.RequestBuilder{ Method: http.MethodPut, @@ -126,7 +126,7 @@ func BatchUpload( bar.Abort(false) if err != nil { - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, fur.GUID, 0, false) errCh <- err continue } @@ -135,7 +135,7 @@ func BatchUpload( bodyBytes, _ := io.ReadAll(resp.Body) resp.Body.Close() errMsg := fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - g3i.Logger().Failed(fur.FilePath, fur.Filename, fur.FileMetadata, fur.GUID, 0, false) + g3i.Logger().Failed(fur.SourcePath, fur.ObjectKey, fur.FileMetadata, fur.GUID, 0, false) errCh <- errMsg continue } @@ -144,8 +144,8 @@ func BatchUpload( // Success respCh <- resp - g3i.Logger().DeleteFromFailedLog(fur.FilePath) - g3i.Logger().Succeeded(fur.FilePath, fur.GUID) + g3i.Logger().DeleteFromFailedLog(fur.SourcePath) + g3i.Logger().Succeeded(fur.SourcePath, fur.GUID) g3i.Logger().Scoreboard().IncrementSB(0) } }() diff --git a/client/upload/multipart.go b/upload/multipart.go similarity index 60% rename from client/upload/multipart.go rename to upload/multipart.go index df8d0cd..b0e3cce 100644 --- a/client/upload/multipart.go +++ b/upload/multipart.go @@ -1,9 +1,7 @@ package upload import ( - "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -14,15 +12,15 @@ import ( "sync" "sync/atomic" - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - req "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/fence" + client "github.com/calypr/data-client/g3client" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" ) func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, file *os.File, showProgress bool) error { - g3.Logger().Printf("File Upload Request: %#v\n", req) + g3.Logger().InfoContext(ctx, "File Upload Request", "request", req) stat, err := file.Stat() if err != nil { @@ -31,7 +29,7 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi fileSize := stat.Size() if fileSize == 0 { - return fmt.Errorf("file is empty: %s", req.Filename) + return fmt.Errorf("file is empty: %s", req.ObjectKey) } var p *mpb.Progress @@ -40,7 +38,7 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi p = mpb.New(mpb.WithOutput(os.Stdout)) bar = p.AddBar(fileSize, mpb.PrependDecorators( - decor.Name(req.Filename+" "), + decor.Name(req.ObjectKey+" "), decor.CountersKibiByte("%.1f / %.1f"), ), mpb.AppendDecorators( @@ -58,8 +56,8 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi // 2. Construct the S3 Key correctly // Ensure finalGUID is not empty to avoid a leading slash - key := fmt.Sprintf("%s/%s", finalGUID, req.Filename) - g3.Logger().Printf("Initialized Upload: ID=%s, Key=%s\n", uploadID, key) + key := fmt.Sprintf("%s/%s", finalGUID, req.ObjectKey) + g3.Logger().InfoContext(ctx, "Initialized Upload", "id", uploadID, "key", key) chunkSize := OptimalChunkSize(fileSize) @@ -74,11 +72,17 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi var ( wg sync.WaitGroup mu sync.Mutex - parts []MultipartPartObject + parts []fence.MultipartPart uploadErrors []error totalBytes int64 // Atomic counter for monotonically increasing BytesSoFar ) + progressCallback := common.GetProgress(ctx) + oid := common.GetOid(ctx) + if oid == "" { + oid = resolveUploadOID(req) + } + // 3. Worker logic worker := func() { defer wg.Done() @@ -113,24 +117,21 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi } mu.Lock() - parts = append(parts, MultipartPartObject{ + parts = append(parts, fence.MultipartPart{ PartNumber: partNum, ETag: etag, }) if bar != nil { bar.IncrInt64(size) } - if req.Progress != nil { + if progressCallback != nil { currentTotal := atomic.AddInt64(&totalBytes, size) - err = req.Progress(common.ProgressEvent{ + _ = progressCallback(common.ProgressEvent{ Event: "progress", - Oid: req.OID, + Oid: oid, BytesSinceLast: size, BytesSoFar: currentTotal, }) - if err != nil { - g3.Logger().Printf("progress callback error: %v", err) - } } mu.Unlock() } @@ -160,34 +161,13 @@ func MultipartUpload(ctx context.Context, g3 client.Gen3Interface, req common.Fi return fmt.Errorf("failed to complete multipart upload: %w", err) } - g3.Logger().Printf("Successfully uploaded %s to %s", req.Filename, key) - g3.Logger().Succeeded(req.FilePath, req.GUID) + g3.Logger().InfoContext(ctx, "Successfully uploaded", "file", req.ObjectKey, "key", key) + g3.Logger().SucceededContext(ctx, req.SourcePath, req.GUID) return nil } -// InitMultipartUpload helps sending requests to FENCE to init a multipart upload func initMultipartUpload(ctx context.Context, g3 client.Gen3Interface, furObject common.FileUploadRequestObject, bucketName string) (string, string, error) { - // Use Filename and GUID directly from the unified request object - - reader, err := common.ToJSONReader( - InitRequestObject{ - Filename: furObject.Filename, - Bucket: bucketName, - GUID: furObject.GUID, - }, - ) - - cred := g3.GetCredential() - resp, err := g3.Do( - ctx, - &req.RequestBuilder{ - Method: http.MethodPost, - Url: cred.APIEndpoint + common.FenceDataMultipartInitEndpoint, - Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, - Body: reader, - Token: cred.AccessToken, - }, - ) + msg, err := g3.Fence().InitMultipartUpload(ctx, furObject.ObjectKey, bucketName, furObject.GUID) if err != nil { if strings.Contains(err.Error(), "404") { @@ -196,81 +176,26 @@ func initMultipartUpload(ctx context.Context, g3 client.Gen3Interface, furObject return "", "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) } - msg, err := g3.ParseFenceURLResponse(resp) - if err != nil { - return "", "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) - - } - if msg.UploadID == "" || msg.GUID == "" { return "", "", errors.New("unknown error has occurred during multipart upload initialization. Please check logs from Gen3 services") } - return msg.UploadID, msg.GUID, err + return msg.UploadID, msg.GUID, nil } -// GenerateMultipartPresignedURL helps sending requests to FENCE to get a presigned URL for a part during a multipart upload func generateMultipartPresignedURL(ctx context.Context, g3 client.Gen3Interface, key string, uploadID string, partNumber int, bucketName string) (string, error) { - - reader, err := common.ToJSONReader( - MultipartUploadRequestObject{ - Key: key, - UploadID: uploadID, - PartNumber: partNumber, - Bucket: bucketName, - }, - ) - if err != nil { - return "", err - } - - cred := g3.GetCredential() - resp, err := g3.Do( - ctx, - &req.RequestBuilder{ - Url: cred.APIEndpoint + common.FenceDataMultipartUploadEndpoint, - Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, - Method: http.MethodPost, - Body: reader, - Token: cred.AccessToken, - }, - ) + url, err := g3.Fence().GenerateMultipartPresignedURL(ctx, key, uploadID, partNumber, bucketName) if err != nil { return "", errors.New("Error has occurred during multipart upload presigned url generation, detailed error message: " + err.Error()) } - msg, err := g3.ParseFenceURLResponse(resp) - if err != nil { - return "", errors.New("Error has occurred during multipart upload initialization, detailed error message: " + err.Error()) - } - - if msg.PresignedURL == "" { + if url == "" { return "", errors.New("unknown error has occurred during multipart upload presigned url generation. Please check logs from Gen3 services") } - return msg.PresignedURL, err + return url, nil } -// CompleteMultipartUpload helps sending requests to FENCE to complete a multipart upload -func CompleteMultipartUpload(ctx context.Context, g3 client.Gen3Interface, key string, uploadID string, parts []MultipartPartObject, bucketName string) error { - multipartCompleteObject := MultipartCompleteRequestObject{Key: key, UploadID: uploadID, Parts: parts, Bucket: bucketName} - - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(multipartCompleteObject) - if err != nil { - return errors.New("Error occurred during encoding multipart upload data: " + err.Error()) - } - - // TOOD: error check this, return resp information - cred := g3.GetCredential() - _, err = g3.Do( - ctx, - &req.RequestBuilder{ - Url: cred.APIEndpoint + common.FenceDataMultipartCompleteEndpoint, - Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, - Body: &buf, - Method: http.MethodPost, - Token: cred.AccessToken, - }, - ) +func CompleteMultipartUpload(ctx context.Context, g3 client.Gen3Interface, key string, uploadID string, parts []fence.MultipartPart, bucketName string) error { + err := g3.Fence().CompleteMultipartUpload(ctx, key, uploadID, parts, bucketName) if err != nil { return errors.New("Error has occurred during completing multipart upload, detailed error message: " + err.Error()) } diff --git a/client/upload/multipart_test.go b/upload/multipart_test.go similarity index 54% rename from client/upload/multipart_test.go rename to upload/multipart_test.go index c03cbea..d9cad7a 100644 --- a/client/upload/multipart_test.go +++ b/upload/multipart_test.go @@ -3,6 +3,7 @@ package upload import ( "bytes" "context" + "encoding/json" "fmt" "io" "net/http" @@ -13,42 +14,73 @@ import ( "sync" "testing" - "github.com/calypr/data-client/client/api" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/conf" - "github.com/calypr/data-client/client/logs" - "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/indexd" + "github.com/calypr/data-client/logs" + "github.com/calypr/data-client/request" + "github.com/calypr/data-client/sower" ) type fakeGen3Upload struct { cred *conf.Credential - logger *logs.TeeLogger + logger *logs.Gen3Logger doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) } func (f *fakeGen3Upload) GetCredential() *conf.Credential { return f.cred } -func (f *fakeGen3Upload) Logger() *logs.TeeLogger { return f.logger } -func (f *fakeGen3Upload) New(method, url string) *request.RequestBuilder { - return &request.RequestBuilder{Method: method, Url: url} +func (f *fakeGen3Upload) Logger() *logs.Gen3Logger { return f.logger } +func (f *fakeGen3Upload) ExportCredential(ctx context.Context, cred *conf.Credential) error { + return nil } -func (f *fakeGen3Upload) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { +func (f *fakeGen3Upload) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } +func (f *fakeGen3Upload) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } +func (f *fakeGen3Upload) Sower() sower.SowerInterface { return nil } + +type fakeFence struct { + fence.FenceInterface + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeFence) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { return f.doFunc(ctx, req) } -func (f *fakeGen3Upload) CheckPrivileges(context.Context) (map[string]any, error) { - return nil, nil +func (f *fakeFence) InitMultipartUpload(ctx context.Context, filename string, bucket string, guid string) (fence.FenceResponse, error) { + resp, err := f.Do(ctx, &request.RequestBuilder{Url: common.FenceDataMultipartInitEndpoint}) + if err != nil { + return fence.FenceResponse{}, err + } + return f.ParseFenceURLResponse(resp) +} +func (f *fakeFence) GenerateMultipartPresignedURL(ctx context.Context, key string, uploadID string, partNumber int, bucket string) (string, error) { + resp, err := f.Do(ctx, &request.RequestBuilder{Url: common.FenceDataMultipartUploadEndpoint}) + if err != nil { + return "", err + } + msg, err := f.ParseFenceURLResponse(resp) + return msg.PresignedURL, err +} +func (f *fakeFence) CompleteMultipartUpload(ctx context.Context, key string, uploadID string, parts []fence.MultipartPart, bucket string) error { + _, err := f.Do(ctx, &request.RequestBuilder{Url: common.FenceDataMultipartCompleteEndpoint}) + return err } -func (f *fakeGen3Upload) CheckForShepherdAPI(context.Context) (bool, error) { return false, nil } -func (f *fakeGen3Upload) DeleteRecord(context.Context, string) (string, error) { - return "", nil +func (f *fakeFence) ParseFenceURLResponse(resp *http.Response) (fence.FenceResponse, error) { + var msg fence.FenceResponse + if resp != nil && resp.Body != nil { + json.NewDecoder(resp.Body).Decode(&msg) + } + return msg, nil } -func (f *fakeGen3Upload) GetDownloadPresignedUrl(context.Context, string, string) (string, error) { - return "", nil + +type fakeIndexd struct { + indexd.IndexdInterface + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) } -func (f *fakeGen3Upload) ParseFenceURLResponse(resp *http.Response) (api.FenceResponse, error) { - return (&api.Functions{}).ParseFenceURLResponse(resp) + +func (f *fakeIndexd) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + return f.doFunc(ctx, req) } -func (f *fakeGen3Upload) ExportCredential(context.Context, *conf.Credential) error { return nil } -func (f *fakeGen3Upload) NewAccessToken(context.Context) error { return nil } func TestMultipartUploadProgressIntegration(t *testing.T) { ctx := context.Background() @@ -89,7 +121,7 @@ func TestMultipartUploadProgressIntegration(t *testing.T) { return nil } - logger := logs.NewTeeLogger("", "", io.Discard) + logger := logs.NewGen3Logger(nil, "", "") fake := &fakeGen3Upload{ cred: &conf.Credential{ APIEndpoint: "https://example.com", @@ -111,14 +143,15 @@ func TestMultipartUploadProgressIntegration(t *testing.T) { } requestObject := common.FileUploadRequestObject{ - FilePath: file.Name(), - Filename: "multipart.bin", - GUID: "guid-123", - OID: "oid-123", - Bucket: "bucket", - Progress: progress, + SourcePath: file.Name(), + ObjectKey: "multipart.bin", + GUID: "guid-123", + Bucket: "bucket", } + ctx = common.WithProgress(ctx, progress) + ctx = common.WithOid(ctx, "guid-123") + if err := MultipartUpload(ctx, fake, requestObject, file, false); err != nil { t.Fatalf("multipart upload failed: %v", err) } diff --git a/client/upload/progress_reader.go b/upload/progress_reader.go similarity index 56% rename from client/upload/progress_reader.go rename to upload/progress_reader.go index 8262ee5..b7b3294 100644 --- a/client/upload/progress_reader.go +++ b/upload/progress_reader.go @@ -1,36 +1,34 @@ package upload import ( + "fmt" "io" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) type progressReader struct { reader io.Reader onProgress common.ProgressCallback - oid string + hash string total int64 bytesSoFar int64 } -func newProgressReader(reader io.Reader, onProgress common.ProgressCallback, oid string, total int64) *progressReader { +func newProgressReader(reader io.Reader, onProgress common.ProgressCallback, hash string, total int64) *progressReader { return &progressReader{ reader: reader, onProgress: onProgress, - oid: oid, + hash: hash, total: total, } } func resolveUploadOID(req common.FileUploadRequestObject) string { - if req.OID != "" { - return req.OID + if req.ObjectKey != "" { + return req.ObjectKey } - if req.GUID != "" { - return req.GUID - } - return req.Filename + return req.GUID } func (pr *progressReader) Read(p []byte) (int, error) { @@ -40,7 +38,7 @@ func (pr *progressReader) Read(p []byte) (int, error) { pr.bytesSoFar += delta if progressErr := pr.onProgress(common.ProgressEvent{ Event: "progress", - Oid: pr.oid, + Oid: pr.hash, BytesSoFar: pr.bytesSoFar, BytesSinceLast: delta, }); progressErr != nil { @@ -51,18 +49,18 @@ func (pr *progressReader) Read(p []byte) (int, error) { } func (pr *progressReader) Finalize() error { - if pr.onProgress == nil { - return nil - } - if pr.total == 0 || pr.bytesSoFar >= pr.total { - return nil + if pr.total > 0 && pr.bytesSoFar < pr.total { + delta := pr.total - pr.bytesSoFar + pr.bytesSoFar = pr.total + if pr.onProgress != nil { + _ = pr.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pr.hash, + BytesSoFar: pr.bytesSoFar, + BytesSinceLast: delta, + }) + } + return fmt.Errorf("upload incomplete: %d/%d bytes", pr.bytesSoFar-delta, pr.total) } - delta := pr.total - pr.bytesSoFar - pr.bytesSoFar = pr.total - return pr.onProgress(common.ProgressEvent{ - Event: "progress", - Oid: pr.oid, - BytesSoFar: pr.bytesSoFar, - BytesSinceLast: delta, - }) + return nil } diff --git a/client/upload/progress_reader_test.go b/upload/progress_reader_test.go similarity index 95% rename from client/upload/progress_reader_test.go rename to upload/progress_reader_test.go index de77d8e..789afa0 100644 --- a/client/upload/progress_reader_test.go +++ b/upload/progress_reader_test.go @@ -5,7 +5,7 @@ import ( "io" "testing" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) func TestProgressReaderFinalizes(t *testing.T) { diff --git a/client/upload/request.go b/upload/request.go similarity index 53% rename from client/upload/request.go rename to upload/request.go index 894db52..6036fed 100644 --- a/client/upload/request.go +++ b/upload/request.go @@ -6,47 +6,24 @@ import ( "errors" "fmt" "net/http" - "net/url" "os" "strings" - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - req "github.com/calypr/data-client/client/request" + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" + req "github.com/calypr/data-client/request" "github.com/vbauerster/mpb/v8" ) // GeneratePresignedURL handles both Shepherd and Fence fallback func GeneratePresignedUploadURL(ctx context.Context, g3 client.Gen3Interface, filename string, metadata common.FileMetadata, bucket string) (*PresignedURLResponse, error) { - hasShepherd, err := g3.CheckForShepherdAPI(ctx) + hasShepherd, err := g3.Fence().CheckForShepherdAPI(ctx) if err != nil || !hasShepherd { - payload := map[string]string{ - "file_name": filename, - } - if bucket != "" { - payload["bucket"] = bucket - } - - buf, err := common.ToJSONReader(payload) - if err != nil { - return nil, err - } - - cred := g3.GetCredential() - resp, err := g3.Do( - ctx, - &req.RequestBuilder{ - Method: http.MethodPost, - Url: cred.APIEndpoint + common.FenceDataUploadEndpoint, - Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, - Body: buf, - Token: cred.AccessToken, - }) + msg, err := g3.Fence().InitUpload(ctx, filename, bucket, "") if err != nil { return nil, err } - msg, err := g3.ParseFenceURLResponse(resp) - return &PresignedURLResponse{msg.URL, msg.GUID}, err + return &PresignedURLResponse{URL: msg.URL, GUID: msg.GUID}, nil } shepherdPayload := ShepherdInitRequestObject{ @@ -64,7 +41,7 @@ func GeneratePresignedUploadURL(ctx context.Context, g3 client.Gen3Interface, fi } cred := g3.GetCredential() - r, err := g3.Do( + r, err := g3.Fence().Do( ctx, &req.RequestBuilder{ Url: cred.APIEndpoint + common.ShepherdEndpoint + "/objects", @@ -86,42 +63,23 @@ func GeneratePresignedUploadURL(ctx context.Context, g3 client.Gen3Interface, fi // GenerateUploadRequest helps preparing the HTTP request for upload and the progress bar for single part upload func generateUploadRequest(ctx context.Context, g3 client.Gen3Interface, furObject common.FileUploadRequestObject, file *os.File, progress *mpb.Progress) (common.FileUploadRequestObject, error) { if furObject.PresignedURL == "" { - endPointPostfix := common.FenceDataUploadEndpoint + "/" + furObject.GUID + "?file_name=" + url.QueryEscape(furObject.Filename) - - if furObject.Bucket != "" { - endPointPostfix += "&bucket=" + furObject.Bucket - } - cred := g3.GetCredential() - resp, err := g3.Do( - ctx, - &req.RequestBuilder{ - Url: cred.APIEndpoint + endPointPostfix, - Headers: map[string]string{common.HeaderContentType: common.MIMEApplicationJSON}, - Token: cred.AccessToken, - Method: http.MethodGet, - }, - ) - if err != nil { - return furObject, fmt.Errorf("Upload error: %w", err) - } - - msg, err := g3.ParseFenceURLResponse(resp) + msg, err := g3.Fence().GetUploadPresignedUrl(ctx, furObject.GUID, furObject.ObjectKey, furObject.Bucket) if err != nil && !strings.Contains(err.Error(), "No GUID found") { return furObject, fmt.Errorf("Upload error: %w", err) } if msg.URL == "" { - return furObject, errors.New("Upload error: error in generating presigned URL for " + furObject.Filename) + return furObject, errors.New("Upload error: error in generating presigned URL for " + furObject.ObjectKey) } furObject.PresignedURL = msg.URL } fi, err := file.Stat() if err != nil { - return furObject, errors.New("File stat error for file" + furObject.Filename + ", file may be missing or unreadable because of permissions.\n") + return furObject, errors.New("File stat error for file" + furObject.ObjectKey + ", file may be missing or unreadable because of permissions.\n") } if fi.Size() > common.FileSizeLimit { - return furObject, errors.New("The file size of file " + furObject.Filename + " exceeds the limit allowed and cannot be uploaded. The maximum allowed file size is " + FormatSize(common.FileSizeLimit) + ".\n") + return furObject, errors.New("The file size of file " + furObject.ObjectKey + " exceeds the limit allowed and cannot be uploaded. The maximum allowed file size is " + FormatSize(common.FileSizeLimit) + ".\n") } return furObject, err diff --git a/client/upload/retry.go b/upload/retry.go similarity index 77% rename from client/upload/retry.go rename to upload/retry.go index ca3779f..679a93d 100644 --- a/client/upload/retry.go +++ b/upload/retry.go @@ -6,8 +6,8 @@ import ( "path/filepath" "time" - client "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" ) // GetWaitTime calculates exponential backoff with cap @@ -45,14 +45,14 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap for ro := range retryChan { ro.RetryCount++ - logger.Printf("#%d retry — %s\n", ro.RetryCount, ro.FilePath) + logger.Printf("#%d retry — %s\n", ro.RetryCount, ro.SourcePath) wait := GetWaitTime(ro.RetryCount) logger.Printf("Waiting %.0f seconds before retry...\n", wait.Seconds()) time.Sleep(wait) // Clean up old record if exists if ro.GUID != "" { - if msg, err := g3.DeleteRecord( + if msg, err := g3.Fence().DeleteRecord( ctx, ro.GUID, ); err == nil { @@ -60,29 +60,29 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap } } - file, err := os.Open(ro.FilePath) + file, err := os.Open(ro.SourcePath) if err != nil { continue } // Ensure filename is set - if ro.Filename == "" { - absPath, _ := common.GetAbsolutePath(ro.FilePath) - ro.Filename = filepath.Base(absPath) + if ro.ObjectKey == "" { + absPath, _ := common.GetAbsolutePath(ro.SourcePath) + ro.ObjectKey = filepath.Base(absPath) } if ro.Multipart { // Retry multipart req := common.FileUploadRequestObject{ - FilePath: ro.FilePath, - Filename: ro.Filename, + SourcePath: ro.SourcePath, + ObjectKey: ro.ObjectKey, GUID: ro.GUID, FileMetadata: ro.FileMetadata, Bucket: ro.Bucket, } err = MultipartUpload(ctx, g3, req, file, true) if err == nil { - logger.Succeeded(ro.FilePath, req.GUID) + logger.Succeeded(ro.SourcePath, req.GUID) if sb != nil { sb.IncrementSB(ro.RetryCount - 1) } @@ -90,13 +90,13 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap } } else { // Retry single-part - respObj, err := GeneratePresignedUploadURL(ctx, g3, ro.Filename, ro.FileMetadata, ro.Bucket) + respObj, err := GeneratePresignedUploadURL(ctx, g3, ro.ObjectKey, ro.FileMetadata, ro.Bucket) if err != nil { handleRetryFailure(ctx, g3, ro, retryChan, err) continue } - file, err := os.Open(ro.FilePath) + file, err := os.Open(ro.SourcePath) if err != nil { handleRetryFailure(ctx, g3, ro, retryChan, err) continue @@ -111,8 +111,8 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap } fur := common.FileUploadRequestObject{ - FilePath: ro.FilePath, - Filename: ro.Filename, + SourcePath: ro.SourcePath, + ObjectKey: ro.ObjectKey, FileMetadata: ro.FileMetadata, GUID: respObj.GUID, PresignedURL: respObj.URL, @@ -124,9 +124,9 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap continue } - err = UploadSingleFile(ctx, g3, fur, true) + err = UploadSingle(ctx, g3, fur, true) if err == nil { - logger.Succeeded(ro.FilePath, fur.GUID) + logger.Succeeded(ro.SourcePath, fur.GUID) if sb != nil { sb.IncrementSB(ro.RetryCount - 1) } @@ -142,7 +142,7 @@ func RetryFailedUploads(ctx context.Context, g3 client.Gen3Interface, failedMap // handleRetryFailure logs failure and requeues if retries remain func handleRetryFailure(ctx context.Context, g3 client.Gen3Interface, ro common.RetryObject, retryChan chan common.RetryObject, err error) { logger := g3.Logger() - logger.Failed(ro.FilePath, ro.Filename, ro.FileMetadata, ro.GUID, ro.RetryCount, ro.Multipart) + logger.Failed(ro.SourcePath, ro.ObjectKey, ro.FileMetadata, ro.GUID, ro.RetryCount, ro.Multipart) if err != nil { logger.Println("Retry error:", err) } @@ -154,7 +154,7 @@ func handleRetryFailure(ctx context.Context, g3 client.Gen3Interface, ro common. // Max retries reached — final cleanup if ro.GUID != "" { - if msg, err := g3.DeleteRecord(ctx, ro.GUID); err == nil { + if msg, err := g3.Fence().DeleteRecord(ctx, ro.GUID); err == nil { logger.Println("Cleaned up failed record:", msg) } else { logger.Println("Cleanup failed:", err) diff --git a/upload/singleFile.go b/upload/singleFile.go new file mode 100644 index 0000000..962a468 --- /dev/null +++ b/upload/singleFile.go @@ -0,0 +1,96 @@ +package upload + +import ( + "context" + "fmt" + "io" + "os" + + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" +) + +func UploadSingle(ctx context.Context, g3Client client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + + // We use the provided client interface + g3i := g3Client + + g3i.Logger().InfoContext(ctx, "File Upload Request", "request", req) + + // Helper to handle * in path if it was passed, though optimally caller handles this. + // We will trust the SourcePath in the request object mostly, but for safety we can check existence. + // But commonly parsing happens before creating the object usually. + // Let's assume req.SourcePath is a single valid file path for now as per design. + + file, err := os.Open(req.SourcePath) + if err != nil { + if showProgress { + sb := g3i.Logger().Scoreboard() + if sb != nil { + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } + } + g3i.Logger().Failed(req.SourcePath, req.ObjectKey, common.FileMetadata{}, "", 0, false) + g3i.Logger().ErrorContext(ctx, "File open error", "file", req.SourcePath, "error", err) + return fmt.Errorf("[ERROR] when opening file path %s, an error occurred: %s\n", req.SourcePath, err.Error()) + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to stat file: %w", err) + } + fileSize := fi.Size() + + furObject, err := generateUploadRequest(ctx, g3i, req, file, nil) + if err != nil { + if showProgress { + sb := g3i.Logger().Scoreboard() + if sb != nil { + sb.IncrementSB(len(sb.Counts)) + sb.PrintSB() + } + } + g3i.Logger().Failed(req.SourcePath, req.ObjectKey, common.FileMetadata{}, req.GUID, 0, false) + g3i.Logger().ErrorContext(ctx, "Error occurred during request generation", "file", req.SourcePath, "error", err) + return fmt.Errorf("[ERROR] Error occurred during request generation for file %s: %s\n", req.SourcePath, err.Error()) + } + + progressCallback := common.GetProgress(ctx) + oid := common.GetOid(ctx) + if oid == "" { + oid = resolveUploadOID(furObject) + } + + var reader io.Reader = file + var progressTracker *progressReader + if progressCallback != nil { + progressTracker = newProgressReader(file, progressCallback, oid, fileSize) + reader = progressTracker + } + + _, err = uploadPart(ctx, furObject.PresignedURL, reader, fileSize) + if progressTracker != nil { + if finalizeErr := progressTracker.Finalize(); finalizeErr != nil && err == nil { + err = finalizeErr + } + } + + if err != nil { + g3i.Logger().ErrorContext(ctx, "Upload failed", "error", err) + return err + } + + g3i.Logger().InfoContext(ctx, "Successfully uploaded", "file", req.ObjectKey) + g3i.Logger().Succeeded(req.SourcePath, req.GUID) + + if showProgress { + sb := g3i.Logger().Scoreboard() + if sb != nil { + sb.IncrementSB(0) + sb.PrintSB() + } + } + return nil +} diff --git a/client/upload/types.go b/upload/types.go similarity index 52% rename from client/upload/types.go rename to upload/types.go index 2ef2f29..8c69ce1 100644 --- a/client/upload/types.go +++ b/upload/types.go @@ -1,17 +1,12 @@ package upload -import "github.com/calypr/data-client/client/common" +import "github.com/calypr/data-client/common" type PresignedURLResponse struct { GUID string `json:"guid"` URL string `json:"upload_url"` } -type MultipartPartObject struct { - PartNumber int `json:"PartNumber"` - ETag string `json:"ETag"` -} - type UploadConfig struct { BucketName string NumParallel int @@ -21,13 +16,6 @@ type UploadConfig struct { ShowProgress bool } -// InitRequestObject represents the payload that sends to FENCE for getting a singlepart upload presignedURL or init a multipart upload for new object file -type InitRequestObject struct { - Filename string `json:"file_name"` - Bucket string `json:"bucket,omitempty"` - GUID string `json:"guid,omitempty"` -} - // ShepherdInitRequestObject represents the payload that sends to Shepherd for getting a singlepart upload presignedURL or init a multipart upload for new object file type ShepherdInitRequestObject struct { Filename string `json:"file_name"` @@ -36,27 +24,12 @@ type ShepherdInitRequestObject struct { // Metadata is an encoded JSON string of any arbitrary metadata the user wishes to upload. Metadata map[string]any `json:"metadata"` } + type ShepherdAuthz struct { Version string `json:"version"` ResourcePaths []string `json:"resource_paths"` } -// MultipartUploadRequestObject represents the payload that sends to FENCE for getting a presignedURL for a part -type MultipartUploadRequestObject struct { - Key string `json:"key"` - UploadID string `json:"uploadId"` - PartNumber int `json:"partNumber"` - Bucket string `json:"bucket,omitempty"` -} - -// MultipartCompleteRequestObject represents the payload that sends to FENCE for completeing a multipart upload -type MultipartCompleteRequestObject struct { - Key string `json:"key"` - UploadID string `json:"uploadId"` - Parts []MultipartPartObject `json:"parts"` - Bucket string `json:"bucket,omitempty"` -} - // FileInfo is a helper struct for including subdirname as filename type FileInfo struct { FilePath string diff --git a/upload/upload.go b/upload/upload.go new file mode 100644 index 0000000..fb62e96 --- /dev/null +++ b/upload/upload.go @@ -0,0 +1,208 @@ +package upload + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" + drs "github.com/calypr/data-client/indexd/drs" // Imported for DRSObject + "github.com/vbauerster/mpb/v8" +) + +// Upload is a unified catch-all function that automatically chooses between +// single-part and multipart upload based on file size. +func Upload(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + g3.Logger().Printf("Processing Upload Request for: %s\n", req.SourcePath) + + file, err := os.Open(req.SourcePath) + if err != nil { + return fmt.Errorf("cannot open file %s: %w", req.SourcePath, err) + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return fmt.Errorf("cannot stat file: %w", err) + } + + fileSize := stat.Size() + if fileSize == 0 { + return fmt.Errorf("file is empty: %s", req.ObjectKey) + } + + // Use Single-Part if file is smaller than 5GB (or your defined limit) + if fileSize < 5*common.GB { + g3.Logger().Printf("File size %d bytes (< 5GB), performing single-part upload\n", fileSize) + return UploadSingle(ctx, g3, req, true) + } + g3.Logger().Printf("File size %d bytes (>= 5GB), performing multipart upload\n", fileSize) + return MultipartUpload(ctx, g3, req, file, showProgress) +} + +// UploadSingleFile handles single-part upload with progress +func UploadSingleFile(ctx context.Context, g3 client.Gen3Interface, req common.FileUploadRequestObject, showProgress bool) error { + file, err := os.Open(req.SourcePath) + if err != nil { + return err + } + defer file.Close() + + fi, _ := file.Stat() + if fi.Size() > common.FileSizeLimit { + return fmt.Errorf("file exceeds 5GB limit") + } + + if fi.Size() > common.FileSizeLimit { + return fmt.Errorf("file exceeds 5GB limit") + } + + // Generate request with progress bar + var p *mpb.Progress + if showProgress { + p = mpb.New(mpb.WithOutput(os.Stdout)) + } + + // Populate PresignedURL and GUID if missing + fur, err := generateUploadRequest(ctx, g3, req, file, p) + if err != nil { + return err + } + + return MultipartUpload(ctx, g3, fur, file, showProgress) +} + +// RegisterAndUploadFile orchestrates registration with Indexd and uploading via Fence. +// It handles checking for existing records, upsert logic, checking if file is already downloadable, and performing the upload. +func RegisterAndUploadFile(ctx context.Context, g3 client.Gen3Interface, drsObject *drs.DRSObject, filePath string, bucketName string, upsert bool) (*drs.DRSObject, error) { + // 1. Register with Indexd + // Note: The caller is responsible for converting local DRS object to data-client DRS object if needed. + + res, err := g3.Indexd().RegisterRecord(ctx, drsObject) + if err != nil { + if strings.Contains(err.Error(), "already exists") { + if !upsert { + g3.Logger().Printf("indexd record already exists, proceeding for %s\n", drsObject.Id) + } else { + g3.Logger().Printf("indexd record already exists, deleting and re-adding for %s\n", drsObject.Id) + err = g3.Indexd().DeleteIndexdRecord(ctx, drsObject.Id) + if err != nil { + return nil, fmt.Errorf("failed to delete existing record: %w", err) + } + res, err = g3.Indexd().RegisterRecord(ctx, drsObject) + if err != nil { + return nil, fmt.Errorf("failed to re-register record: %w", err) + } + } + } else { + return nil, fmt.Errorf("error registering indexd record: %w", err) + } + } else { + // If registration succeeded, use the returned object which might have updated fields (e.g. created time) + // although we typically reuse the ID for upload. + } + + // If we didn't get a new object (upsert=false case), we should fetch the existing one to be sure about its state? + // But we have the ID in drsObject.Id. + + // 2. Check if file is downloadable + downloadable, err := isFileDownloadable(ctx, g3, drsObject.Id) + if err != nil { + return nil, fmt.Errorf("failed to check if file is downloadable: %w", err) + } + + if downloadable { + g3.Logger().Printf("File %s is already downloadable, skipping upload.\n", drsObject.Id) + // Return the registered object (or the one passed in if we didn't re-register) + // If we re-registered, res is populated. If not, we might want to return the fetched object? + // For consistency, let's return res if set, or fetch it. + if res != nil { + return res, nil + } + return g3.Indexd().GetObject(ctx, drsObject.Id) + } + + // 3. Upload File + uploadFilename := filepath.Base(filePath) + + // Attempt to determine the correct upload filename from the registered object's URL. + // git-drs registers s3://bucket/GUID/SHA, so we want to upload to "SHA", not "filename.ext". + if res != nil && len(res.AccessMethods) > 0 { + for _, am := range res.AccessMethods { + if am.Type == "s3" && am.AccessURL.URL != "" { + // Parse s3://bucket/guid/sha -> sha + parts := strings.Split(am.AccessURL.URL, "/") + if len(parts) > 0 { + candidate := parts[len(parts)-1] + if candidate != "" { + uploadFilename = candidate + } + } + break + } + } + } else if len(drsObject.AccessMethods) > 0 { + // Fallback to checking the input object if res didn't have methods (unlikely for upsert=false) + for _, am := range drsObject.AccessMethods { + if am.Type == "s3" && am.AccessURL.URL != "" { + parts := strings.Split(am.AccessURL.URL, "/") + if len(parts) > 0 { + candidate := parts[len(parts)-1] + if candidate != "" { + uploadFilename = candidate + } + } + break + } + } + } + + req := common.FileUploadRequestObject{ + SourcePath: filePath, + ObjectKey: uploadFilename, + GUID: drsObject.Id, + Bucket: bucketName, + } + + // Use Upload function which handles single/multipart selection + err = Upload(ctx, g3, req, false) + if err != nil { + return nil, fmt.Errorf("failed to upload file: %w", err) + } + + // Return the object + if res != nil { + return res, nil + } + return g3.Indexd().GetObject(ctx, drsObject.Id) +} + +func isFileDownloadable(ctx context.Context, g3 client.Gen3Interface, did string) (bool, error) { + // Get the object to find access methods + obj, err := g3.Indexd().GetObject(ctx, did) + if err != nil { + return false, err + } + + if len(obj.AccessMethods) == 0 { + return false, nil + } + + accessType := obj.AccessMethods[0].Type + res, err := g3.Indexd().GetDownloadURL(ctx, did, accessType) + if err != nil { + // If we can't get a download URL, it's not downloadable + return false, nil + } + + if res.URL == "" { + return false, nil + } + + // Check if the URL is accessible + err = common.CanDownloadFile(res.URL) + return err == nil, nil +} diff --git a/client/upload/utils.go b/upload/utils.go similarity index 87% rename from client/upload/utils.go rename to upload/utils.go index c26f3fc..54cf836 100644 --- a/client/upload/utils.go +++ b/upload/utils.go @@ -8,9 +8,9 @@ import ( "path/filepath" "strings" - "github.com/calypr/data-client/client/client" - "github.com/calypr/data-client/client/common" - "github.com/calypr/data-client/client/logs" + "github.com/calypr/data-client/common" + client "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/logs" ) func SeparateSingleAndMultipartUploads(g3i client.Gen3Interface, objects []common.FileUploadRequestObject) ([]common.FileUploadRequestObject, []common.FileUploadRequestObject) { @@ -20,21 +20,21 @@ func SeparateSingleAndMultipartUploads(g3i client.Gen3Interface, objects []commo var multipartObjects []common.FileUploadRequestObject for _, object := range objects { - fi, err := os.Stat(object.FilePath) + fi, err := os.Stat(object.SourcePath) if err != nil { if os.IsNotExist(err) { - g3i.Logger().Printf("The file you specified \"%s\" does not exist locally\n", object.FilePath) + g3i.Logger().Printf("The file you specified \"%s\" does not exist locally\n", object.SourcePath) } else { g3i.Logger().Println("File stat error: " + err.Error()) } - g3i.Logger().Failed(object.FilePath, object.Filename, object.FileMetadata, object.GUID, 0, false) + g3i.Logger().Failed(object.SourcePath, object.ObjectKey, object.FileMetadata, object.GUID, 0, false) continue } if fi.IsDir() { continue } - if _, ok := g3i.Logger().GetSucceededLogMap()[object.FilePath]; ok { - g3i.Logger().Println("File \"" + object.FilePath + "\" found in history. Skipping.") + if _, ok := g3i.Logger().GetSucceededLogMap()[object.SourcePath]; ok { + g3i.Logger().Println("File \"" + object.SourcePath + "\" found in history. Skipping.") continue } if fi.Size() > common.MultipartFileSizeLimit { @@ -51,7 +51,7 @@ func SeparateSingleAndMultipartUploads(g3i client.Gen3Interface, objects []commo } // ProcessFilename returns an FileInfo object which has the information about the path and name to be used for upload of a file -func ProcessFilename(logger logs.Logger, uploadPath string, filePath string, objectId string, includeSubDirName bool, includeMetadata bool) (common.FileUploadRequestObject, error) { +func ProcessFilename(logger *logs.Gen3Logger, uploadPath string, filePath string, objectId string, includeSubDirName bool, includeMetadata bool) (common.FileUploadRequestObject, error) { var err error filePath, err = common.GetAbsolutePath(filePath) if err != nil { @@ -102,7 +102,7 @@ func ProcessFilename(logger logs.Logger, uploadPath string, filePath string, obj logger.Printf("WARNING: File metadata is enabled, but could not find the metadata file %v for file %v. Execute `data-client upload --help` for more info on file metadata.\n", metadataFilePath, filePath) } } - return common.FileUploadRequestObject{FilePath: filePath, Filename: filename, FileMetadata: metadata, GUID: objectId}, nil + return common.FileUploadRequestObject{SourcePath: filePath, ObjectKey: filename, FileMetadata: metadata, GUID: objectId}, nil } // FormatSize helps to parse a int64 size into string diff --git a/client/upload/utils_test.go b/upload/utils_test.go similarity index 98% rename from client/upload/utils_test.go rename to upload/utils_test.go index 8681096..6abe45e 100644 --- a/client/upload/utils_test.go +++ b/upload/utils_test.go @@ -3,7 +3,7 @@ package upload import ( "testing" - "github.com/calypr/data-client/client/common" + "github.com/calypr/data-client/common" ) func TestOptimalChunkSize(t *testing.T) { From 16fb1ee19c839b01dd47e001cb9a6a598265917a Mon Sep 17 00:00:00 2001 From: Brian Date: Thu, 5 Feb 2026 08:05:24 -0800 Subject: [PATCH 13/14] Refactor/progress callback (#29) * fix tests * make progress logging tied to minChunkSize * refactor/add-url #28 * typo * improve onProgress bytesSinceReport >= 1MB * OnProgressThreshold --------- Co-authored-by: matthewpeterkort --- common/constants.go | 2 + download/progress_writer.go | 39 +- download/transfer_test.go | 24 +- {indexd/drs => drs}/drs.go | 2 +- drs/object_builder.go | 56 +++ drs/object_builder_test.go | 51 ++ {indexd/drs => drs}/types.go | 4 +- {indexd/hash => hash}/hash.go | 0 {indexd/hash => hash}/hash_test.go | 0 indexd/add_url.go | 106 ---- indexd/client.go | 5 +- indexd/client_test.go | 4 +- indexd/convert.go | 2 +- indexd/records.go | 2 +- indexd/tests/add-url-integration_test.go | 68 --- indexd/tests/client_read_test.go.todo | 134 ----- indexd/tests/client_write_test.go.todo | 369 -------------- indexd/tests/mock_servers_test.go | 610 ----------------------- indexd/types.go | 9 +- indexd/types_test.go | 4 +- indexd/upsert.go | 54 ++ mocks/mock_fence.go | 29 ++ mocks/mock_gen3interface.go | 30 ++ mocks/mock_indexd.go | 2 +- {indexd => s3utils}/s3_utils.go | 7 +- upload/multipart_test.go | 8 +- upload/progress_reader.go | 39 +- upload/upload.go | 2 +- 28 files changed, 320 insertions(+), 1342 deletions(-) rename {indexd/drs => drs}/drs.go (98%) create mode 100644 drs/object_builder.go create mode 100644 drs/object_builder_test.go rename {indexd/drs => drs}/types.go (96%) rename {indexd/hash => hash}/hash.go (100%) rename {indexd/hash => hash}/hash_test.go (100%) delete mode 100644 indexd/add_url.go delete mode 100644 indexd/tests/add-url-integration_test.go delete mode 100644 indexd/tests/client_read_test.go.todo delete mode 100644 indexd/tests/client_write_test.go.todo delete mode 100644 indexd/tests/mock_servers_test.go create mode 100644 indexd/upsert.go rename {indexd => s3utils}/s3_utils.go (97%) diff --git a/common/constants.go b/common/constants.go index 6299a2c..1191795 100644 --- a/common/constants.go +++ b/common/constants.go @@ -85,6 +85,8 @@ const ( MaxMultipartParts = 10000 MaxConcurrentUploads = 10 MaxRetries = 5 + + OnProgressThreshold = 1 * MB ) var ( diff --git a/download/progress_writer.go b/download/progress_writer.go index dd1abf0..3917234 100644 --- a/download/progress_writer.go +++ b/download/progress_writer.go @@ -8,11 +8,12 @@ import ( ) type progressWriter struct { - writer io.Writer - onProgress common.ProgressCallback - hash string - total int64 - bytesSoFar int64 + writer io.Writer + onProgress common.ProgressCallback + hash string + total int64 + bytesSoFar int64 + bytesSinceReport int64 } func newProgressWriter(writer io.Writer, onProgress common.ProgressCallback, hash string, total int64) *progressWriter { @@ -29,19 +30,33 @@ func (pw *progressWriter) Write(p []byte) (int, error) { if n > 0 && pw.onProgress != nil { delta := int64(n) pw.bytesSoFar += delta - if progressErr := pw.onProgress(common.ProgressEvent{ - Event: "progress", - Oid: pw.hash, - BytesSoFar: pw.bytesSoFar, - BytesSinceLast: delta, - }); progressErr != nil { - return n, progressErr + pw.bytesSinceReport += delta + + if pw.bytesSinceReport >= common.OnProgressThreshold { + if progressErr := pw.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pw.hash, + BytesSoFar: pw.bytesSoFar, + BytesSinceLast: pw.bytesSinceReport, + }); progressErr != nil { + return n, progressErr + } + pw.bytesSinceReport = 0 } } return n, err } func (pw *progressWriter) Finalize() error { + if pw.onProgress != nil && pw.bytesSinceReport > 0 { + _ = pw.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pw.hash, + BytesSoFar: pw.bytesSoFar, + BytesSinceLast: pw.bytesSinceReport, + }) + pw.bytesSinceReport = 0 + } if pw.total > 0 && pw.bytesSoFar < pw.total { delta := pw.total - pw.bytesSoFar pw.bytesSoFar = pw.total diff --git a/download/transfer_test.go b/download/transfer_test.go index aab05e7..d811afe 100644 --- a/download/transfer_test.go +++ b/download/transfer_test.go @@ -15,11 +15,12 @@ import ( "github.com/calypr/data-client/common" "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/fence" "github.com/calypr/data-client/indexd" - "github.com/calypr/data-client/indexd/drs" "github.com/calypr/data-client/logs" "github.com/calypr/data-client/request" + "github.com/calypr/data-client/requestor" "github.com/calypr/data-client/sower" ) @@ -34,9 +35,10 @@ func (f *fakeGen3Download) Logger() *logs.Gen3Logger { return f.logger } func (f *fakeGen3Download) ExportCredential(ctx context.Context, cred *conf.Credential) error { return nil } -func (f *fakeGen3Download) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } -func (f *fakeGen3Download) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } -func (f *fakeGen3Download) Sower() sower.SowerInterface { return nil } +func (f *fakeGen3Download) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } +func (f *fakeGen3Download) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } +func (f *fakeGen3Download) Sower() sower.SowerInterface { return nil } +func (f *fakeGen3Download) Requestor() requestor.RequestorInterface { return nil } type fakeFence struct { fence.FenceInterface @@ -200,3 +202,17 @@ func newDownloadResponse(rawURL string, payload []byte, status int) *http.Respon Header: make(http.Header), } } + +// fakeRequestor implements requestor.RequestorInterface using the same doFunc. +type fakeRequestor struct { + requestor.RequestorInterface + doFunc func(context.Context, *request.RequestBuilder) (*http.Response, error) +} + +func (f *fakeRequestor) Do(ctx context.Context, req *request.RequestBuilder) (*http.Response, error) { + return f.doFunc(ctx, req) +} + +func (f *fakeRequestor) New(method, url string) *request.RequestBuilder { + return &request.RequestBuilder{Method: method, Url: url, Headers: make(map[string]string)} +} diff --git a/indexd/drs/drs.go b/drs/drs.go similarity index 98% rename from indexd/drs/drs.go rename to drs/drs.go index 46ea800..55feb1a 100644 --- a/indexd/drs/drs.go +++ b/drs/drs.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" "github.com/google/uuid" ) diff --git a/drs/object_builder.go b/drs/object_builder.go new file mode 100644 index 0000000..61fec11 --- /dev/null +++ b/drs/object_builder.go @@ -0,0 +1,56 @@ +package drs + +import ( + "fmt" + "path/filepath" + + "github.com/calypr/data-client/hash" +) + +type ObjectBuilder struct { + Bucket string + ProjectID string + AccessType string +} + +func NewObjectBuilder(bucket, projectID string) ObjectBuilder { + return ObjectBuilder{ + Bucket: bucket, + ProjectID: projectID, + AccessType: "s3", + } +} + +func (b ObjectBuilder) Build(fileName string, checksum string, size int64, drsID string) (*DRSObject, error) { + if b.Bucket == "" { + return nil, fmt.Errorf("error: bucket name is empty in config file") + } + accessType := b.AccessType + if accessType == "" { + accessType = "s3" + } + + fileURL := fmt.Sprintf("s3://%s", filepath.Join(b.Bucket, drsID, checksum)) + + authzStr, err := ProjectToResource(b.ProjectID) + if err != nil { + return nil, err + } + authorizations := Authorizations{ + Value: authzStr, + } + + drsObj := DRSObject{ + Id: drsID, + Name: fileName, + AccessMethods: []AccessMethod{{ + Type: accessType, + AccessURL: AccessURL{URL: fileURL}, + Authorizations: &authorizations, + }}, + Checksums: hash.HashInfo{SHA256: checksum}, + Size: size, + } + + return &drsObj, nil +} diff --git a/drs/object_builder_test.go b/drs/object_builder_test.go new file mode 100644 index 0000000..e196e00 --- /dev/null +++ b/drs/object_builder_test.go @@ -0,0 +1,51 @@ +package drs + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestObjectBuilderBuildSuccess(t *testing.T) { + builder := ObjectBuilder{ + ProjectID: "test-project", + Bucket: "bucket", + } + + obj, err := builder.Build("file.txt", "sha-256", 12, "did-1") + if err != nil { + t.Fatalf("Build error: %v", err) + } + if obj.Id != "did-1" { + t.Fatalf("unexpected Id: %s", obj.Id) + } + if obj.Name != "file.txt" { + t.Fatalf("unexpected Name: %s", obj.Name) + } + if obj.Checksums.SHA256 != "sha-256" { + t.Fatalf("unexpected checksum: %v", obj.Checksums) + } + if obj.Size != 12 { + t.Fatalf("unexpected size: %d", obj.Size) + } + if len(obj.AccessMethods) != 1 { + t.Fatalf("expected 1 access method, got %d", len(obj.AccessMethods)) + } + if !strings.Contains(obj.AccessMethods[0].AccessURL.URL, filepath.Join("bucket", "did-1", "sha-256")) { + t.Fatalf("unexpected access URL: %s", obj.AccessMethods[0].AccessURL.URL) + } + if obj.AccessMethods[0].Type != "s3" { + t.Fatalf("unexpected access method type: %s", obj.AccessMethods[0].Type) + } +} + +func TestObjectBuilderBuildEmptyBucket(t *testing.T) { + builder := ObjectBuilder{ + ProjectID: "test-project", + Bucket: "", + } + + if _, err := builder.Build("file.txt", "sha-256", 12, "did-1"); err == nil { + t.Fatalf("expected error when Bucket is empty") + } +} diff --git a/indexd/drs/types.go b/drs/types.go similarity index 96% rename from indexd/drs/types.go rename to drs/types.go index d17cd45..ff203cc 100644 --- a/indexd/drs/types.go +++ b/drs/types.go @@ -1,8 +1,6 @@ package drs -import ( - "github.com/calypr/data-client/indexd/hash" -) +import "github.com/calypr/data-client/hash" type ChecksumType = hash.ChecksumType type Checksum = hash.Checksum diff --git a/indexd/hash/hash.go b/hash/hash.go similarity index 100% rename from indexd/hash/hash.go rename to hash/hash.go diff --git a/indexd/hash/hash_test.go b/hash/hash_test.go similarity index 100% rename from indexd/hash/hash_test.go rename to hash/hash_test.go diff --git a/indexd/add_url.go b/indexd/add_url.go deleted file mode 100644 index af85298..0000000 --- a/indexd/add_url.go +++ /dev/null @@ -1,106 +0,0 @@ -package indexd - -import ( - "context" - "fmt" - "slices" - - "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/calypr/data-client/fence" - "github.com/calypr/data-client/indexd/drs" -) - -// UpsertIndexdRecord creates or updates an indexd record with a new URL. -func (c *IndexdClient) UpsertIndexdRecord(ctx context.Context, url string, sha256 string, fileSize int64, projectId string) (*drs.DRSObject, error) { - uuid := drs.DrsUUID(projectId, sha256) - - records, err := c.GetObjectByHash(ctx, "sha256", sha256) - if err != nil { - return nil, fmt.Errorf("error querying indexd server: %v", err) - } - - var matchingRecord *drs.DRSObject - for i := range records { - if records[i].Id == uuid { - matchingRecord = &records[i] - break - } - } - - if matchingRecord != nil { - existingURLs := IndexdURLFromDrsAccessURLs(matchingRecord.AccessMethods) - if slices.Contains(existingURLs, url) { - c.logger.Debug("Nothing to do: file already registered") - return matchingRecord, nil - } - - c.logger.Debug("updating existing record with new url") - updatedRecord := drs.DRSObject{AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: url}}}} - return c.UpdateRecord(ctx, &updatedRecord, matchingRecord.Id) - } - - // If no record exists, create one - c.logger.Debug("creating new record") - _, key, err := ParseS3URL(url) - if err != nil { - return nil, err - } - - drsObj, err := drs.BuildDrsObj(key, sha256, fileSize, uuid, "placeholder-bucket", projectId) - if err != nil { - return nil, err - } - - return c.RegisterRecord(ctx, drsObj) -} - -// AddURL implements the AddURL logic ported from git-drs. -func (c *IndexdClient) AddURL( - ctx context.Context, - fClient fence.FenceInterface, - s3URL string, - sha256 string, - awsAccessKey string, - awsSecretKey string, - region string, - endpoint string, - s3Client *s3.Client, -) (S3Meta, error) { - if err := ValidateInputs(s3URL, sha256); err != nil { - return S3Meta{}, err - } - - bucket, _, err := ParseS3URL(s3URL) - if err != nil { - return S3Meta{}, err - } - - var bucketDetails *fence.S3Bucket - if fClient != nil { - bucketDetails, err = fClient.GetBucketDetails(ctx, bucket) - if err != nil { - c.logger.Debug(fmt.Sprintf("Warning: unable to get bucket details from Gen3: %v", err)) - } - } - - size, modifiedDate, err := FetchS3MetadataWithBucketDetails( - ctx, s3URL, awsAccessKey, awsSecretKey, region, endpoint, bucketDetails, s3Client, c.logger, - ) - if err != nil { - return S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) - } - - // This part needs project ID. In git-drs it was in the client config. - projectId := "unknown-project" - // ... (logic to get project ID) - - _, err = c.UpsertIndexdRecord(ctx, s3URL, sha256, size, projectId) - if err != nil { - return S3Meta{}, fmt.Errorf("failed to upsert indexd record: %w", err) - } - - return S3Meta{ - Size: size, - LastModified: modifiedDate, - }, nil -} diff --git a/indexd/client.go b/indexd/client.go index 29ea378..989e69e 100644 --- a/indexd/client.go +++ b/indexd/client.go @@ -9,10 +9,8 @@ import ( "log/slog" "net/http" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/calypr/data-client/conf" - "github.com/calypr/data-client/fence" - "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/request" ) @@ -36,7 +34,6 @@ type IndexdInterface interface { DeleteRecordByHash(ctx context.Context, hashValue string, projectId string) error RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) UpsertIndexdRecord(ctx context.Context, url string, sha256 string, fileSize int64, projectId string) (*drs.DRSObject, error) - AddURL(ctx context.Context, fClient fence.FenceInterface, s3URL, sha256, awsAccessKey, awsSecretKey, region, endpoint string, s3Client *s3.Client) (S3Meta, error) } // IndexdClient implements IndexdInterface diff --git a/indexd/client_test.go b/indexd/client_test.go index 2fb76ed..818e498 100644 --- a/indexd/client_test.go +++ b/indexd/client_test.go @@ -11,8 +11,8 @@ import ( "testing" "github.com/calypr/data-client/conf" - drs "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + drs "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/logs" "github.com/calypr/data-client/request" ) diff --git a/indexd/convert.go b/indexd/convert.go index 0fb44d9..117cac9 100644 --- a/indexd/convert.go +++ b/indexd/convert.go @@ -6,7 +6,7 @@ import ( "fmt" "net/url" - "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/drs" ) // IndexdRecordFromDrsObject represents a simplified version of an indexd record for conversion purposes diff --git a/indexd/records.go b/indexd/records.go index 72e2de6..7f03613 100644 --- a/indexd/records.go +++ b/indexd/records.go @@ -3,7 +3,7 @@ package indexd // https://github.com/uc-cdis/indexd/blob/master/openapis/swagger.yaml import ( - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" ) // subset of the OpenAPI spec for the InputInfo object in indexd diff --git a/indexd/tests/add-url-integration_test.go b/indexd/tests/add-url-integration_test.go deleted file mode 100644 index 0500b10..0000000 --- a/indexd/tests/add-url-integration_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package indexd_tests - -// // TODO: fix this during add-url fix -// import ( -// "testing" - -// "github.com/calypr/git-drs/utils" -// "github.com/stretchr/testify/require" -// ) - -// //////////////////// -// // E2E TESTS // -// // & MISC TESTS // -// //////////////////// - -// // TestAddURL_E2E_IdempotentSameURL tests end-to-end idempotency -// func TestAddURL_E2E_IdempotentSameURL(t *testing.T) { -// // Arrange: Start mock servers -// gen3Mock := NewMockGen3Server(t, "http://localhost:9000") -// defer gen3Mock.Close() - -// s3Mock := NewMockS3Server(t) -// defer s3Mock.Close() - -// indexdMock := NewMockIndexdServer(t) -// defer indexdMock.Close() - -// // Pre-populate S3 with test object -// s3Mock.AddObject("test-bucket", "sample.bam", 1024) - -// // TODO: This test is limited because AddURL has hardcoded config.LoadConfig() -// // In a real scenario, we'd need to mock that too or refactor AddURL to accept config -// t.Skip("Requires AddURL refactoring to accept config parameter") -// } - -// // TestAddURL_E2E_UpdateDifferentURL tests updating record with different URL -// // TODO: stubbed -// func TestAddURL_E2E_UpdateDifferentURL(t *testing.T) { -// // TODO: This test is skipped because it requires AddURL refactoring -// // See TestAddURL_E2E_IdempotentSameURL for explanation -// t.Skip("Requires AddURL refactoring to accept config parameter") -// } - -// // TestAddURL_E2E_LFSNotTracked tests LFS validation -// func TestAddURL_E2E_LFSNotTracked(t *testing.T) { -// // This test validates the LFS tracking check -// // The actual utils.IsLFSTracked function is tested separately in utils package - -// // Test the pattern matching logic by verifying ParseGitAttributes works -// gitattributesContent := `*.bam filter=lfs diff=lfs merge=lfs -text -// *.vcf filter=lfs diff=lfs merge=lfs -text` - -// attributes, err := utils.ParseGitAttributes(gitattributesContent) -// require.NoError(t, err) -// require.GreaterOrEqual(t, len(attributes), 2) - -// // Verify .bam pattern exists -// found := false -// for _, attr := range attributes { -// if attr.Pattern == "*.bam" { -// if filter, exists := attr.Attributes["filter"]; exists { -// require.Equal(t, "lfs", filter) -// found = true -// } -// } -// } -// require.True(t, found, "*.bam pattern with lfs filter should exist") -// } diff --git a/indexd/tests/client_read_test.go.todo b/indexd/tests/client_read_test.go.todo deleted file mode 100644 index 51857e1..0000000 --- a/indexd/tests/client_read_test.go.todo +++ /dev/null @@ -1,134 +0,0 @@ -package indexd_tests - -import ( - "testing" - - "github.com/calypr/git-drs/drs/hash" - "github.com/stretchr/testify/require" -) - -/////////////////// -// READ TESTS // -/////////////////// - -// Integration tests for READ operations on IndexdClient using mock indexd server. -// These tests verify non-mutating operations that query and retrieve data: -// - GetRecord / GetIndexdRecordByDID - Retrieve a single record by DID -// - GetObjectsByHash - Query records by hash -// - GetDownloadURL - Get signed download URLs -// - GetProjectId - Simple getter for project ID - -// TestIndexdClient_GetRecord tests retrieving a record via the client method with mocked auth -func TestIndexdClient_GetRecord(t *testing.T) { - // Arrange: Start mock server - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - // Pre-populate mock with test record - testRecord := newTestRecord("uuid-test-123") - addRecordToMockServer(mockServer, testRecord) - - // Act: Use client method with mocked auth (tests actual client logic) - client := testIndexdClientWithMockAuth(mockServer.URL()) - record, err := client.GetIndexdRecordByDID(testRecord.Did) - - // Assert: Test actual client logic - require.NoError(t, err) - require.NotNil(t, record) - require.Equal(t, testRecord.Did, record.Did) - require.Equal(t, testRecord.Size, record.Size) - require.Equal(t, testRecord.FileName, record.FileName) -} - -// TestIndexdClient_GetRecord_NotFound tests error handling for non-existent records -func TestIndexdClient_GetRecord_NotFound(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - // Act: Use client method to request non-existent record - client := testIndexdClientWithMockAuth(mockServer.URL()) - record, err := client.GetIndexdRecordByDID("does-not-exist") - - // Assert: Client should handle 404 errors properly - require.Error(t, err) - require.Nil(t, record) - require.Contains(t, err.Error(), "failed to get record") -} - -/////////////////////////////// -// GetObjectsByHash Tests -/////////////////////////////// - -// TestIndexdClient_GetObjectsByHash tests hash-based queries via client method with mocked auth -func TestIndexdClient_GetObjectsByHash(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - testRecord := newTestRecord("uuid-test-456", withTestRecordSize(2048)) - sha256 := testRecord.Hashes["sha256"] - addRecordWithHashIndex(mockServer, testRecord, "sha256", sha256) - - // Create client with mocked auth - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act: Call the actual client method - results, err := client.GetObjectByHash(&hash.Checksum{Type: "sha256", Checksum: sha256}) - - // Assert: Verify client method works end-to-end - require.NoError(t, err) - require.Len(t, results, 1) - - // Verify correct record was returned - record := results[0] - require.Equal(t, testRecord.Did, record.Id) - require.Equal(t, testRecord.Size, record.Size) - require.Equal(t, sha256, record.Checksums.SHA256) - - require.Equal(t, testRecord.URLs[0], record.AccessMethods[0].AccessURL.URL) - require.Equal(t, testRecord.Authz[0], record.AccessMethods[0].Authorizations.Value) - - // Test: Query with non-existent hash - emptyResults, err := client.GetObjectByHash(&hash.Checksum{Type: "sha256", Checksum: "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}) - require.NoError(t, err) - require.Len(t, emptyResults, 0) -} - -/////////////////////////////// -// GetProjectId Tests -/////////////////////////////// - -// TestIndexdClient_GetProjectId tests the simple getter for project ID -func TestIndexdClient_GetProjectId(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act - projectId := client.GetProjectId() - - // Assert: Should return the project ID set during client creation - require.Equal(t, "test-project", projectId, "Should return configured project ID") -} - -// TestIndexdClient_GetProjectId_ConsistentAcrossCalls tests that GetProjectId is consistent -func TestIndexdClient_GetProjectId_ConsistentAcrossCalls(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act: Call multiple times - projectId1 := client.GetProjectId() - projectId2 := client.GetProjectId() - projectId3 := client.GetProjectId() - - // Assert: All calls should return the same value - require.Equal(t, projectId1, projectId2, "GetProjectId should be consistent") - require.Equal(t, projectId2, projectId3, "GetProjectId should be consistent") - require.Equal(t, "test-project", projectId1) -} diff --git a/indexd/tests/client_write_test.go.todo b/indexd/tests/client_write_test.go.todo deleted file mode 100644 index 1f6ee62..0000000 --- a/indexd/tests/client_write_test.go.todo +++ /dev/null @@ -1,369 +0,0 @@ -package indexd_tests - -import ( - "testing" - - indexd_client "github.com/calypr/git-drs/client/indexd" - "github.com/calypr/git-drs/drs" - "github.com/calypr/git-drs/drs/hash" - "github.com/stretchr/testify/require" -) - -/////////////////// -// WRITE TESTS // -/////////////////// - -// Integration tests for WRITE operations on IndexdClient using mock indexd server. -// These tests verify mutating operations that create, update, or delete data: -// - RegisterRecord / RegisterIndexdRecord - Create new records -// - UpdateRecord / UpdateRecord - Modify existing records -// - DeleteRecord / DeleteIndexdRecord - Remove records - -// TestIndexdClient_RegisterRecord tests the high-level RegisterRecord method -// which converts a DRSObject to IndexdRecord and registers it -func TestIndexdClient_RegisterRecord(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Create a DRS object to register - drsObject := &drs.DRSObject{ - Id: "uuid-drs-register-test", - Name: "test-file.bam", - Size: 3000, - Checksums: hash.HashInfo{ - SHA256: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - }, - AccessMethods: []drs.AccessMethod{ - { - AccessURL: drs.AccessURL{ - URL: "s3://drs-test-bucket/test-file.bam", - }, - Authorizations: &drs.Authorizations{ - Value: "/programs/drs-test/projects/test", - }, - }, - }, - } - - // Act: Call RegisterRecord which should: - // 1. Convert DRSObject to IndexdRecord - // 2. Call RegisterIndexdRecord - // 3. Return the registered DRSObject - result, err := client.RegisterRecord(drsObject) - - // Assert - require.NoError(t, err, "RegisterRecord should succeed") - require.NotNil(t, result, "Should return a valid DRSObject") - - // Verify the record was created in the mock server - storedRecord := mockServer.GetRecord(drsObject.Id) - require.NotNil(t, storedRecord, "Record should be stored in mock server") - require.Equal(t, drsObject.Name, storedRecord.FileName) - require.Equal(t, drsObject.Size, storedRecord.Size) - require.Contains(t, storedRecord.URLs, "s3://drs-test-bucket/test-file.bam") - - // Verify the returned DRS object matches - require.Equal(t, drsObject.Id, result.Id) - require.Equal(t, drsObject.Name, result.Name) - require.Equal(t, drsObject.Size, result.Size) -} - -// TestIndexdClient_RegisterRecord_MissingDID tests error handling when DID is missing -func TestIndexdClient_RegisterRecord_MissingDID(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Create a DRS object without ID (mock server will reject it) - invalidDrsObject := &drs.DRSObject{ - Name: "test-file.bam", - Size: 3000, - // Missing Id field - mock server should reject - } - - // Act - result, err := client.RegisterRecord(invalidDrsObject) - - // Assert: Should fail when registering with server (missing DID) - require.Error(t, err, "Should fail when DID is missing") - require.Nil(t, result) - require.Contains(t, err.Error(), "Missing required field: did") -} - -// TestIndexdClient_RegisterIndexdRecord_CreatesNewRecord tests record creation via client method -func TestIndexdClient_RegisterIndexdRecord_CreatesNewRecord(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Create input record to register - // IndexdRecord used here is the client-side object - // We don't use the newTestRecord helper bc that's the [mock] server-side object - newRecord := &indexd_client.IndexdRecord{ - Did: "uuid-register-test", - FileName: "new-file.bam", - Size: 5000, - URLs: []string{"s3://bucket/new-file.bam"}, - Authz: []string{"/workspace/test"}, - Hashes: hash.HashInfo{ - SHA256: "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", - }, - Metadata: map[string]string{ - "source": "test", - }, - } - - // Act: Call the RegisterIndexdRecord client method - // This tests: - // 1. Wrapping IndexdRecord in IndexdRecordForm with form="object" - // 2. Setting correct headers (Content-Type, accept) - // 3. Injecting auth header via MockAuthHandler - // 4. POSTing to /index/index endpoint - // 5. Handling 200 OK response - // 6. Querying the new record via GET /ga4gh/drs/v1/objects/{did} - // 7. Returning a valid DRSObject - drsObj, err := client.RegisterIndexdRecord(newRecord) - - // Assert: Verify the client method executed successfully - require.NoError(t, err, "RegisterIndexdRecord should succeed") - require.NotNil(t, drsObj, "Should return a valid DRSObject") - - // Verify the stored record matches input - storedRecord := mockServer.GetRecord(newRecord.Did) - require.NotNil(t, storedRecord, "Record should be stored in mock server after POST") - require.Equal(t, newRecord.FileName, storedRecord.FileName) - require.Equal(t, newRecord.Size, storedRecord.Size) - require.Equal(t, newRecord.URLs, storedRecord.URLs) - require.Equal(t, newRecord.Hashes.SHA256, storedRecord.Hashes["sha256"]) - - // Verify the returned DRS object matches input - require.Equal(t, newRecord.Did, drsObj.Id, "DRS object ID should match DID") - require.Equal(t, newRecord.FileName, drsObj.Name, "DRS object name should match FileName") - require.Equal(t, newRecord.Size, drsObj.Size, "DRS object size should match") - require.NotEmpty(t, drsObj.Checksums.SHA256, "Should have SHA256 checksum") - require.Equal(t, newRecord.Hashes.SHA256, drsObj.Checksums.SHA256) - require.Len(t, drsObj.AccessMethods, 1, "Should have one access method") - require.Equal(t, newRecord.URLs[0], drsObj.AccessMethods[0].AccessURL.URL) -} - -/////////////////////////////// -// UpdateRecord / UpdateRecord Tests -/////////////////////////////// - -// TestIndexdClient_UpdateRecord_AppendsURLs tests updating record via client method -func TestIndexdClient_UpdateRecord_AppendsURLs(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - originalRecord := newTestRecord("uuid-update-test", - withTestRecordFileName("file.bam"), - withTestRecordSize(2048), - withTestRecordURLs("s3://original-bucket/file.bam"), - withTestRecordHash("sha256", "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd")) - addRecordToMockServer(mockServer, originalRecord) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Create update info with new URL - newURL := "s3://new-bucket/file-v2.bam" - updateInfo := &drs.DRSObject{ - AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: newURL}}}, - } - - // Act: Call the UpdateRecord client method - // This tests: - // 1. Getting the existing record via GET /index/{did} - // 2. Appending new URLs to existing URLs - // 3. Marshaling UpdateInputInfo to JSON - // 4. Setting correct headers (Content-Type, accept) - // 5. Injecting auth header via MockAuthHandler - // 6. PUTting to /index/index/{did} endpoint with new URLs - // 7. Handling 200 OK response - // 8. Querying the updated record via GET /ga4gh/drs/v1/objects/{did} - // 9. Returning a valid DRSObject - drsObj, err := client.UpdateRecord(updateInfo, originalRecord.Did) - - // Assert: Verify the client method executed successfully - require.NoError(t, err, "UpdateRecord should succeed") - require.NotNil(t, drsObj, "Should return a valid DRSObject") - - // Verify the URLs were appended correctly - updatedRecord := mockServer.GetRecord(originalRecord.Did) - require.NotNil(t, updatedRecord) - require.Equal(t, 2, len(updatedRecord.URLs), "Should have appended new URL to existing") - require.Contains(t, updatedRecord.URLs, originalRecord.URLs[0]) - require.Contains(t, updatedRecord.URLs, newURL) - - // Verify the returned DRS object - require.Equal(t, originalRecord.Did, drsObj.Id, "DRS object ID should match DID") - require.Equal(t, originalRecord.FileName, drsObj.Name, "DRS object name should match FileName") - require.Equal(t, originalRecord.Size, drsObj.Size, "DRS object size should match") - require.NotEmpty(t, drsObj.Checksums.SHA256, "Should have SHA256 checksum") - require.Equal(t, originalRecord.Hashes["sha256"], drsObj.Checksums.SHA256) - require.Len(t, drsObj.AccessMethods, 2, "Should have two access methods (URLs)") - urls := []string{drsObj.AccessMethods[0].AccessURL.URL, drsObj.AccessMethods[1].AccessURL.URL} - require.Contains(t, urls, originalRecord.URLs[0]) - require.Contains(t, urls, newURL) -} - -// TestIndexdClient_RegisterFile_UsesSingleHashQuery verifies RegisterFile reuses -// the initial hash lookup when checking downloadability. -func TestIndexdClient_RegisterFile_UsesSingleHashQuery(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - mockServer.signedURLBase = mockServer.URL() + "/signed" - - record := newTestRecord("uuid-register-file-test", - withTestRecordHash("sha256", testSHA256Hash), - withTestRecordURLs("s3://test-bucket/test-file.bam")) - addRecordWithHashIndex(mockServer, record, "sha256", testSHA256Hash) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act - result, err := client.RegisterFile(testSHA256Hash) - - // Assert - require.NoError(t, err, "RegisterFile should not error when file is downloadable") - require.NotNil(t, result, "RegisterFile should return the existing DRS object") - require.Equal(t, 1, mockServer.HashQueryCount(), "expected a single hash query during RegisterFile") -} - -// TestIndexdClient_UpdateRecord_Idempotent tests URL appending idempotency via client method -func TestIndexdClient_UpdateRecord_Idempotent(t *testing.T) { - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - originalRecord := newTestRecord("uuid-update-idempotent", - withTestRecordURLs("s3://bucket1/file.bam"), - withTestRecordHash("sha256", "aaaa...")) - addRecordToMockServer(mockServer, originalRecord) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Create update info with same URL (should be idempotent) - updateInfo := &drs.DRSObject{ - AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: originalRecord.URLs[0]}}}, - } - - // call the UpdateRecord client method - drsObj, err := client.UpdateRecord(updateInfo, originalRecord.Did) - require.NoError(t, err) - - // Verify URL wasn't duplicated - updated := mockServer.GetRecord(drsObj.Id) - require.NotNil(t, updated) - require.Equal(t, 1, len(updated.URLs)) - require.Equal(t, originalRecord.URLs[0], updated.URLs[0]) -} - -/////////////////////////////// -// DeleteRecord / DeleteIndexdRecord Tests -/////////////////////////////// - -// TestIndexdClient_DeleteRecord tests deleting a record by OID -func TestIndexdClient_DeleteRecord(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - // Pre-populate with a test record - testHash := "1111111111111111111111111111111111111111111111111111111111111111" - testRecord := newTestRecord("uuid-delete-by-oid", - withTestRecordFileName("delete-me.bam"), - withTestRecordSize(4096), - withTestRecordHash("sha256", testHash)) - addRecordWithHashIndex(mockServer, testRecord, "sha256", testHash) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Verify record exists before deletion - recordBefore := mockServer.GetRecord(testRecord.Did) - require.NotNil(t, recordBefore, "Record should exist before deletion") - - // Act: Delete by OID (which is the hash) - err := client.DeleteRecord(testHash) - - // Assert - require.NoError(t, err, "DeleteRecord should succeed") - - // Verify record was deleted - recordAfter := mockServer.GetRecord(testRecord.Did) - require.Nil(t, recordAfter, "Record should be deleted") -} - -// TestIndexdClient_DeleteRecord_NotFound tests deleting a non-existent record -func TestIndexdClient_DeleteRecord_NotFound(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act: Try to delete a record that doesn't exist - nonExistentHash := "9999999999999999999999999999999999999999999999999999999999999999" - err := client.DeleteRecord(nonExistentHash) - - // Assert: Should return error - require.Error(t, err, "Should fail when record doesn't exist") - require.Contains(t, err.Error(), "no records found for OID") -} - -// TestIndexdClient_DeleteRecord_NoMatchingProject tests deletion when record exists but for different project -func TestIndexdClient_DeleteRecord_NoMatchingProject(t *testing.T) { - // Arrange - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - // Create a record with a DIFFERENT project authorization - testHash := "2222222222222222222222222222222222222222222222222222222222222222" - differentProjectAuthz := "/programs/other-program/projects/other-project" - testRecord := newTestRecord("uuid-different-project", - withTestRecordFileName("other-project.bam"), - withTestRecordHash("sha256", testHash)) - testRecord.Authz = []string{differentProjectAuthz} // Override with different project - addRecordWithHashIndex(mockServer, testRecord, "sha256", testHash) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Act: Try to delete - should fail because project doesn't match - err := client.DeleteRecord(testHash) - - // Assert - require.Error(t, err, "Should fail when no matching project") - require.Contains(t, err.Error(), "no matching record found for project") - - // Verify record still exists (wasn't deleted) - recordAfter := mockServer.GetRecord(testRecord.Did) - require.NotNil(t, recordAfter, "Record should still exist") -} - -// TestIndexdClient_DeleteIndexdRecord_Removes tests record deletion via client method -func TestIndexdClient_DeleteIndexdRecord_Removes(t *testing.T) { - mockServer := NewMockIndexdServer(t) - defer mockServer.Close() - - testRecord := newTestRecord("uuid-delete-test", withTestRecordURLs("s3://bucket/file.bam")) - addRecordToMockServer(mockServer, testRecord) - - client := testIndexdClientWithMockAuth(mockServer.URL()) - - // Delete record via client method - err := client.DeleteIndexdRecord(testRecord.Did) - - require.NoError(t, err) - - // Verify it's gone - deletedRecord := mockServer.GetRecord(testRecord.Did) - require.Nil(t, deletedRecord) -} diff --git a/indexd/tests/mock_servers_test.go b/indexd/tests/mock_servers_test.go deleted file mode 100644 index 869cd51..0000000 --- a/indexd/tests/mock_servers_test.go +++ /dev/null @@ -1,610 +0,0 @@ -package indexd_tests - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "slices" - "strings" - "sync" - "testing" - "time" - - indexd_client "github.com/calypr/data-client/indexd" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" -) - -////////////////// -// MOCK SERVERS // -////////////////// - -// MockIndexdRecord represents a stored Indexd record in memory -type MockIndexdRecord struct { - Did string `json:"did"` - FileName string `json:"file_name"` - Size int64 `json:"size"` - Hashes map[string]string `json:"hashes"` - URLs []string `json:"urls"` - Authz []string `json:"authz"` - Metadata map[string]string `json:"metadata"` - CreatedAt time.Time `json:"-"` // Not serialized -} - -// MockIndexdServer simulates an Indexd server with in-memory storage -type MockIndexdServer struct { - httpServer *httptest.Server - records map[string]*MockIndexdRecord - hashIndex map[string][]string // hash -> [DIDs] - signedURLBase string - hashQueryCount int - recordMutex sync.RWMutex -} - -// NewMockIndexdServer creates and starts a mock Indexd server -func NewMockIndexdServer(t *testing.T) *MockIndexdServer { - mis := &MockIndexdServer{ - records: make(map[string]*MockIndexdRecord), - hashIndex: make(map[string][]string), - signedURLBase: "https://signed-url.example.com", - } - - mux := http.NewServeMux() - - // Register handlers for /index and /index/ paths - // /index matches exact path and query params (POST, GET with ?hash=) - mux.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) { - // POST /index - create record - if r.Method == http.MethodPost { - mis.handleCreateRecord(w, r) - return - } - - // GET /index?hash=... - query by hash - if r.Method == http.MethodGet { - mis.handleQueryByHash(w, r) - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - }) - - // /index/index handles /index/index for POST and /index/index?hash= for GET - mux.HandleFunc("/index/index", func(w http.ResponseWriter, r *http.Request) { - // POST /index/index - create record - if r.Method == http.MethodPost { - mis.handleCreateRecord(w, r) - return - } - - // GET /index/index?hash=... - query by hash - if r.Method == http.MethodGet { - mis.handleQueryByHash(w, r) - return - } - - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - }) - - // /ga4gh/drs/v1/objects/ handles GET requests for DRS object and signed URLs - mux.HandleFunc("/ga4gh/drs/v1/objects/", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Extract path after /ga4gh/drs/v1/objects/ - path := strings.TrimPrefix(r.URL.Path, "/ga4gh/drs/v1/objects/") - if path == "" { - http.Error(w, "Missing object ID", http.StatusBadRequest) - return - } - - // Split path to determine if this is object request or access request - pathParts := strings.Split(path, "/") - - if len(pathParts) == 1 { - // GET /ga4gh/drs/v1/objects/{id} - get DRS object - mis.handleGetDRSObject(w, r, pathParts[0]) - } else if len(pathParts) == 3 && pathParts[1] == "access" { - // GET /ga4gh/drs/v1/objects/{id}/access/{accessId} - get signed URL - mis.handleGetSignedURL(w, r, pathParts[0], pathParts[2]) - } else { - http.Error(w, "Invalid path", http.StatusBadRequest) - } - }) - - mux.HandleFunc("/signed/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - // /index/ matches /index/{guid} (trailing slash pattern) - mux.HandleFunc("/index/", func(w http.ResponseWriter, r *http.Request) { - // Extract DID from path: /index/{guid} -> {guid} - // This handles both /index/{id} and /index/index/{id} - path := r.URL.Path - var did string - - if strings.HasPrefix(path, "/index/index/") { - did = strings.TrimPrefix(path, "/index/index/") - } else { - did = strings.TrimPrefix(path, "/index/") - } - - if did == "" || did == "index" { - http.Error(w, "Missing DID", http.StatusBadRequest) - return - } - - switch r.Method { - case http.MethodGet: - mis.handleGetRecord(w, r, did) - case http.MethodPut: - mis.handleUpdateRecord(w, r, did) - case http.MethodDelete: - mis.handleDeleteRecord(w, r, did) - default: - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - }) - - mis.httpServer = httptest.NewServer(mux) - return mis -} - -func (mis *MockIndexdServer) handleGetRecord(w http.ResponseWriter, r *http.Request, did string) { - mis.recordMutex.RLock() - record, exists := mis.records[did] - mis.recordMutex.RUnlock() - - if !exists { - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "Record not found"}) - return - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(record) -} - -func (mis *MockIndexdServer) handleGetDRSObject(w http.ResponseWriter, r *http.Request, id string) { - mis.recordMutex.RLock() - record, exists := mis.records[id] - mis.recordMutex.RUnlock() - if !exists { - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "Object not found"}) - return - } - - // Build standard DRS checksums array - checksums := []map[string]string{} - for typ, sum := range record.Hashes { - if sum != "" { - checksums = append(checksums, map[string]string{ - "type": strings.ToLower(typ), - "checksum": sum, - }) - } - } - - // Build access methods - accessMethods := []map[string]any{} - for i, url := range record.URLs { - am := map[string]any{ - "type": "https", - "access_id": fmt.Sprintf("https-%d", i), - "access_url": map[string]string{"url": url}, - } - // Only add authorizations if present, and as a SINGLE object (not array) - if len(record.Authz) > 0 { - am["authorizations"] = map[string]string{ - "value": record.Authz[0], - } - } - accessMethods = append(accessMethods, am) - } - - // Full response - response := map[string]any{ - "id": record.Did, - "name": record.FileName, - "size": record.Size, - "created_time": record.CreatedAt.Format(time.RFC3339), - "checksums": checksums, - "access_methods": accessMethods, - "description": "Mock DRS object from Indexd record", - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(response) -} - -func (mis *MockIndexdServer) handleGetSignedURL(w http.ResponseWriter, r *http.Request, objectId, accessId string) { - mis.recordMutex.RLock() - _, exists := mis.records[objectId] - mis.recordMutex.RUnlock() - - if !exists { - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "Object not found"}) - return - } - - // Create a mock signed URL - base := strings.TrimSuffix(mis.signedURLBase, "/") - signedURL := drs.AccessURL{ - URL: fmt.Sprintf("%s/%s/%s", base, objectId, accessId), - Headers: []string{}, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(signedURL) -} - -func (mis *MockIndexdServer) handleCreateRecord(w http.ResponseWriter, r *http.Request) { - // Handle IndexdRecordForm (client sends this with POST) - var form struct { - indexd_client.IndexdRecord - Form string `json:"form"` - Rev string `json:"rev"` - } - - if err := json.NewDecoder(r.Body).Decode(&form); err != nil { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) - return - } - - // Extract the core record data - record := MockIndexdRecord{ - Did: form.Did, - FileName: form.FileName, - Size: form.Size, - URLs: form.URLs, - Authz: form.Authz, - Hashes: hash.ConvertHashInfoToMap(form.Hashes), - Metadata: form.Metadata, // Already map[string]string from IndexdRecord - CreatedAt: time.Now(), - } - - if record.Did == "" { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Missing required field: did"}) - return - } - - mis.recordMutex.Lock() - defer mis.recordMutex.Unlock() - - if _, exists := mis.records[record.Did]; exists { - w.WriteHeader(http.StatusConflict) - json.NewEncoder(w).Encode(map[string]string{"error": "Record already exists"}) - return - } - - // Index by hash for queryability - for hashType, hash := range record.Hashes { - if hash != "" { // Only index non-empty hashes - key := hashType + ":" + hash - mis.hashIndex[key] = append(mis.hashIndex[key], record.Did) - } - } - - mis.records[record.Did] = &record - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(record) -} - -func (mis *MockIndexdServer) handleUpdateRecord(w http.ResponseWriter, r *http.Request, did string) { - mis.recordMutex.Lock() - defer mis.recordMutex.Unlock() - - record, exists := mis.records[did] - if !exists { - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "Record not found"}) - return - } - - var update struct { - URLs []string `json:"urls"` - } - if err := json.NewDecoder(r.Body).Decode(&update); err != nil { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) - return - } - - // Append new URLs (avoid duplicates) - for _, newURL := range update.URLs { - if !slices.Contains(record.URLs, newURL) { - record.URLs = append(record.URLs, newURL) - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(record) -} - -func (mis *MockIndexdServer) handleQueryByHash(w http.ResponseWriter, r *http.Request) { - hashQuery := r.URL.Query().Get("hash") // format: "sha256:aaaa..." - - mis.recordMutex.Lock() - mis.hashQueryCount++ - mis.recordMutex.Unlock() - - mis.recordMutex.RLock() - dids, exists := mis.hashIndex[hashQuery] - mis.recordMutex.RUnlock() - - outputRecords := []indexd_client.OutputInfo{} - if exists { - mis.recordMutex.RLock() - for _, did := range dids { - if record, ok := mis.records[did]; ok { - // Convert sha256 hash string to HashInfo struct - hashes := hash.HashInfo{} - if sha256, ok := record.Hashes["sha256"]; ok { - hashes.SHA256 = sha256 - } - - // Convert metadata - metadata := make(map[string]any) - for k, v := range record.Metadata { - metadata[k] = v - } - - outputRecords = append(outputRecords, indexd_client.OutputInfo{ - Did: record.Did, - Size: record.Size, - Hashes: hashes, - URLs: record.URLs, - Authz: record.Authz, - Metadata: metadata, - }) - } - } - mis.recordMutex.RUnlock() - } - - w.Header().Set("Content-Type", "application/json") - // Return wrapped in ListRecords object matching Indexd API - response := indexd_client.ListRecords{ - Records: outputRecords, - IDs: dids, - Size: int64(len(outputRecords)), - } - json.NewEncoder(w).Encode(response) -} - -func (mis *MockIndexdServer) handleDeleteRecord(w http.ResponseWriter, r *http.Request, did string) { - mis.recordMutex.Lock() - defer mis.recordMutex.Unlock() - - _, exists := mis.records[did] - if !exists { - w.WriteHeader(http.StatusNotFound) - return - } - - delete(mis.records, did) - w.WriteHeader(http.StatusNoContent) -} - -// URL returns the mock server URL -func (mis *MockIndexdServer) URL() string { - return mis.httpServer.URL -} - -// Close closes the mock server -func (mis *MockIndexdServer) Close() { - mis.httpServer.Close() -} - -// GetAllRecords returns all records for testing purposes -func (mis *MockIndexdServer) GetAllRecords() []*MockIndexdRecord { - mis.recordMutex.RLock() - defer mis.recordMutex.RUnlock() - - records := make([]*MockIndexdRecord, 0, len(mis.records)) - for _, record := range mis.records { - records = append(records, record) - } - return records -} - -// GetRecord retrieves a single record by DID -func (mis *MockIndexdServer) GetRecord(did string) *MockIndexdRecord { - mis.recordMutex.RLock() - defer mis.recordMutex.RUnlock() - return mis.records[did] -} - -// HashQueryCount returns the number of hash query requests observed by the mock server. -func (mis *MockIndexdServer) HashQueryCount() int { - mis.recordMutex.RLock() - defer mis.recordMutex.RUnlock() - return mis.hashQueryCount -} - -// MockGen3Server simulates Gen3 /user/data/buckets endpoint -type MockGen3Server struct { - httpServer *httptest.Server - s3Endpoint string -} - -// NewMockGen3Server creates and starts a mock Gen3 server -func NewMockGen3Server(t *testing.T, s3Endpoint string) *MockGen3Server { - mgs := &MockGen3Server{ - s3Endpoint: s3Endpoint, - } - - mux := http.NewServeMux() - - mux.HandleFunc("/user/data/buckets", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - response := map[string]any{ - "S3_BUCKETS": map[string]any{ - "test-bucket": map[string]any{ - "region": "us-west-2", - "endpoint_url": mgs.s3Endpoint, - "programs": []string{"test-program"}, - }, - }, - "GS_BUCKETS": map[string]any{}, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) - }) - - mgs.httpServer = httptest.NewServer(mux) - return mgs -} - -// URL returns the mock server URL -func (mgs *MockGen3Server) URL() string { - return mgs.httpServer.URL -} - -// Client returns the mock server HTTP client -func (mgs *MockGen3Server) Client() *http.Client { - return mgs.httpServer.Client() -} - -// Close closes the mock server -func (mgs *MockGen3Server) Close() { - mgs.httpServer.Close() -} - -// MockS3Object represents a stored S3 object -type MockS3Object struct { - Size int64 - LastModified time.Time - ContentType string -} - -// MockS3Server simulates S3 HEAD object endpoint -type MockS3Server struct { - httpServer *httptest.Server - objects map[string]*MockS3Object // "bucket/key" -> object - objMutex sync.RWMutex -} - -// NewMockS3Server creates and starts a mock S3 server -func NewMockS3Server(t *testing.T) *MockS3Server { - mss := &MockS3Server{ - objects: make(map[string]*MockS3Object), - } - - mux := http.NewServeMux() - - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - path := strings.TrimPrefix(r.URL.Path, "/") - if path == "" { - http.Error(w, "Not found", http.StatusNotFound) - return - } - - if r.Method == http.MethodHead || r.Method == http.MethodGet { - mss.handleHeadObject(w, r, path) - } else { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - }) - - mss.httpServer = httptest.NewServer(mux) - return mss -} - -func (mss *MockS3Server) handleHeadObject(w http.ResponseWriter, r *http.Request, path string) { - mss.objMutex.RLock() - object, exists := mss.objects[path] - mss.objMutex.RUnlock() - - if !exists { - w.WriteHeader(http.StatusNotFound) - return - } - - w.Header().Set("Content-Length", fmt.Sprintf("%d", object.Size)) - w.Header().Set("Last-Modified", object.LastModified.UTC().Format(http.TimeFormat)) - w.Header().Set("Content-Type", object.ContentType) - w.Header().Set("ETag", fmt.Sprintf("\"%x\"", object.LastModified.Unix())) - - if r.Method == http.MethodHead { - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusOK) - w.Write(make([]byte, 0)) - } -} - -// AddObject adds a mock S3 object for testing -func (mss *MockS3Server) AddObject(bucket, key string, size int64) { - path := bucket + "/" + key - mss.objMutex.Lock() - defer mss.objMutex.Unlock() - - mss.objects[path] = &MockS3Object{ - Size: size, - LastModified: time.Now(), - ContentType: "application/octet-stream", - } -} - -// URL returns the mock server URL -func (mss *MockS3Server) URL() string { - return mss.httpServer.URL -} - -// Close closes the mock server -func (mss *MockS3Server) Close() { - mss.httpServer.Close() -} - -// Helper functions for type conversion -func convertMockRecordToDRSObject(record *MockIndexdRecord) *drs.DRSObject { - - // Convert URLs to AccessMethods - accessMethods := make([]drs.AccessMethod, 0) - for i, url := range record.URLs { - // Get the first authz as the authorization for this access method - var authzPtr *drs.Authorizations - if len(record.Authz) > 0 { - authzPtr = &drs.Authorizations{ - Value: record.Authz[0], - } - } - - accessMethods = append(accessMethods, drs.AccessMethod{ - Type: "https", - AccessID: fmt.Sprintf("access-method-%d", i), - AccessURL: drs.AccessURL{ - URL: url, - Headers: []string{}, - }, - Authorizations: authzPtr, - }) - } - - return &drs.DRSObject{ - Id: record.Did, - Name: record.FileName, - Size: record.Size, - Checksums: hash.ConvertStringMapToHashInfo(record.Hashes), - AccessMethods: accessMethods, - CreatedTime: record.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), - Description: "DRS object created from Indexd record", - } -} diff --git a/indexd/types.go b/indexd/types.go index dff0e48..54c601a 100644 --- a/indexd/types.go +++ b/indexd/types.go @@ -1,8 +1,8 @@ package indexd import ( - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/hash" ) type OutputObject struct { @@ -68,8 +68,3 @@ type UpdateInputInfo struct { // List of authorization policies Authz []string `json:"authz,omitempty"` } - -type S3Meta struct { - Size int64 - LastModified string -} diff --git a/indexd/types_test.go b/indexd/types_test.go index 3125f03..c81536c 100644 --- a/indexd/types_test.go +++ b/indexd/types_test.go @@ -3,8 +3,8 @@ package indexd import ( "testing" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/hash" ) func TestConvertOutputObjectToDRSObject(t *testing.T) { diff --git a/indexd/upsert.go b/indexd/upsert.go new file mode 100644 index 0000000..31f7411 --- /dev/null +++ b/indexd/upsert.go @@ -0,0 +1,54 @@ +package indexd + +import ( + "context" + "fmt" + "slices" + + "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/s3utils" +) + +// UpsertIndexdRecord creates or updates an indexd record with a new URL. +func (c *IndexdClient) UpsertIndexdRecord(ctx context.Context, url string, sha256 string, fileSize int64, projectId string) (*drs.DRSObject, error) { + uuid := drs.DrsUUID(projectId, sha256) + + records, err := c.GetObjectByHash(ctx, "sha256", sha256) + if err != nil { + return nil, fmt.Errorf("error querying indexd server: %v", err) + } + + var matchingRecord *drs.DRSObject + for i := range records { + if records[i].Id == uuid { + matchingRecord = &records[i] + break + } + } + + if matchingRecord != nil { + existingURLs := IndexdURLFromDrsAccessURLs(matchingRecord.AccessMethods) + if slices.Contains(existingURLs, url) { + c.logger.Debug("Nothing to do: file already registered") + return matchingRecord, nil + } + + c.logger.Debug("updating existing record with new url") + updatedRecord := drs.DRSObject{AccessMethods: []drs.AccessMethod{{AccessURL: drs.AccessURL{URL: url}}}} + return c.UpdateRecord(ctx, &updatedRecord, matchingRecord.Id) + } + + // If no record exists, create one + c.logger.Debug("creating new record") + _, key, err := s3utils.ParseS3URL(url) + if err != nil { + return nil, err + } + + drsObj, err := drs.BuildDrsObj(key, sha256, fileSize, uuid, "placeholder-bucket", projectId) + if err != nil { + return nil, err + } + + return c.RegisterRecord(ctx, drsObj) +} diff --git a/mocks/mock_fence.go b/mocks/mock_fence.go index f2577d0..004b633 100644 --- a/mocks/mock_fence.go +++ b/mocks/mock_fence.go @@ -250,3 +250,32 @@ func (mr *MockFenceInterfaceMockRecorder) ParseFenceURLResponse(resp any) *gomoc mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseFenceURLResponse", reflect.TypeOf((*MockFenceInterface)(nil).ParseFenceURLResponse), resp) } + +// RefreshToken mocks base method. +func (m *MockFenceInterface) RefreshToken(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshToken", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// RefreshToken indicates an expected call of RefreshToken. +func (mr *MockFenceInterfaceMockRecorder) RefreshToken(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshToken", reflect.TypeOf((*MockFenceInterface)(nil).RefreshToken), ctx) +} + +// UserPing mocks base method. +func (m *MockFenceInterface) UserPing(ctx context.Context) (*fence.PingResp, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UserPing", ctx) + ret0, _ := ret[0].(*fence.PingResp) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UserPing indicates an expected call of UserPing. +func (mr *MockFenceInterfaceMockRecorder) UserPing(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserPing", reflect.TypeOf((*MockFenceInterface)(nil).UserPing), ctx) +} diff --git a/mocks/mock_gen3interface.go b/mocks/mock_gen3interface.go index a627c69..7524b7c 100644 --- a/mocks/mock_gen3interface.go +++ b/mocks/mock_gen3interface.go @@ -17,6 +17,8 @@ import ( fence "github.com/calypr/data-client/fence" indexd "github.com/calypr/data-client/indexd" logs "github.com/calypr/data-client/logs" + requestor "github.com/calypr/data-client/requestor" + sower "github.com/calypr/data-client/sower" gomock "go.uber.org/mock/gomock" ) @@ -113,3 +115,31 @@ func (mr *MockGen3InterfaceMockRecorder) Logger() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockGen3Interface)(nil).Logger)) } + +// Requestor mocks base method. +func (m *MockGen3Interface) Requestor() requestor.RequestorInterface { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Requestor") + ret0, _ := ret[0].(requestor.RequestorInterface) + return ret0 +} + +// Requestor indicates an expected call of Requestor. +func (mr *MockGen3InterfaceMockRecorder) Requestor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Requestor", reflect.TypeOf((*MockGen3Interface)(nil).Requestor)) +} + +// Sower mocks base method. +func (m *MockGen3Interface) Sower() sower.SowerInterface { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sower") + ret0, _ := ret[0].(sower.SowerInterface) + return ret0 +} + +// Sower indicates an expected call of Sower. +func (mr *MockGen3InterfaceMockRecorder) Sower() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sower", reflect.TypeOf((*MockGen3Interface)(nil).Sower)) +} diff --git a/mocks/mock_indexd.go b/mocks/mock_indexd.go index 6c0d5e2..6a4f217 100644 --- a/mocks/mock_indexd.go +++ b/mocks/mock_indexd.go @@ -14,8 +14,8 @@ import ( http "net/http" reflect "reflect" + drs "github.com/calypr/data-client/drs" indexd "github.com/calypr/data-client/indexd" - drs "github.com/calypr/data-client/indexd/drs" request "github.com/calypr/data-client/request" gomock "go.uber.org/mock/gomock" ) diff --git a/indexd/s3_utils.go b/s3utils/s3_utils.go similarity index 97% rename from indexd/s3_utils.go rename to s3utils/s3_utils.go index 09997c8..9e47805 100644 --- a/indexd/s3_utils.go +++ b/s3utils/s3_utils.go @@ -1,4 +1,4 @@ -package indexd +package s3utils import ( "context" @@ -122,3 +122,8 @@ func FetchS3MetadataWithBucketDetails( return contentLength, lastModified, nil } + +type S3Meta struct { + Size int64 + LastModified string +} diff --git a/upload/multipart_test.go b/upload/multipart_test.go index d9cad7a..b7ded8d 100644 --- a/upload/multipart_test.go +++ b/upload/multipart_test.go @@ -20,6 +20,7 @@ import ( "github.com/calypr/data-client/indexd" "github.com/calypr/data-client/logs" "github.com/calypr/data-client/request" + "github.com/calypr/data-client/requestor" "github.com/calypr/data-client/sower" ) @@ -34,9 +35,10 @@ func (f *fakeGen3Upload) Logger() *logs.Gen3Logger { return f.logger } func (f *fakeGen3Upload) ExportCredential(ctx context.Context, cred *conf.Credential) error { return nil } -func (f *fakeGen3Upload) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } -func (f *fakeGen3Upload) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } -func (f *fakeGen3Upload) Sower() sower.SowerInterface { return nil } +func (f *fakeGen3Upload) Fence() fence.FenceInterface { return &fakeFence{doFunc: f.doFunc} } +func (f *fakeGen3Upload) Indexd() indexd.IndexdInterface { return &fakeIndexd{doFunc: f.doFunc} } +func (f *fakeGen3Upload) Sower() sower.SowerInterface { return nil } +func (f *fakeGen3Upload) Requestor() requestor.RequestorInterface { return nil } type fakeFence struct { fence.FenceInterface diff --git a/upload/progress_reader.go b/upload/progress_reader.go index b7b3294..da12f7d 100644 --- a/upload/progress_reader.go +++ b/upload/progress_reader.go @@ -8,11 +8,12 @@ import ( ) type progressReader struct { - reader io.Reader - onProgress common.ProgressCallback - hash string - total int64 - bytesSoFar int64 + reader io.Reader + onProgress common.ProgressCallback + hash string + total int64 + bytesSoFar int64 + bytesSinceReport int64 } func newProgressReader(reader io.Reader, onProgress common.ProgressCallback, hash string, total int64) *progressReader { @@ -36,19 +37,33 @@ func (pr *progressReader) Read(p []byte) (int, error) { if n > 0 && pr.onProgress != nil { delta := int64(n) pr.bytesSoFar += delta - if progressErr := pr.onProgress(common.ProgressEvent{ - Event: "progress", - Oid: pr.hash, - BytesSoFar: pr.bytesSoFar, - BytesSinceLast: delta, - }); progressErr != nil { - return n, progressErr + pr.bytesSinceReport += delta + + if pr.bytesSinceReport >= common.OnProgressThreshold { + if progressErr := pr.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pr.hash, + BytesSoFar: pr.bytesSoFar, + BytesSinceLast: pr.bytesSinceReport, + }); progressErr != nil { + return n, progressErr + } + pr.bytesSinceReport = 0 } } return n, err } func (pr *progressReader) Finalize() error { + if pr.onProgress != nil && pr.bytesSinceReport > 0 { + _ = pr.onProgress(common.ProgressEvent{ + Event: "progress", + Oid: pr.hash, + BytesSoFar: pr.bytesSoFar, + BytesSinceLast: pr.bytesSinceReport, + }) + pr.bytesSinceReport = 0 + } if pr.total > 0 && pr.bytesSoFar < pr.total { delta := pr.total - pr.bytesSoFar pr.bytesSoFar = pr.total diff --git a/upload/upload.go b/upload/upload.go index fb62e96..aec8d66 100644 --- a/upload/upload.go +++ b/upload/upload.go @@ -8,8 +8,8 @@ import ( "strings" "github.com/calypr/data-client/common" + drs "github.com/calypr/data-client/drs" // Imported for DRSObject client "github.com/calypr/data-client/g3client" - drs "github.com/calypr/data-client/indexd/drs" // Imported for DRSObject "github.com/vbauerster/mpb/v8" ) From e400e130cc189beb4e7e8e67979aa6ebd98064f5 Mon Sep 17 00:00:00 2001 From: Brian Date: Tue, 10 Feb 2026 16:49:07 -0800 Subject: [PATCH 14/14] fix:TLS handshake timeout #30 (#31) --- request/request.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/request/request.go b/request/request.go index 82711ba..cdb46fb 100644 --- a/request/request.go +++ b/request/request.go @@ -36,13 +36,13 @@ func NewRequestInterface( retryClient.RetryWaitMax = 15 * time.Second baseTransport := &http.Transport{ DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, + Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, MaxIdleConns: 100, MaxIdleConnsPerHost: 100, - TLSHandshakeTimeout: 5 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, + TLSHandshakeTimeout: 15 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, } authTransport := &AuthTransport{