diff --git a/DockerFile b/DockerFile index 4f5a885..8121c44 100644 --- a/DockerFile +++ b/DockerFile @@ -6,12 +6,24 @@ RUN go get github.com/akrylysov/algnhsa && \ go get github.com/sirupsen/logrus && \ go build ./main.go -FROM openeuler/openeuler:22.03 +FROM openeuler/openeuler:24.03 LABEL maintainer="Zhou Yi 1123678689@qq.com" + +# 安装依赖工具 +RUN dnf install -y git wget tar gzip && \ + # 下载git-lfs + wget https://github.com/git-lfs/git-lfs/releases/download/v3.3.0/git-lfs-linux-amd64-v3.3.0.tar.gz && \ + tar -xzf git-lfs-linux-amd64-v3.3.0.tar.gz && \ + cd git-lfs-3.3.0 && \ + ./install.sh && \ + rm -rf git-lfs-* && \ + dnf clean all + RUN useradd -s /bin/bash BigFiles USER BigFiles WORKDIR /home/BigFiles COPY --chown=BigFiles:group --from=BUILDER /home/main /home/BigFiles/main +COPY --chown=BigFiles:group --from=BUILDER /home/scripts/lfsNameQuery.py /home/BigFiles/lfsNameQuery.py EXPOSE 5000 ENTRYPOINT ["/home/BigFiles/main"] \ No newline at end of file diff --git a/auth/gitee.go b/auth/gitee.go index 96c1ba5..8a9729f 100644 --- a/auth/gitee.go +++ b/auth/gitee.go @@ -8,7 +8,10 @@ import ( "net/http" "net/url" "os" + "os/exec" + "path/filepath" "strings" + "time" "github.com/metalogical/BigFiles/batch" "github.com/metalogical/BigFiles/config" @@ -139,7 +142,7 @@ func GiteeAuth() func(UserInRepo) error { } } - if err := CheckRepoOwner(userInRepo); err != nil { + if _, err := CheckRepoOwner(userInRepo); err != nil { return err } @@ -148,7 +151,7 @@ func GiteeAuth() func(UserInRepo) error { } // CheckRepoOwner checks whether the owner of a repo is allowed to use lfs server -func CheckRepoOwner(userInRepo UserInRepo) error { +func CheckRepoOwner(userInRepo UserInRepo) (Repo, error) { path := fmt.Sprintf( "https://gitee.com/api/v5/repos/%s/%s%s", userInRepo.Owner, @@ -165,24 +168,24 @@ func CheckRepoOwner(userInRepo UserInRepo) error { err := getParsedResponse("GET", path, headers, nil, &repo) if err != nil { msg := err.Error() + ": check repo_id failed" - return 
errors.New(msg) + return *repo, errors.New(msg) } for _, allowedRepo := range allowedRepos { if strings.Split(repo.Fullname, "/")[0] == allowedRepo { - return nil + return *repo, nil } } if repo.Parent.Fullname != "" { for _, allowedRepo := range allowedRepos { if strings.Split(repo.Parent.Fullname, "/")[0] == allowedRepo { - return nil + return *repo, nil } } } msg := "forbidden: repo has no permission to use this lfs server" logrus.Error(fmt.Sprintf("CheckRepoOwner | %s", msg)) - return errors.New(msg) + return *repo, errors.New(msg) } // getToken gets access_token by username and password @@ -242,7 +245,7 @@ func VerifyUser(userInRepo UserInRepo) error { } else if userInRepo.Operation == "delete" { return verifyUserDelete(giteeUser, userInRepo) } else { - msg := "system_error: unknow operation" + msg := "system_error: unknown operation" logrus.Error(fmt.Sprintf(formatLogString, verifyLog, msg)) return errors.New(msg) } @@ -385,6 +388,120 @@ func GetAccountManageToken() (string, error) { return managerTokenOutput.Token, err } +// FileInfo 包含LFS文件的名称和大小信息 +type FileInfo struct { + Name string `json:"name"` + Size int64 `json:"size"` +} + +// GetLFSMapping 调用Python脚本获取LFS文件映射 +// 参数: +// - userInRepo: 仓库相关信息 +// - pythonScriptPath: Python脚本路径(可选) +// +// 返回: +// - map[string]FileInfo: OID到文件信息的映射 +// - error: 错误信息 +func GetLFSMapping(userInRepo UserInRepo, pythonScriptPath ...string) (map[string]FileInfo, error) { + owner := userInRepo.Owner + repo := userInRepo.Repo + username := userInRepo.Username + token := userInRepo.Token + + // 确定Python脚本路径 + scriptPath, err := resolveScriptPath(pythonScriptPath...) 
+ if err != nil { + return nil, err + } + + // 创建临时文件 + outputFile, cleanup, err := createTempOutputFile() + if err != nil { + return nil, err + } + defer cleanup() + + // 构建并执行命令 + cmd := exec.Command("python3") + cmd.Args = append(cmd.Args, scriptPath) + cmd.Args = append(cmd.Args, owner) + cmd.Args = append(cmd.Args, repo) + cmd.Args = append(cmd.Args, outputFile) + cmd.Args = append(cmd.Args, username) + cmd.Args = append(cmd.Args, token) + cmd.Stderr = os.Stderr + + // 运行命令 + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return nil, fmt.Errorf("Python脚本执行失败,退出码%d: %w", exitErr.ExitCode(), err) + } + return nil, fmt.Errorf("执行Python脚本出错: %w", err) + } + + // 读取并解析结果 + return parseOutputFile(outputFile) +} + +// 解析脚本路径 +func resolveScriptPath(pythonScriptPath ...string) (string, error) { + if len(pythonScriptPath) > 0 { + return pythonScriptPath[0], nil + } + + exePath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("获取可执行文件路径失败: %w", err) + } + + scriptPath := filepath.Join(filepath.Dir(exePath), "lfsNameQuery.py") + if _, err := os.Stat(scriptPath); os.IsNotExist(err) { + return "", fmt.Errorf("Python脚本不存在于: %s", scriptPath) + } + + return scriptPath, nil +} + +// 创建临时输出文件 +func createTempOutputFile() (string, func(), error) { + tempDir := os.TempDir() + outputFile := filepath.Join(tempDir, fmt.Sprintf("lfs_mapping_%d.json", time.Now().UnixNano())) + cleanup := func() { + err := os.Remove(outputFile) + if err != nil { + logrus.Warnf("Failed to remove temporary output file %s", outputFile) + return + } + } + return outputFile, cleanup, nil +} + +// 解析输出文件 +func parseOutputFile(outputFile string) (map[string]FileInfo, error) { + // 转换为绝对路径 + absPath, err := filepath.Abs(outputFile) + safeDir := "/tmp" + // 检查路径是否在安全目录内 + if !strings.HasPrefix(filepath.Clean(absPath), safeDir) { + return nil, fmt.Errorf("access denied: file must be in %s", safeDir) + } + if err != nil { + return nil, 
fmt.Errorf("invalid file path: %w", err) + } + data, err := os.ReadFile(absPath) + if err != nil { + return nil, fmt.Errorf("读取输出文件失败: %w", err) + } + + var mapping map[string]FileInfo + if err := json.Unmarshal(data, &mapping); err != nil { + return nil, fmt.Errorf("解析JSON失败: %w", err) + } + + return mapping, nil +} + func generateError(err error, m string) error { msg := err.Error() + m logrus.Error(fmt.Sprintf(formatLogString, verifyLog, msg)) diff --git a/auth/gitee_test.go b/auth/gitee_test.go index 2249d58..8143509 100644 --- a/auth/gitee_test.go +++ b/auth/gitee_test.go @@ -71,7 +71,7 @@ func (s *SuiteGitee) TestCheckRepoOwner() { Owner: s.Owner, Token: s.cfg.DefaultToken, } - err := CheckRepoOwner(userInRepo) + _, err := CheckRepoOwner(userInRepo) assert.NotNil(s.T(), err) // check no_exist repo @@ -80,7 +80,7 @@ func (s *SuiteGitee) TestCheckRepoOwner() { Owner: "owner", Token: s.cfg.DefaultToken, } - err = CheckRepoOwner(userInRepo) + _, err = CheckRepoOwner(userInRepo) assert.NotNil(s.T(), err) } diff --git a/config/config.go b/config/config.go index 4c9b275..4930090 100644 --- a/config/config.go +++ b/config/config.go @@ -25,9 +25,10 @@ type Config struct { type ValidateConfig struct { OwnerRegexp string `json:"OWNER_REGEXP" required:"true"` - RepoNameRegexp string `json:"REPONAME_REGEXP" required:"true"` - UsernameRegexp string `json:"USERNAME_REGEXP" required:"true"` - PasswordRegexp string `json:"PASSWORD_REGEXP" required:"true"` + RepoNameRegexp string `json:"REPONAME_REGEXP" required:"true"` + UsernameRegexp string `json:"USERNAME_REGEXP" required:"true"` + PasswordRegexp string `json:"PASSWORD_REGEXP" required:"true"` + WebhookKey string `json:"WEBHOOK_KEY" required:"true"` } type DBConfig struct { diff --git a/db/db.go b/db/db.go index 0630f5f..0799774 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "errors" "fmt" "log" "time" @@ -13,7 +14,8 @@ import ( ) var ( - Db *gorm.DB + Db *gorm.DB + oidHolder = "oid = ?" 
) // Init initializes the database connection and configuration. @@ -52,6 +54,7 @@ func DB() *gorm.DB { type LfsObj struct { ID int `gorm:"primaryKey;autoIncrement;comment:'自增ID'"` Oid string `gorm:"size:511;not null;default:'';index:idx_oid;comment:'文件OID'"` + FileName string `gorm:"size:255;default:'';comment:'文件名'"` Size int `gorm:"not null;comment:'文件大小'"` Platform string `gorm:"size:64;not null;default:'gitee';index:idx_platform;comment:'所属平台,默认为gitee'"` Owner string `gorm:"size:100;not null;index:idx_platform;comment:'仓库owner'"` @@ -96,7 +99,7 @@ func DeleteLFSObj(obj LfsObj) error { // CountLFSObj 查找指定 OID 的 LFS 元数据数量 func CountLFSObj(obj LfsObj) (int64, error) { var count int64 - result := Db.Model(&LfsObj{}).Where("oid = ?", obj.Oid).Count(&count) + result := Db.Model(&LfsObj{}).Where(oidHolder, obj.Oid).Count(&count) if result.Error != nil { return 0, fmt.Errorf("failed to count LFS objects: %w", result.Error) } @@ -116,8 +119,73 @@ func GetUploadLfsObj() ([]LfsObj, error) { // SelectLfsObjByOid 通过OID查找指定了LFS数据 func SelectLfsObjByOid(oid string) ([]LfsObj, error) { var result []LfsObj - if err := Db.Where("oid = ?", oid).Find(&result).Error; err != nil { + if err := Db.Where(oidHolder, oid).Find(&result).Error; err != nil { return nil, fmt.Errorf("failed to get LfsObj: %w", err) } return result, nil } + +// UpdateLFSObjFileName 更新LFS对象的文件名 +// 参数: +// - oid: 文件OID +// - newFileName: 新文件名 +// - operator: 操作人 +// +// 返回: +// - error: 错误信息 +func UpdateLFSObjFileName(oid, newFileName, operator string) error { + // 参数校验 + if oid == "" { + return fmt.Errorf("OID不能为空") + } + if newFileName == "" { + return fmt.Errorf("新文件名不能为空") + } + + // 开启事务 + tx := Db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + // 查询现有记录 + var obj LfsObj + if err := tx.Where(oidHolder, oid).First(&obj).Error; err != nil { + tx.Rollback() + if errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("未找到OID为 %s 的记录", oid) + } + return 
fmt.Errorf("查询失败: %w", err) + } + + // 检查是否需要更新 + if obj.FileName == newFileName { + tx.Rollback() + return nil // 文件名相同,无需更新 + } + + // 执行更新 + updateData := map[string]interface{}{ + "file_name": newFileName, + "operator": operator, + "update_time": time.Now().Add(8 * time.Hour), + } + + if err := tx.Model(&LfsObj{}). + Where(oidHolder, oid). + Updates(updateData).Error; err != nil { + tx.Rollback() + return fmt.Errorf("更新失败: %w", err) + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + return fmt.Errorf("提交事务失败: %w", err) + } + + log.Printf("成功更新文件OID: %s, 旧文件名: %s, 新文件名: %s", + oid, obj.FileName, newFileName) + return nil +} diff --git a/main.go b/main.go index 6c7c2b8..7a86b3b 100644 --- a/main.go +++ b/main.go @@ -133,6 +133,7 @@ func main() { }) go server.StartScheduledTask() + go server.ScheduledCheckOidAndFileName() srv := &http.Server{ Addr: "0.0.0.0:5000", diff --git a/scripts/lfsNameQuery.py b/scripts/lfsNameQuery.py new file mode 100644 index 0000000..59a469a --- /dev/null +++ b/scripts/lfsNameQuery.py @@ -0,0 +1,157 @@ +import os +import sys +import subprocess +import json +import shutil +from urllib.parse import quote_plus + + +def clone_repo_skip_lfs(gitee_owner, gitee_repo, username=None, token=None, target_dir="temp_repo"): + """克隆仓库并强制跳过LFS文件下载""" + if username and token: + encoded_username = quote_plus(username) + encoded_token = quote_plus(token) + repo_url = f"https://{encoded_username}:{encoded_token}@gitee.com/{gitee_owner}/{gitee_repo}.git" + else: + repo_url = f"https://gitee.com/{gitee_owner}/{gitee_repo}.git" + + force_remove(target_dir) + + try: + env = os.environ.copy() + env.update({ + "GIT_LFS_SKIP_SMUDGE": "1", + "GIT_CLONE_PROTECTION_ACTIVE": "false" + }) + + subprocess.run( + ["git", "clone", repo_url, target_dir], + env=env, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + return target_dir + except subprocess.CalledProcessError as e: + error_msg = "克隆失败: " + if username and token: + 
def main(owner, repo, output_file="lfs_mapping.json", username=None, token=None):
    """Clone a Gitee repository without LFS content and dump its OID mapping.

    Args:
        owner: Repository owner on Gitee.
        repo: Repository name.
        output_file: Path of the JSON file the mapping is written to.
        username: Optional user name for authenticated clones.
        token: Optional access token for authenticated clones.

    Returns:
        True when the mapping was written, False when any step failed (the
        error is printed to stderr).

    Each call clones into its own freshly created temporary directory, so
    concurrent invocations cannot delete or overwrite one another's working
    copy. The directory is removed before returning.
    """
    import tempfile  # local import: only this function needs it

    work_dir = None
    try:
        work_dir = tempfile.mkdtemp(prefix="lfs_name_query_")
        repo_dir = clone_repo_skip_lfs(
            owner,
            repo,
            username,
            token,
            target_dir=os.path.join(work_dir, "repo"),
        )
        mapping = get_all_branches_lfs_mapping(repo_dir)

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(mapping, f, indent=2, ensure_ascii=False)

        print(f"结果已保存到 {output_file}")
        return True
    except Exception as e:
        print(f"错误: {str(e)}", file=sys.stderr)
        return False
    finally:
        # Removing the parent directory also removes the clone inside it.
        if work_dir is not None:
            force_remove(work_dir)
r.Post("/webhook/merge", s.handleGiteeWebhook) return r, nil } @@ -153,7 +154,7 @@ func (s *server) handleBatch(w http.ResponseWriter, r *http.Request) { return } - if err = auth.CheckRepoOwner(userInRepo); req.Operation == "upload" || err != nil { + if _, err = auth.CheckRepoOwner(userInRepo); req.Operation == "upload" || err != nil { err := s.dealWithAuthError(userInRepo, w, r) if err != nil { return @@ -162,6 +163,13 @@ func (s *server) handleBatch(w http.ResponseWriter, r *http.Request) { resp := s.handleRequestObject(req) + // 添加元数据 + addMetaData(req, w, userInRepo) + + must(json.NewEncoder(w).Encode(resp)) +} + +func addMetaData(req batch.Request, w http.ResponseWriter, userInRepo auth.UserInRepo) { // 添加元数据 if req.Operation == "upload" { for _, object := range req.Objects { @@ -170,10 +178,9 @@ func (s *server) handleBatch(w http.ResponseWriter, r *http.Request) { Owner: userInRepo.Owner, Oid: object.OID, Size: object.Size, - Exist: 2, // 默认设置为2 - Platform: "gitee", // 默认平台 - //TODO - Operator: "", // 操作人 + Exist: 2, // 默认设置为2 + Platform: "gitee", // 默认平台 + Operator: userInRepo.Username, // 操作人 } if err := db.InsertLFSObj(lfsObj); err != nil { @@ -185,9 +192,16 @@ func (s *server) handleBatch(w http.ResponseWriter, r *http.Request) { } logrus.Infof("insert lfsobj succeed") } + // 10分钟后异步执行,带错误恢复 + time.AfterFunc(10*time.Minute, func() { + defer func() { + if err := recover(); err != nil { + logrus.Errorf("checkRepoOidName panic: %v", err) + } + }() + checkRepoOidName(userInRepo) + }) } - - must(json.NewEncoder(w).Encode(resp)) } func (s *server) handleRequestObject(req batch.Request) batch.Response { @@ -512,6 +526,7 @@ type FileResponse struct { Repo string `json:"repo"` Size int `json:"size"` Oid string `json:"oid"` + FileName string `json:"file_name"` CreateTime int64 `json:"create_time"` UpdateTime int64 `json:"update_time"` } @@ -524,6 +539,7 @@ func (s *server) buildListResponse(files []db.LfsObj, total int64) interface{} { Repo: file.Repo, Size: 
file.Size, Oid: file.Oid, + FileName: file.FileName, CreateTime: file.CreateTime.Unix(), UpdateTime: file.UpdateTime.Unix(), } @@ -541,7 +557,7 @@ func (s *server) buildListResponse(files []db.LfsObj, total int64) interface{} { func (s *server) listAllRepos(w http.ResponseWriter, r *http.Request) { searchKey, page, limit := s.getQueryParams(r) - repoList, total, err := s.fetchRepoList(searchKey, page, limit) + repoList, total, err := fetchRepoList(searchKey, page, limit) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -557,6 +573,115 @@ func (s *server) listAllRepos(w http.ResponseWriter, r *http.Request) { return } } +func checkOid(w http.ResponseWriter, r *http.Request) { + checkOidFileName() +} + +func checkOidFileName() { + repoList, _, err := fetchRepoList("", 0, 0) + if err != nil { + logrus.Errorf("fetch repo list failed: %v", err) + return + } + for _, repo := range repoList { + userInRepo := auth.UserInRepo{ + Repo: repo.Repo, + Owner: repo.Owner} + logrus.Infof("checkOidFileName owner:%v repo:%v", repo.Owner, repo.Repo) + checkRepoOidName(userInRepo) + + } + +} + +func checkRepoOidName(userInRepo auth.UserInRepo) (oidFileNameMap map[string]auth.FileInfo) { + oidFileNameMap, err := auth.GetLFSMapping(userInRepo) + if err != nil { + logrus.Errorf("get lfs mapping failed: %v", err) + } + checkOidFileNameMap(oidFileNameMap, userInRepo) + if strings.ToLower(userInRepo.Owner) != "src-openeuler" { + logrus.Infof("after check owner:%v repo:%v, check src-openeuler", userInRepo.Owner, userInRepo.Repo) + repo, err := auth.CheckRepoOwner(userInRepo) + if err != nil { + return nil + } + if repo.Parent.Fullname != "" { + userInRepo.Owner = strings.Split(repo.Parent.Fullname, "/")[0] + userInRepo.Repo = strings.Split(repo.Parent.Fullname, "/")[1] + return checkRepoOidName(userInRepo) + } + } + return oidFileNameMap +} + +func checkOidFileNameMap(oidFileNameMap map[string]auth.FileInfo, userInRepo auth.UserInRepo) { + if 
oidFileNameMap == nil { + return + } + for oid, fileInfo := range oidFileNameMap { + lfsObjs, err := db.SelectLfsObjByOid(oid) + if err != nil { + logrus.Errorf("get lfs obj by oid failed: %v", err) + continue + } + + if len(lfsObjs) == 0 { + logrus.Infof("oid:%v not exist, create", oid) + lfsObj := db.LfsObj{ + Repo: userInRepo.Repo, + Owner: userInRepo.Owner, + Oid: oid, + Size: int(fileInfo.Size), + FileName: fileInfo.Name, + Exist: 2, + Platform: "gitee", + Operator: "", + } + if err = db.InsertLFSObj(lfsObj); err != nil { + logrus.Errorf("insert lfs obj failed: %v", err) + } + continue + } + + // 检查对应oid文件在对应仓库下是否存在,如果不存在则创建对应数据 + checkLfsObjsInfo(oid, lfsObjs, fileInfo, userInRepo) + } +} + +func checkLfsObjsInfo(oid string, lfsObjs []db.LfsObj, fileInfo auth.FileInfo, userInRepo auth.UserInRepo) { + exist := false + logrus.Infof("check oid:%v info", oid) + for _, lfsObj := range lfsObjs { + if lfsObj.Owner == userInRepo.Owner { + exist = true + } + + if "" == lfsObj.FileName { + err := db.UpdateLFSObjFileName(oid, fileInfo.Name, "") + if err != nil { + logrus.Errorf("update file name failed: %v", err) + return + } + } + } + + if !exist { + lfsObj := db.LfsObj{ + Repo: userInRepo.Repo, + Owner: userInRepo.Owner, + Oid: oid, + Size: int(fileInfo.Size), + FileName: fileInfo.Name, + Exist: 2, + Platform: "gitee", + Operator: "", + } + if err := db.InsertLFSObj(lfsObj); err != nil { + logrus.Errorf("insert not exist lfs obj failed: %v", err) + } + } +} func (s *server) getQueryParams(r *http.Request) (string, int, int) { searchKey := r.URL.Query().Get("searchKey") @@ -578,7 +703,7 @@ func (s *server) getQueryParams(r *http.Request) (string, int, int) { return searchKey, page, limit } -func (s *server) fetchRepoList(searchKey string, page, limit int) ([]struct { +func fetchRepoList(searchKey string, page, limit int) ([]struct { Owner string `json:"owner"` Repo string `json:"repo"` TotalSize int `json:"total_size"` @@ -601,7 +726,7 @@ func (s *server) 
fetchRepoList(searchKey string, page, limit int) ([]struct { Having("SUM(CASE WHEN exist = 1 THEN size ELSE 0 END) > 0") if searchKey != "" { - query = s.applySearchFilter(query, searchKey) + query = applySearchFilter(query, searchKey) } var total int64 @@ -627,7 +752,7 @@ func (s *server) fetchRepoList(searchKey string, page, limit int) ([]struct { return repoList, total, nil } -func (s *server) applySearchFilter(query *gorm.DB, searchKey string) *gorm.DB { +func applySearchFilter(query *gorm.DB, searchKey string) *gorm.DB { parts := strings.SplitN(searchKey, "/", 2) if len(parts) == 2 { owner := parts[0] diff --git a/server/validate.go b/server/validate.go index f7e40fb..5decc4f 100644 --- a/server/validate.go +++ b/server/validate.go @@ -14,9 +14,12 @@ type validateConfig struct { } var validatecfg validateConfig +var Webhook_key string func Init(cfg config.ValidateConfig) error { var err error + Webhook_key = cfg.WebhookKey + validatecfg.ownerRegexp, err = regexp.Compile(cfg.OwnerRegexp) if err != nil { return fmt.Errorf("failed to compile owner regexp: %w", err) diff --git a/server/webhook.go b/server/webhook.go new file mode 100644 index 0000000..4261882 --- /dev/null +++ b/server/webhook.go @@ -0,0 +1,303 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/metalogical/BigFiles/db" + "github.com/sirupsen/logrus" + "gorm.io/gorm" + "io" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +var diffcheck = "+size " + +func (s *server) handleGiteeWebhook(w http.ResponseWriter, r *http.Request) { + if !verifyWebhookKey(r) { + logrus.Errorf("Invalid Gitee token") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + payload, err := s.parseWebhookPayload(r) + if err != nil { + logrus.Errorf("Failed to decode webhook payload: %v", err) + http.Error(w, "Invalid payload", http.StatusBadRequest) + return + } + + if shouldSkipProcessing(payload) { + writeJSONResponse(w, http.StatusOK, 
map[string]string{"message": "Skipped non-merge request"}) + return + } + + lfsFiles, err := s.processMergeRequest(payload) + if err != nil { + logrus.Errorf("Failed to process merge request: %v", err) + http.Error(w, "Failed to process request", http.StatusInternalServerError) + return + } + + s.writeSuccessResponse(w, payload, lfsFiles) +} + +// verifyWebhookKey 验证 Gitee Webhook 的 Token +func verifyWebhookKey(r *http.Request) bool { + // 从 Header 中获取 token + receivedToken := r.Header.Get("X-Gitee-Token") + if receivedToken == "" { + logrus.Warn("Missing X-Gitee-Token in header") + return false + } + return receivedToken == Webhook_key +} + +// parseWebhookPayload 解析webhook请求体 +func (s *server) parseWebhookPayload(r *http.Request) (*GiteeWebhookPayload, error) { + var payload GiteeWebhookPayload + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + return nil, err + } + return &payload, nil +} + +// shouldSkipProcessing 判断是否应该跳过处理 +func shouldSkipProcessing(payload *GiteeWebhookPayload) bool { + return payload.HookName != "merge_request_hooks" || !payload.PullRequest.Merged +} + +// processMergeRequest 处理合并请求的核心逻辑 +func (s *server) processMergeRequest(payload *GiteeWebhookPayload) ([]LFSFile, error) { + lfsFiles, err := s.extractLFSFilesFromDiff(payload.PullRequest.DiffURL) + if err != nil { + return nil, fmt.Errorf("failed to check diff for LFS files: %w", err) + } + + if len(lfsFiles) == 0 { + return nil, nil + } + + repoOwner, repoName, _ := strings.Cut(payload.PullRequest.Base.Repo.FullName, "/") + operator := payload.PullRequest.User.Login + + for _, lfsFile := range lfsFiles { + if err := s.processLFSFile(lfsFile, repoOwner, repoName, operator); err != nil { + return nil, err + } + } + + return lfsFiles, nil +} + +// processLFSFile 处理单个LFS文件 +func (s *server) processLFSFile(lfsFile LFSFile, repoOwner, repoName, operator string) error { + existingObj, err := db.SelectLfsObjByOid(lfsFile.Oid) + if err != nil && !errors.Is(err, 
gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query LFS object by OID: %w", err) + } + + if existingObj == nil { + logrus.Infof("LFS object with OID %s not exists, skipping", lfsFile.Oid) + return nil + } + + obj := db.LfsObj{ + Oid: lfsFile.Oid, + FileName: lfsFile.FileName, + Size: lfsFile.Size, + Platform: "gitee", + Owner: repoOwner, + Repo: repoName, + Operator: operator, + Exist: 2, + } + + if err := db.InsertLFSObj(obj); err != nil { + return fmt.Errorf("failed to insert LFS object: %w", err) + } + + return nil +} + +// writeSuccessResponse 写入成功响应 +func (s *server) writeSuccessResponse(w http.ResponseWriter, payload *GiteeWebhookPayload, lfsFiles []LFSFile) { + response := map[string]interface{}{ + "message": "Webhook processed successfully", + "lfs_files_count": len(lfsFiles), + "pull_request_id": payload.PullRequest.ID, + "pull_request_url": payload.PullRequest.HTMLURL, + "merged": payload.PullRequest.Merged, + } + writeJSONResponse(w, http.StatusOK, response) +} + +// GiteeWebhookPayload 定义webhook负载结构 +type GiteeWebhookPayload struct { + HookName string `json:"hook_name"` + PullRequest struct { + ID int `json:"id"` + Number int `json:"number"` + State string `json:"state"` + Title string `json:"title"` + HTMLURL string `json:"html_url"` + DiffURL string `json:"diff_url"` + Merged bool `json:"merged"` + MergedAt string `json:"merged_at"` + CreatedAt string `json:"created_at"` + User struct { + Login string `json:"login"` + } `json:"user"` + Head struct { + Ref string `json:"ref"` + Sha string `json:"sha"` + Repo struct { + FullName string `json:"full_name"` + Owner struct { + Login string `json:"login"` + } `json:"owner"` + Name string `json:"name"` + } `json:"repo"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + Sha string `json:"sha"` + Repo struct { + FullName string `json:"full_name"` + Owner struct { + Login string `json:"login"` + } `json:"owner"` + Name string `json:"name"` + } `json:"repo"` + } `json:"base"` + } 
`json:"pull_request"` +} + +// LFSFile 表示从diff中提取的LFS文件信息 +type LFSFile struct { + Oid string `json:"oid"` + FileName string `json:"file_name"` + Size int `json:"size"` +} + +// extractLFSFilesFromDiff 从diff中提取LFS文件信息 +func (s *server) extractLFSFilesFromDiff(diffURL string) ([]LFSFile, error) { + parsedURL, err := url.Parse(diffURL) + if err != nil { + return nil, fmt.Errorf("invalid URL format: %w", err) + } + + if parsedURL.Scheme != "https" { + return nil, fmt.Errorf("only HTTPS protocol is allowed") + } + + hostname := parsedURL.Hostname() + if !strings.HasSuffix(hostname, ".gitee.com") && hostname != "gitee.com" { + return nil, fmt.Errorf("only gitee.com domains are permitted") + } + + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } + + resp, err := client.Get(diffURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch diff: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + diff, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read diff: %w", err) + } + + return parseLFSFilesFromDiff(string(diff)) +} + +// parseLFSFilesFromDiff 从diff内容中解析LFS文件信息 +func parseLFSFilesFromDiff(diffContent string) ([]LFSFile, error) { + var lfsFiles []LFSFile + lines := strings.Split(diffContent, "\n") + + for i := 0; i < len(lines); i++ { + if !isOIDLine(lines[i]) { + continue + } + + fileInfo, skip := extractLFSFileInfo(lines, i) + if fileInfo != nil { + lfsFiles = append(lfsFiles, *fileInfo) + } + if skip { + i++ // 跳过已处理的size行 + } + } + + return lfsFiles, nil +} + +// extractLFSFileInfo 从指定位置提取LFS文件信息 +func extractLFSFileInfo(lines []string, currentIdx int) (*LFSFile, bool) { + // 提取OID + oid := strings.TrimPrefix(lines[currentIdx], "+oid sha256:") + + // 提取Size + size, skip := 0, false + if currentIdx+1 < len(lines) && 
strings.HasPrefix(lines[currentIdx+1], diffcheck) { + sizeStr := strings.TrimPrefix(lines[currentIdx+1], diffcheck) + size, _ = strconv.Atoi(sizeStr) + skip = true + } + + // 提取文件名 + fileName := findFileName(lines, currentIdx) + + if oid != "" && fileName != "" { + return &LFSFile{ + Oid: oid, + FileName: fileName, + Size: size, + }, skip + } + + return nil, skip +} + +// isOIDLine 判断是否是OID行 +func isOIDLine(line string) bool { + return strings.HasPrefix(line, "+oid sha256:") +} + +// findFileName 向上查找文件名 +func findFileName(lines []string, currentIdx int) string { + for j := currentIdx; j >= 0 && j >= currentIdx-10; j-- { + if strings.HasPrefix(lines[j], "diff --git a/") { + parts := strings.SplitN(lines[j][len("diff --git a/"):], " ", 2) + if len(parts) > 0 { + return parts[0] + } + } + } + return "" +} + +// writeJSONResponse 辅助函数 +func writeJSONResponse(w http.ResponseWriter, statusCode int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Printf("Failed to encode JSON response: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +}