refactor(baoyu-image-gen): export functions for testability and add module entry guard

This commit is contained in:
Jim Liu 宝玉 2026-03-13 16:16:50 -05:00
parent a11613c11b
commit 3398509d9e
5 changed files with 54 additions and 40 deletions

View File

@ -1,6 +1,7 @@
import path from "node:path"; import path from "node:path";
import process from "node:process"; import process from "node:process";
import { homedir } from "node:os"; import { homedir } from "node:os";
import { fileURLToPath } from "node:url";
import { access, mkdir, readFile, writeFile } from "node:fs/promises"; import { access, mkdir, readFile, writeFile } from "node:fs/promises";
import type { import type {
BatchFile, BatchFile,
@ -136,7 +137,7 @@ Environment variables:
Env file load order: CLI args > EXTEND.md > process.env > <cwd>/.baoyu-skills/.env > ~/.baoyu-skills/.env`); Env file load order: CLI args > EXTEND.md > process.env > <cwd>/.baoyu-skills/.env > ~/.baoyu-skills/.env`);
} }
function parseArgs(argv: string[]): CliArgs { export function parseArgs(argv: string[]): CliArgs {
const out: CliArgs = { const out: CliArgs = {
prompt: null, prompt: null,
promptFiles: [], promptFiles: [],
@ -338,12 +339,12 @@ async function loadEnv(): Promise<void> {
} }
} }
function extractYamlFrontMatter(content: string): string | null { export function extractYamlFrontMatter(content: string): string | null {
const match = content.match(/^---\s*\n([\s\S]*?)\n---\s*$/m); const match = content.match(/^---\s*\n([\s\S]*?)\n---\s*$/m);
return match ? match[1] : null; return match ? match[1] : null;
} }
function parseSimpleYaml(yaml: string): Partial<ExtendConfig> { export function parseSimpleYaml(yaml: string): Partial<ExtendConfig> {
const config: Partial<ExtendConfig> = {}; const config: Partial<ExtendConfig> = {};
const lines = yaml.split("\n"); const lines = yaml.split("\n");
let currentKey: string | null = null; let currentKey: string | null = null;
@ -473,7 +474,7 @@ async function loadExtendConfig(): Promise<Partial<ExtendConfig>> {
return {}; return {};
} }
function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs { export function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
return { return {
...args, ...args,
provider: args.provider ?? extend.default_provider ?? null, provider: args.provider ?? extend.default_provider ?? null,
@ -483,13 +484,13 @@ function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
}; };
} }
function parsePositiveInt(value: string | undefined): number | null { export function parsePositiveInt(value: string | undefined): number | null {
if (!value) return null; if (!value) return null;
const parsed = parseInt(value, 10); const parsed = parseInt(value, 10);
return Number.isFinite(parsed) && parsed > 0 ? parsed : null; return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
} }
function parsePositiveBatchInt(value: unknown): number | null { export function parsePositiveBatchInt(value: unknown): number | null {
if (value === null || value === undefined) return null; if (value === null || value === undefined) return null;
if (typeof value === "number") { if (typeof value === "number") {
return Number.isInteger(value) && value > 0 ? value : null; return Number.isInteger(value) && value > 0 ? value : null;
@ -500,13 +501,13 @@ function parsePositiveBatchInt(value: unknown): number | null {
return null; return null;
} }
function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number { export function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number {
const envValue = parsePositiveInt(process.env.BAOYU_IMAGE_GEN_MAX_WORKERS); const envValue = parsePositiveInt(process.env.BAOYU_IMAGE_GEN_MAX_WORKERS);
const configValue = extendConfig.batch?.max_workers ?? null; const configValue = extendConfig.batch?.max_workers ?? null;
return Math.max(1, envValue ?? configValue ?? DEFAULT_MAX_WORKERS); return Math.max(1, envValue ?? configValue ?? DEFAULT_MAX_WORKERS);
} }
function getConfiguredProviderRateLimits( export function getConfiguredProviderRateLimits(
extendConfig: Partial<ExtendConfig> extendConfig: Partial<ExtendConfig>
): Record<Provider, ProviderRateLimit> { ): Record<Provider, ProviderRateLimit> {
const configured: Record<Provider, ProviderRateLimit> = { const configured: Record<Provider, ProviderRateLimit> = {
@ -559,14 +560,14 @@ async function readPromptFromStdin(): Promise<string | null> {
} }
} }
function normalizeOutputImagePath(p: string): string { export function normalizeOutputImagePath(p: string): string {
const full = path.resolve(p); const full = path.resolve(p);
const ext = path.extname(full); const ext = path.extname(full);
if (ext) return full; if (ext) return full;
return `${full}.png`; return `${full}.png`;
} }
function detectProvider(args: CliArgs): Provider { export function detectProvider(args: CliArgs): Provider {
if ( if (
args.referenceImages.length > 0 && args.referenceImages.length > 0 &&
args.provider && args.provider &&
@ -619,7 +620,7 @@ function detectProvider(args: CliArgs): Provider {
); );
} }
async function validateReferenceImages(referenceImages: string[]): Promise<void> { export async function validateReferenceImages(referenceImages: string[]): Promise<void> {
for (const refPath of referenceImages) { for (const refPath of referenceImages) {
const fullPath = path.resolve(refPath); const fullPath = path.resolve(refPath);
try { try {
@ -630,7 +631,7 @@ async function validateReferenceImages(referenceImages: string[]): Promise<void>
} }
} }
function isRetryableGenerationError(error: unknown): boolean { export function isRetryableGenerationError(error: unknown): boolean {
const msg = error instanceof Error ? error.message : String(error); const msg = error instanceof Error ? error.message : String(error);
const nonRetryableMarkers = [ const nonRetryableMarkers = [
"Reference image", "Reference image",
@ -712,7 +713,7 @@ async function prepareSingleTask(args: CliArgs, extendConfig: Partial<ExtendConf
}; };
} }
async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> { export async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> {
const resolvedBatchFilePath = path.resolve(batchFilePath); const resolvedBatchFilePath = path.resolve(batchFilePath);
const content = await readFile(resolvedBatchFilePath, "utf8"); const content = await readFile(resolvedBatchFilePath, "utf8");
const parsed = JSON.parse(content.replace(/^\uFEFF/, "")) as BatchFile; const parsed = JSON.parse(content.replace(/^\uFEFF/, "")) as BatchFile;
@ -738,11 +739,11 @@ async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks>
throw new Error("Invalid batch file. Expected an array of tasks or an object with a tasks array."); throw new Error("Invalid batch file. Expected an array of tasks or an object with a tasks array.");
} }
function resolveBatchPath(batchDir: string, filePath: string): string { export function resolveBatchPath(batchDir: string, filePath: string): string {
return path.isAbsolute(filePath) ? filePath : path.resolve(batchDir, filePath); return path.isAbsolute(filePath) ? filePath : path.resolve(batchDir, filePath);
} }
function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs { export function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs {
return { return {
...baseArgs, ...baseArgs,
prompt: task.prompt ?? null, prompt: task.prompt ?? null,
@ -881,7 +882,7 @@ function createProviderGate(providerRateLimits: Record<Provider, ProviderRateLim
}; };
} }
function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number { export function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number {
const requested = jobs ?? Math.min(taskCount, maxWorkers); const requested = jobs ?? Math.min(taskCount, maxWorkers);
return Math.max(1, Math.min(requested, taskCount, maxWorkers)); return Math.max(1, Math.min(requested, taskCount, maxWorkers));
} }
@ -1011,8 +1012,21 @@ async function main(): Promise<void> {
await runSingleMode(mergedArgs, extendConfig); await runSingleMode(mergedArgs, extendConfig);
} }
main().catch((error) => { function isDirectExecution(metaUrl: string): boolean {
const message = error instanceof Error ? error.message : String(error); const entryPath = process.argv[1];
console.error(message); if (!entryPath) return false;
process.exit(1);
}); try {
return path.resolve(entryPath) === fileURLToPath(metaUrl);
} catch {
return false;
}
}
if (isDirectExecution(import.meta.url)) {
main().catch((error) => {
const message = error instanceof Error ? error.message : String(error);
console.error(message);
process.exit(1);
});
}

View File

@ -13,7 +13,7 @@ function getBaseUrl(): string {
return base.replace(/\/+$/g, ""); return base.replace(/\/+$/g, "");
} }
function parseAspectRatio(ar: string): { width: number; height: number } | null { export function parseAspectRatio(ar: string): { width: number; height: number } | null {
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/); const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
if (!match) return null; if (!match) return null;
const w = parseFloat(match[1]!); const w = parseFloat(match[1]!);
@ -45,7 +45,7 @@ const STANDARD_SIZES_2K: [number, number][] = [
[2048, 2048], [2048, 2048],
]; ];
function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string { export function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string {
const is2k = quality === "2k"; const is2k = quality === "2k";
const defaultSize = is2k ? "1536*1536" : "1024*1024"; const defaultSize = is2k ? "1536*1536" : "1024*1024";
@ -71,7 +71,7 @@ function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]):
return best; return best;
} }
function normalizeSize(size: string): string { export function normalizeSize(size: string): string {
return size.replace("x", "*"); return size.replace("x", "*");
} }

View File

@ -17,16 +17,16 @@ export function getDefaultModel(): string {
return process.env.GOOGLE_IMAGE_MODEL || "gemini-3-pro-image-preview"; return process.env.GOOGLE_IMAGE_MODEL || "gemini-3-pro-image-preview";
} }
function normalizeGoogleModelId(model: string): string { export function normalizeGoogleModelId(model: string): string {
return model.startsWith("models/") ? model.slice("models/".length) : model; return model.startsWith("models/") ? model.slice("models/".length) : model;
} }
function isGoogleMultimodal(model: string): boolean { export function isGoogleMultimodal(model: string): boolean {
const normalized = normalizeGoogleModelId(model); const normalized = normalizeGoogleModelId(model);
return GOOGLE_MULTIMODAL_MODELS.some((m) => normalized.includes(m)); return GOOGLE_MULTIMODAL_MODELS.some((m) => normalized.includes(m));
} }
function isGoogleImagen(model: string): boolean { export function isGoogleImagen(model: string): boolean {
const normalized = normalizeGoogleModelId(model); const normalized = normalizeGoogleModelId(model);
return GOOGLE_IMAGEN_MODELS.some((m) => normalized.includes(m)); return GOOGLE_IMAGEN_MODELS.some((m) => normalized.includes(m));
} }
@ -35,7 +35,7 @@ function getGoogleApiKey(): string | null {
return process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || null; return process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || null;
} }
function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" { export function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" {
if (args.imageSize) return args.imageSize as "1K" | "2K" | "4K"; if (args.imageSize) return args.imageSize as "1K" | "2K" | "4K";
return args.quality === "2k" ? "2K" : "1K"; return args.quality === "2k" ? "2K" : "1K";
} }
@ -46,7 +46,7 @@ function getGoogleBaseUrl(): string {
return base.replace(/\/+$/g, ""); return base.replace(/\/+$/g, "");
} }
function buildGoogleUrl(pathname: string): string { export function buildGoogleUrl(pathname: string): string {
const base = getGoogleBaseUrl(); const base = getGoogleBaseUrl();
const cleanedPath = pathname.replace(/^\/+/g, ""); const cleanedPath = pathname.replace(/^\/+/g, "");
if (base.endsWith("/v1beta")) return `${base}/${cleanedPath}`; if (base.endsWith("/v1beta")) return `${base}/${cleanedPath}`;
@ -162,7 +162,7 @@ async function postGoogleJson<T>(pathname: string, body: unknown): Promise<T> {
return postGoogleJsonViaFetch<T>(url, apiKey, body); return postGoogleJsonViaFetch<T>(url, apiKey, body);
} }
function buildPromptWithAspect( export function buildPromptWithAspect(
prompt: string, prompt: string,
ar: string | null, ar: string | null,
quality: CliArgs["quality"], quality: CliArgs["quality"],
@ -177,7 +177,7 @@ function buildPromptWithAspect(
return result; return result;
} }
function addAspectRatioToPrompt(prompt: string, ar: string | null): string { export function addAspectRatioToPrompt(prompt: string, ar: string | null): string {
if (!ar) return prompt; if (!ar) return prompt;
return `${prompt} Aspect ratio: ${ar}.`; return `${prompt} Aspect ratio: ${ar}.`;
} }
@ -194,7 +194,7 @@ async function readImageAsBase64(
return { data: buf.toString("base64"), mimeType }; return { data: buf.toString("base64"), mimeType };
} }
function extractInlineImageData(response: { export function extractInlineImageData(response: {
candidates?: Array<{ candidates?: Array<{
content?: { parts?: Array<{ inlineData?: { data?: string } }> }; content?: { parts?: Array<{ inlineData?: { data?: string } }> };
}>; }>;
@ -208,7 +208,7 @@ function extractInlineImageData(response: {
return null; return null;
} }
function extractPredictedImageData(response: { export function extractPredictedImageData(response: {
predictions?: Array<any>; predictions?: Array<any>;
generatedImages?: Array<any>; generatedImages?: Array<any>;
}): string | null { }): string | null {

View File

@ -8,7 +8,7 @@ export function getDefaultModel(): string {
type OpenAIImageResponse = { data: Array<{ url?: string; b64_json?: string }> }; type OpenAIImageResponse = { data: Array<{ url?: string; b64_json?: string }> };
function parseAspectRatio(ar: string): { width: number; height: number } | null { export function parseAspectRatio(ar: string): { width: number; height: number } | null {
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/); const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
if (!match) return null; if (!match) return null;
const w = parseFloat(match[1]!); const w = parseFloat(match[1]!);
@ -23,7 +23,7 @@ type SizeMapping = {
portrait: string; portrait: string;
}; };
function getOpenAISize( export function getOpenAISize(
model: string, model: string,
ar: string | null, ar: string | null,
quality: CliArgs["quality"] quality: CliArgs["quality"]
@ -201,7 +201,7 @@ async function generateWithOpenAIEdits(
return extractImageFromResponse(result); return extractImageFromResponse(result);
} }
function getMimeType(filename: string): string { export function getMimeType(filename: string): string {
const ext = path.extname(filename).toLowerCase(); const ext = path.extname(filename).toLowerCase();
if (ext === ".jpg" || ext === ".jpeg") return "image/jpeg"; if (ext === ".jpg" || ext === ".jpeg") return "image/jpeg";
if (ext === ".webp") return "image/webp"; if (ext === ".webp") return "image/webp";
@ -209,7 +209,7 @@ function getMimeType(filename: string): string {
return "image/png"; return "image/png";
} }
async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> { export async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> {
const img = result.data[0]; const img = result.data[0];
if (img?.b64_json) { if (img?.b64_json) {

View File

@ -20,7 +20,7 @@ function getBaseUrl(): string {
return base.replace(/\/+$/g, ""); return base.replace(/\/+$/g, "");
} }
function parseModelId(model: string): { owner: string; name: string; version: string | null } { export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
const [ownerName, version] = model.split(":"); const [ownerName, version] = model.split(":");
const parts = ownerName!.split("/"); const parts = ownerName!.split("/");
if (parts.length !== 2 || !parts[0] || !parts[1]) { if (parts.length !== 2 || !parts[0] || !parts[1]) {
@ -31,7 +31,7 @@ function parseModelId(model: string): { owner: string; name: string; version: st
return { owner: parts[0], name: parts[1], version: version || null }; return { owner: parts[0], name: parts[1], version: version || null };
} }
function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> { export function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
const input: Record<string, unknown> = { prompt }; const input: Record<string, unknown> = { prompt };
if (args.aspectRatio) { if (args.aspectRatio) {
@ -144,7 +144,7 @@ async function pollPrediction(apiToken: string, getUrl: string): Promise<Predict
throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`); throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
} }
function extractOutputUrl(prediction: PredictionResponse): string { export function extractOutputUrl(prediction: PredictionResponse): string {
const output = prediction.output; const output = prediction.output;
if (typeof output === "string") return output; if (typeof output === "string") return output;