add Cloudflare R2 storage integration and update bug report handling
Some checks failed
Build & Publish Docker Image / build-and-push (push) Failing after 11s

This commit is contained in:
Flavio Fois
2026-05-27 21:35:26 +02:00
parent e6d663f4f2
commit 3ec7bb5222
17 changed files with 841 additions and 54 deletions

View File

@@ -2,6 +2,7 @@ package config
import (
"os"
"regexp"
"strconv"
"strings"
"sync"
@@ -19,16 +20,28 @@ type RateLimitConfig struct {
AuthBanDur time.Duration
}
// R2Config groups the settings needed to reach a Cloudflare R2 bucket through
// its S3-compatible API. load() fills every field from environment variables.
type R2Config struct {
// AccountID comes from CF_ACCOUNT_ID.
AccountID string
// AccessKeyID comes from CF_R2_ACCESS_KEY_ID.
AccessKeyID string
// SecretAccessKey comes from CF_R2_SECRET_ACCESS_KEY.
SecretAccessKey string
// BucketName comes from CF_R2_BUCKET_NAME.
BucketName string
// Region comes from CF_R2_REGION and falls back to "auto" when unset.
Region string
// Endpoint comes from CF_R2_ENDPOINT and may be empty.
Endpoint string
}
type Config struct {
Port string
DSN string
Database string
APIKey string
AdminKey string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime int
RateLimit RateLimitConfig
Port string
DSN string
Database string
APIKey string
AdminKey string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime int
UpdatesEnabled bool
UseS3CompatibleStorage bool
RateLimit RateLimitConfig
R2 R2Config
}
var (
@@ -84,20 +97,39 @@ func load() *Config {
if dbName == "" {
panic("DATABASE_NAME environment variable is required")
}
dbNameRegex := regexp.MustCompile("^[a-zA-Z0-9_]+$")
// Test the regex against the dbName, otherwise panic to prevent potential SQL injection
validDbName, err := regexp.Match(dbNameRegex.String(), []byte(dbName))
if err != nil {
panic("failed to validate database name: " + err.Error())
}
if !validDbName {
panic("invalid database name: must match regex " + dbNameRegex.String())
}
if os.Getenv("DB_DSN") == "" {
panic("DB_DSN environment variable is required")
}
return &Config{
Port: port,
DSN: os.Getenv("DB_DSN"),
Database: dbName,
APIKey: apiKey,
AdminKey: adminKey,
MaxOpenConns: maxOpenConns,
MaxIdleConns: maxIdleConns,
ConnMaxLifetime: connMaxLifetime,
Port: port,
DSN: os.Getenv("DB_DSN"),
Database: dbName,
APIKey: apiKey,
AdminKey: adminKey,
MaxOpenConns: maxOpenConns,
MaxIdleConns: maxIdleConns,
ConnMaxLifetime: connMaxLifetime,
UpdatesEnabled: strings.ToLower(strings.TrimSpace(os.Getenv("UPDATES_ENABLED"))) == "true",
UseS3CompatibleStorage: strings.ToLower(strings.TrimSpace(os.Getenv("USE_S3_COMPATIBLE_STORAGE"))) == "true",
R2: R2Config{
AccountID: os.Getenv("CF_ACCOUNT_ID"),
AccessKeyID: os.Getenv("CF_R2_ACCESS_KEY_ID"),
SecretAccessKey: os.Getenv("CF_R2_SECRET_ACCESS_KEY"),
BucketName: os.Getenv("CF_R2_BUCKET_NAME"),
Region: envString("CF_R2_REGION", "auto"),
Endpoint: os.Getenv("CF_R2_ENDPOINT"),
},
RateLimit: RateLimitConfig{
UnauthMaxReqs: envInt("RL_UNAUTH_MAX_REQS", 10),
UnauthWindow: envDuration("RL_UNAUTH_WINDOW", 5*time.Minute),
@@ -111,6 +143,13 @@ func load() *Config {
}
}
// envString returns the value of the environment variable named key.
// An unset variable and a variable set to the empty string are treated the
// same way: fallback is returned instead.
func envString(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
func envInt(key string, fallback int) int {
if s := os.Getenv(key); s != "" {
if n, err := strconv.Atoi(s); err == nil {

View File

@@ -19,5 +19,13 @@
{ "type": "column_not_exists", "table": "user", "column": "enabled" }
]
}
,{
"id": "3_updates",
"sql_file": "3_updates.sql",
"description": "Create update_releases table for API-managed software updates.",
"conditions": [
{ "type": "table_not_exists", "table": "update_releases" }
]
}
]
}

View File

@@ -198,8 +198,8 @@ func ValidateSession(db *sqlx.DB) http.HandlerFunc {
sessionID,
)
if err != nil {
jsonError(w, http.StatusUnauthorized, "invalid session")
log.Fatalf("Database error during session validation: %v", err)
log.Printf("[AUTH] Database error during session validation: %v", err)
jsonError(w, http.StatusInternalServerError, "internal server error")
return
}

View File

@@ -3,6 +3,7 @@ package handlers
import (
"archive/zip"
"bytes"
"context"
"database/sql"
"embed"
"encoding/json"
@@ -21,6 +22,8 @@ import (
"github.com/jmoiron/sqlx"
"emly-api-go/internal/models"
"emly-api-go/internal/storage"
"emly-api-go/internal/timing"
)
//go:embed templates/report.txt.tmpl
@@ -41,12 +44,13 @@ var fileRoles = []struct {
{"config", models.FileRoleConfig, "application/json"},
}
func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc {
func CreateBugReport(db *sqlx.DB, dbName string, s3conn *storage.S3Connector) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(32 << 20); err != nil {
jsonError(w, http.StatusBadRequest, "invalid multipart form: "+err.Error())
return
}
timing.Mark(r.Context(), "parse_form")
name := r.FormValue("name")
email := r.FormValue("email")
@@ -84,6 +88,7 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_insert_report")
reportID, err := result.LastInsertId()
if err != nil {
@@ -120,7 +125,7 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc {
log.Printf("[BUGREPORT] File uploaded: role=%s size=%d bytes", fr.role, len(data))
_, err = db.ExecContext(r.Context(),
fileResult, err := db.ExecContext(r.Context(),
fmt.Sprintf("INSERT INTO %s.bug_report_files (report_id, file_role, filename, mime_type, file_size, data) VALUES (?, ?, ?, ?, ?, ?)", dbName),
reportID, fr.role, filename, mimeType, len(data), data,
)
@@ -128,6 +133,25 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_insert_file_"+string(fr.role))
if s3conn != nil {
fileID, err := fileResult.LastInsertId()
if err != nil {
log.Printf("[S3] could not get file insert id for report %d role %s: %v", reportID, fr.role, err)
} else {
s3Key := fmt.Sprintf("emly-api-files/bug-reports/%d/files/%s", reportID, filename)
if _, err := s3conn.UploadFile(
context.Background(), s3Key,
bytes.NewReader(data), mimeType,
map[string]string{"filename": filename, "id": strconv.FormatInt(fileID, 10)},
); err != nil {
log.Printf("[S3] upload failed for key %s: %v", s3Key, err)
} else {
timing.Mark(r.Context(), "s3_upload_file_"+string(fr.role))
}
}
}
}
log.Printf("[BUGREPORT] Created successfully with id=%d", reportID)
@@ -177,6 +201,7 @@ func GetAllBugReports(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_count")
mainQuery := fmt.Sprintf(`
SELECT br.*, COUNT(bf.id) as file_count
@@ -193,6 +218,7 @@ func GetAllBugReports(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_select")
jsonOK(w, map[string]interface{}{
"data": reports,
@@ -280,7 +306,7 @@ func GetReportFilesByReportID(db *sqlx.DB, dbName string) http.HandlerFunc {
}
}
func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc {
func GetBugReportZipByID(db *sqlx.DB, dbName string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
if id == "" {
@@ -298,12 +324,14 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_fetch_report")
var files []models.BugReportFile
if err := db.SelectContext(r.Context(), &files, fmt.Sprintf("SELECT * FROM %s.bug_report_files WHERE report_id = ?", dbName), id); err != nil {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_fetch_files")
var sysInfoStr string
if len(report.SystemInfo) > 0 && string(report.SystemInfo) != "null" {
@@ -353,6 +381,7 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "zip_build")
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"report-%d.zip\"", report.ID))
@@ -364,7 +393,7 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc {
}
}
func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc {
func GetReportFileByFileID(db *sqlx.DB, dbName string, s3conn *storage.S3Connector) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
reportId := chi.URLParam(r, "id")
if reportId == "" {
@@ -377,6 +406,44 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc {
return
}
var filename string
if err := db.GetContext(r.Context(), &filename, fmt.Sprintf("SELECT filename FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileId); err != nil {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_fetch_filename_by_id")
// Try S3 first.
if s3conn != nil {
s3Key := fmt.Sprintf("emly-api-files/bug-reports/%s/files/%s", reportId, filename)
rc, info, err := s3conn.GetFile(r.Context(), s3Key)
if err == nil {
defer rc.Close()
timing.Mark(r.Context(), "s3_hit")
log.Println("[S3] cache hit for key", s3Key)
mimeType := info.ContentType
if mimeType == "" {
mimeType = "application/octet-stream"
}
filename := info.Metadata["filename"]
if filename == "" {
filename = fileId
}
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Disposition", "attachment; filename=\""+filename+"\"")
_, _ = io.Copy(w, rc)
return
}
if storage.IsNotFound(err) {
log.Printf("[S3] file %s not found on s3", fileId)
}
if !storage.IsNotFound(err) {
log.Printf("[S3] unexpected error fetching key %s: %v", s3Key, err)
}
}
// Fallback: query DB.
var file models.BugReportFile
err := db.GetContext(r.Context(), &file, fmt.Sprintf("SELECT filename, mime_type, data FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileId)
if errors.Is(err, sql.ErrNoRows) {
@@ -387,6 +454,25 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc {
jsonError(w, http.StatusInternalServerError, err.Error())
return
}
timing.Mark(r.Context(), "db_select")
// Lazy-upload to S3 so future requests are served from there.
if s3conn != nil {
s3Key := fmt.Sprintf("emly-api-files/bug-reports/%s/files/%s", reportId, fileId)
dataCopy := make([]byte, len(file.Data))
copy(dataCopy, file.Data)
mime := file.MimeType
fname := file.Filename
go func() {
if _, err := s3conn.UploadFile(
context.Background(), s3Key,
bytes.NewReader(dataCopy), mime,
map[string]string{"filename": fname},
); err != nil {
log.Printf("[S3] lazy upload failed for key %s: %v", s3Key, err)
}
}()
}
mimeType := file.MimeType
if mimeType == "" {
@@ -394,10 +480,7 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc {
}
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Disposition", "attachment; filename=\""+file.Filename+"\"")
_, err = w.Write(file.Data)
if err != nil {
return
}
_, _ = w.Write(file.Data)
}
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"fmt"
"log"
"net/http"
"strings"
"time"
"emly-api-go/internal/timing"
)
// Timing wraps next so that every request gets a *timing.Timer in its context
// and produces one "[TIMING]" log line once the handler has returned.
//
// Handlers record named steps with timing.Mark(r.Context(), "step_name").
// The log line has the form:
//
//	[TIMING] METHOD /path step1=1.2ms step2=18ms total=20ms
//
// Every step is measured from the previous checkpoint (the request start for
// the first one). Time spent after the last checkpoint is reported as
// "response", so the individual values sum to the total. A request without
// checkpoints is logged with its total only.
func Timing(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		ctx, timer := timing.NewContext(r.Context())
		next.ServeHTTP(w, r.WithContext(ctx))

		elapsed := time.Since(timer.Start)
		marks := timer.Checkpoints()
		if len(marks) == 0 {
			// Still log the request so that it shows up in the timing output.
			log.Printf("[TIMING] %s %s total=%s", r.Method, r.URL.Path, round(elapsed))
			return
		}

		segments := make([]string, 0, len(marks)+2)
		cursor := timer.Start
		for i := range marks {
			segments = append(segments, marks[i].Name+"="+round(marks[i].At.Sub(cursor)))
			cursor = marks[i].At
		}

		// Whatever happened between the last checkpoint and the handler returning.
		if rest := elapsed - cursor.Sub(timer.Start); rest > 0 {
			segments = append(segments, "response="+round(rest))
		}
		segments = append(segments, "total="+round(elapsed))

		log.Printf("[TIMING] %s %s %s", r.Method, r.URL.Path, strings.Join(segments, " "))
	}
	return http.HandlerFunc(handle)
}
// round renders d for the timing log using the largest fitting unit:
// whole nanoseconds below 1µs, microseconds with two decimals below 1ms,
// and milliseconds with two decimals for everything else.
func round(d time.Duration) string {
	if d < time.Microsecond {
		return fmt.Sprintf("%dns", d.Nanoseconds())
	}

	unit, suffix := time.Millisecond, "ms"
	if d < time.Millisecond {
		unit, suffix = time.Microsecond, "µs"
	}
	return fmt.Sprintf("%.2f%s", float64(d)/float64(unit), suffix)
}

View File

@@ -1,20 +1,18 @@
package routes
import (
v2 "emly-api-go/internal/routes/v2"
"net/http"
v1 "emly-api-go/internal/routes/v1"
v2 "emly-api-go/internal/routes/v2"
"emly-api-go/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/jmoiron/sqlx"
)
// RegisterAll mounts every versioned API onto the root router.
// To add a new API version, create internal/routes/v2 and add:
//
// r.Mount("/v2", v2.NewRouter(db))
func RegisterAll(r chi.Router, db *sqlx.DB) {
func RegisterAll(r chi.Router, db *sqlx.DB, s3conn *storage.S3Connector) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("emly-api-go"))
if err != nil {
@@ -22,6 +20,6 @@ func RegisterAll(r chi.Router, db *sqlx.DB) {
}
})
r.Mount("/v1", v1.NewRouter(db))
r.Mount("/v2", v2.NewRouter(db))
r.Mount("/v1", v1.NewRouter(db, s3conn))
r.Mount("/v2", v2.NewRouter(db, s3conn))
}

View File

@@ -5,13 +5,14 @@ import (
"time"
"emly-api-go/internal/handlers"
"emly-api-go/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/go-chi/httprate"
"github.com/jmoiron/sqlx"
)
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string, s3conn *storage.S3Connector) {
r.Route("/bug-reports", func(r chi.Router) {
// API key only: submit a report and check count
r.Group(func(r chi.Router) {
@@ -19,7 +20,7 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
r.Use(httprate.LimitByIP(30, time.Minute))
r.Get("/count", handlers.GetReportsCount(db, dbName))
r.Post("/", handlers.CreateBugReport(db, dbName))
r.Post("/", handlers.CreateBugReport(db, dbName, s3conn))
})
// API key + admin key: full read/write access
@@ -32,10 +33,10 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
r.Get("/{id}", handlers.GetBugReportByID(db, dbName))
r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName))
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName))
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName))
r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName))
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName, s3conn))
r.Get("/{id}/download", handlers.GetBugReportZipByID(db, dbName))
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName))
r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName))
})
})
}
}

View File

@@ -6,15 +6,14 @@ import (
"emly-api-go/internal/config"
"emly-api-go/internal/handlers"
"emly-api-go/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/jmoiron/sqlx"
)
// NewRouter returns a chi.Router with all /v1 routes mounted.
// Add new API versions by creating an analogous package (e.g. v2) and
// mounting it alongside this one in internal/routes/routes.go.
func NewRouter(db *sqlx.DB) http.Handler {
func NewRouter(db *sqlx.DB, s3conn *storage.S3Connector) http.Handler {
r := chi.NewRouter()
rl := emlyMiddleware.NewRateLimiter(config.Load())
@@ -33,7 +32,7 @@ func NewRouter(db *sqlx.DB) http.Handler {
r.Route("/api", func(r chi.Router) {
registerAdmin(r, db)
registerBugReports(r, db, config.Load().Database)
registerBugReports(r, db, config.Load().Database, s3conn)
})
return r

View File

@@ -5,13 +5,14 @@ import (
"time"
"emly-api-go/internal/handlers"
"emly-api-go/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/go-chi/httprate"
"github.com/jmoiron/sqlx"
)
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string, s3conn *storage.S3Connector) {
r.Route("/bug-report", func(r chi.Router) {
// API key only: submit a report and check count
r.Group(func(r chi.Router) {
@@ -19,7 +20,7 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
r.Use(httprate.LimitByIP(30, time.Minute))
r.Get("/count", handlers.GetReportsCount(db, dbName))
r.Post("/", handlers.CreateBugReport(db, dbName))
r.Post("/", handlers.CreateBugReport(db, dbName, s3conn))
})
// API key + admin key: full read/write access
@@ -32,8 +33,8 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
r.Get("/{id}", handlers.GetBugReportByID(db, dbName))
r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName))
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName))
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName))
r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName))
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName, s3conn))
r.Get("/{id}/download", handlers.GetBugReportZipByID(db, dbName))
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName))
r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName))
})

View File

@@ -6,15 +6,14 @@ import (
"emly-api-go/internal/config"
"emly-api-go/internal/handlers"
"emly-api-go/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/jmoiron/sqlx"
)
// NewRouter returns a chi.Router with all /v1 routes mounted.
// Add new API versions by creating an analogous package (e.g. v2) and
// mounting it alongside this one in internal/routes/routes.go.
func NewRouter(db *sqlx.DB) http.Handler {
// NewRouter returns a chi.Router with all /v2 routes mounted.
func NewRouter(db *sqlx.DB, s3conn *storage.S3Connector) http.Handler {
r := chi.NewRouter()
rl := emlyMiddleware.NewRateLimiter(config.Load())
@@ -33,7 +32,8 @@ func NewRouter(db *sqlx.DB) http.Handler {
r.Route("/api", func(r chi.Router) {
registerAdmin(r, db)
registerBugReports(r, db, config.Load().Database)
registerBugReports(r, db, config.Load().Database, s3conn)
registerUpdates(r, db, config.Load())
})
return r

View File

@@ -0,0 +1,122 @@
package storage
import (
"bytes"
"context"
"database/sql"
"emly-api-go/internal/models"
"errors"
"fmt"
"log"
"sync"
"time"
"github.com/jmoiron/sqlx"
)
// MigrateReportFilesToS3 uploads every attachment stored in
// <dbName>.bug_report_files to the bucket behind s3conn, using the key layout
// "emly-api-files/bug-reports/<report id>/files/<filename>".
//
// Reports are visited newest first. A file row that disappears between being
// listed and being loaded is counted as skipped. A failed upload does not stop
// the run: it is logged and counted, and the function returns an error
// carrying the failure count once all uploads have finished. A database error
// aborts the run after waiting for uploads that are already in flight.
//
// dbName is interpolated into the SQL text, so callers must pass a trusted,
// already validated identifier. When s3conn is nil there is no upload target
// and the function returns nil without querying the database.
func MigrateReportFilesToS3(db *sqlx.DB, s3conn *S3Connector, dbName string) error {
	// Upper bound on simultaneous uploads. Each in-flight upload keeps its
	// whole payload in memory, so this also bounds memory use.
	const maxConcurrentUploads = 8

	if s3conn == nil {
		log.Printf("[migrate] no S3 connector configured, nothing to do")
		return nil
	}

	// fileRef is the lightweight listing row; the blob itself is loaded later,
	// one file at a time.
	type fileRef struct {
		ID       int    `db:"id"`
		Filename string `db:"filename"`
	}

	ctx := context.Background()
	started := time.Now()

	// Load the id lists completely before issuing further queries so that no
	// result set stays open (and no connection stays pinned) during the run.
	var reportIDs []int
	reportsQuery := fmt.Sprintf("SELECT id FROM %s.bug_reports ORDER BY created_at DESC", dbName)
	if err := db.SelectContext(ctx, &reportIDs, reportsQuery); err != nil {
		return fmt.Errorf("list bug reports: %w", err)
	}

	fileRefsQuery := fmt.Sprintf("SELECT id, filename FROM %s.bug_report_files WHERE report_id = ?", dbName)
	fileDataQuery := fmt.Sprintf("SELECT filename, mime_type, data FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName)

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards uploadErrors
		uploadErrors int
	)
	slots := make(chan struct{}, maxConcurrentUploads)

	// abort lets running uploads finish before an error is handed back, so no
	// goroutine outlives the call.
	abort := func(err error) error {
		wg.Wait()
		return err
	}

	var queued, skipped int
	for _, reportID := range reportIDs {
		log.Printf("[migrate] processing report %d", reportID)

		var refs []fileRef
		if err := db.SelectContext(ctx, &refs, fileRefsQuery, reportID); err != nil {
			return abort(fmt.Errorf("report %d: list files: %w", reportID, err))
		}

		for _, ref := range refs {
			// Take an upload slot before loading the blob; this is what keeps
			// the number of payloads held in memory bounded.
			slots <- struct{}{}

			var file models.BugReportFile
			err := db.GetContext(ctx, &file, fileDataQuery, reportID, ref.ID)
			if errors.Is(err, sql.ErrNoRows) {
				<-slots
				log.Printf("[migrate] report %d / file %d: not found in bug_report_files, skipping", reportID, ref.ID)
				skipped++
				continue
			}
			if err != nil {
				<-slots
				return abort(fmt.Errorf("report %d / file %d: %w", reportID, ref.ID, err))
			}

			s3Key := fmt.Sprintf("emly-api-files/bug-reports/%d/files/%s", reportID, ref.Filename)
			queued++
			log.Printf("[migrate] report %d / file %d (%s, %d bytes): uploading to s3://%s",
				reportID, ref.ID, file.Filename, len(file.Data), s3Key)

			wg.Add(1)
			// file is declared inside the loop body, so every goroutine gets
			// its own copy and its own data slice.
			go func(key string, f models.BugReportFile) {
				defer wg.Done()
				defer func() { <-slots }()

				_, upErr := s3conn.UploadFile(
					ctx,
					key,
					bytes.NewReader(f.Data),
					f.MimeType,
					map[string]string{"filename": f.Filename},
				)
				if upErr != nil {
					mu.Lock()
					uploadErrors++
					mu.Unlock()
					log.Printf("[migrate] [ERROR] upload failed for s3://%s: %v", key, upErr)
					return
				}
				log.Printf("[migrate] upload complete: s3://%s", key)
			}(s3Key, file)
		}
	}

	wg.Wait()

	log.Printf("[migrate] done — reports: %d, files queued: %d, skipped: %d, upload errors: %d, elapsed: %s",
		len(reportIDs), queued, skipped, uploadErrors, time.Since(started).Round(time.Millisecond))

	if uploadErrors > 0 {
		return fmt.Errorf("migration completed with %d upload errors", uploadErrors)
	}
	return nil
}

View File

@@ -0,0 +1,322 @@
package storage
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"emly-api-go/internal/config"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// S3Connector bundles an S3 API client with the upload and download managers
// built on top of it, all bound to a single bucket. Instances are created by
// NewCloudflareR2Connector.
type S3Connector struct {
// client issues the direct S3 API calls (get, delete, copy, list, head).
client *s3.Client
// uploader is the upload manager used by UploadFile.
uploader *manager.Uploader
// downloader is the download manager used by DownloadFile.
downloader *manager.Downloader
// bucket is the name of the bucket every operation targets.
bucket string
}
// FileInfo describes one stored object as returned by GetFile and ListFiles.
type FileInfo struct {
// Key is the full object key inside the bucket.
Key string
// Size is the object size in bytes as reported by the store.
Size int64
// LastModified is left at the zero time when the store does not report it.
LastModified time.Time
// ETag is the entity tag with surrounding double quotes removed.
ETag string
// ContentType is filled in by GetFile; ListFiles leaves it empty.
ContentType string
// Metadata is the object's user metadata; filled in by GetFile only.
Metadata map[string]string
}
// IsNotFound reports whether err means that the requested object does not
// exist. It returns false for a nil error.
//
// Two shapes are recognised anywhere in err's chain: the typed
// *types.NoSuchKey error, and any error exposing an ErrorCode() method whose
// code is "NoSuchKey", "NotFound" or "404". The second form covers
// S3-compatible stores (e.g. Cloudflare R2) that only surface a generic
// API error.
func IsNotFound(err error) bool {
	if err == nil {
		return false
	}

	var missingKey *types.NoSuchKey
	if errors.As(err, &missingKey) {
		return true
	}

	var coded interface{ ErrorCode() string }
	if !errors.As(err, &coded) {
		return false
	}
	code := coded.ErrorCode()
	return code == "NoSuchKey" || code == "NotFound" || code == "404"
}
// FolderInfo describes one "folder", i.e. a common key prefix returned by
// ListFolders.
type FolderInfo struct {
// Prefix is the common prefix exactly as reported by the listing call.
Prefix string
}
// NewCloudflareR2Connector builds an S3Connector for the bucket described by
// cfg.
//
// AccessKeyID, SecretAccessKey and BucketName are mandatory. The endpoint is
// cfg.Endpoint when set; otherwise it is derived from cfg.AccountID as
// https://<account>.r2.cloudflarestorage.com, and an error is returned if
// both are empty. An empty region falls back to "auto". The client uses
// static credentials and path-style addressing. No request is sent here; use
// Ping to verify connectivity.
func NewCloudflareR2Connector(cfg config.R2Config) (*S3Connector, error) {
	hasCredentials := cfg.AccessKeyID != "" && cfg.SecretAccessKey != ""
	if !hasCredentials || cfg.BucketName == "" {
		return nil, fmt.Errorf("missing required R2 config fields (CF_R2_ACCESS_KEY_ID, CF_R2_SECRET_ACCESS_KEY, CF_R2_BUCKET_NAME)")
	}

	baseURL := cfg.Endpoint
	if baseURL == "" && cfg.AccountID == "" {
		return nil, fmt.Errorf("either CF_R2_ENDPOINT or CF_ACCOUNT_ID must be set")
	}
	if baseURL == "" {
		baseURL = fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.AccountID)
	}

	signingRegion := cfg.Region
	if signingRegion == "" {
		signingRegion = "auto"
	}

	staticCreds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
	sdkCfg, err := awsconfig.LoadDefaultConfig(context.TODO(),
		awsconfig.WithCredentialsProvider(staticCreds),
		awsconfig.WithRegion(signingRegion),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load R2 config: %w", err)
	}

	s3Client := s3.NewFromConfig(sdkCfg, func(opts *s3.Options) {
		opts.BaseEndpoint = aws.String(baseURL)
		opts.UsePathStyle = true
	})

	conn := &S3Connector{
		client:     s3Client,
		uploader:   manager.NewUploader(s3Client),
		downloader: manager.NewDownloader(s3Client),
		bucket:     cfg.BucketName,
	}
	return conn, nil
}
// Ping checks that the configured bucket is reachable by issuing a HeadBucket
// request. Any failure is returned wrapped with the bucket name.
func (c *S3Connector) Ping(ctx context.Context) error {
	input := &s3.HeadBucketInput{Bucket: aws.String(c.bucket)}
	if _, err := c.client.HeadBucket(ctx, input); err != nil {
		return fmt.Errorf("R2 ping failed for bucket %q: %w", c.bucket, err)
	}
	return nil
}
// UploadFile streams body into the bucket under key through the upload
// manager and returns the location string reported by that manager.
//
// contentType is stored as the object's Content-Type. metadata is optional
// user metadata attached to the object; pass nil when there is none.
func (c *S3Connector) UploadFile(ctx context.Context, key string, body io.Reader, contentType string, metadata map[string]string) (string, error) {
	input := &s3.PutObjectInput{
		Bucket:      aws.String(c.bucket),
		Key:         aws.String(key),
		Body:        body,
		ContentType: aws.String(contentType),
		Metadata:    metadata,
	}

	out, err := c.uploader.Upload(ctx, input)
	if err != nil {
		return "", fmt.Errorf("upload %q: %w", key, err)
	}
	return out.Location, nil
}
// GetFile fetches the object stored at key and returns its body together with
// the attributes reported by the store (size, ETag without quotes, content
// type, user metadata and, when present, the last-modified time).
//
// The caller owns the returned body and must close it. On error both the body
// and the info are nil; use IsNotFound to detect a missing object.
func (c *S3Connector) GetFile(ctx context.Context, key string) (io.ReadCloser, *FileInfo, error) {
	input := &s3.GetObjectInput{
		Bucket: aws.String(c.bucket),
		Key:    aws.String(key),
	}
	resp, err := c.client.GetObject(ctx, input)
	if err != nil {
		return nil, nil, fmt.Errorf("get %q: %w", key, err)
	}

	return resp.Body, &FileInfo{
		Key:          key,
		Size:         aws.ToInt64(resp.ContentLength),
		LastModified: aws.ToTime(resp.LastModified), // zero time when absent
		ETag:         strings.Trim(aws.ToString(resp.ETag), `"`),
		ContentType:  aws.ToString(resp.ContentType),
		Metadata:     resp.Metadata,
	}, nil
}
// DownloadFile writes the object stored at key into dst through the download
// manager and returns the number of bytes written. On failure it returns 0
// and the wrapped error.
func (c *S3Connector) DownloadFile(ctx context.Context, key string, dst io.WriterAt) (int64, error) {
	input := &s3.GetObjectInput{
		Bucket: aws.String(c.bucket),
		Key:    aws.String(key),
	}

	written, err := c.downloader.Download(ctx, dst, input)
	if err != nil {
		return 0, fmt.Errorf("download %q: %w", key, err)
	}
	return written, nil
}
// DeleteFile removes the single object stored at key from the bucket.
func (c *S3Connector) DeleteFile(ctx context.Context, key string) error {
	input := &s3.DeleteObjectInput{
		Bucket: aws.String(c.bucket),
		Key:    aws.String(key),
	}
	if _, err := c.client.DeleteObject(ctx, input); err != nil {
		return fmt.Errorf("delete %q: %w", key, err)
	}
	return nil
}
// DeleteFiles deletes all given keys, sending them in batches of at most
// 1000 keys per DeleteObjects request. An empty or nil keys slice is a no-op.
//
// A batch request can succeed as a whole while individual keys fail; those
// per-object failures come back in the response body rather than as a call
// error. They are turned into an error here so callers never mistake a
// partial delete for a complete one. Processing stops at the first batch
// that fails, so later batches are left untouched.
func (c *S3Connector) DeleteFiles(ctx context.Context, keys []string) error {
	const batchLimit = 1000

	for start := 0; start < len(keys); start += batchLimit {
		end := start + batchLimit
		if end > len(keys) {
			end = len(keys)
		}
		batch := keys[start:end]

		objects := make([]types.ObjectIdentifier, len(batch))
		for i, k := range batch {
			objects[i] = types.ObjectIdentifier{Key: aws.String(k)}
		}

		out, err := c.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
			Bucket: aws.String(c.bucket),
			// Quiet mode suppresses the list of deleted keys but still
			// reports the keys that could not be deleted.
			Delete: &types.Delete{Objects: objects, Quiet: aws.Bool(true)},
		})
		if err != nil {
			return fmt.Errorf("batch delete: %w", err)
		}
		if failed := len(out.Errors); failed > 0 {
			first := out.Errors[0]
			return fmt.Errorf("batch delete: %d of %d objects failed (first: key %q, code %s: %s)",
				failed, len(batch),
				aws.ToString(first.Key), aws.ToString(first.Code), aws.ToString(first.Message))
		}
	}
	return nil
}
// RenameFile moves the object at srcKey to dstKey by copying it and then
// deleting the source (R2 has no native rename).
//
// Renaming a key onto itself is a no-op: without this guard a successful
// copy would be followed by deleting the only copy of the object. If the copy
// fails the source is left in place. If the copy succeeds but the delete
// fails, both keys exist and the delete error is returned.
func (c *S3Connector) RenameFile(ctx context.Context, srcKey, dstKey string) error {
	if srcKey == dstKey {
		return nil
	}

	// NOTE(review): the S3 API documents CopySource as a URL-encoded value;
	// keys containing characters outside the URL-safe set may need escaping
	// here — confirm against the keys this service generates.
	_, err := c.client.CopyObject(ctx, &s3.CopyObjectInput{
		Bucket:     aws.String(c.bucket),
		CopySource: aws.String(c.bucket + "/" + srcKey),
		Key:        aws.String(dstKey),
	})
	if err != nil {
		return fmt.Errorf("copy %q → %q: %w", srcKey, dstKey, err)
	}
	return c.DeleteFile(ctx, srcKey)
}
// ListFiles returns the objects that sit directly under prefix, without
// descending into sub-folders. The prefix is normalised first, and all result
// pages are fetched. Keys ending in "/" are folder placeholders and are left
// out. Only Key, Size, ETag (quotes removed) and LastModified are filled in.
func (c *S3Connector) ListFiles(ctx context.Context, prefix string) ([]FileInfo, error) {
	prefix = normalizePrefix(prefix)

	paginator := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{
		Bucket:    aws.String(c.bucket),
		Prefix:    aws.String(prefix),
		Delimiter: aws.String("/"),
	})

	var result []FileInfo
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("list files under %q: %w", prefix, err)
		}

		for i := range page.Contents {
			entry := &page.Contents[i]
			name := aws.ToString(entry.Key)
			if !strings.HasSuffix(name, "/") {
				result = append(result, FileInfo{
					Key:          name,
					Size:         aws.ToInt64(entry.Size),
					LastModified: aws.ToTime(entry.LastModified), // zero time when absent
					ETag:         strings.Trim(aws.ToString(entry.ETag), `"`),
				})
			}
		}
	}
	return result, nil
}
// ListFolders returns the immediate sub-folders of prefix, i.e. the common
// prefixes reported when listing with "/" as the delimiter. The prefix is
// normalised first, and all result pages are fetched.
func (c *S3Connector) ListFolders(ctx context.Context, prefix string) ([]FolderInfo, error) {
	prefix = normalizePrefix(prefix)

	paginator := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{
		Bucket:    aws.String(c.bucket),
		Prefix:    aws.String(prefix),
		Delimiter: aws.String("/"),
	})

	var result []FolderInfo
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("list folders under %q: %w", prefix, err)
		}
		for i := range page.CommonPrefixes {
			name := aws.ToString(page.CommonPrefixes[i].Prefix)
			result = append(result, FolderInfo{Prefix: name})
		}
	}
	return result, nil
}
// CreateFolder makes folderPath show up as a folder by storing a zero-byte
// placeholder object whose key is the normalised path (ending in "/").
// A path that normalises to the empty string is rejected.
func (c *S3Connector) CreateFolder(ctx context.Context, folderPath string) error {
	placeholder := normalizePrefix(folderPath)
	if placeholder == "" {
		return fmt.Errorf("folder path cannot be empty")
	}

	input := &s3.PutObjectInput{
		Bucket:        aws.String(c.bucket),
		Key:           aws.String(placeholder),
		ContentLength: aws.Int64(0),
	}
	if _, err := c.client.PutObject(ctx, input); err != nil {
		return fmt.Errorf("create folder %q: %w", placeholder, err)
	}
	return nil
}
// DeleteFolder removes every object whose key starts with the normalised
// folderPath, including objects in nested sub-folders. All matching keys are
// listed first and then deleted in batches of 1000.
//
// A path that normalises to the empty string is rejected, mirroring
// CreateFolder: an empty prefix matches every key, so accepting it would
// delete the entire bucket. If a batch fails, the error is returned and the
// remaining batches are not attempted.
func (c *S3Connector) DeleteFolder(ctx context.Context, folderPath string) error {
	prefix := normalizePrefix(folderPath)
	if prefix == "" {
		return fmt.Errorf("folder path cannot be empty")
	}

	var keys []string
	pager := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{
		Bucket: aws.String(c.bucket),
		Prefix: aws.String(prefix),
	})
	for pager.HasMorePages() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return fmt.Errorf("list for delete %q: %w", prefix, err)
		}
		for _, obj := range page.Contents {
			keys = append(keys, aws.ToString(obj.Key))
		}
	}

	const batchSize = 1000
	for start := 0; start < len(keys); start += batchSize {
		end := start + batchSize
		if end > len(keys) {
			end = len(keys)
		}
		if err := c.DeleteFiles(ctx, keys[start:end]); err != nil {
			return err
		}
	}
	return nil
}
// normalizePrefix converts a folder path into a listing prefix: a single
// leading "/" is dropped and a trailing "/" is appended when missing.
// The root ("" or "/") yields the empty string.
func normalizePrefix(p string) string {
	trimmed := strings.TrimPrefix(p, "/")
	switch {
	case trimmed == "":
		return ""
	case strings.HasSuffix(trimmed, "/"):
		return trimmed
	default:
		return trimmed + "/"
	}
}

50
internal/timing/timing.go Normal file
View File

@@ -0,0 +1,50 @@
package timing
import (
"context"
"time"
)
// contextKey is the private key type under which a *Timer is stored in a
// context; being unexported, it cannot collide with keys from other packages.
type contextKey struct{}

// Checkpoint pairs a step name with the instant at which it was recorded.
type Checkpoint struct {
	Name string
	At   time.Time
}

// Timer holds the moment a request started and the checkpoints recorded
// since then, in the order they were marked.
type Timer struct {
	Start       time.Time
	checkpoints []Checkpoint
}

// Mark appends a checkpoint called name, stamped with the current time.
func (t *Timer) Mark(name string) {
	cp := Checkpoint{Name: name, At: time.Now()}
	t.checkpoints = append(t.checkpoints, cp)
}

// Checkpoints returns the recorded checkpoints in recording order. The
// returned slice is the Timer's own slice, not a copy.
func (t *Timer) Checkpoints() []Checkpoint {
	return t.checkpoints
}

// NewContext creates a Timer whose Start is the current time and returns it
// together with a context derived from ctx that carries it.
func NewContext(ctx context.Context) (context.Context, *Timer) {
	timer := &Timer{Start: time.Now()}
	ctx = context.WithValue(ctx, contextKey{}, timer)
	return ctx, timer
}

// FromContext returns the Timer carried by ctx, or nil when ctx has none.
func FromContext(ctx context.Context) *Timer {
	if timer, ok := ctx.Value(contextKey{}).(*Timer); ok {
		return timer
	}
	return nil
}

// Mark records a checkpoint called name on the Timer carried by ctx.
// It does nothing when ctx carries no Timer (e.g. in tests).
func Mark(ctx context.Context, name string) {
	timer := FromContext(ctx)
	if timer == nil {
		return
	}
	timer.Mark(name)
}