537 lines
15 KiB
Go
537 lines
15 KiB
Go
package handlers
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"ai-media-hub/backend/models"
|
|
"ai-media-hub/backend/services"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type App struct {
|
|
DB *sql.DB
|
|
DownloadsDir string
|
|
WorkerScript string
|
|
SearchService *services.SearchService
|
|
GeminiService *services.GeminiService
|
|
Hub *Hub
|
|
}
|
|
|
|
type Hub struct {
|
|
clients map[*websocket.Conn]bool
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func NewHub() *Hub {
|
|
return &Hub{clients: map[*websocket.Conn]bool{}}
|
|
}
|
|
|
|
func (h *Hub) Broadcast(event string, data any) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
payload, _ := json.Marshal(gin.H{"event": event, "data": data})
|
|
for conn := range h.clients {
|
|
_ = conn.WriteMessage(websocket.TextMessage, payload)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Add(conn *websocket.Conn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.clients[conn] = true
|
|
}
|
|
|
|
func (h *Hub) Remove(conn *websocket.Conn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
delete(h.clients, conn)
|
|
_ = conn.Close()
|
|
}
|
|
|
|
// PreviewResponse is the JSON document decoded from the worker script's probe
// output and returned unchanged to the client by the preview endpoint.
//
// NOTE(review): the field meanings below are inferred from the names only;
// the worker script defines the actual contents — confirm against it.
type PreviewResponse struct {
	// Title is the media title.
	Title string `json:"title"`
	// Thumbnail is presumably a thumbnail image URL.
	Thumbnail string `json:"thumbnail"`
	// PreviewStreamURL is presumably a URL the client can stream for preview.
	PreviewStreamURL string `json:"previewStreamUrl"`
	// Duration is the duration as a display string.
	Duration string `json:"duration"`
	// DurationSeconds is the duration as a whole number of seconds.
	DurationSeconds int `json:"durationSeconds"`
	// StartDefault is presumably the suggested clip start.
	StartDefault string `json:"startDefault"`
	// EndDefault is presumably the suggested clip end.
	EndDefault string `json:"endDefault"`
	// Qualities is a list of free-form quality descriptors.
	Qualities []map[string]any `json:"qualities"`
}
|
|
|
|
func RegisterRoutes(router *gin.Engine, app *App) {
|
|
router.GET("/healthz", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
router.GET("/ws", app.handleWS)
|
|
router.GET("/api/history/check", app.checkDuplicate)
|
|
router.POST("/api/download/preview", app.previewDownload)
|
|
router.POST("/api/upload", app.uploadFile)
|
|
router.POST("/api/download", app.startDownload)
|
|
router.POST("/api/search", app.searchMedia)
|
|
}
|
|
|
|
func (a *App) handleWS(c *gin.Context) {
|
|
upgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
a.Hub.Add(conn)
|
|
defer a.Hub.Remove(conn)
|
|
|
|
for {
|
|
if _, _, err := conn.ReadMessage(); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *App) checkDuplicate(c *gin.Context) {
|
|
url := strings.TrimSpace(c.Query("url"))
|
|
if url == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
|
return
|
|
}
|
|
record, err := models.FindByURL(a.DB, url)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"exists": record != nil, "record": record})
|
|
}
|
|
|
|
func (a *App) uploadFile(c *gin.Context) {
|
|
file, err := c.FormFile("file")
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "file is required"})
|
|
return
|
|
}
|
|
|
|
a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "started", "progress": 5, "filename": file.Filename})
|
|
|
|
safeName := normalizeFilename(file.Filename)
|
|
targetPath := filepath.Join(a.DownloadsDir, safeName)
|
|
if err := c.SaveUploadedFile(file, targetPath); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "completed", "progress": 100, "filename": safeName})
|
|
c.JSON(http.StatusOK, gin.H{"message": "uploaded", "path": targetPath, "filename": safeName})
|
|
}
|
|
|
|
func (a *App) startDownload(c *gin.Context) {
|
|
var req struct {
|
|
URL string `json:"url"`
|
|
Start string `json:"start"`
|
|
End string `json:"end"`
|
|
Quality string `json:"quality"`
|
|
Force bool `json:"force"`
|
|
}
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
rec, err := models.FindByURL(a.DB, req.URL)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if rec != nil && !req.Force {
|
|
c.JSON(http.StatusConflict, gin.H{"error": "duplicate url", "record": rec})
|
|
return
|
|
}
|
|
|
|
outputBase := uuid.NewString()
|
|
outputPath := filepath.Join(a.DownloadsDir, outputBase+".mp4")
|
|
recordID, err := models.InsertDownload(a.DB, req.URL, detectSource(req.URL), outputPath, "queued")
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
quality := strings.TrimSpace(req.Quality)
|
|
if quality == "" {
|
|
quality = "best"
|
|
}
|
|
|
|
go a.runDownload(recordID, req.URL, req.Start, req.End, quality, outputPath)
|
|
c.JSON(http.StatusAccepted, gin.H{"message": "download started", "recordId": recordID})
|
|
}
|
|
|
|
func (a *App) previewDownload(c *gin.Context) {
|
|
var req struct {
|
|
URL string `json:"url"`
|
|
}
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.URL) == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
|
return
|
|
}
|
|
|
|
cmd := exec.Command("python3", a.WorkerScript, "--mode", "probe", "--url", req.URL)
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": summarizeOutput("download preview probe failed", output, err)})
|
|
return
|
|
}
|
|
|
|
var preview PreviewResponse
|
|
if err := json.Unmarshal(output, &preview); err != nil {
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": summarizeOutput("download preview returned invalid JSON", output, err)})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, preview)
|
|
}
|
|
|
|
func (a *App) runDownload(recordID int64, url, start, end, quality, outputPath string) {
|
|
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "queued", "progress": 0, "url": url})
|
|
cmd := exec.Command("python3", a.WorkerScript, "--url", url, "--start", start, "--end", end, "--quality", quality, "--output", outputPath)
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
|
|
_ = models.MarkDownloadCompleted(a.DB, recordID, "failed")
|
|
return
|
|
}
|
|
cmd.Stderr = cmd.Stdout
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
|
|
_ = models.MarkDownloadCompleted(a.DB, recordID, "failed")
|
|
return
|
|
}
|
|
|
|
scanner := bufio.NewScanner(stdout)
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
var msg map[string]any
|
|
if err := json.Unmarshal(line, &msg); err == nil {
|
|
msg["type"] = "download"
|
|
a.Hub.Broadcast("progress", msg)
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 100, "message": err.Error()})
|
|
}
|
|
|
|
status := "completed"
|
|
if err := cmd.Wait(); err != nil {
|
|
status = "failed"
|
|
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 100, "message": err.Error()})
|
|
}
|
|
_ = models.MarkDownloadCompleted(a.DB, recordID, status)
|
|
}
|
|
|
|
func (a *App) searchMedia(c *gin.Context) {
|
|
var req struct {
|
|
Query string `json:"query"`
|
|
Platforms []string `json:"platforms"`
|
|
}
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.Query) == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "query is required"})
|
|
return
|
|
}
|
|
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "expanding query with Gemini", "progress": 10})
|
|
queryVariants, _ := a.GeminiService.ExpandQuery(req.Query)
|
|
if len(queryVariants) == 0 {
|
|
queryVariants = []string{req.Query}
|
|
}
|
|
|
|
enabledPlatforms := normalizePlatforms(req.Platforms)
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "searching " + selectedPlatformLabel(enabledPlatforms), "progress": 35})
|
|
results, err := a.SearchService.SearchMedia(queryVariants, enabledPlatforms)
|
|
if err != nil {
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "search failed", "progress": 100, "message": err.Error()})
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
if len(results) == 0 {
|
|
warning := "SearXNG returned no renderable results."
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "no renderable search results", "progress": 100, "message": warning})
|
|
c.JSON(http.StatusOK, gin.H{"results": []services.AIRecommendation{}, "warning": warning})
|
|
return
|
|
}
|
|
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "ranking thumbnail candidates", "progress": 55})
|
|
rankQuery := req.Query
|
|
if len(queryVariants) > 0 {
|
|
rankQuery = strings.Join(queryVariants[:min(len(queryVariants), 3)], " ")
|
|
}
|
|
scored := rankSearchResults(rankQuery, results)
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "analyzing all candidate visuals with Gemini Vision", "progress": 75})
|
|
recommended := evaluateAllCandidatesWithGemini(a.GeminiService, req.Query, scored)
|
|
err = nil
|
|
if len(recommended) == 0 {
|
|
err = fmt.Errorf("gemini vision returned no recommended items across all candidate batches")
|
|
}
|
|
if err != nil {
|
|
fallback := make([]services.AIRecommendation, 0, min(20, len(scored)))
|
|
for _, result := range scored[:min(20, len(scored))] {
|
|
fallback = append(fallback, services.AIRecommendation{
|
|
Title: result.Title,
|
|
Link: result.Link,
|
|
ThumbnailURL: result.ThumbnailURL,
|
|
PreviewVideoURL: result.PreviewVideoURL,
|
|
Source: result.Source,
|
|
Reason: "Keyword-ranked result added without extra Gemini vision tokens.",
|
|
Recommended: true,
|
|
})
|
|
}
|
|
warning := err.Error()
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "Gemini Vision fallback to ranked results", "progress": 90, "message": warning})
|
|
c.JSON(http.StatusOK, gin.H{"results": fallback, "warning": warning, "queries": queryVariants})
|
|
return
|
|
}
|
|
|
|
response := gin.H{"results": mergeRecommendations(recommended, scored, 20), "queries": queryVariants}
|
|
a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "search complete", "progress": 100})
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
// unsafeFilenameChars matches every run of characters outside [a-z0-9]. It is
// compiled once here instead of on every normalizeFilename call.
var unsafeFilenameChars = regexp.MustCompile(`[^a-z0-9]+`)

// normalizeFilename turns a client-supplied file name into a lowercase name
// of the form "<base>.<ext>" that contains only [a-z0-9-] plus a single dot.
//
// Runs of other characters in the base become one "-", with leading and
// trailing dashes removed; an empty base is replaced by "upload-<unix time>".
// The extension is stripped of everything outside [a-z0-9] (so it can no
// longer carry spaces, backslashes or other odd bytes) and defaults to "bin"
// when nothing usable remains.
func normalizeFilename(name string) string {
	lowered := strings.ToLower(strings.TrimSpace(name))
	rawExt := filepath.Ext(lowered)

	base := unsafeFilenameChars.ReplaceAllString(strings.TrimSuffix(lowered, rawExt), "-")
	base = strings.Trim(base, "-")
	if base == "" {
		base = fmt.Sprintf("upload-%d", time.Now().Unix())
	}

	// rawExt still starts with its dot; stripping removes that as well.
	ext := unsafeFilenameChars.ReplaceAllString(rawExt, "")
	if ext == "" {
		ext = "bin"
	}
	return base + "." + ext
}
|
|
|
|
// detectSource labels a URL by the platform it appears to belong to:
// "YouTube", "TikTok", or "direct" for everything else.
//
// Matching is a case-insensitive substring test over the whole URL, and the
// youtu.be short-link domain counts as YouTube.
func detectSource(url string) string {
	lowered := strings.ToLower(url)
	switch {
	case strings.Contains(lowered, "youtube"), strings.Contains(lowered, "youtu.be"):
		return "YouTube"
	case strings.Contains(lowered, "tiktok"):
		return "TikTok"
	default:
		return "direct"
	}
}
|
|
|
|
// min returns the smaller of a and b. Declaring it at package level shadows
// the generic builtin of the same name available in newer Go releases.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|
|
// normalizePlatforms converts the client's platform list into a set keyed by
// canonical platform name ("envato", "artgrid", "google video").
//
// An empty list enables all three platforms. Otherwise each entry is trimmed,
// lowercased and matched against the known names and aliases; unrecognized
// entries are ignored, so the result may be empty.
func normalizePlatforms(platforms []string) map[string]bool {
	if len(platforms) == 0 {
		return map[string]bool{
			"envato":       true,
			"artgrid":      true,
			"google video": true,
		}
	}

	// Accepted spellings mapped to their canonical key.
	aliases := map[string]string{
		"envato":       "envato",
		"artgrid":      "artgrid",
		"google video": "google video",
		"google_video": "google video",
		"google":       "google video",
	}

	enabled := make(map[string]bool)
	for _, raw := range platforms {
		if canonical, known := aliases[strings.ToLower(strings.TrimSpace(raw))]; known {
			enabled[canonical] = true
		}
	}
	return enabled
}
|
|
|
|
// selectedPlatformLabel renders the enabled platforms as a comma-separated
// display string in the fixed order Envato, Artgrid, Google Video. When none
// of them is enabled it returns the generic "selected platforms".
func selectedPlatformLabel(platforms map[string]bool) string {
	known := [...]struct{ key, label string }{
		{"envato", "Envato"},
		{"artgrid", "Artgrid"},
		{"google video", "Google Video"},
	}

	var names []string
	for _, platform := range known {
		if platforms[platform.key] {
			names = append(names, platform.label)
		}
	}
	if len(names) == 0 {
		return "selected platforms"
	}
	return strings.Join(names, ", ")
}
|
|
|
|
func evaluateAllCandidatesWithGemini(service *services.GeminiService, query string, ranked []services.SearchResult) []services.AIRecommendation {
|
|
const chunkSize = 8
|
|
merged := make([]services.AIRecommendation, 0, len(ranked))
|
|
seen := map[string]bool{}
|
|
for start := 0; start < len(ranked); start += chunkSize {
|
|
end := start + chunkSize
|
|
if end > len(ranked) {
|
|
end = len(ranked)
|
|
}
|
|
batch := ranked[start:end]
|
|
recommended, err := service.Recommend(query, batch)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
for _, item := range recommended {
|
|
if item.Link == "" || seen[item.Link] {
|
|
continue
|
|
}
|
|
seen[item.Link] = true
|
|
merged = append(merged, item)
|
|
}
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func rankSearchResults(query string, results []services.SearchResult) []services.SearchResult {
|
|
queryTerms := strings.Fields(strings.ToLower(query))
|
|
positiveTerms := []string{
|
|
"b-roll", "b roll", "stock", "stock footage", "footage", "cinematic", "editorial",
|
|
"establishing", "4k", "hd", "drone", "ambient", "scene", "urban", "cityscape",
|
|
}
|
|
negativeTerms := []string{
|
|
"shocking", "amazing", "crazy", "must watch", "reaction", "gossip", "celebrity",
|
|
"thumbnail", "meme", "prank", "drama", "breaking", "viral", "tutorial",
|
|
"how to", "review", "walkthrough", "course", "lesson", "podcast", "interview",
|
|
"premiere pro", "after effects", "explained", "breakdown", "vlog",
|
|
}
|
|
type scoredResult struct {
|
|
item services.SearchResult
|
|
score int
|
|
}
|
|
|
|
scored := make([]scoredResult, 0, len(results))
|
|
for _, result := range results {
|
|
score := 0
|
|
text := strings.ToLower(result.Title + " " + result.Snippet + " " + result.Source)
|
|
for _, term := range queryTerms {
|
|
if strings.Contains(text, term) {
|
|
score += 3
|
|
}
|
|
}
|
|
for _, term := range positiveTerms {
|
|
if strings.Contains(text, term) {
|
|
score += 2
|
|
}
|
|
}
|
|
for _, term := range negativeTerms {
|
|
if strings.Contains(text, term) {
|
|
score -= 4
|
|
}
|
|
}
|
|
if result.ThumbnailURL != "" {
|
|
score += 2
|
|
}
|
|
if result.PreviewVideoURL != "" {
|
|
score += 3
|
|
}
|
|
switch result.Source {
|
|
case "Google Video":
|
|
score -= 1
|
|
case "Envato":
|
|
score += 7
|
|
case "Artgrid":
|
|
score += 7
|
|
}
|
|
scored = append(scored, scoredResult{item: result, score: score})
|
|
}
|
|
|
|
sort.SliceStable(scored, func(i, j int) bool {
|
|
return scored[i].score > scored[j].score
|
|
})
|
|
|
|
ranked := make([]services.SearchResult, 0, len(scored))
|
|
for _, item := range scored {
|
|
ranked = append(ranked, item.item)
|
|
}
|
|
return ranked
|
|
}
|
|
|
|
func mergeRecommendations(recommended []services.AIRecommendation, ranked []services.SearchResult, limit int) []services.AIRecommendation {
|
|
merged := make([]services.AIRecommendation, 0, min(limit, len(ranked)))
|
|
seen := map[string]bool{}
|
|
|
|
for _, item := range recommended {
|
|
if item.Link == "" || seen[item.Link] {
|
|
continue
|
|
}
|
|
seen[item.Link] = true
|
|
merged = append(merged, item)
|
|
}
|
|
|
|
for _, item := range ranked {
|
|
if len(merged) >= limit || item.Link == "" || seen[item.Link] {
|
|
continue
|
|
}
|
|
seen[item.Link] = true
|
|
merged = append(merged, services.AIRecommendation{
|
|
Title: item.Title,
|
|
Link: item.Link,
|
|
ThumbnailURL: item.ThumbnailURL,
|
|
PreviewVideoURL: item.PreviewVideoURL,
|
|
Source: item.Source,
|
|
Reason: "Keyword-ranked result added without extra Gemini vision tokens.",
|
|
Recommended: true,
|
|
})
|
|
}
|
|
return merged
|
|
}
|
|
|
|
// EnsurePaths prepares the filesystem locations the handlers rely on. It
// creates downloadsDir (and any missing parents) with mode 0755 and checks
// that workerScript can be stat'ed.
//
// It returns the MkdirAll error, a "worker script not found" error when the
// script does not exist, or any other stat error unchanged.
func EnsurePaths(downloadsDir, workerScript string) error {
	if mkdirErr := os.MkdirAll(downloadsDir, 0o755); mkdirErr != nil {
		return mkdirErr
	}

	_, statErr := os.Stat(workerScript)
	switch {
	case statErr == nil:
		return nil
	case errors.Is(statErr, os.ErrNotExist):
		return fmt.Errorf("worker script not found: %s", workerScript)
	default:
		return statErr
	}
}
|
|
|
|
// summarizeOutput builds a one-line error description from a command's
// captured output.
//
// With non-blank output the result is "<prefix>: <output>", where the output
// is whitespace-trimmed and cut to at most 1200 bytes followed by "...". The
// cut is moved back to the nearest UTF-8 character boundary so the summary
// never ends in a partial character. With blank output the result is
// "<prefix>: <err>" when err is non-nil, and just prefix otherwise. Note that
// err is not included whenever there is output to show.
func summarizeOutput(prefix string, output []byte, err error) string {
	const maxDetail = 1200

	detail := strings.TrimSpace(string(output))
	switch {
	case detail == "" && err != nil:
		return prefix + ": " + err.Error()
	case detail == "":
		return prefix
	}

	if len(detail) > maxDetail {
		cut := maxDetail
		// Bytes of the form 10xxxxxx continue a multi-byte character; step
		// back until the first dropped byte starts a character.
		for cut > 0 && detail[cut]&0xC0 == 0x80 {
			cut--
		}
		detail = detail[:cut] + "..."
	}
	return prefix + ": " + detail
}
|