172 lines
3.9 KiB
Go
172 lines
3.9 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"memory/internal/config"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type UploadHandler struct {
|
|
cfg *config.Config
|
|
s3Client *s3.Client
|
|
}
|
|
|
|
func NewUploadHandler(cfg *config.Config) *UploadHandler {
|
|
h := &UploadHandler{cfg: cfg}
|
|
|
|
if cfg.R2AccountID != "" && cfg.R2AccessKeyID != "" {
|
|
// 初始化 R2 客户端
|
|
r2Resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
|
|
return aws.Endpoint{
|
|
URL: fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.R2AccountID),
|
|
}, nil
|
|
})
|
|
|
|
awsCfg, err := awsconfig.LoadDefaultConfig(context.TODO(),
|
|
awsconfig.WithEndpointResolverWithOptions(r2Resolver),
|
|
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
|
cfg.R2AccessKeyID,
|
|
cfg.R2AccessKeySecret,
|
|
"",
|
|
)),
|
|
awsconfig.WithRegion("auto"),
|
|
)
|
|
if err == nil {
|
|
h.s3Client = s3.NewFromConfig(awsCfg)
|
|
}
|
|
}
|
|
|
|
return h
|
|
}
|
|
|
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
|
file, err := c.FormFile("file")
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "no file uploaded"})
|
|
return
|
|
}
|
|
|
|
// 检查文件类型
|
|
ext := strings.ToLower(filepath.Ext(file.Filename))
|
|
allowedExts := map[string]string{
|
|
".jpg": "image/jpeg",
|
|
".jpeg": "image/jpeg",
|
|
".png": "image/png",
|
|
".gif": "image/gif",
|
|
".webp": "image/webp",
|
|
".mp4": "video/mp4",
|
|
".mov": "video/quicktime",
|
|
}
|
|
|
|
contentType, ok := allowedExts[ext]
|
|
if !ok {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported file type"})
|
|
return
|
|
}
|
|
|
|
// 检查文件大小 (最大 50MB)
|
|
if file.Size > 50*1024*1024 {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "file too large (max 50MB)"})
|
|
return
|
|
}
|
|
|
|
// 生成唯一文件名
|
|
filename := fmt.Sprintf("%s/%s%s",
|
|
time.Now().Format("2006/01"),
|
|
uuid.New().String(),
|
|
ext,
|
|
)
|
|
|
|
// 打开文件
|
|
src, err := file.Open()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read file"})
|
|
return
|
|
}
|
|
defer src.Close()
|
|
|
|
// 优先使用本地存储
|
|
if h.cfg.LocalUploadPath != "" {
|
|
url, err := h.saveLocal(src, filename)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save file"})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"url": url,
|
|
"filename": filename,
|
|
})
|
|
return
|
|
}
|
|
|
|
// 使用 R2 存储
|
|
if h.s3Client != nil {
|
|
url, err := h.uploadToR2(src, filename, contentType)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to upload file"})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"url": url,
|
|
"filename": filename,
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "storage not configured"})
|
|
}
|
|
|
|
func (h *UploadHandler) saveLocal(src io.Reader, filename string) (string, error) {
|
|
// 创建目录
|
|
fullPath := filepath.Join(h.cfg.LocalUploadPath, filename)
|
|
dir := filepath.Dir(fullPath)
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// 创建文件
|
|
dst, err := os.Create(fullPath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer dst.Close()
|
|
|
|
// 复制内容
|
|
if _, err := io.Copy(dst, src); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// 返回 URL
|
|
url := fmt.Sprintf("%s/uploads/%s", strings.TrimSuffix(h.cfg.BaseURL, "/"), filename)
|
|
return url, nil
|
|
}
|
|
|
|
func (h *UploadHandler) uploadToR2(src io.Reader, filename, contentType string) (string, error) {
|
|
_, err := h.s3Client.PutObject(context.TODO(), &s3.PutObjectInput{
|
|
Bucket: aws.String(h.cfg.R2BucketName),
|
|
Key: aws.String(filename),
|
|
Body: src,
|
|
ContentType: aws.String(contentType),
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(h.cfg.R2PublicURL, "/"), filename)
|
|
return url, nil
|
|
}
|