Files
ai-media-hub/backend/handlers/api.go
T
AI Assistant 7dfb1ad2de
build-push / docker (push) Successful in 4m14s
Stabilize search pipeline and improve preview diagnostics
2026-03-13 18:32:54 +09:00

689 lines
21 KiB
Go

package handlers
import (
"bufio"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
"ai-media-hub/backend/models"
"ai-media-hub/backend/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// App bundles the dependencies shared by the HTTP handlers in this file.
type App struct {
// DB is the database handle passed to the models package helpers
// (FindByURL, InsertDownload, MarkDownloadCompleted).
DB *sql.DB
// DownloadsDir is the directory where uploaded files and download outputs are written.
DownloadsDir string
// WorkerScript is the path of the Python script run with `python3` for
// preview probes and downloads.
WorkerScript string
// SearchService performs the platform search used by searchMedia.
SearchService *services.SearchService
// GeminiService expands search queries and picks recommended candidates.
GeminiService *services.GeminiService
// Hub fans out "progress" and "debug" events to connected WebSocket clients.
Hub *Hub
}
// Hub tracks the connected WebSocket clients and broadcasts events to them.
// Create it with NewHub; the zero value has a nil clients map.
type Hub struct {
// clients is used as a set of live connections; the bool value is always true.
clients map[*websocket.Conn]bool
// mu guards clients and is held while broadcasting, which also serializes
// writes to each connection.
mu sync.Mutex
}
// NewHub returns an empty Hub that is ready to register connections.
func NewHub() *Hub {
	hub := new(Hub)
	hub.clients = make(map[*websocket.Conn]bool)
	return hub
}
// Broadcast sends {"event": event, "data": data} as a JSON text frame to every
// registered connection.
//
// Each write gets a deadline so that a stalled client cannot block the hub
// (and every other broadcaster waiting on mu) indefinitely. A connection whose
// write fails or times out is unregistered and closed, rather than failing
// again on every later broadcast. If data cannot be marshalled, the event is
// dropped instead of sending an empty frame.
func (h *Hub) Broadcast(event string, data any) {
	const writeTimeout = 10 * time.Second

	payload, err := json.Marshal(gin.H{"event": event, "data": data})
	if err != nil {
		return
	}

	h.mu.Lock()
	defer h.mu.Unlock()
	for conn := range h.clients {
		_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			// Deleting during range is safe in Go. The handler's read loop will
			// notice the closed connection and call Remove, which is then a no-op.
			delete(h.clients, conn)
			_ = conn.Close()
		}
	}
}
// Add registers conn so that it receives subsequent broadcasts.
func (h *Hub) Add(conn *websocket.Conn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	clients := h.clients
	clients[conn] = true
}
// Remove unregisters conn and closes it. Removing a connection that is not
// registered is harmless; it is still closed.
func (h *Hub) Remove(conn *websocket.Conn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	clients := h.clients
	delete(clients, conn)
	// Closing an already-closed connection only returns an error, which is
	// irrelevant during teardown.
	_ = conn.Close()
}
// PreviewResponse is the JSON document that the worker script prints in
// "--mode probe". previewDownload decodes it and sends it back to the client.
// Only the fields declared here survive that round trip.
type PreviewResponse struct {
Title string `json:"title"`
Thumbnail string `json:"thumbnail"`
PreviewStreamURL string `json:"previewStreamUrl"`
// NOTE(review): Duration looks like a display string and DurationSeconds the
// same length in seconds — confirm against the worker's probe output.
Duration string `json:"duration"`
DurationSeconds int `json:"durationSeconds"`
// StartDefault / EndDefault presumably seed the clip range inputs sent back as
// start/end to /api/download — TODO confirm with the frontend.
StartDefault string `json:"startDefault"`
EndDefault string `json:"endDefault"`
// Qualities is passed through untyped; its item schema is defined by the worker.
Qualities []map[string]any `json:"qualities"`
}
// searchDebugSummary is the payload of the search-related "debug" events
// built by summarizeSearchResults and summarizeRecommendationResults.
type searchDebugSummary struct {
// Total is the number of results summarized.
Total int `json:"total"`
// BySource counts results per Source value.
BySource map[string]int `json:"bySource"`
// WithPreview / WithThumbnail count results with a non-blank preview video / thumbnail URL.
WithPreview int `json:"withPreview"`
WithThumbnail int `json:"withThumbnail"`
// Top holds sample fields of the first six results.
Top []map[string]any `json:"top"`
Warning string `json:"warning,omitempty"`
// DurationMS is the time elapsed since the search request started.
DurationMS int64 `json:"durationMs,omitempty"`
// GeminiCandidateCap is how many top results are sent to Gemini (0 when not applicable).
GeminiCandidateCap int `json:"geminiCandidateCap,omitempty"`
}
// geminiBatchStats describes one run of evaluateAllCandidatesWithGemini and is
// emitted as a "debug" event.
type geminiBatchStats struct {
// CandidateCap is the value returned by geminiCandidateLimit.
CandidateCap int `json:"candidateCap"`
// Requested is how many ranked candidates were actually sent.
Requested int `json:"requested"`
// Batches counts Recommend calls; Succeeded + Failed == Batches.
Batches int `json:"batches"`
Succeeded int `json:"succeeded"`
Failed int `json:"failed"`
// RecommendedCount is the number of distinct recommended links collected.
RecommendedCount int `json:"recommendedCount"`
// Errors keeps at most the first five batch error messages.
Errors []string `json:"errors,omitempty"`
}
// RegisterRoutes mounts the health check, the WebSocket event stream, and the
// JSON API handlers backed by app onto router.
func RegisterRoutes(router *gin.Engine, app *App) {
	healthz := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	}

	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodGet, "/healthz", healthz},
		{http.MethodGet, "/ws", app.handleWS},
		{http.MethodGet, "/api/history/check", app.checkDuplicate},
		{http.MethodPost, "/api/download/preview", app.previewDownload},
		{http.MethodPost, "/api/upload", app.uploadFile},
		{http.MethodPost, "/api/download", app.startDownload},
		{http.MethodPost, "/api/search", app.searchMedia},
	}
	for _, route := range routes {
		router.Handle(route.method, route.path, route.handler)
	}
}
// debug broadcasts a "debug" event carrying a human-readable message and an
// arbitrary payload to every connected WebSocket client.
func (a *App) debug(message string, data any) {
	event := gin.H{"data": data, "message": message}
	a.Hub.Broadcast("debug", event)
}
// handleWS upgrades GET /ws to a WebSocket and registers the connection with
// the hub so that it receives broadcast events. Incoming messages are read and
// discarded. The first read error (e.g. the client disconnecting) unregisters
// and closes the connection.
//
// NOTE(review): CheckOrigin accepts every origin, so any web page can open
// this socket and receive progress/debug events (which include URLs and
// worker output). Consider restricting origins if the service is exposed.
func (a *App) handleWS(c *gin.Context) {
	allowAnyOrigin := func(*http.Request) bool { return true }
	upgrader := websocket.Upgrader{CheckOrigin: allowAnyOrigin}

	conn, upgradeErr := upgrader.Upgrade(c.Writer, c.Request, nil)
	if upgradeErr != nil {
		return
	}
	a.Hub.Add(conn)
	defer a.Hub.Remove(conn)

	for {
		_, _, readErr := conn.ReadMessage()
		if readErr != nil {
			return
		}
	}
}
// checkDuplicate handles GET /api/history/check?url=... and reports whether a
// download record already exists for that URL (surrounding whitespace is
// ignored). It responds 400 when url is missing and 500 when the lookup fails.
func (a *App) checkDuplicate(c *gin.Context) {
	target := strings.TrimSpace(c.Query("url"))
	if len(target) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
		return
	}

	record, lookupErr := models.FindByURL(a.DB, target)
	if lookupErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": lookupErr.Error()})
		return
	}

	exists := record != nil
	c.JSON(http.StatusOK, gin.H{"exists": exists, "record": record})
}
// uploadFile handles POST /api/upload. It stores the multipart "file" field in
// DownloadsDir under a sanitized name (normalizeFilename) and reports progress
// over the hub: "started", then either "completed" or "error".
//
// NOTE(review): uploads whose names normalize to the same value target the
// same path; SaveUploadedFile presumably overwrites it — confirm if that matters.
func (a *App) uploadFile(c *gin.Context) {
	file, err := c.FormFile("file")
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "file is required"})
		return
	}
	a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "started", "progress": 5, "filename": file.Filename})

	safeName := normalizeFilename(file.Filename)
	targetPath := filepath.Join(a.DownloadsDir, safeName)
	if err := c.SaveUploadedFile(file, targetPath); err != nil {
		// Close out the "started" event so clients are not left waiting on it forever.
		a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "error", "progress": 100, "filename": safeName, "message": err.Error()})
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "completed", "progress": 100, "filename": safeName})
	c.JSON(http.StatusOK, gin.H{"message": "uploaded", "path": targetPath, "filename": safeName})
}
// startDownload handles POST /api/download. It validates the URL, refuses
// URLs that already have a download record unless "force" is set (409 with
// the existing record), inserts a "queued" record with a fresh <uuid>.mp4
// output path, and runs the worker asynchronously via runDownload.
// It responds 202 with the new record id.
//
// The URL is whitespace-trimmed before use, matching checkDuplicate, so that
// the duplicate check agrees with /api/history/check. An empty "quality"
// defaults to "best".
func (a *App) startDownload(c *gin.Context) {
	var req struct {
		URL     string `json:"url"`
		Start   string `json:"start"`
		End     string `json:"end"`
		Quality string `json:"quality"`
		Force   bool   `json:"force"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	url := strings.TrimSpace(req.URL)
	if url == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
		return
	}

	rec, err := models.FindByURL(a.DB, url)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if rec != nil && !req.Force {
		c.JSON(http.StatusConflict, gin.H{"error": "duplicate url", "record": rec})
		return
	}

	outputPath := filepath.Join(a.DownloadsDir, uuid.NewString()+".mp4")
	recordID, err := models.InsertDownload(a.DB, url, detectSource(url), outputPath, "queued")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	quality := strings.TrimSpace(req.Quality)
	if quality == "" {
		quality = "best"
	}
	go a.runDownload(recordID, url, req.Start, req.End, quality, outputPath)
	c.JSON(http.StatusAccepted, gin.H{"message": "download started", "recordId": recordID})
}
// previewDownload handles POST /api/download/preview. It runs the worker
// script with "--mode probe" for the requested URL and relays the decoded
// PreviewResponse to the client. Probe failures and unparsable output yield
// 502 with a summarized diagnostic.
//
// Only the worker's stdout is parsed as JSON. stderr is captured separately,
// so warnings or log lines written there can no longer corrupt the JSON, and
// it is preferred as the diagnostic text when the probe fails. The probe is
// bound to the request context and is killed if the client goes away.
func (a *App) previewDownload(c *gin.Context) {
	var req struct {
		URL string `json:"url"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	url := strings.TrimSpace(req.URL)
	if url == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
		return
	}
	a.debug("download preview requested", gin.H{"url": url})

	var stdout, stderr strings.Builder
	cmd := exec.CommandContext(c.Request.Context(), "python3", a.WorkerScript, "--mode", "probe", "--url", url)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		diagnostic := stderr.String()
		if strings.TrimSpace(diagnostic) == "" {
			diagnostic = stdout.String()
		}
		a.debug("download preview failed", gin.H{"url": url, "output": stdout.String(), "stderr": stderr.String(), "error": err.Error()})
		c.JSON(http.StatusBadGateway, gin.H{"error": summarizeOutput("download preview probe failed", []byte(diagnostic), err)})
		return
	}

	var preview PreviewResponse
	if err := json.Unmarshal([]byte(stdout.String()), &preview); err != nil {
		a.debug("download preview invalid json", gin.H{"url": url, "output": stdout.String(), "stderr": stderr.String()})
		c.JSON(http.StatusBadGateway, gin.H{"error": summarizeOutput("download preview returned invalid JSON", []byte(stdout.String()), err)})
		return
	}
	a.debug("download preview succeeded", preview)
	c.JSON(http.StatusOK, preview)
}
// runDownload executes the worker script for one queued download and relays
// its progress over the hub. startDownload runs it on its own goroutine.
//
// The worker's stdout and stderr share one pipe. Each line that decodes to a
// JSON object is tagged with "type": "download" and broadcast as a "progress"
// event. Any other non-blank line (logs, tracebacks) is forwarded as a "debug"
// event instead of being silently dropped. When the process exits, the
// download record is marked "completed" or "failed".
func (a *App) runDownload(recordID int64, url, start, end, quality, outputPath string) {
	markRecord := func(status string) {
		if err := models.MarkDownloadCompleted(a.DB, recordID, status); err != nil {
			a.debug("download record update failed", gin.H{"recordId": recordID, "status": status, "error": err.Error()})
		}
	}

	a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "queued", "progress": 0, "url": url})
	a.debug("download command started", gin.H{"url": url, "start": start, "end": end, "quality": quality, "outputPath": outputPath})
	cmd := exec.Command("python3", a.WorkerScript, "--url", url, "--start", start, "--end", end, "--quality", quality, "--output", outputPath)
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
		markRecord("failed")
		return
	}
	// Send stderr into the same pipe so worker tracebacks reach the scanner below.
	cmd.Stderr = cmd.Stdout
	if err := cmd.Start(); err != nil {
		a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
		markRecord("failed")
		return
	}

	scanner := bufio.NewScanner(stdout)
	// Raise the per-line cap above bufio's 64 KiB default; a longer line would
	// otherwise end the scan early.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Bytes()
		var msg map[string]any
		// A literal `null` line decodes without error into a nil map. Writing to
		// that map would panic and crash the whole server, so treat it like any
		// other non-object line.
		if err := json.Unmarshal(line, &msg); err != nil || msg == nil {
			if text := strings.TrimSpace(string(line)); text != "" {
				a.debug("download worker output", gin.H{"url": url, "line": truncateText(text, 500)})
			}
			continue
		}
		msg["type"] = "download"
		a.debug("download worker event", msg)
		a.Hub.Broadcast("progress", msg)
	}
	if err := scanner.Err(); err != nil {
		a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 100, "message": err.Error()})
		// The scan stopped before EOF. Keep draining so the worker never blocks
		// writing to a full pipe, which would make cmd.Wait hang forever.
		discard := make([]byte, 32*1024)
		for {
			if _, readErr := stdout.Read(discard); readErr != nil {
				break
			}
		}
	}

	status := "completed"
	if err := cmd.Wait(); err != nil {
		status = "failed"
		a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 100, "message": err.Error()})
		a.debug("download command failed", gin.H{"url": url, "error": err.Error()})
	}
	a.debug("download command completed", gin.H{"url": url, "status": status, "outputPath": outputPath})
	markRecord(status)
}
// searchMedia handles POST /api/search. The pipeline is:
//  1. expand the query into variants with Gemini (falling back to the query itself),
//  2. search the selected platforms with every variant,
//  3. rank the results with keyword heuristics (rankSearchResults),
//  4. ask Gemini Vision to pick from the top candidates, and
//  5. return up to maxResults items: Gemini's picks first, topped up with the
//     best keyword-ranked results. If Gemini recommends nothing, only the
//     keyword-ranked results are returned, together with a warning.
//
// Progress is broadcast as "progress" events and intermediate summaries as
// "debug" events.
func (a *App) searchMedia(c *gin.Context) {
	const maxResults = 20
	started := time.Now()
	var req struct {
		Query     string   `json:"query"`
		Platforms []string `json:"platforms"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	query := strings.TrimSpace(req.Query)
	if query == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "query is required"})
		return
	}

	a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "expanding query with Gemini", "progress": 10})
	queryVariants, expandErr := a.GeminiService.ExpandQuery(query)
	if expandErr != nil {
		// Expansion is optional, but record why it was skipped instead of discarding the error.
		a.debug("search query expansion failed", gin.H{"query": query, "error": expandErr.Error()})
	}
	if len(queryVariants) == 0 {
		queryVariants = []string{query}
	}
	a.debug("search query variants", gin.H{
		"query":         query,
		"platforms":     req.Platforms,
		"variants":      queryVariants,
		"variantCount":  len(queryVariants),
		"requestIdHint": time.Now().UnixNano(),
	})

	enabledPlatforms := normalizePlatforms(req.Platforms)
	a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "searching " + selectedPlatformLabel(enabledPlatforms), "progress": 35})
	results, err := a.SearchService.SearchMedia(queryVariants, enabledPlatforms)
	if err != nil {
		a.debug("search backend failed", gin.H{"error": err.Error(), "variants": queryVariants, "durationMs": time.Since(started).Milliseconds()})
		a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "search failed", "progress": 100, "message": err.Error()})
		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
		return
	}
	a.debug("search backend summary", summarizeSearchResults(results, time.Since(started), 0, ""))
	if len(results) == 0 {
		warning := "SearXNG returned no renderable results."
		a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "no renderable search results", "progress": 100, "message": warning})
		c.JSON(http.StatusOK, gin.H{"results": []services.AIRecommendation{}, "warning": warning})
		return
	}

	a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "ranking thumbnail candidates", "progress": 55})
	// queryVariants is never empty here; rank against at most the first three variants.
	rankQuery := strings.Join(queryVariants[:min(len(queryVariants), 3)], " ")
	scored := rankSearchResults(rankQuery, results)
	a.debug("search ranked summary", summarizeSearchResults(scored, time.Since(started), geminiCandidateLimit(len(scored)), ""))

	a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "analyzing top candidate visuals with Gemini Vision", "progress": 75})
	recommended, geminiStats := evaluateAllCandidatesWithGemini(a.GeminiService, query, scored)
	a.debug("search gemini evaluation", geminiStats)

	if len(recommended) == 0 {
		warning := "gemini vision returned no recommended items across all candidate batches"
		// Same keyword-ranked top-up as the success path: de-duplicated, link-less items skipped.
		fallback := mergeRecommendations(nil, scored, maxResults)
		a.debug("search fallback summary", summarizeRecommendationResults(fallback, time.Since(started), warning))
		a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "Gemini Vision fallback to ranked results", "progress": 90, "message": warning})
		c.JSON(http.StatusOK, gin.H{"results": fallback, "warning": warning, "queries": queryVariants})
		return
	}

	merged := mergeRecommendations(recommended, scored, maxResults)
	a.debug("search complete summary", summarizeRecommendationResults(merged, time.Since(started), ""))
	a.Hub.Broadcast("progress", gin.H{"type": "search", "status": "search complete", "progress": 100})
	c.JSON(http.StatusOK, gin.H{"results": merged, "queries": queryVariants})
}
// Compiled once at package scope rather than on every upload.
var (
	// filenameUnsafeRun matches each run of characters not allowed in a file stem.
	filenameUnsafeRun = regexp.MustCompile(`[^a-z0-9]+`)
	// extensionUnsafe matches each character not allowed in a file extension.
	extensionUnsafe = regexp.MustCompile(`[^a-z0-9]`)
)

// normalizeFilename turns a client-supplied file name into a safe, flat file
// name.
//
// The name is trimmed and lower-cased. In the stem, every run of characters
// outside [a-z0-9] becomes a single '-', and leading/trailing '-' are
// trimmed. The extension keeps only [a-z0-9] characters and defaults to
// "bin". A stem that ends up empty (e.g. a purely non-ASCII name) becomes
// "upload-<unix nanoseconds>", which keeps concurrent uploads apart. The
// result never contains a path separator.
func normalizeFilename(name string) string {
	base := strings.ToLower(strings.TrimSpace(name))
	ext := filepath.Ext(base)
	base = strings.TrimSuffix(base, ext)

	base = strings.Trim(filenameUnsafeRun.ReplaceAllString(base, "-"), "-")
	if base == "" {
		base = fmt.Sprintf("upload-%d", time.Now().UnixNano())
	}

	ext = extensionUnsafe.ReplaceAllString(strings.TrimPrefix(ext, "."), "")
	if ext == "" {
		ext = "bin"
	}
	return base + "." + ext
}
// detectSource labels a download URL with the platform it comes from:
// "YouTube" (including youtu.be short links), "TikTok", or "direct" for
// anything else. The check is a case-insensitive substring match on the whole URL.
func detectSource(url string) string {
	lower := strings.ToLower(url)
	switch {
	case strings.Contains(lower, "youtube"), strings.Contains(lower, "youtu.be"):
		return "YouTube"
	case strings.Contains(lower, "tiktok"):
		return "TikTok"
	default:
		return "direct"
	}
}
// min returns the smaller of two ints. Because it is declared at package
// level, it takes precedence over Go 1.21's builtin min inside this package.
// It is called throughout this file.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// normalizePlatforms converts client-supplied platform names into the
// canonical set keys "envato", "artgrid", and "google video".
//
// Matching ignores case and surrounding whitespace. "google_video" and
// "google" are aliases for "google video", and unrecognised names are
// dropped. An empty or nil list selects all three platforms.
func normalizePlatforms(platforms []string) map[string]bool {
	if len(platforms) == 0 {
		return map[string]bool{"envato": true, "artgrid": true, "google video": true}
	}

	canonical := map[string]string{
		"envato":       "envato",
		"artgrid":      "artgrid",
		"google video": "google video",
		"google_video": "google video",
		"google":       "google video",
	}
	selected := make(map[string]bool)
	for _, raw := range platforms {
		if key, known := canonical[strings.ToLower(strings.TrimSpace(raw))]; known {
			selected[key] = true
		}
	}
	return selected
}
// selectedPlatformLabel renders the enabled platforms as a human-readable,
// comma-separated list in a fixed order (Envato, Artgrid, Google Video). It
// returns "selected platforms" when none of them is enabled.
func selectedPlatformLabel(platforms map[string]bool) string {
	ordered := [...]struct{ key, label string }{
		{"envato", "Envato"},
		{"artgrid", "Artgrid"},
		{"google video", "Google Video"},
	}

	var labels []string
	for _, p := range ordered {
		if platforms[p.key] {
			labels = append(labels, p.label)
		}
	}
	if len(labels) == 0 {
		return "selected platforms"
	}
	return strings.Join(labels, ", ")
}
// evaluateAllCandidatesWithGemini sends the top-ranked candidates to Gemini in
// batches of 8 and collects the items it recommends. Recommendations are
// de-duplicated by Link, and items with an empty Link are dropped.
//
// A failed batch is counted and skipped rather than aborting the run. Up to
// five error messages are kept in the returned stats.
//
// The number of candidates sent is geminiCandidateLimit(len(ranked)), clamped
// to len(ranked). The clamp matters because the cap may exceed the number of
// results (e.g. 12 for 9–11 results), and slicing past the end of ranked
// would panic.
func evaluateAllCandidatesWithGemini(service *services.GeminiService, query string, ranked []services.SearchResult) ([]services.AIRecommendation, geminiBatchStats) {
	const chunkSize = 8
	limit := geminiCandidateLimit(len(ranked))
	requested := min(limit, len(ranked))
	stats := geminiBatchStats{
		CandidateCap: limit,
		Requested:    requested,
	}

	merged := make([]services.AIRecommendation, 0, requested)
	seen := map[string]bool{}
	for start := 0; start < requested; start += chunkSize {
		end := min(start+chunkSize, requested)
		stats.Batches++
		recommended, err := service.Recommend(query, ranked[start:end])
		if err != nil {
			stats.Failed++
			if len(stats.Errors) < 5 {
				stats.Errors = append(stats.Errors, err.Error())
			}
			continue
		}
		stats.Succeeded++
		for _, item := range recommended {
			if item.Link == "" || seen[item.Link] {
				continue
			}
			seen[item.Link] = true
			merged = append(merged, item)
		}
	}
	stats.RecommendedCount = len(merged)
	return merged, stats
}
// rankSearchResults orders search results by a heuristic relevance score,
// highest first. Results with equal scores keep their original relative order.
//
// Each result is scored against the lower-cased "title snippet source" text:
//   - +3 for every whitespace-separated query term contained in the text
//     (a term repeated in the query counts each time),
//   - +2 for every stock-footage style keyword (b-roll, cinematic, 4k, ...),
//   - -4 for every clickbait or tutorial style keyword (viral, how to, review, ...),
//   - +2 when a thumbnail URL is present and +3 when a preview video URL is present,
//   - a source bias: Envato and Artgrid +7, Google Video -1.
//
// All checks are plain substring matches, so short terms may also match inside
// longer words.
func rankSearchResults(query string, results []services.SearchResult) []services.SearchResult {
	queryTerms := strings.Fields(strings.ToLower(query))
	positiveTerms := []string{
		"b-roll", "b roll", "stock", "stock footage", "footage", "cinematic", "editorial",
		"establishing", "4k", "hd", "drone", "ambient", "scene", "urban", "cityscape",
	}
	negativeTerms := []string{
		"shocking", "amazing", "crazy", "must watch", "reaction", "gossip", "celebrity",
		"thumbnail", "meme", "prank", "drama", "breaking", "viral", "tutorial",
		"how to", "review", "walkthrough", "course", "lesson", "podcast", "interview",
		"premiere pro", "after effects", "explained", "breakdown", "vlog",
	}
	sourceBias := map[string]int{
		"Google Video": -1,
		"Envato":       7,
		"Artgrid":      7,
	}
	countHits := func(text string, terms []string) int {
		hits := 0
		for _, term := range terms {
			if strings.Contains(text, term) {
				hits++
			}
		}
		return hits
	}

	scores := make([]int, len(results))
	order := make([]int, len(results))
	for i, r := range results {
		text := strings.ToLower(r.Title + " " + r.Snippet + " " + r.Source)
		score := 3*countHits(text, queryTerms) + 2*countHits(text, positiveTerms) - 4*countHits(text, negativeTerms)
		if r.ThumbnailURL != "" {
			score += 2
		}
		if r.PreviewVideoURL != "" {
			score += 3
		}
		scores[i] = score + sourceBias[r.Source]
		order[i] = i
	}

	// Sort result indices rather than the results themselves; stability keeps ties in input order.
	sort.SliceStable(order, func(x, y int) bool {
		return scores[order[x]] > scores[order[y]]
	})

	ranked := make([]services.SearchResult, 0, len(results))
	for _, idx := range order {
		ranked = append(ranked, results[idx])
	}
	return ranked
}
// mergeRecommendations builds a result list of at most limit items.
//
// Gemini's recommendations come first, in their original order. Ranked search
// results then fill any remaining slots, converted to AIRecommendation with a
// fixed keyword-ranking reason. Entries with an empty Link are skipped and
// each Link appears at most once. The limit applies to both sources, and a
// non-positive limit yields an empty, non-nil slice (it encodes as [] in JSON).
func mergeRecommendations(recommended []services.AIRecommendation, ranked []services.SearchResult, limit int) []services.AIRecommendation {
	if limit <= 0 {
		return []services.AIRecommendation{}
	}
	merged := make([]services.AIRecommendation, 0, min(limit, len(recommended)+len(ranked)))
	seen := make(map[string]bool, limit)

	for _, item := range recommended {
		if len(merged) >= limit {
			return merged
		}
		if item.Link == "" || seen[item.Link] {
			continue
		}
		seen[item.Link] = true
		merged = append(merged, item)
	}

	for _, item := range ranked {
		if len(merged) >= limit {
			break
		}
		if item.Link == "" || seen[item.Link] {
			continue
		}
		seen[item.Link] = true
		merged = append(merged, services.AIRecommendation{
			Title:           item.Title,
			Link:            item.Link,
			Snippet:         item.Snippet,
			ThumbnailURL:    item.ThumbnailURL,
			PreviewVideoURL: item.PreviewVideoURL,
			Source:          item.Source,
			Reason:          "Keyword-ranked result added without extra Gemini vision tokens.",
			Recommended:     true,
		})
	}
	return merged
}
// geminiCandidateLimit returns how many of the top-ranked results are sent to
// Gemini Vision:
//   - every result when there are at most 12,
//   - 12 when there are 13–16,
//   - 16 otherwise.
//
// The result never exceeds total, so ranked[:geminiCandidateLimit(len(ranked))]
// is always in bounds. Previously, totals of 9–11 returned 12.
func geminiCandidateLimit(total int) int {
	switch {
	case total <= 12:
		return total
	case total <= 16:
		return 12
	default:
		return 16
	}
}
// summarizeSearchResults condenses raw search results into a
// searchDebugSummary for "debug" events. The summary contains:
//   - the total and a per-source count,
//   - how many results carry a non-blank preview video or thumbnail URL,
//   - sample fields for the first six results.
//
// duration, geminiCap and warning are copied into the summary as given.
func summarizeSearchResults(results []services.SearchResult, duration time.Duration, geminiCap int, warning string) searchDebugSummary {
	bySource := map[string]int{}
	withPreview, withThumbnail := 0, 0
	top := make([]map[string]any, 0, min(6, len(results)))
	for idx, item := range results {
		bySource[item.Source]++
		// One blank-aware test feeds both the counters and the per-item flags, so
		// they cannot disagree about whitespace-only URLs.
		hasPreview := strings.TrimSpace(item.PreviewVideoURL) != ""
		hasThumbnail := strings.TrimSpace(item.ThumbnailURL) != ""
		if hasPreview {
			withPreview++
		}
		if hasThumbnail {
			withThumbnail++
		}
		if idx < 6 {
			top = append(top, map[string]any{
				"title":         truncateText(item.Title, 120),
				"source":        item.Source,
				"hasPreview":    hasPreview,
				"hasThumbnail":  hasThumbnail,
				"displayLink":   item.DisplayLink,
				"snippetSample": truncateText(item.Snippet, 160),
			})
		}
	}
	return searchDebugSummary{
		Total:              len(results),
		BySource:           bySource,
		WithPreview:        withPreview,
		WithThumbnail:      withThumbnail,
		Top:                top,
		Warning:            warning,
		DurationMS:         duration.Milliseconds(),
		GeminiCandidateCap: geminiCap,
	}
}
// summarizeRecommendationResults condenses the final recommendation list into
// a searchDebugSummary for "debug" events. The summary contains:
//   - the total and a per-source count,
//   - how many items carry a non-blank preview video or thumbnail URL,
//   - sample fields (including the recommendation reason) for the first six items.
//
// duration and warning are copied into the summary as given.
func summarizeRecommendationResults(results []services.AIRecommendation, duration time.Duration, warning string) searchDebugSummary {
	bySource := map[string]int{}
	withPreview, withThumbnail := 0, 0
	top := make([]map[string]any, 0, min(6, len(results)))
	for idx, item := range results {
		bySource[item.Source]++
		// One blank-aware test feeds both the counters and the per-item flags, so
		// they cannot disagree about whitespace-only URLs.
		hasPreview := strings.TrimSpace(item.PreviewVideoURL) != ""
		hasThumbnail := strings.TrimSpace(item.ThumbnailURL) != ""
		if hasPreview {
			withPreview++
		}
		if hasThumbnail {
			withThumbnail++
		}
		if idx < 6 {
			top = append(top, map[string]any{
				"title":         truncateText(item.Title, 120),
				"source":        item.Source,
				"hasPreview":    hasPreview,
				"hasThumbnail":  hasThumbnail,
				"reasonSample":  truncateText(item.Reason, 120),
				"snippetSample": truncateText(item.Snippet, 160),
			})
		}
	}
	return searchDebugSummary{
		Total:         len(results),
		BySource:      bySource,
		WithPreview:   withPreview,
		WithThumbnail: withThumbnail,
		Top:           top,
		Warning:       warning,
		DurationMS:    duration.Milliseconds(),
	}
}
// truncateText trims surrounding whitespace from text. If the result is longer
// than limit bytes, it is cut to at most limit bytes and "..." is appended.
//
// The cut is moved back to the nearest UTF-8 character boundary, so multi-byte
// text (e.g. Japanese titles) is never split mid-character. Splitting would
// produce invalid UTF-8 in the JSON payloads.
func truncateText(text string, limit int) string {
	trimmed := strings.TrimSpace(text)
	if len(trimmed) <= limit {
		return trimmed
	}
	cut := 0
	// Ranging over a string yields the byte offset of each rune start.
	for i := range trimmed {
		if i > limit {
			break
		}
		cut = i
	}
	return trimmed[:cut] + "..."
}
// EnsurePaths prepares the filesystem for the handlers. It creates
// downloadsDir, including any missing parents, with mode 0755, and confirms
// that workerScript exists.
//
// A missing worker script yields a "worker script not found: <path>" error.
// Any other mkdir or stat failure is returned unchanged.
func EnsurePaths(downloadsDir, workerScript string) error {
	if mkErr := os.MkdirAll(downloadsDir, 0o755); mkErr != nil {
		return mkErr
	}
	_, statErr := os.Stat(workerScript)
	switch {
	case statErr == nil:
		return nil
	case errors.Is(statErr, os.ErrNotExist):
		return fmt.Errorf("worker script not found: %s", workerScript)
	default:
		return statErr
	}
}
// summarizeOutput builds a client-facing error message for a subprocess
// failure. The message is prefix, then ": <err>" when err is non-nil, then
// ": <output>" when the whitespace-trimmed output is non-empty.
//
// Output longer than 1200 bytes is cut on a UTF-8 character boundary and
// suffixed with "...". The error is included even when there is output, so
// diagnostics such as the JSON decode error or the exit status are not lost.
func summarizeOutput(prefix string, output []byte, err error) string {
	const maxOutput = 1200
	msg := prefix
	if err != nil {
		msg += ": " + err.Error()
	}
	trimmed := strings.TrimSpace(string(output))
	if trimmed == "" {
		return msg
	}
	if len(trimmed) > maxOutput {
		cut := 0
		// Ranging over a string yields the byte offset of each rune start.
		for i := range trimmed {
			if i > maxOutput {
				break
			}
			cut = i
		}
		trimmed = trimmed[:cut] + "..."
	}
	return msg + ": " + trimmed
}