Files
recommender/packages/backend/src/pipelines/recommendation.ts
2026-04-20 19:37:33 -03:00

328 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { eq } from 'drizzle-orm';
import { db } from '../db.js';
import { recommendations } from '../db/schema.js';
import { runInterpreter } from '../agents/interpreter.js';
import { runRetrieval } from '../agents/retrieval.js';
import { runValidator } from '../agents/validator.js';
import { runRanking } from '../agents/ranking.js';
import { runCurator } from '../agents/curator.js';
import type { CuratorOutput, InterpreterOutput, MediaType, RankingOutput, RetrievalCandidate, SSEEvent } from '../types/agents.js';
import { generateTitle } from '../agents/titleGenerator.js';
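/* -- Agent pipeline --
[1] Interpreter -> turns the raw user input into structured preferences
[2] Retrieval -> brainstorms candidate titles from OpenAI (high temperature)
[2.5] Validator (optional) -> verifies that candidates actually exist and removes junk ("trash") entries
[3] Ranking -> sorts candidates into fit categories (full match, definitely like, might like, questionable, will not like)
[4] Curator -> turns the ranked candidates into the final recommendation list
Stages [2]-[4] fan out into parallel buckets of at most 15 items each. With self_expansive enabled,
[2]-[4] are repeated for extra passes that skip already-seen titles and receive earlier full matches as context.
Finally, an AI title is generated and the category-sorted list is saved.
*/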
/** Row shape of the `recommendations` table, as inferred by Drizzle (`$inferSelect`). */
type RecommendationRecord = typeof recommendations.$inferSelect;
/**
 * Number of parallel agent calls needed so that no call handles more than
 * `bucketSize` items.
 *
 * @param count - Total number of items to process. Zero, negative or NaN counts yield 0.
 * @param bucketSize - Maximum number of items per bucket (defaults to 15).
 * @returns `ceil(count / bucketSize)`, or 0 when there is nothing to process.
 * @throws RangeError if `bucketSize` is not a positive finite number.
 */
function getBucketCount(count: number, bucketSize = 15): number {
  if (!Number.isFinite(bucketSize) || bucketSize <= 0) {
    throw new RangeError(`bucketSize must be a positive finite number, got ${bucketSize}`);
  }
  // Normalise degenerate counts to a plain 0 instead of leaking -0 or NaN to callers.
  if (Number.isNaN(count) || count <= 0) return 0;
  return Math.ceil(count / bucketSize);
}
/**
 * Keeps only the first candidate for each title, comparing titles case-insensitively
 * and preserving the original order.
 *
 * @param candidates - Candidates to filter.
 * @param seenTitles - Optional set of lower-cased titles to treat as already taken. It is
 *   updated in place with every title accepted here, so it can be shared across calls.
 * @returns Candidates whose titles had not been seen before.
 */
function deduplicateCandidates(candidates: RetrievalCandidate[], seenTitles?: Set<string>): RetrievalCandidate[] {
  const known = seenTitles ?? new Set<string>();
  const unique: RetrievalCandidate[] = [];
  for (const candidate of candidates) {
    const normalized = candidate.title.toLowerCase();
    if (!known.has(normalized)) {
      known.add(normalized);
      unique.push(candidate);
    }
  }
  return unique;
}
/**
 * Splits `items` into at most `n` contiguous, order-preserving chunks of near-equal size.
 *
 * Empty chunks are never returned, so the result may contain fewer than `n` chunks.
 * `n` is floored and capped at `items.length`; values below 1 (including NaN) are treated
 * as 1, so every input item always ends up in exactly one chunk.
 *
 * @param items - Items to split.
 * @param n - Desired number of chunks.
 * @returns The chunks, in input order.
 */
function splitIntoBuckets<T>(items: T[], n: number): T[][] {
  if (items.length === 0) return [];
  // Guard against fractional/invalid n: a fractional n previously made Array.from floor the
  // bucket count while `size` used the fraction, silently dropping trailing items.
  const bucketCount = n >= 1 ? Math.min(Math.floor(n), items.length) : 1;
  const size = Math.ceil(items.length / bucketCount);
  const buckets: T[][] = [];
  for (let start = 0; start < items.length; start += size) {
    buckets.push(items.slice(start, start + size));
  }
  return buckets;
}
/**
 * Appends to `a` every item of `b` whose title (compared case-insensitively) is not already
 * present, keeping the first occurrence of each title.
 *
 * Items of `a` are kept as-is. Items of `b` are de-duplicated both against `a` and against
 * each other. Neither input array is mutated.
 *
 * @param a - Items to keep first.
 * @param b - Items to append when their title is new.
 * @returns A new merged array.
 */
function mergeCuratorOutputs(a: CuratorOutput[], b: CuratorOutput[]): CuratorOutput[] {
  const seen = new Set(a.map((x) => x.title.toLowerCase()));
  const additions = b.filter((x) => {
    const key = x.title.toLowerCase();
    if (seen.has(key)) return false;
    // Record accepted titles so repeats later in `b` are dropped too.
    seen.add(key);
    return true;
  });
  return [...a, ...additions];
}
/**
 * Writes a timestamped pipeline log line tagged with the recommendation id.
 *
 * @param recId - Recommendation id included in the line prefix.
 * @param msg - Human-readable message.
 * @param data - Optional payload passed to `console.log` as a second argument; omitted
 *   entirely when `undefined`.
 */
function log(recId: string, msg: string, data?: unknown) {
  const line = `[pipeline] [${new Date().toISOString()}] [${recId}] ${msg}`;
  const args: unknown[] = data === undefined ? [line] : [line, data];
  console.log(...args);
}
/** Inputs for one Retrieval -> [Validator] -> Ranking -> Curator pass run by `runSubPipeline`. */
interface SubPipelineCtx {
  /** Recommendation id; used to tag log lines. */
  recId: string;
  /** Structured preferences produced by the Interpreter; forwarded to every agent. */
  interpreterOutput: InterpreterOutput;
  /** Media type forwarded to every agent call. */
  mediaType: MediaType;
  /** Forwarded to the Retrieval and Curator agents. */
  useWebSearch: boolean;
  /** When true, run the Validator between Retrieval and Ranking and drop candidates it flags as trash. */
  useValidator: boolean;
  /** Forwarded to the Retrieval and Ranking agents. */
  useHardRequirements: boolean;
  /** Total number of candidates to request from Retrieval, spread across parallel buckets. */
  brainstormCount: number;
  /** Full-match titles from earlier passes (empty on the first pass); forwarded to Retrieval. */
  previousFullMatches: string[];
  /** Lower-cased titles already retrieved; updated in place so passes sharing it skip repeats. */
  allSeenTitles: Set<string>;
  /** Prepended to every SSE stage name ('' for the first pass, e.g. 'pass2:' for expansive passes). */
  stagePrefix: string;
  /** Sink for the per-stage start/done progress events. */
  sseWrite: (event: SSEEvent) => void;
}
/**
 * Runs one Retrieval -> [Validator] -> Ranking -> Curator pass and returns its curated items.
 *
 * Each stage fans its work out into parallel buckets sized by `getBucketCount` and emits
 * `start`/`done` SSE events whose stage names carry `ctx.stagePrefix`. Retrieved titles are
 * recorded in `ctx.allSeenTitles`, so later passes sharing that set skip them.
 *
 * @param ctx - Inputs, feature flags and the SSE sink for this pass.
 * @returns The curator output, merged across buckets and de-duplicated by title.
 * @throws Propagates any agent failure; stages use `Promise.all`, so one failing bucket fails the pass.
 */
async function runSubPipeline(ctx: SubPipelineCtx): Promise<CuratorOutput[]> {
  const retrieved = await subPipelineRetrieve(ctx);
  const rankable = await subPipelineValidate(ctx, retrieved);
  const ranking = await subPipelineRank(ctx, rankable);
  return subPipelineCurate(ctx, ranking);
}

/** Builds the SSE stage name for this pass, e.g. 'pass2:' + 'ranking'. */
function subPipelineStage(ctx: SubPipelineCtx, stage: string): SSEEvent['stage'] {
  return (ctx.stagePrefix + stage) as SSEEvent['stage'];
}

/** Retrieval stage: requests candidates in parallel buckets, then de-duplicates the pooled results. */
async function subPipelineRetrieve(ctx: SubPipelineCtx): Promise<RetrievalCandidate[]> {
  const { recId, stagePrefix, sseWrite, brainstormCount } = ctx;
  log(recId, `${stagePrefix}Retrieval: start`);
  sseWrite({ stage: subPipelineStage(ctx, 'retrieval'), status: 'start' });
  const startedAt = Date.now();
  const bucketCount = getBucketCount(brainstormCount);
  const perBucketCount = Math.ceil(brainstormCount / bucketCount);
  const results = await Promise.all(
    Array.from({ length: bucketCount }, () =>
      runRetrieval(ctx.interpreterOutput, perBucketCount, ctx.mediaType, ctx.useWebSearch, ctx.useHardRequirements, ctx.previousFullMatches)
    )
  );
  const allCandidates = results.flatMap((r) => r.candidates);
  // Also records the accepted titles in ctx.allSeenTitles for later passes.
  const deduped = deduplicateCandidates(allCandidates, ctx.allSeenTitles);
  log(recId, `${stagePrefix}Retrieval: done (${Date.now() - startedAt}ms) — ${deduped.length} candidates (${bucketCount} buckets, ${allCandidates.length} before dedup)`, {
    titles: deduped.map((c) => c.title),
  });
  sseWrite({ stage: subPipelineStage(ctx, 'retrieval'), status: 'done', data: { candidates: deduped } });
  return deduped;
}

/** Validator stage (optional): drops candidates flagged as trash; reports `skipped` when disabled. */
async function subPipelineValidate(ctx: SubPipelineCtx, candidates: RetrievalCandidate[]): Promise<RetrievalCandidate[]> {
  const { recId, stagePrefix, sseWrite } = ctx;
  if (!ctx.useValidator) {
    sseWrite({ stage: subPipelineStage(ctx, 'validator'), status: 'done', data: { skipped: true } });
    return candidates;
  }
  log(recId, `${stagePrefix}Validator: start`);
  sseWrite({ stage: subPipelineStage(ctx, 'validator'), status: 'start' });
  const startedAt = Date.now();
  const validatorOutput = await runValidator(candidates, ctx.mediaType);
  const verified = validatorOutput.candidates.filter((c) => !c.isTrash);
  const trashCount = validatorOutput.candidates.length - verified.length;
  log(recId, `${stagePrefix}Validator: done (${Date.now() - startedAt}ms) — removed ${trashCount} trash entries`);
  sseWrite({ stage: subPipelineStage(ctx, 'validator'), status: 'done', data: { removed: trashCount } });
  // Strip validator-only fields so Ranking receives plain retrieval candidates.
  return verified.map(({ title, reason }) => ({ title, reason }));
}

/** Ranking stage: ranks candidate buckets in parallel and merges the per-category title lists. */
async function subPipelineRank(ctx: SubPipelineCtx, candidates: RetrievalCandidate[]): Promise<RankingOutput> {
  const { recId, stagePrefix, sseWrite } = ctx;
  log(recId, `${stagePrefix}Ranking: start`);
  sseWrite({ stage: subPipelineStage(ctx, 'ranking'), status: 'start' });
  const startedAt = Date.now();
  const bucketCount = getBucketCount(candidates.length);
  const buckets = splitIntoBuckets(candidates, bucketCount);
  const results = await Promise.all(
    buckets.map((bucket) =>
      runRanking(ctx.interpreterOutput, { candidates: bucket }, ctx.mediaType, ctx.useHardRequirements)
    )
  );
  // Case-insensitive de-dup within one category: the Map keeps the first position, last spelling.
  const dedupTitles = (titles: string[]) => [...new Map(titles.map((t) => [t.toLowerCase(), t])).values()];
  const ranking: RankingOutput = {
    full_match: dedupTitles(results.flatMap((r) => r.full_match)),
    definitely_like: dedupTitles(results.flatMap((r) => r.definitely_like)),
    might_like: dedupTitles(results.flatMap((r) => r.might_like)),
    questionable: dedupTitles(results.flatMap((r) => r.questionable)),
    will_not_like: dedupTitles(results.flatMap((r) => r.will_not_like)),
  };
  log(recId, `${stagePrefix}Ranking: done (${Date.now() - startedAt}ms) — ${bucketCount} buckets`, {
    full_match: ranking.full_match.length,
    definitely_like: ranking.definitely_like.length,
    might_like: ranking.might_like.length,
    questionable: ranking.questionable.length,
    will_not_like: ranking.will_not_like.length,
  });
  sseWrite({ stage: subPipelineStage(ctx, 'ranking'), status: 'done', data: ranking });
  return ranking;
}

/** Curator stage: re-buckets the ranked titles (keeping category order) and curates them in parallel. */
async function subPipelineCurate(ctx: SubPipelineCtx, ranking: RankingOutput): Promise<CuratorOutput[]> {
  const { recId, stagePrefix, sseWrite } = ctx;
  log(recId, `${stagePrefix}Curator: start`);
  sseWrite({ stage: subPipelineStage(ctx, 'curator'), status: 'start' });
  const startedAt = Date.now();
  type CategorizedItem = { title: string; category: keyof RankingOutput };
  // Flatten in category order (full_match ... will_not_like); buckets are contiguous slices of
  // this list, so each bucket stays grouped by category.
  const categorizedItems: CategorizedItem[] = [
    ...ranking.full_match.map((t) => ({ title: t, category: 'full_match' as const })),
    ...ranking.definitely_like.map((t) => ({ title: t, category: 'definitely_like' as const })),
    ...ranking.might_like.map((t) => ({ title: t, category: 'might_like' as const })),
    ...ranking.questionable.map((t) => ({ title: t, category: 'questionable' as const })),
    ...ranking.will_not_like.map((t) => ({ title: t, category: 'will_not_like' as const })),
  ];
  const bucketCount = getBucketCount(categorizedItems.length);
  const bucketRankings: RankingOutput[] = splitIntoBuckets(categorizedItems, bucketCount).map((bucket) => ({
    full_match: bucket.filter((i) => i.category === 'full_match').map((i) => i.title),
    definitely_like: bucket.filter((i) => i.category === 'definitely_like').map((i) => i.title),
    might_like: bucket.filter((i) => i.category === 'might_like').map((i) => i.title),
    questionable: bucket.filter((i) => i.category === 'questionable').map((i) => i.title),
    will_not_like: bucket.filter((i) => i.category === 'will_not_like').map((i) => i.title),
  }));
  const outputs = await Promise.all(
    bucketRankings.map((bucketRanking) =>
      runCurator(bucketRanking, ctx.interpreterOutput, ctx.mediaType, ctx.useWebSearch)
    )
  );
  const curated = outputs.reduce((acc, bucket) => mergeCuratorOutputs(acc, bucket), [] as CuratorOutput[]);
  log(recId, `${stagePrefix}Curator: done (${Date.now() - startedAt}ms) — ${curated.length} items curated (${bucketCount} buckets)`);
  sseWrite({ stage: subPipelineStage(ctx, 'curator'), status: 'done', data: curated });
  return curated;
}
/**
 * Runs the full recommendation pipeline for one stored request and persists the result.
 *
 * Flow:
 * 1. Mark the row `running`.
 * 2. Interpreter.
 * 3. Pass 1 (Retrieval -> [Validator] -> Ranking -> Curator).
 * 4. Optional self-expansive passes, each given all full matches found so far as context.
 * 5. AI title.
 * 6. Save the category-sorted list with status `done`.
 *
 * Progress is streamed through `sseWrite`. On failure, the stage that was last active is
 * reported in an `error` event and the row is marked `error`.
 *
 * @param rec - The recommendation row to process.
 * @param sseWrite - Sink for progress events.
 * @param feedbackContext - Optional feedback text forwarded to the Interpreter.
 * @returns The saved, category-sorted recommendations, or `[]` if the pipeline failed.
 *   Failures are reported (log, SSE, DB status) rather than thrown; secondary failures while
 *   reporting are logged as well.
 */
export async function runPipeline(
  rec: RecommendationRecord,
  sseWrite: (event: SSEEvent) => void,
  feedbackContext?: string,
): Promise<CuratorOutput[]> {
  // Last stage that reported progress; used to attribute failures in the catch block.
  let currentStage: SSEEvent['stage'] = 'interpreter';
  const startTime = Date.now();
  const mediaType = (rec.media_type ?? 'tv_show') as MediaType;
  const useWebSearch = rec.use_web_search ?? false;
  const useValidator = rec.use_validator ?? false;
  const useHardRequirements = rec.hard_requirements ?? false;
  const selfExpansive = rec.self_expansive ?? false;
  log(rec.id, `Starting pipeline for "${rec.title}" [${mediaType}${useWebSearch ? ', web_search' : ''}${useValidator ? ', validator' : ''}${useHardRequirements ? ', hard_req' : ''}${selfExpansive ? `, expansive×${rec.expansive_passes}(${rec.expansive_mode})` : ''}]${feedbackContext ? ' (with feedback context)' : ''}`);

  // Forwards sub-pipeline events while remembering the most recent stage.
  const trackingSseWrite = (event: SSEEvent) => {
    currentStage = event.stage;
    sseWrite(event);
  };

  try {
    log(rec.id, 'Setting status → running');
    await db
      .update(recommendations)
      .set({ status: 'running' })
      .where(eq(recommendations.id, rec.id));

    // --- Interpreter --- (currentStage is still 'interpreter' here)
    log(rec.id, 'Interpreter: start');
    sseWrite({ stage: 'interpreter', status: 'start' });
    const t0 = Date.now();
    const interpreterOutput = await runInterpreter({
      main_prompt: rec.main_prompt,
      liked_series: rec.liked_series,
      disliked_series: rec.disliked_series,
      themes: rec.themes,
      media_type: mediaType,
      ...(feedbackContext !== undefined ? { feedback_context: feedbackContext } : {}),
    });
    log(rec.id, `Interpreter: done (${Date.now() - t0}ms)`, {
      liked: interpreterOutput.liked,
      disliked: interpreterOutput.disliked,
      themes: interpreterOutput.themes,
      tone: interpreterOutput.tone,
      avoid: interpreterOutput.avoid,
    });
    sseWrite({ stage: 'interpreter', status: 'done', data: interpreterOutput });

    // --- Pass 1: Retrieval → [Validator?] → Ranking → Curator ---
    currentStage = 'retrieval';
    // Shared across all passes so later passes never re-propose an already retrieved title.
    const allSeenTitles = new Set<string>();
    const pass1Output = await runSubPipeline({
      recId: rec.id,
      interpreterOutput,
      mediaType,
      useWebSearch,
      useValidator,
      useHardRequirements,
      brainstormCount: rec.brainstorm_count,
      previousFullMatches: [],
      allSeenTitles,
      stagePrefix: '',
      sseWrite: trackingSseWrite,
    });
    let mergedOutput = pass1Output;

    // --- Self Expansive: extra passes ---
    if (selfExpansive && rec.expansive_passes > 0) {
      // Full matches found so far are passed to Retrieval as context for each later pass.
      const allFullMatches = pass1Output
        .filter((c) => c.category === 'Full Match')
        .map((c) => c.title);
      // 'extreme' repeats the full brainstorm size; any other mode requests 60 candidates per pass.
      const passCount = rec.expansive_mode === 'extreme' ? rec.brainstorm_count : 60;
      for (let i = 0; i < rec.expansive_passes; i++) {
        const passNum = i + 2;
        const passPrefix = `pass${passNum}:` as const;
        log(rec.id, `Self Expansive Pass ${passNum}: start (${passCount} candidates, ${allFullMatches.length} full matches as context)`);
        currentStage = `${passPrefix}retrieval` as SSEEvent['stage'];
        const passOutput = await runSubPipeline({
          recId: rec.id,
          interpreterOutput,
          mediaType,
          useWebSearch,
          useValidator,
          useHardRequirements,
          brainstormCount: passCount,
          // Snapshot, so the sub-pipeline never observes the pushes below.
          previousFullMatches: [...allFullMatches],
          allSeenTitles,
          stagePrefix: passPrefix,
          sseWrite: trackingSseWrite,
        });
        mergedOutput = mergeCuratorOutputs(mergedOutput, passOutput);
        allFullMatches.push(
          ...passOutput.filter((c) => c.category === 'Full Match').map((c) => c.title),
        );
        log(rec.id, `Self Expansive Pass ${passNum}: done — ${passOutput.length} new items, ${mergedOutput.length} total`);
      }
    }

    // Generate AI title (best effort: failures keep the initial title)
    let aiTitle: string = rec.title;
    try {
      log(rec.id, 'Title generation: start');
      aiTitle = await generateTitle(interpreterOutput, mediaType);
      log(rec.id, `Title generation: done — "${aiTitle}"`);
    } catch (err) {
      log(rec.id, `Title generation failed, keeping initial title: ${String(err)}`);
    }

    // Sort by category order before saving. Sort a copy: mergedOutput may be the very array
    // already emitted in the pass-1 curator SSE event, which must not be reordered afterwards.
    const CATEGORY_ORDER: Record<string, number> = {
      'Full Match': 0,
      'Definitely Like': 1,
      'Might Like': 2,
      'Questionable': 3,
      'Will Not Like': 4,
    };
    const sortedOutput = [...mergedOutput].sort(
      (a, b) => (CATEGORY_ORDER[a.category] ?? 99) - (CATEGORY_ORDER[b.category] ?? 99),
    );

    // Save results to DB
    log(rec.id, 'Saving results to DB');
    await db
      .update(recommendations)
      .set({ recommendations: sortedOutput, status: 'done', title: aiTitle })
      .where(eq(recommendations.id, rec.id));
    sseWrite({ stage: 'complete', status: 'done', data: { title: aiTitle } });
    log(rec.id, `Pipeline complete (total: ${Date.now() - startTime}ms)`);
    return sortedOutput;
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    log(rec.id, `Pipeline error at stage "${currentStage}": ${message}`);
    // Reporting must not throw: a dead SSE client or DB hiccup here would otherwise reject
    // this promise and leave the row stuck in 'running'.
    try {
      sseWrite({ stage: currentStage, status: 'error', data: { message } });
    } catch (sseErr) {
      log(rec.id, `Failed to emit error event: ${String(sseErr)}`);
    }
    try {
      await db
        .update(recommendations)
        .set({ status: 'error' })
        .where(eq(recommendations.id, rec.id));
    } catch (dbErr) {
      log(rec.id, `Failed to mark recommendation as error: ${String(dbErr)}`);
    }
    return [];
  }
}