From 99a2a3203942e42bc0df686bc4b8b2d3f3cdf854 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Fri, 13 Mar 2026 01:03:14 +0000 Subject: [PATCH 1/2] feat: implement threads service --- .github/workflows/ci.yml | 34 +++ .github/workflows/release.yml | 99 +++++++ .gitignore | 4 + Dockerfile | 28 ++ Makefile | 25 ++ buf.gen.yaml | 16 ++ buf.yaml | 4 + charts/threads/Chart.yaml | 6 + charts/threads/templates/_helpers.tpl | 42 +++ charts/threads/templates/deployment.yaml | 112 ++++++++ charts/threads/templates/service.yaml | 22 ++ charts/threads/templates/serviceaccount.yaml | 12 + charts/threads/values.yaml | 61 +++++ cmd/threads/main.go | 85 ++++++ go.mod | 23 ++ go.sum | 62 +++++ internal/config/config.go | 29 ++ internal/db/migrate.go | 57 ++++ internal/notifier/notifier.go | 51 ++++ internal/server/converter.go | 53 ++++ internal/server/server.go | 263 +++++++++++++++++++ internal/store/messages.go | 228 ++++++++++++++++ internal/store/pagination.go | 155 +++++++++++ internal/store/store.go | 136 ++++++++++ internal/store/threads.go | 180 +++++++++++++ internal/store/types.go | 73 +++++ migrations/0001_init.sql | 40 +++ migrations/embed.go | 8 + 28 files changed, 1908 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/release.yml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 buf.gen.yaml create mode 100644 buf.yaml create mode 100644 charts/threads/Chart.yaml create mode 100644 charts/threads/templates/_helpers.tpl create mode 100644 charts/threads/templates/deployment.yaml create mode 100644 charts/threads/templates/service.yaml create mode 100644 charts/threads/templates/serviceaccount.yaml create mode 100644 charts/threads/values.yaml create mode 100644 cmd/threads/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/config/config.go create mode 100644 internal/db/migrate.go create mode 100644 internal/notifier/notifier.go create mode 100644 internal/server/converter.go create mode 100644 internal/server/server.go create mode 100644 internal/store/messages.go create mode 100644 internal/store/pagination.go create mode 100644 internal/store/store.go create mode 100644 internal/store/threads.go create mode 100644 internal/store/types.go create mode 100644 migrations/0001_init.sql create mode 100644 migrations/embed.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..348868b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + env: + GOTOOLCHAIN: local + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.24.x' + + - name: Install buf + run: | + curl -sSL https://github.com/bufbuild/buf/releases/download/v1.64.0/buf-Linux-x86_64.tar.gz \ + | sudo tar -xzf - -C /usr/local --strip-components=1 buf/bin/buf + + - name: Generate protobuf bindings + run: | + buf generate buf.build/agynio/api --path agynio/api/threads/v1 --path agynio/api/notifications/v1 + + - name: Run tests + run: go test ./... + + - name: Build + run: go build ./... diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..79c3798 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,99 @@ +name: Release + +on: + push: + branches: + - main + tags: + - 'v*' + +permissions: + contents: read + packages: write + +jobs: + publish-edge: + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push edge image + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: true + tags: ghcr.io/agynio/threads:edge + + publish-release: + if: startsWith(github.ref, 'refs/tags/v') + runs-on: ubuntu-latest + env: + GOTOOLCHAIN: local + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.24.x' + + - uses: docker/setup-buildx-action@v3 + + - name: Install buf + run: | + curl -sSL https://github.com/bufbuild/buf/releases/download/v1.64.0/buf-Linux-x86_64.tar.gz \ + | sudo tar -xzf - -C /usr/local --strip-components=1 buf/bin/buf + + - name: Set release version + run: echo "VERSION=${GITHUB_REF_NAME#v}" >> "$GITHUB_ENV" + + - name: Generate protobufs + run: make proto + + - name: Run tests + run: go test ./... + + - name: Build binaries + run: go build ./... + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push release images + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: true + tags: | + ghcr.io/agynio/threads:${{ env.VERSION }} + ghcr.io/agynio/threads:latest + ghcr.io/agynio/threads:${{ github.sha }} + + - name: Setup Helm + uses: azure/setup-helm@v4 + + - name: Package Helm chart + run: | + helm package charts/threads --version "${VERSION}" --app-version "${VERSION}" + + - name: Push Helm chart to GHCR + env: + VERSION: ${{ env.VERSION }} + run: | + helm registry login ghcr.io --username ${{ github.actor }} --password ${{ secrets.GITHUB_TOKEN }} + helm push threads-${VERSION}.tgz oci://ghcr.io/agynio/charts diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..58587e0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.local +.env +.DS_Store +gen/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d2be725 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +# syntax=docker/dockerfile:1.8 + +FROM golang:1.24 AS builder + +ARG TARGETOS +ARG TARGETARCH + +WORKDIR /src + +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg/mod \ + go mod download + +COPY . . + +RUN --mount=type=cache,target=/go/pkg/mod \ + CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ + go build -trimpath -ldflags "-s -w" -o /out/threads ./cmd/threads + +FROM gcr.io/distroless/base-debian12 AS runtime + +WORKDIR /app + +COPY --from=builder /out/threads /usr/local/bin/threads + +USER nonroot:nonroot + +ENTRYPOINT ["/usr/local/bin/threads"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..81b3554 --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +SHELL := /bin/bash + +PROTO_PATHS := agynio/api/threads/v1 agynio/api/notifications/v1 + +.PHONY: all proto build test lint fmt clean + +all: build + +proto: + buf generate buf.build/agynio/api $(foreach p,$(PROTO_PATHS),--path $(p)) + +build: + GOFLAGS=-mod=mod go build ./... + +test: + GOFLAGS=-mod=mod go test ./... + +lint: + GOFLAGS=-mod=mod go vet ./... + +fmt: + gofmt -w $(shell find . -type f -name '*.go') + +clean: + rm -rf gen diff --git a/buf.gen.yaml b/buf.gen.yaml new file mode 100644 index 0000000..c8b741f --- /dev/null +++ b/buf.gen.yaml @@ -0,0 +1,16 @@ +version: v1 + +plugins: + - plugin: buf.build/protocolbuffers/go + out: gen/go + opt: + - paths=source_relative + - Magynio/api/threads/v1/threads.proto=github.com/agynio/threads/gen/go/agynio/api/threads/v1 + - Magynio/api/notifications/v1/notifications.proto=github.com/agynio/threads/gen/go/agynio/api/notifications/v1 + - plugin: buf.build/grpc/go + out: gen/go + opt: + - paths=source_relative + - require_unimplemented_servers=false + - Magynio/api/threads/v1/threads.proto=github.com/agynio/threads/gen/go/agynio/api/threads/v1 + - Magynio/api/notifications/v1/notifications.proto=github.com/agynio/threads/gen/go/agynio/api/notifications/v1 diff --git a/buf.yaml b/buf.yaml new file mode 100644 index 0000000..1569982 --- /dev/null +++ b/buf.yaml @@ -0,0 +1,4 @@ +version: v1 + +deps: + - buf.build/agynio/api diff --git a/charts/threads/Chart.yaml b/charts/threads/Chart.yaml new file mode 100644 index 0000000..0191b60 --- /dev/null +++ b/charts/threads/Chart.yaml @@ -0,0 +1,6 @@ +apiVersion: v2 +name: threads +description: Helm chart for the agynio threads gRPC service +type: application +version: 0.1.0 +appVersion: 0.1.0 diff --git a/charts/threads/templates/_helpers.tpl b/charts/threads/templates/_helpers.tpl new file mode 100644 index 0000000..637e009 --- /dev/null +++ b/charts/threads/templates/_helpers.tpl @@ -0,0 +1,42 @@ +{{- define "threads.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}} +{{- end -}} + +{{- define "threads.fullname" -}} +{{- if .Values.fullnameOverride -}} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" -}} +{{- else -}} +{{- $name := default .Chart.Name .Values.nameOverride -}} +{{- if contains $name .Release.Name -}} +{{- .Release.Name | trunc 63 | trimSuffix "-" -}} +{{- else -}} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" -}} +{{- end -}} +{{- end -}} +{{- end -}} + +{{- define "threads.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" -}} +{{- end -}} + +{{- define "threads.labels" -}} +helm.sh/chart: {{ include "threads.chart" . }} +{{ include "threads.selectorLabels" . }} +{{- with .Chart.AppVersion }} +app.kubernetes.io/version: {{ . | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end -}} + +{{- define "threads.selectorLabels" -}} +app.kubernetes.io/name: {{ include "threads.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end -}} + +{{- define "threads.serviceAccountName" -}} +{{- if .Values.serviceAccount.create -}} +{{- default (include "threads.fullname" .) .Values.serviceAccount.name -}} +{{- else -}} +{{- default "default" .Values.serviceAccount.name -}} +{{- end -}} +{{- end -}} diff --git a/charts/threads/templates/deployment.yaml b/charts/threads/templates/deployment.yaml new file mode 100644 index 0000000..06c09fd --- /dev/null +++ b/charts/threads/templates/deployment.yaml @@ -0,0 +1,112 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "threads.fullname" . }} + labels: + {{- include "threads.labels" . | nindent 4 }} +spec: + replicas: {{ .Values.replicaCount }} + selector: + matchLabels: + {{- include "threads.selectorLabels" . | nindent 6 }} + template: + metadata: + labels: + {{- include "threads.selectorLabels" . | nindent 8 }} + {{- with .Values.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + serviceAccountName: {{ include "threads.serviceAccountName" . }} + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.podSecurityContext }} + securityContext: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: threads + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + ports: + - name: grpc + containerPort: {{ .Values.service.port }} + protocol: TCP + {{- if .Values.probes.liveness.enabled }} + livenessProbe: + tcpSocket: + port: grpc + initialDelaySeconds: {{ .Values.probes.liveness.initialDelaySeconds }} + periodSeconds: {{ .Values.probes.liveness.periodSeconds }} + timeoutSeconds: {{ .Values.probes.liveness.timeoutSeconds }} + failureThreshold: {{ .Values.probes.liveness.failureThreshold }} + successThreshold: {{ .Values.probes.liveness.successThreshold }} + {{- end }} + {{- if .Values.probes.readiness.enabled }} + readinessProbe: + tcpSocket: + port: grpc + initialDelaySeconds: {{ .Values.probes.readiness.initialDelaySeconds }} + periodSeconds: {{ .Values.probes.readiness.periodSeconds }} + timeoutSeconds: {{ .Values.probes.readiness.timeoutSeconds }} + failureThreshold: {{ .Values.probes.readiness.failureThreshold }} + successThreshold: {{ .Values.probes.readiness.successThreshold }} + {{- end }} + {{- with .Values.securityContext }} + securityContext: + {{- toYaml . | nindent 12 }} + {{- end }} + env: + {{- $db := .Values.database -}} + {{- $url := "" -}} + {{- $secretName := "" -}} + {{- $secretKey := "database-url" -}} + {{- if $db }} + {{- with $db.url }} + {{- $url = . -}} + {{- end }} + {{- with $db.existingSecret }} + {{- with .name }} + {{- $secretName = . -}} + {{- end }} + {{- with .key }} + {{- $secretKey = . -}} + {{- end }} + {{- end }} + {{- end }} + {{- if and $secretName $url }} + {{- fail "set only one of database.url or database.existingSecret.name" -}} + {{- end }} + - name: DATABASE_URL + {{- if $secretName }} + valueFrom: + secretKeyRef: + name: {{ $secretName }} + key: {{ $secretKey }} + {{- else }} + value: {{ required "database.url or database.existingSecret.name must be set" $url | quote }} + {{- end }} + - name: NOTIFICATIONS_ADDRESS + value: {{ required "notifications.address must be set" .Values.notifications.address | quote }} + {{- with .Values.resources }} + resources: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/threads/templates/service.yaml b/charts/threads/templates/service.yaml new file mode 100644 index 0000000..c7992d7 --- /dev/null +++ b/charts/threads/templates/service.yaml @@ -0,0 +1,22 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "threads.fullname" . }} + labels: + {{- include "threads.labels" . | nindent 4 }} + {{- with .Values.service.labels }} + {{- toYaml . | nindent 4 }} + {{- end }} + {{- with .Values.service.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + type: {{ .Values.service.type }} + selector: + {{- include "threads.selectorLabels" . | nindent 4 }} + ports: + - name: grpc + port: {{ .Values.service.port }} + targetPort: grpc + protocol: TCP diff --git a/charts/threads/templates/serviceaccount.yaml b/charts/threads/templates/serviceaccount.yaml new file mode 100644 index 0000000..10711ee --- /dev/null +++ b/charts/threads/templates/serviceaccount.yaml @@ -0,0 +1,12 @@ +{{- if .Values.serviceAccount.create -}} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ include "threads.serviceAccountName" . }} + labels: + {{- include "threads.labels" . | nindent 4 }} + {{- with .Values.serviceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +{{- end }} diff --git a/charts/threads/values.yaml b/charts/threads/values.yaml new file mode 100644 index 0000000..0b21ab3 --- /dev/null +++ b/charts/threads/values.yaml @@ -0,0 +1,61 @@ +replicaCount: 1 + +image: + repository: ghcr.io/agynio/threads + tag: "" + pullPolicy: IfNotPresent + +imagePullSecrets: [] +nameOverride: "" +fullnameOverride: "" + +serviceAccount: + create: true + annotations: {} + name: "" + +podAnnotations: {} +podLabels: {} + +podSecurityContext: {} + +securityContext: {} + +service: + type: ClusterIP + port: 50051 + annotations: {} + labels: {} + +resources: {} + +nodeSelector: {} + +tolerations: [] + +affinity: {} + +probes: + liveness: + enabled: true + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 1 + failureThreshold: 3 + successThreshold: 1 + readiness: + enabled: true + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 1 + failureThreshold: 3 + successThreshold: 1 + +database: + url: "" + existingSecret: + name: "" + key: database-url + +notifications: + address: "" diff --git a/cmd/threads/main.go b/cmd/threads/main.go new file mode 100644 index 0000000..66d7113 --- /dev/null +++ b/cmd/threads/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log" + "net" + "os" + "os/signal" + "syscall" + + notificationsv1 "github.com/agynio/threads/gen/go/agynio/api/notifications/v1" + threadsv1 "github.com/agynio/threads/gen/go/agynio/api/threads/v1" + "github.com/jackc/pgx/v5/pgxpool" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/agynio/threads/internal/config" + "github.com/agynio/threads/internal/db" + "github.com/agynio/threads/internal/notifier" + "github.com/agynio/threads/internal/server" + "github.com/agynio/threads/internal/store" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("threads: %v", err) + } +} + +func run() error { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + cfg, err := config.FromEnv() + if err != nil { + return err + } + + poolCfg, err := pgxpool.ParseConfig(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("parse database url: %w", err) + } + pool, err := pgxpool.NewWithConfig(ctx, poolCfg) + if err != nil { + return fmt.Errorf("create connection pool: %w", err) + } + defer pool.Close() + + if err := db.ApplyMigrations(ctx, pool); err != nil { + return fmt.Errorf("apply migrations: %w", err) + } + + notificationsConn, err := grpc.DialContext(ctx, cfg.NotificationsAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return fmt.Errorf("dial notifications: %w", err) + } + defer notificationsConn.Close() + + threadsServer := grpc.NewServer() + threadsv1.RegisterThreadsServiceServer( + threadsServer, + server.New(store.NewStore(pool), notifier.New(notificationsv1.NewNotificationsServiceClient(notificationsConn))), + ) + + lis, err := net.Listen("tcp", cfg.GRPCAddress) + if err != nil { + return fmt.Errorf("listen on %s: %w", cfg.GRPCAddress, err) + } + + go func() { + <-ctx.Done() + threadsServer.GracefulStop() + }() + + log.Printf("ThreadsService listening on %s", cfg.GRPCAddress) + if err := threadsServer.Serve(lis); err != nil { + if errors.Is(err, grpc.ErrServerStopped) { + return nil + } + return fmt.Errorf("serve: %w", err) + } + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..64e2b95 --- /dev/null +++ b/go.mod @@ -0,0 +1,23 @@ +module github.com/agynio/threads + +go 1.24.0 + +toolchain go1.24.13 + +require ( + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.8.0 + google.golang.org/grpc v1.79.1 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f930091 --- /dev/null +++ b/go.sum @@ -0,0 +1,62 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= +google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..b8e1884 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "fmt" + "os" +) + +type Config struct { + GRPCAddress string + DatabaseURL string + NotificationsAddress string +} + +func FromEnv() (Config, error) { + cfg := Config{} + cfg.GRPCAddress = os.Getenv("GRPC_ADDRESS") + if cfg.GRPCAddress == "" { + cfg.GRPCAddress = ":50051" + } + cfg.DatabaseURL = os.Getenv("DATABASE_URL") + if cfg.DatabaseURL == "" { + return Config{}, fmt.Errorf("DATABASE_URL must be set") + } + cfg.NotificationsAddress = os.Getenv("NOTIFICATIONS_ADDRESS") + if cfg.NotificationsAddress == "" { + return Config{}, fmt.Errorf("NOTIFICATIONS_ADDRESS must be set") + } + return cfg, nil +} diff --git a/internal/db/migrate.go b/internal/db/migrate.go new file mode 100644 index 0000000..1d5a5a2 --- /dev/null +++ b/internal/db/migrate.go @@ -0,0 +1,57 @@ +package db + +import ( + "context" + "fmt" + "io/fs" + "sort" + + "github.com/agynio/threads/migrations" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +func ApplyMigrations(ctx context.Context, pool *pgxpool.Pool) error { + conn, err := pool.Acquire(ctx) + if err != nil { + return fmt.Errorf("acquire connection: %w", err) + } + defer conn.Release() + + return pgx.BeginFunc(ctx, conn.Conn(), func(tx pgx.Tx) error { + if _, err := tx.Exec(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())`); err != nil { + return fmt.Errorf("ensure schema_migrations: %w", err) + } + + entries, err := fs.ReadDir(migrations.Files, ".") + if err != nil { + return fmt.Errorf("read migrations: %w", err) + } + sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + + for _, entry := range entries { + if entry.IsDir() { + continue + } + version := entry.Name() + var applied bool + if err := tx.QueryRow(ctx, `SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)`, version).Scan(&applied); err != nil { + return fmt.Errorf("check migration %s: %w", version, err) + } + if applied { + continue + } + content, err := migrations.Files.ReadFile(version) + if err != nil { + return fmt.Errorf("read migration %s: %w", version, err) + } + if _, err := tx.Exec(ctx, string(content)); err != nil { + return fmt.Errorf("apply migration %s: %w", version, err) + } + if _, err := tx.Exec(ctx, `INSERT INTO schema_migrations (version) VALUES ($1)`, version); err != nil { + return fmt.Errorf("record migration %s: %w", version, err) + } + } + return nil + }) +} diff --git a/internal/notifier/notifier.go b/internal/notifier/notifier.go new file mode 100644 index 0000000..b4971b7 --- /dev/null +++ b/internal/notifier/notifier.go @@ -0,0 +1,51 @@ +package notifier + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "google.golang.org/protobuf/types/known/structpb" + + notificationsv1 "github.com/agynio/threads/gen/go/agynio/api/notifications/v1" +) + +const ( + messageCreatedEvent = "message.created" + messageSource = "threads" +) + +type Notifier struct { + client notificationsv1.NotificationsServiceClient +} + +func New(client notificationsv1.NotificationsServiceClient) *Notifier { + return &Notifier{client: client} +} + +func (n *Notifier) PublishMessageCreated(ctx context.Context, threadID, messageID uuid.UUID, recipients []uuid.UUID) error { + if len(recipients) == 0 { + return nil + } + rooms := make([]string, len(recipients)) + for i, recipient := range recipients { + rooms[i] = fmt.Sprintf("thread_participant:%s", recipient) + } + payload, err := structpb.NewStruct(map[string]any{ + "thread_id": threadID.String(), + "message_id": messageID.String(), + }) + if err != nil { + return fmt.Errorf("build payload: %w", err) + } + _, err = n.client.Publish(ctx, ¬ificationsv1.PublishRequest{ + Event: messageCreatedEvent, + Rooms: rooms, + Payload: payload, + Source: messageSource, + }) + if err != nil { + return fmt.Errorf("publish notification: %w", err) + } + return nil +} diff --git a/internal/server/converter.go b/internal/server/converter.go new file mode 100644 index 0000000..bd28c47 --- /dev/null +++ b/internal/server/converter.go @@ -0,0 +1,53 @@ +package server + +import ( + threadsv1 "github.com/agynio/threads/gen/go/agynio/api/threads/v1" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/agynio/threads/internal/store" +) + +func toProtoThread(thread store.Thread) *threadsv1.Thread { + protoThread := &threadsv1.Thread{ + Id: thread.ID.String(), + Status: toProtoThreadStatus(thread.Status), + CreatedAt: timestamppb.New(thread.CreatedAt), + UpdatedAt: timestamppb.New(thread.UpdatedAt), + } + if len(thread.Participants) > 0 { + protoThread.Participants = make([]*threadsv1.Participant, len(thread.Participants)) + for i, participant := range thread.Participants { + protoThread.Participants[i] = &threadsv1.Participant{ + Id: participant.ID.String(), + JoinedAt: timestamppb.New(participant.JoinedAt), + } + } + } + return protoThread +} + +func toProtoMessage(message store.Message) *threadsv1.Message { + fileIDs := make([]string, len(message.FileIDs)) + for i, id := range message.FileIDs { + fileIDs[i] = id.String() + } + return &threadsv1.Message{ + Id: message.ID.String(), + ThreadId: message.ThreadID.String(), + SenderId: message.SenderID.String(), + Body: message.Body, + FileIds: fileIDs, + CreatedAt: timestamppb.New(message.CreatedAt), + } +} + +func toProtoThreadStatus(status store.ThreadStatus) threadsv1.ThreadStatus { + switch status { + case store.ThreadStatusActive: + return threadsv1.ThreadStatus_THREAD_STATUS_ACTIVE + case store.ThreadStatusArchived: + return threadsv1.ThreadStatus_THREAD_STATUS_ARCHIVED + default: + return threadsv1.ThreadStatus_THREAD_STATUS_UNSPECIFIED + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..4d3948a --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,263 @@ +package server + +import ( + "context" + "errors" + "fmt" + + threadsv1 "github.com/agynio/threads/gen/go/agynio/api/threads/v1" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/agynio/threads/internal/notifier" + "github.com/agynio/threads/internal/store" +) + +type Server struct { + threadsv1.UnimplementedThreadsServiceServer + store *store.Store + notifier *notifier.Notifier +} + +func New(store *store.Store, notifier *notifier.Notifier) *Server { + return &Server{store: store, notifier: notifier} +} + +func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRequest) (*threadsv1.CreateThreadResponse, error) { + ids := req.GetParticipantIds() + if len(ids) == 0 { + return nil, status.Error(codes.InvalidArgument, "participant_ids must be provided") + } + participantIDs := make([]uuid.UUID, len(ids)) + seen := make(map[uuid.UUID]struct{}, len(ids)) + for i, raw := range ids { + id, err := parseUUID(raw) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_ids[%d]: %v", i, err) + } + if _, ok := seen[id]; ok { + return nil, status.Errorf(codes.InvalidArgument, "participant_ids[%d]: duplicate participant", i) + } + seen[id] = struct{}{} + participantIDs[i] = id + } + + thread, err := s.store.CreateThread(ctx, participantIDs) + if err != nil { + return nil, toStatusError(err) + } + return &threadsv1.CreateThreadResponse{Thread: toProtoThread(thread)}, nil +} + +func (s *Server) ArchiveThread(ctx context.Context, req *threadsv1.ArchiveThreadRequest) (*threadsv1.ArchiveThreadResponse, error) { + threadID, err := parseUUID(req.GetThreadId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "thread_id: %v", err) + } + thread, err := s.store.ArchiveThread(ctx, threadID) + if err != nil { + return nil, toStatusError(err) + } + return &threadsv1.ArchiveThreadResponse{Thread: toProtoThread(thread)}, nil +} + +func (s *Server) AddParticipant(ctx context.Context, req *threadsv1.AddParticipantRequest) (*threadsv1.AddParticipantResponse, error) { + threadID, err := parseUUID(req.GetThreadId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "thread_id: %v", err) + } + participantID, err := parseUUID(req.GetParticipantId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err) + } + thread, err := s.store.AddParticipant(ctx, threadID, participantID) + if err != nil { + return nil, toStatusError(err) + } + return &threadsv1.AddParticipantResponse{Thread: toProtoThread(thread)}, nil +} + +func (s *Server) SendMessage(ctx context.Context, req *threadsv1.SendMessageRequest) (*threadsv1.SendMessageResponse, error) { + threadID, err := parseUUID(req.GetThreadId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "thread_id: %v", err) + } + senderID, err := parseUUID(req.GetSenderId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "sender_id: %v", err) + } + if req.GetBody() == "" && len(req.GetFileIds()) == 0 { + return nil, status.Error(codes.InvalidArgument, "body or file_ids must be provided") + } + fileIDs := make([]uuid.UUID, len(req.GetFileIds())) + for i, raw := range req.GetFileIds() { + id, err := parseUUID(raw) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "file_ids[%d]: %v", i, err) + } + fileIDs[i] = id + } + + result, err := s.store.SendMessage(ctx, threadID, senderID, req.GetBody(), fileIDs) + if err != nil { + return nil, toStatusError(err) + } + if err := s.notifier.PublishMessageCreated(ctx, threadID, result.Message.ID, result.Recipients); err != nil { + return nil, status.Errorf(codes.Internal, "notify recipients: %v", err) + } + return &threadsv1.SendMessageResponse{Message: toProtoMessage(result.Message)}, nil +} + +func (s *Server) GetThreads(ctx context.Context, req *threadsv1.GetThreadsRequest) (*threadsv1.GetThreadsResponse, error) { + participantID, err := parseUUID(req.GetParticipantId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err) + } + var cursor *store.ThreadCursor + if token := req.GetPageToken(); token != "" { + tokenID, tokenCursor, err := store.DecodeThreadPageToken(token) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid page_token: %v", err) + } + if tokenID != participantID { + return nil, status.Error(codes.InvalidArgument, "page_token does not match participant") + } + cursor = &tokenCursor + } + + result, err := s.store.ListThreads(ctx, participantID, req.GetPageSize(), cursor) + if err != nil { + return nil, toStatusError(err) + } + resp := &threadsv1.GetThreadsResponse{Threads: make([]*threadsv1.Thread, len(result.Threads))} + for i, thread := range result.Threads { + resp.Threads[i] = toProtoThread(thread) + } + if result.NextCursor != nil { + token, err := store.EncodeThreadPageToken(participantID, *result.NextCursor) + if err != nil { + return nil, status.Errorf(codes.Internal, "encode page token: %v", err) + } + resp.NextPageToken = token + } + return resp, nil +} + +func (s *Server) GetMessages(ctx context.Context, req *threadsv1.GetMessagesRequest) (*threadsv1.GetMessagesResponse, error) { + threadID, err := parseUUID(req.GetThreadId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "thread_id: %v", err) + } + var cursor *store.MessageCursor + if token := req.GetPageToken(); token != "" { + tokenID, tokenCursor, err := store.DecodeThreadMessagePageToken(token) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid page_token: %v", err) + } + if tokenID != threadID { + return nil, status.Error(codes.InvalidArgument, "page_token does not match thread") + } + cursor = &tokenCursor + } + + result, err := s.store.ListMessages(ctx, threadID, req.GetPageSize(), cursor) + if err != nil { + return nil, toStatusError(err) + } + resp := &threadsv1.GetMessagesResponse{Messages: make([]*threadsv1.Message, len(result.Messages))} + for i, message := range result.Messages { + resp.Messages[i] = toProtoMessage(message) + } + if result.NextCursor != nil { + token, err := store.EncodeThreadMessagePageToken(threadID, *result.NextCursor) + if err != nil { + return nil, status.Errorf(codes.Internal, "encode page token: %v", err) + } + resp.NextPageToken = token + } + return resp, nil +} + +func (s *Server) GetUnackedMessages(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest) (*threadsv1.GetUnackedMessagesResponse, error) { + participantID, err := parseUUID(req.GetParticipantId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err) + } + var cursor *store.MessageCursor + if token := req.GetPageToken(); token != "" { + tokenID, tokenCursor, err := store.DecodeUnackedMessagePageToken(token) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid page_token: %v", err) + } + if tokenID != participantID { + return nil, status.Error(codes.InvalidArgument, "page_token does not match participant") + } + cursor = &tokenCursor + } + + result, err := s.store.ListUnackedMessages(ctx, participantID, req.GetPageSize(), cursor) + if err != nil { + return nil, toStatusError(err) + } + resp := &threadsv1.GetUnackedMessagesResponse{Messages: make([]*threadsv1.Message, len(result.Messages))} + for i, message := range result.Messages { + resp.Messages[i] = toProtoMessage(message) + } + if result.NextCursor != nil { + token, err := store.EncodeUnackedMessagePageToken(participantID, *result.NextCursor) + if err != nil { + return nil, status.Errorf(codes.Internal, "encode page token: %v", err) + } + resp.NextPageToken = token + } + return resp, nil +} + +func (s *Server) AckMessages(ctx context.Context, req *threadsv1.AckMessagesRequest) (*threadsv1.AckMessagesResponse, error) { + participantID, err := parseUUID(req.GetParticipantId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err) + } + if len(req.GetMessageIds()) == 0 { + return nil, status.Error(codes.InvalidArgument, "message_ids must be provided") + } + messageIDs := make([]uuid.UUID, len(req.GetMessageIds())) + for i, raw := range req.GetMessageIds() { + id, err := parseUUID(raw) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "message_ids[%d]: %v", i, err) + } + messageIDs[i] = id + } + + count, err := s.store.AckMessages(ctx, participantID, messageIDs) + if err != nil { + return nil, toStatusError(err) + } + return &threadsv1.AckMessagesResponse{AckedCount: count}, nil +} + +func parseUUID(value string) (uuid.UUID, error) { + if value == "" { + return uuid.UUID{}, fmt.Errorf("value is empty") + } + id, err := uuid.Parse(value) + if err != nil { + return uuid.UUID{}, err + } + return id, nil +} + +func toStatusError(err error) error { + switch { + case errors.Is(err, store.ErrThreadNotFound): + return status.Error(codes.NotFound, err.Error()) + case errors.Is(err, store.ErrThreadArchived): + return status.Error(codes.FailedPrecondition, err.Error()) + case errors.Is(err, store.ErrParticipantNotInThread): + return status.Error(codes.InvalidArgument, err.Error()) + default: + return status.Errorf(codes.Internal, "internal error: %v", err) + } +} diff --git a/internal/store/messages.go b/internal/store/messages.go new file mode 100644 index 0000000..d0c1847 --- /dev/null +++ b/internal/store/messages.go @@ -0,0 +1,228 @@ +package store + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +type SendMessageResult struct { + Message Message + Recipients []uuid.UUID +} + +func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, body string, fileIDs []uuid.UUID) (SendMessageResult, error) { + var result SendMessageResult + err := s.runTx(ctx, func(tx pgx.Tx) error { + status, _, _, err := loadThreadRow(ctx, tx, threadID, true) + if err != nil { + return err + } + if status == ThreadStatusArchived { + return ErrThreadArchived + } + var isParticipant bool + if err := tx.QueryRow(ctx, `SELECT EXISTS (SELECT 1 FROM thread_participants WHERE thread_id = $1 AND participant_id = $2)`, threadID, senderID).Scan(&isParticipant); err != nil { + return err + } + if !isParticipant { + return ErrParticipantNotInThread + } + now := time.Now().UTC() + messageID := uuid.New() + fileIDArray := pgtype.FlatArray[string](uuidsToStrings(fileIDs)) + if _, err := tx.Exec(ctx, `INSERT INTO messages (id, thread_id, sender_id, body, file_ids, created_at) VALUES ($1, $2, $3, $4, $5, $6)`, messageID, threadID, senderID, body, fileIDArray, now); err != nil { + return err + } + recipients, err := loadRecipients(ctx, tx, threadID, senderID) + if err != nil { + return err + } + if len(recipients) > 0 { + rows := make([][]any, len(recipients)) + for i, recipientID := range recipients { + rows[i] = []any{messageID, threadID, recipientID} + } + if _, err := tx.CopyFrom(ctx, pgx.Identifier{"message_recipients"}, []string{"message_id", "thread_id", "participant_id"}, pgx.CopyFromRows(rows)); err != nil { + return err + } + } + if _, err := tx.Exec(ctx, `UPDATE threads SET updated_at = $2 WHERE id = $1`, threadID, now); err != nil { + return err + } + result = SendMessageResult{ + Message: Message{ + ID: messageID, + ThreadID: threadID, + SenderID: senderID, + Body: body, + FileIDs: fileIDs, + CreatedAt: now, + }, + Recipients: recipients, + } + return nil + }) + if err != nil { + return SendMessageResult{}, err + } + return result, nil +} + +func loadRecipients(ctx context.Context, q queryer, threadID, senderID uuid.UUID) ([]uuid.UUID, error) { + rows, err := q.Query(ctx, `SELECT participant_id FROM thread_participants WHERE thread_id = $1 AND participant_id <> $2 ORDER BY participant_id ASC`, threadID, senderID) + if err != nil { + return nil, err + } + defer rows.Close() + + recipients := []uuid.UUID{} + for rows.Next() { + var recipientID uuid.UUID + if err := rows.Scan(&recipientID); err != nil { + return nil, err + } + recipients = append(recipients, recipientID) + } + if err := rows.Err(); err != nil { + return nil, err + } + return recipients, nil +} + +func (s *Store) ListMessages(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *MessageCursor) (MessageListResult, error) { + if err := ensureThreadExists(ctx, s.pool, threadID); err != nil { + return MessageListResult{}, err + } + limit := normalizePageSize(pageSize) + query := strings.Builder{} + query.WriteString(`SELECT id, thread_id, sender_id, body, file_ids, created_at + FROM messages + WHERE thread_id = $1`) + args := []any{threadID} + paramIndex := 2 + if cursor != nil { + query.WriteString(fmt.Sprintf(" AND (created_at, id) > ($%d, $%d)", paramIndex, paramIndex+1)) + args = append(args, cursor.CreatedAt, cursor.MessageID) + paramIndex += 2 + } + query.WriteString(fmt.Sprintf(" ORDER BY created_at ASC, id ASC LIMIT $%d", paramIndex)) + args = append(args, int(limit)+1) + + rows, err := s.pool.Query(ctx, query.String(), args...) + if err != nil { + return MessageListResult{}, err + } + defer rows.Close() + + messages := make([]Message, 0, limit) + var ( + nextCursor *MessageCursor + lastID uuid.UUID + lastTime time.Time + hasMore bool + ) + for rows.Next() { + var msg Message + var fileIDs []string + if err := rows.Scan(&msg.ID, &msg.ThreadID, &msg.SenderID, &msg.Body, &fileIDs, &msg.CreatedAt); err != nil { + return MessageListResult{}, err + } + if int32(len(messages)) == limit { + hasMore = true + break + } + parsedIDs, err := stringsToUUIDs(fileIDs) + if err != nil { + return MessageListResult{}, fmt.Errorf("parse file ids: %w", err) + } + msg.FileIDs = parsedIDs + messages = append(messages, msg) + lastID = msg.ID + lastTime = msg.CreatedAt + } + if err := rows.Err(); err != nil { + return MessageListResult{}, err + } + if hasMore { + nextCursor = &MessageCursor{CreatedAt: lastTime, MessageID: lastID} + } + return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil +} + +func (s *Store) ListUnackedMessages(ctx context.Context, participantID uuid.UUID, pageSize int32, cursor *MessageCursor) (MessageListResult, error) { + limit := normalizePageSize(pageSize) + query := strings.Builder{} + query.WriteString(`SELECT m.id, m.thread_id, m.sender_id, m.body, m.file_ids, m.created_at + FROM message_recipients mr + JOIN messages m ON m.id = mr.message_id + WHERE mr.participant_id = $1 AND mr.acked_at IS NULL`) + args := []any{participantID} + paramIndex := 2 + if cursor != nil { + query.WriteString(fmt.Sprintf(" AND (m.created_at, m.id) > ($%d, $%d)", paramIndex, paramIndex+1)) + args = append(args, cursor.CreatedAt, cursor.MessageID) + paramIndex += 2 + } + query.WriteString(fmt.Sprintf(" ORDER BY m.created_at ASC, m.id ASC LIMIT $%d", paramIndex)) + args = append(args, int(limit)+1) + + rows, err := s.pool.Query(ctx, query.String(), args...) + if err != nil { + return MessageListResult{}, err + } + defer rows.Close() + + messages := make([]Message, 0, limit) + var ( + nextCursor *MessageCursor + lastID uuid.UUID + lastTime time.Time + hasMore bool + ) + for rows.Next() { + var msg Message + var fileIDs []string + if err := rows.Scan(&msg.ID, &msg.ThreadID, &msg.SenderID, &msg.Body, &fileIDs, &msg.CreatedAt); err != nil { + return MessageListResult{}, err + } + if int32(len(messages)) == limit { + hasMore = true + break + } + parsedIDs, err := stringsToUUIDs(fileIDs) + if err != nil { + return MessageListResult{}, fmt.Errorf("parse file ids: %w", err) + } + msg.FileIDs = parsedIDs + messages = append(messages, msg) + lastID = msg.ID + lastTime = msg.CreatedAt + } + if err := rows.Err(); err != nil { + return MessageListResult{}, err + } + if hasMore { + nextCursor = &MessageCursor{CreatedAt: lastTime, MessageID: lastID} + } + return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil +} + +func (s *Store) AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) { + now := time.Now().UTC() + messageIDArray := pgtype.FlatArray[uuid.UUID](messageIDs) + cmd, err := s.pool.Exec(ctx, `UPDATE message_recipients SET acked_at = $1 WHERE participant_id = $2 AND message_id = ANY($3) AND acked_at IS NULL`, now, participantID, messageIDArray) + if err != nil { + return 0, err + } + count := cmd.RowsAffected() + if count > int64(^uint32(0)>>1) { + return 0, fmt.Errorf("acked count overflow: %d", count) + } + return int32(count), nil +} diff --git a/internal/store/pagination.go b/internal/store/pagination.go new file mode 100644 index 0000000..673c5c1 --- /dev/null +++ b/internal/store/pagination.go @@ -0,0 +1,155 @@ +package store + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/google/uuid" +) + +const ( + defaultPageSize int32 = 50 + maxPageSize int32 = 100 +) + +func normalizePageSize(size int32) int32 { + if size <= 0 { + return defaultPageSize + } + if size > maxPageSize { + return maxPageSize + } + return size +} + +type threadPageToken struct { + ParticipantID string `json:"participant_id"` + UpdatedAtNanos int64 `json:"updated_at_nanos"` + ThreadID string `json:"thread_id"` +} + +func EncodeThreadPageToken(participantID uuid.UUID, cursor ThreadCursor) (string, error) { + payload := threadPageToken{ + ParticipantID: participantID.String(), + UpdatedAtNanos: cursor.UpdatedAt.UnixNano(), + ThreadID: cursor.ThreadID.String(), + } + buf, err := json.Marshal(payload) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func DecodeThreadPageToken(token string) (uuid.UUID, ThreadCursor, error) { + if token == "" { + return uuid.UUID{}, ThreadCursor{}, errors.New("empty token") + } + data, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return uuid.UUID{}, ThreadCursor{}, fmt.Errorf("decode token: %w", err) + } + var payload threadPageToken + if err := json.Unmarshal(data, &payload); err != nil { + return uuid.UUID{}, ThreadCursor{}, fmt.Errorf("unmarshal token: %w", err) + } + participantID, err := uuid.Parse(payload.ParticipantID) + if err != nil { + return uuid.UUID{}, ThreadCursor{}, fmt.Errorf("parse participant id: %w", err) + } + threadID, err := uuid.Parse(payload.ThreadID) + if err != nil { + return uuid.UUID{}, ThreadCursor{}, fmt.Errorf("parse thread id: %w", err) + } + return participantID, ThreadCursor{ + UpdatedAt: time.Unix(0, payload.UpdatedAtNanos).UTC(), + ThreadID: threadID, + }, nil +} + +type messagePageToken struct { + OwnerID string `json:"owner_id"` + CreatedAtNanos int64 `json:"created_at_nanos"` + MessageID string `json:"message_id"` +} + +func EncodeThreadMessagePageToken(threadID uuid.UUID, cursor MessageCursor) (string, error) { + payload := messagePageToken{ + OwnerID: threadID.String(), + CreatedAtNanos: cursor.CreatedAt.UnixNano(), + MessageID: cursor.MessageID.String(), + } + buf, err := json.Marshal(payload) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func DecodeThreadMessagePageToken(token string) (uuid.UUID, MessageCursor, error) { + if token == "" { + return uuid.UUID{}, MessageCursor{}, errors.New("empty token") + } + data, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("decode token: %w", err) + } + var payload messagePageToken + if err := json.Unmarshal(data, &payload); err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("unmarshal token: %w", err) + } + ownerID, err := uuid.Parse(payload.OwnerID) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse owner id: %w", err) + } + messageID, err := uuid.Parse(payload.MessageID) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse message id: %w", err) + } + return ownerID, MessageCursor{ + CreatedAt: time.Unix(0, payload.CreatedAtNanos).UTC(), + MessageID: messageID, + }, nil +} + +func EncodeUnackedMessagePageToken(participantID uuid.UUID, cursor MessageCursor) (string, error) { + payload := messagePageToken{ + OwnerID: participantID.String(), + CreatedAtNanos: cursor.CreatedAt.UnixNano(), + MessageID: cursor.MessageID.String(), + } + buf, err := json.Marshal(payload) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func DecodeUnackedMessagePageToken(token string) (uuid.UUID, MessageCursor, error) { + if token == "" { + return uuid.UUID{}, MessageCursor{}, errors.New("empty token") + } + data, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("decode token: %w", err) + } + var payload messagePageToken + if err := json.Unmarshal(data, &payload); err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("unmarshal token: %w", err) + } + ownerID, err := uuid.Parse(payload.OwnerID) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse owner id: %w", err) + } + messageID, err := uuid.Parse(payload.MessageID) + if err != nil { + return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse message id: %w", err) + } + return ownerID, MessageCursor{ + CreatedAt: time.Unix(0, payload.CreatedAtNanos).UTC(), + MessageID: messageID, + }, nil +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..cdd84a1 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,136 @@ +package store + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +type Store struct { + pool *pgxpool.Pool +} + +func NewStore(pool *pgxpool.Pool) *Store { + return &Store{pool: pool} +} + +func (s *Store) runTx(ctx context.Context, fn func(pgx.Tx) error) error { + tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback(ctx) + } + }() + if err = fn(tx); err != nil { + return err + } + if err = tx.Commit(ctx); err != nil { + return err + } + return nil +} + +type queryer interface { + Query(context.Context, string, ...any) (pgx.Rows, error) + QueryRow(context.Context, string, ...any) pgx.Row +} + +func loadThreadRow(ctx context.Context, q queryer, id uuid.UUID, forUpdate bool) (ThreadStatus, time.Time, time.Time, error) { + query := "SELECT status, created_at, updated_at FROM threads WHERE id = $1" + if forUpdate { + query += " FOR UPDATE" + } + var statusValue int16 + var createdAt time.Time + var updatedAt time.Time + err := q.QueryRow(ctx, query, id).Scan(&statusValue, &createdAt, &updatedAt) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ThreadStatusUnspecified, time.Time{}, time.Time{}, ErrThreadNotFound + } + return ThreadStatusUnspecified, time.Time{}, time.Time{}, err + } + status, err := ParseThreadStatus(statusValue) + if err != nil { + return ThreadStatusUnspecified, time.Time{}, time.Time{}, fmt.Errorf("invalid thread status: %w", err) + } + return status, createdAt, updatedAt, nil +} + +func loadThread(ctx context.Context, q queryer, id uuid.UUID) (Thread, error) { + status, createdAt, updatedAt, err := loadThreadRow(ctx, q, id, false) + if err != nil { + return Thread{}, err + } + participants, err := loadParticipants(ctx, q, id) + if err != nil { + return Thread{}, err + } + return Thread{ + ID: id, + Status: status, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + Participants: participants, + }, nil +} + +func loadParticipants(ctx context.Context, q queryer, threadID uuid.UUID) ([]Participant, error) { + rows, err := q.Query(ctx, `SELECT participant_id, joined_at FROM thread_participants WHERE thread_id = $1 ORDER BY joined_at ASC, participant_id ASC`, threadID) + if err != nil { + return nil, err + } + defer rows.Close() + + participants := []Participant{} + for rows.Next() { + var participant Participant + if err := rows.Scan(&participant.ID, &participant.JoinedAt); err != nil { + return nil, err + } + participants = append(participants, participant) + } + if err := rows.Err(); err != nil { + return nil, err + } + return participants, nil +} + +func ensureThreadExists(ctx context.Context, q queryer, threadID uuid.UUID) error { + var exists bool + if err := q.QueryRow(ctx, `SELECT EXISTS (SELECT 1 FROM threads WHERE id = $1)`, threadID).Scan(&exists); err != nil { + return err + } + if !exists { + return ErrThreadNotFound + } + return nil +} + +func uuidsToStrings(ids []uuid.UUID) []string { + values := make([]string, len(ids)) + for i, id := range ids { + values[i] = id.String() + } + return values +} + +func stringsToUUIDs(values []string) ([]uuid.UUID, error) { + ids := make([]uuid.UUID, len(values)) + for i, raw := range values { + id, err := uuid.Parse(raw) + if err != nil { + return nil, fmt.Errorf("parse uuid: %w", err) + } + ids[i] = id + } + return ids, nil +} diff --git a/internal/store/threads.go b/internal/store/threads.go new file mode 100644 index 0000000..0ad709b --- /dev/null +++ b/internal/store/threads.go @@ -0,0 +1,180 @@ +package store + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" +) + +func (s *Store) CreateThread(ctx context.Context, participantIDs []uuid.UUID) (Thread, error) { + var thread Thread + err := s.runTx(ctx, func(tx pgx.Tx) error { + threadID := uuid.New() + now := time.Now().UTC() + if _, err := tx.Exec(ctx, `INSERT INTO threads (id, status, created_at, updated_at) VALUES ($1, $2, $3, $3)`, threadID, int16(ThreadStatusActive), now); err != nil { + return err + } + participants := make([]Participant, len(participantIDs)) + for i, participantID := range participantIDs { + if _, err := tx.Exec(ctx, `INSERT INTO thread_participants (thread_id, participant_id, joined_at) VALUES ($1, $2, $3)`, threadID, participantID, now); err != nil { + return err + } + participants[i] = Participant{ID: participantID, JoinedAt: now} + } + thread = Thread{ + ID: threadID, + Status: ThreadStatusActive, + CreatedAt: now, + UpdatedAt: now, + Participants: participants, + } + return nil + }) + if err != nil { + return Thread{}, err + } + return thread, nil +} + +func (s *Store) ArchiveThread(ctx context.Context, threadID uuid.UUID) (Thread, error) { + var thread Thread + err := s.runTx(ctx, func(tx pgx.Tx) error { + now := time.Now().UTC() + var createdAt time.Time + var updatedAt time.Time + if err := tx.QueryRow(ctx, `UPDATE threads SET status = $2, updated_at = $3 WHERE id = $1 RETURNING created_at, updated_at`, threadID, int16(ThreadStatusArchived), now).Scan(&createdAt, &updatedAt); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrThreadNotFound + } + return err + } + participants, err := loadParticipants(ctx, tx, threadID) + if err != nil { + return err + } + thread = Thread{ + ID: threadID, + Status: ThreadStatusArchived, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + Participants: participants, + } + return nil + }) + if err != nil { + return Thread{}, err + } + return thread, nil +} + +func (s *Store) AddParticipant(ctx context.Context, threadID, participantID uuid.UUID) (Thread, error) { + var thread Thread + err := s.runTx(ctx, func(tx pgx.Tx) error { + status, createdAt, updatedAt, err := loadThreadRow(ctx, tx, threadID, true) + if err != nil { + return err + } + if status == ThreadStatusArchived { + return ErrThreadArchived + } + now := time.Now().UTC() + cmd, err := tx.Exec(ctx, `INSERT INTO thread_participants (thread_id, participant_id, joined_at) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`, threadID, participantID, now) + if err != nil { + return err + } + if cmd.RowsAffected() > 0 { + if _, err := tx.Exec(ctx, `UPDATE threads SET updated_at = $2 WHERE id = $1`, threadID, now); err != nil { + return err + } + updatedAt = now + } + participants, err := loadParticipants(ctx, tx, threadID) + if err != nil { + return err + } + thread = Thread{ + ID: threadID, + Status: status, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + Participants: participants, + } + return nil + }) + if err != nil { + return Thread{}, err + } + return thread, nil +} + +func (s *Store) ListThreads(ctx context.Context, participantID uuid.UUID, pageSize int32, cursor *ThreadCursor) (ThreadListResult, error) { + limit := normalizePageSize(pageSize) + query := strings.Builder{} + query.WriteString(`SELECT t.id, t.status, t.created_at, t.updated_at + FROM threads t + JOIN thread_participants tp ON tp.thread_id = t.id + WHERE tp.participant_id = $1`) + args := []any{participantID} + paramIndex := 2 + if cursor != nil { + query.WriteString(fmt.Sprintf(" AND (t.updated_at, t.id) < ($%d, $%d)", paramIndex, paramIndex+1)) + args = append(args, cursor.UpdatedAt, cursor.ThreadID) + paramIndex += 2 + } + query.WriteString(fmt.Sprintf(" ORDER BY t.updated_at DESC, t.id DESC LIMIT $%d", paramIndex)) + args = append(args, int(limit)+1) + + rows, err := s.pool.Query(ctx, query.String(), args...) + if err != nil { + return ThreadListResult{}, err + } + defer rows.Close() + + threads := make([]Thread, 0, limit) + var ( + nextCursor *ThreadCursor + lastID uuid.UUID + lastTime time.Time + hasMore bool + ) + for rows.Next() { + var thread Thread + var statusValue int16 + if err := rows.Scan(&thread.ID, &statusValue, &thread.CreatedAt, &thread.UpdatedAt); err != nil { + return ThreadListResult{}, err + } + if int32(len(threads)) == limit { + hasMore = true + break + } + status, err := ParseThreadStatus(statusValue) + if err != nil { + return ThreadListResult{}, fmt.Errorf("invalid thread status: %w", err) + } + thread.Status = status + threads = append(threads, thread) + lastID = thread.ID + lastTime = thread.UpdatedAt + } + if err := rows.Err(); err != nil { + return ThreadListResult{}, err + } + if hasMore { + nextCursor = &ThreadCursor{UpdatedAt: lastTime, ThreadID: lastID} + } + + for i := range threads { + participants, err := loadParticipants(ctx, s.pool, threads[i].ID) + if err != nil { + return ThreadListResult{}, err + } + threads[i].Participants = participants + } + + return ThreadListResult{Threads: threads, NextCursor: nextCursor}, nil +} diff --git a/internal/store/types.go b/internal/store/types.go new file mode 100644 index 0000000..270ae67 --- /dev/null +++ b/internal/store/types.go @@ -0,0 +1,73 @@ +package store + +import ( + "errors" + "time" + + "github.com/google/uuid" +) + +var ( + ErrThreadNotFound = errors.New("thread not found") + ErrThreadArchived = errors.New("thread is archived") + ErrParticipantNotInThread = errors.New("participant not in thread") +) + +type ThreadStatus int16 + +const ( + ThreadStatusUnspecified ThreadStatus = 0 + ThreadStatusActive ThreadStatus = 1 + ThreadStatusArchived ThreadStatus = 2 +) + +func ParseThreadStatus(value int16) (ThreadStatus, error) { + switch ThreadStatus(value) { + case ThreadStatusActive, ThreadStatusArchived: + return ThreadStatus(value), nil + default: + return ThreadStatusUnspecified, errors.New("invalid thread status") + } +} + +type Thread struct { + ID uuid.UUID + Participants []Participant + Status ThreadStatus + CreatedAt time.Time + UpdatedAt time.Time +} + +type Participant struct { + ID uuid.UUID + JoinedAt time.Time +} + +type Message struct { + ID uuid.UUID + ThreadID uuid.UUID + SenderID uuid.UUID + Body string + FileIDs []uuid.UUID + CreatedAt time.Time +} + +type ThreadCursor struct { + UpdatedAt time.Time + ThreadID uuid.UUID +} + +type MessageCursor struct { + CreatedAt time.Time + MessageID uuid.UUID +} + +type ThreadListResult struct { + Threads []Thread + NextCursor *ThreadCursor +} + +type MessageListResult struct { + Messages []Message + NextCursor *MessageCursor +} diff --git a/migrations/0001_init.sql b/migrations/0001_init.sql new file mode 100644 index 0000000..d48b8c7 --- /dev/null +++ b/migrations/0001_init.sql @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE threads ( + id UUID PRIMARY KEY, + status SMALLINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE thread_participants ( + thread_id UUID NOT NULL REFERENCES threads(id) ON DELETE CASCADE, + participant_id UUID NOT NULL, + joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (thread_id, participant_id) +); + +CREATE TABLE messages ( + id UUID PRIMARY KEY, + thread_id UUID NOT NULL REFERENCES threads(id) ON DELETE CASCADE, + sender_id UUID NOT NULL, + body TEXT NOT NULL, + file_ids TEXT[] NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE message_recipients ( + message_id UUID NOT NULL REFERENCES messages(id) ON DELETE CASCADE, + thread_id UUID NOT NULL REFERENCES threads(id) ON DELETE CASCADE, + participant_id UUID NOT NULL, + acked_at TIMESTAMPTZ, + PRIMARY KEY (message_id, participant_id) +); + +CREATE INDEX idx_thread_participants_participant ON thread_participants (participant_id, thread_id); +CREATE INDEX idx_messages_thread_created_at ON messages (thread_id, created_at, id); +CREATE INDEX idx_message_recipients_participant_ack ON message_recipients (participant_id, acked_at, message_id); +CREATE INDEX idx_message_recipients_thread_participant ON message_recipients (thread_id, participant_id); diff --git a/migrations/embed.go b/migrations/embed.go new file mode 100644 index 0000000..0fcc1ee --- /dev/null +++ b/migrations/embed.go @@ -0,0 +1,8 @@ +package migrations + +import "embed" + +// Files exposes embedded SQL migrations. +// +//go:embed *.sql +var Files embed.FS From 48c60b21a4f716abd30065f72f0fc0b7733ca906 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Fri, 13 Mar 2026 01:46:19 +0000 Subject: [PATCH 2/2] fix: address threads review feedback --- Dockerfile | 10 +++++ cmd/threads/main.go | 2 +- internal/server/converter.go | 16 +++---- internal/server/server.go | 8 ++-- internal/store/messages.go | 87 +++++++++++------------------------- internal/store/pagination.go | 45 ++----------------- internal/store/store.go | 27 +++++++++++ internal/store/threads.go | 15 ++++--- migrations/0001_init.sql | 5 --- 9 files changed, 89 insertions(+), 126 deletions(-) diff --git a/Dockerfile b/Dockerfile index d2be725..2e7954b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,12 +7,22 @@ ARG TARGETARCH WORKDIR /src +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -sSL https://github.com/bufbuild/buf/releases/download/v1.64.0/buf-Linux-x86_64.tar.gz \ + | tar -xzf - -C /usr/local --strip-components=1 buf/bin/buf + COPY go.mod go.sum ./ RUN --mount=type=cache,target=/go/pkg/mod \ go mod download COPY . . +RUN buf generate buf.build/agynio/api --path agynio/api/threads/v1 --path agynio/api/notifications/v1 + RUN --mount=type=cache,target=/go/pkg/mod \ CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ go build -trimpath -ldflags "-s -w" -o /out/threads ./cmd/threads diff --git a/cmd/threads/main.go b/cmd/threads/main.go index 66d7113..72a51ed 100644 --- a/cmd/threads/main.go +++ b/cmd/threads/main.go @@ -52,7 +52,7 @@ func run() error { return fmt.Errorf("apply migrations: %w", err) } - notificationsConn, err := grpc.DialContext(ctx, cfg.NotificationsAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + notificationsConn, err := grpc.NewClient(cfg.NotificationsAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return fmt.Errorf("dial notifications: %w", err) } diff --git a/internal/server/converter.go b/internal/server/converter.go index bd28c47..8f6770e 100644 --- a/internal/server/converter.go +++ b/internal/server/converter.go @@ -1,6 +1,8 @@ package server import ( + "fmt" + threadsv1 "github.com/agynio/threads/gen/go/agynio/api/threads/v1" "google.golang.org/protobuf/types/known/timestamppb" @@ -14,13 +16,11 @@ func toProtoThread(thread store.Thread) *threadsv1.Thread { CreatedAt: timestamppb.New(thread.CreatedAt), UpdatedAt: timestamppb.New(thread.UpdatedAt), } - if len(thread.Participants) > 0 { - protoThread.Participants = make([]*threadsv1.Participant, len(thread.Participants)) - for i, participant := range thread.Participants { - protoThread.Participants[i] = &threadsv1.Participant{ - Id: participant.ID.String(), - JoinedAt: timestamppb.New(participant.JoinedAt), - } + protoThread.Participants = make([]*threadsv1.Participant, len(thread.Participants)) + for i, participant := range thread.Participants { + protoThread.Participants[i] = &threadsv1.Participant{ + Id: participant.ID.String(), + JoinedAt: timestamppb.New(participant.JoinedAt), } } return protoThread @@ -48,6 +48,6 @@ func toProtoThreadStatus(status store.ThreadStatus) threadsv1.ThreadStatus { case store.ThreadStatusArchived: return threadsv1.ThreadStatus_THREAD_STATUS_ARCHIVED default: - return threadsv1.ThreadStatus_THREAD_STATUS_UNSPECIFIED + panic(fmt.Sprintf("unexpected thread status: %d", status)) } } diff --git a/internal/server/server.go b/internal/server/server.go index 4d3948a..0f0a176 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -151,7 +151,7 @@ func (s *Server) GetMessages(ctx context.Context, req *threadsv1.GetMessagesRequ } var cursor *store.MessageCursor if token := req.GetPageToken(); token != "" { - tokenID, tokenCursor, err := store.DecodeThreadMessagePageToken(token) + tokenID, tokenCursor, err := store.DecodeMessagePageToken(token) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid page_token: %v", err) } @@ -170,7 +170,7 @@ func (s *Server) GetMessages(ctx context.Context, req *threadsv1.GetMessagesRequ resp.Messages[i] = toProtoMessage(message) } if result.NextCursor != nil { - token, err := store.EncodeThreadMessagePageToken(threadID, *result.NextCursor) + token, err := store.EncodeMessagePageToken(threadID, *result.NextCursor) if err != nil { return nil, status.Errorf(codes.Internal, "encode page token: %v", err) } @@ -186,7 +186,7 @@ func (s *Server) GetUnackedMessages(ctx context.Context, req *threadsv1.GetUnack } var cursor *store.MessageCursor if token := req.GetPageToken(); token != "" { - tokenID, tokenCursor, err := store.DecodeUnackedMessagePageToken(token) + tokenID, tokenCursor, err := store.DecodeMessagePageToken(token) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid page_token: %v", err) } @@ -205,7 +205,7 @@ func (s *Server) GetUnackedMessages(ctx context.Context, req *threadsv1.GetUnack resp.Messages[i] = toProtoMessage(message) } if result.NextCursor != nil { - token, err := store.EncodeUnackedMessagePageToken(participantID, *result.NextCursor) + token, err := store.EncodeMessagePageToken(participantID, *result.NextCursor) if err != nil { return nil, status.Errorf(codes.Internal, "encode page token: %v", err) } diff --git a/internal/store/messages.go b/internal/store/messages.go index d0c1847..9d20bac 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -3,6 +3,7 @@ package store import ( "context" "fmt" + "math" "strings" "time" @@ -95,26 +96,8 @@ func loadRecipients(ctx context.Context, q queryer, threadID, senderID uuid.UUID return recipients, nil } -func (s *Store) ListMessages(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *MessageCursor) (MessageListResult, error) { - if err := ensureThreadExists(ctx, s.pool, threadID); err != nil { - return MessageListResult{}, err - } - limit := normalizePageSize(pageSize) - query := strings.Builder{} - query.WriteString(`SELECT id, thread_id, sender_id, body, file_ids, created_at - FROM messages - WHERE thread_id = $1`) - args := []any{threadID} - paramIndex := 2 - if cursor != nil { - query.WriteString(fmt.Sprintf(" AND (created_at, id) > ($%d, $%d)", paramIndex, paramIndex+1)) - args = append(args, cursor.CreatedAt, cursor.MessageID) - paramIndex += 2 - } - query.WriteString(fmt.Sprintf(" ORDER BY created_at ASC, id ASC LIMIT $%d", paramIndex)) - args = append(args, int(limit)+1) - - rows, err := s.pool.Query(ctx, query.String(), args...) +func (s *Store) scanMessagePage(ctx context.Context, query string, args []any, limit int32) (MessageListResult, error) { + rows, err := s.pool.Query(ctx, query, args...) if err != nil { return MessageListResult{}, err } @@ -155,6 +138,28 @@ func (s *Store) ListMessages(ctx context.Context, threadID uuid.UUID, pageSize i return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil } +func (s *Store) ListMessages(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *MessageCursor) (MessageListResult, error) { + if err := ensureThreadExists(ctx, s.pool, threadID); err != nil { + return MessageListResult{}, err + } + limit := normalizePageSize(pageSize) + query := strings.Builder{} + query.WriteString(`SELECT id, thread_id, sender_id, body, file_ids, created_at + FROM messages + WHERE thread_id = $1`) + args := []any{threadID} + paramIndex := 2 + if cursor != nil { + query.WriteString(fmt.Sprintf(" AND (created_at, id) > ($%d, $%d)", paramIndex, paramIndex+1)) + args = append(args, cursor.CreatedAt, cursor.MessageID) + paramIndex += 2 + } + query.WriteString(fmt.Sprintf(" ORDER BY created_at ASC, id ASC LIMIT $%d", paramIndex)) + args = append(args, int(limit)+1) + + return s.scanMessagePage(ctx, query.String(), args, limit) +} + func (s *Store) ListUnackedMessages(ctx context.Context, participantID uuid.UUID, pageSize int32, cursor *MessageCursor) (MessageListResult, error) { limit := normalizePageSize(pageSize) query := strings.Builder{} @@ -172,45 +177,7 @@ func (s *Store) ListUnackedMessages(ctx context.Context, participantID uuid.UUID query.WriteString(fmt.Sprintf(" ORDER BY m.created_at ASC, m.id ASC LIMIT $%d", paramIndex)) args = append(args, int(limit)+1) - rows, err := s.pool.Query(ctx, query.String(), args...) - if err != nil { - return MessageListResult{}, err - } - defer rows.Close() - - messages := make([]Message, 0, limit) - var ( - nextCursor *MessageCursor - lastID uuid.UUID - lastTime time.Time - hasMore bool - ) - for rows.Next() { - var msg Message - var fileIDs []string - if err := rows.Scan(&msg.ID, &msg.ThreadID, &msg.SenderID, &msg.Body, &fileIDs, &msg.CreatedAt); err != nil { - return MessageListResult{}, err - } - if int32(len(messages)) == limit { - hasMore = true - break - } - parsedIDs, err := stringsToUUIDs(fileIDs) - if err != nil { - return MessageListResult{}, fmt.Errorf("parse file ids: %w", err) - } - msg.FileIDs = parsedIDs - messages = append(messages, msg) - lastID = msg.ID - lastTime = msg.CreatedAt - } - if err := rows.Err(); err != nil { - return MessageListResult{}, err - } - if hasMore { - nextCursor = &MessageCursor{CreatedAt: lastTime, MessageID: lastID} - } - return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil + return s.scanMessagePage(ctx, query.String(), args, limit) } func (s *Store) AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) { @@ -221,7 +188,7 @@ func (s *Store) AckMessages(ctx context.Context, participantID uuid.UUID, messag return 0, err } count := cmd.RowsAffected() - if count > int64(^uint32(0)>>1) { + if count > math.MaxInt32 { return 0, fmt.Errorf("acked count overflow: %d", count) } return int32(count), nil diff --git a/internal/store/pagination.go b/internal/store/pagination.go index 673c5c1..f46aaac 100644 --- a/internal/store/pagination.go +++ b/internal/store/pagination.go @@ -76,9 +76,9 @@ type messagePageToken struct { MessageID string `json:"message_id"` } -func EncodeThreadMessagePageToken(threadID uuid.UUID, cursor MessageCursor) (string, error) { +func EncodeMessagePageToken(ownerID uuid.UUID, cursor MessageCursor) (string, error) { payload := messagePageToken{ - OwnerID: threadID.String(), + OwnerID: ownerID.String(), CreatedAtNanos: cursor.CreatedAt.UnixNano(), MessageID: cursor.MessageID.String(), } @@ -89,46 +89,7 @@ func EncodeThreadMessagePageToken(threadID uuid.UUID, cursor MessageCursor) (str return base64.RawURLEncoding.EncodeToString(buf), nil } -func DecodeThreadMessagePageToken(token string) (uuid.UUID, MessageCursor, error) { - if token == "" { - return uuid.UUID{}, MessageCursor{}, errors.New("empty token") - } - data, err := base64.RawURLEncoding.DecodeString(token) - if err != nil { - return uuid.UUID{}, MessageCursor{}, fmt.Errorf("decode token: %w", err) - } - var payload messagePageToken - if err := json.Unmarshal(data, &payload); err != nil { - return uuid.UUID{}, MessageCursor{}, fmt.Errorf("unmarshal token: %w", err) - } - ownerID, err := uuid.Parse(payload.OwnerID) - if err != nil { - return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse owner id: %w", err) - } - messageID, err := uuid.Parse(payload.MessageID) - if err != nil { - return uuid.UUID{}, MessageCursor{}, fmt.Errorf("parse message id: %w", err) - } - return ownerID, MessageCursor{ - CreatedAt: time.Unix(0, payload.CreatedAtNanos).UTC(), - MessageID: messageID, - }, nil -} - -func EncodeUnackedMessagePageToken(participantID uuid.UUID, cursor MessageCursor) (string, error) { - payload := messagePageToken{ - OwnerID: participantID.String(), - CreatedAtNanos: cursor.CreatedAt.UnixNano(), - MessageID: cursor.MessageID.String(), - } - buf, err := json.Marshal(payload) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(buf), nil -} - -func DecodeUnackedMessagePageToken(token string) (uuid.UUID, MessageCursor, error) { +func DecodeMessagePageToken(token string) (uuid.UUID, MessageCursor, error) { if token == "" { return uuid.UUID{}, MessageCursor{}, errors.New("empty token") } diff --git a/internal/store/store.go b/internal/store/store.go index cdd84a1..4ba2ff1 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" ) @@ -104,6 +105,32 @@ func loadParticipants(ctx context.Context, q queryer, threadID uuid.UUID) ([]Par return participants, nil } +func loadParticipantsByThreadIDs(ctx context.Context, q queryer, threadIDs []uuid.UUID) (map[uuid.UUID][]Participant, error) { + participantsByThread := make(map[uuid.UUID][]Participant) + if len(threadIDs) == 0 { + return participantsByThread, nil + } + threadIDArray := pgtype.FlatArray[uuid.UUID](threadIDs) + rows, err := q.Query(ctx, `SELECT thread_id, participant_id, joined_at FROM thread_participants WHERE thread_id = ANY($1) ORDER BY thread_id ASC, joined_at ASC, participant_id ASC`, threadIDArray) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var threadID uuid.UUID + var participant Participant + if err := rows.Scan(&threadID, &participant.ID, &participant.JoinedAt); err != nil { + return nil, err + } + participantsByThread[threadID] = append(participantsByThread[threadID], participant) + } + if err := rows.Err(); err != nil { + return nil, err + } + return participantsByThread, nil +} + func ensureThreadExists(ctx context.Context, q queryer, threadID uuid.UUID) error { var exists bool if err := q.QueryRow(ctx, `SELECT EXISTS (SELECT 1 FROM threads WHERE id = $1)`, threadID).Scan(&exists); err != nil { diff --git a/internal/store/threads.go b/internal/store/threads.go index 0ad709b..fbf3922 100644 --- a/internal/store/threads.go +++ b/internal/store/threads.go @@ -167,13 +167,16 @@ func (s *Store) ListThreads(ctx context.Context, participantID uuid.UUID, pageSi if hasMore { nextCursor = &ThreadCursor{UpdatedAt: lastTime, ThreadID: lastID} } - + threadIDs := make([]uuid.UUID, len(threads)) + for i, thread := range threads { + threadIDs[i] = thread.ID + } + participantsByThread, err := loadParticipantsByThreadIDs(ctx, s.pool, threadIDs) + if err != nil { + return ThreadListResult{}, err + } for i := range threads { - participants, err := loadParticipants(ctx, s.pool, threads[i].ID) - if err != nil { - return ThreadListResult{}, err - } - threads[i].Participants = participants + threads[i].Participants = participantsByThread[threads[i].ID] } return ThreadListResult{Threads: threads, NextCursor: nextCursor}, nil diff --git a/migrations/0001_init.sql b/migrations/0001_init.sql index d48b8c7..351bf74 100644 --- a/migrations/0001_init.sql +++ b/migrations/0001_init.sql @@ -1,8 +1,3 @@ -CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - CREATE TABLE threads ( id UUID PRIMARY KEY, status SMALLINT NOT NULL,