refactor(baoyu-image-gen): export functions for testability and add module entry guard
This commit is contained in:
parent
a11613c11b
commit
3398509d9e
|
|
@ -1,6 +1,7 @@
|
||||||
import path from "node:path";
|
import path from "node:path";
|
||||||
import process from "node:process";
|
import process from "node:process";
|
||||||
import { homedir } from "node:os";
|
import { homedir } from "node:os";
|
||||||
|
import { fileURLToPath } from "node:url";
|
||||||
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
|
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
|
||||||
import type {
|
import type {
|
||||||
BatchFile,
|
BatchFile,
|
||||||
|
|
@ -136,7 +137,7 @@ Environment variables:
|
||||||
Env file load order: CLI args > EXTEND.md > process.env > <cwd>/.baoyu-skills/.env > ~/.baoyu-skills/.env`);
|
Env file load order: CLI args > EXTEND.md > process.env > <cwd>/.baoyu-skills/.env > ~/.baoyu-skills/.env`);
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseArgs(argv: string[]): CliArgs {
|
export function parseArgs(argv: string[]): CliArgs {
|
||||||
const out: CliArgs = {
|
const out: CliArgs = {
|
||||||
prompt: null,
|
prompt: null,
|
||||||
promptFiles: [],
|
promptFiles: [],
|
||||||
|
|
@ -338,12 +339,12 @@ async function loadEnv(): Promise<void> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function extractYamlFrontMatter(content: string): string | null {
|
export function extractYamlFrontMatter(content: string): string | null {
|
||||||
const match = content.match(/^---\s*\n([\s\S]*?)\n---\s*$/m);
|
const match = content.match(/^---\s*\n([\s\S]*?)\n---\s*$/m);
|
||||||
return match ? match[1] : null;
|
return match ? match[1] : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseSimpleYaml(yaml: string): Partial<ExtendConfig> {
|
export function parseSimpleYaml(yaml: string): Partial<ExtendConfig> {
|
||||||
const config: Partial<ExtendConfig> = {};
|
const config: Partial<ExtendConfig> = {};
|
||||||
const lines = yaml.split("\n");
|
const lines = yaml.split("\n");
|
||||||
let currentKey: string | null = null;
|
let currentKey: string | null = null;
|
||||||
|
|
@ -473,7 +474,7 @@ async function loadExtendConfig(): Promise<Partial<ExtendConfig>> {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
export function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
||||||
return {
|
return {
|
||||||
...args,
|
...args,
|
||||||
provider: args.provider ?? extend.default_provider ?? null,
|
provider: args.provider ?? extend.default_provider ?? null,
|
||||||
|
|
@ -483,13 +484,13 @@ function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function parsePositiveInt(value: string | undefined): number | null {
|
export function parsePositiveInt(value: string | undefined): number | null {
|
||||||
if (!value) return null;
|
if (!value) return null;
|
||||||
const parsed = parseInt(value, 10);
|
const parsed = parseInt(value, 10);
|
||||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
|
return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function parsePositiveBatchInt(value: unknown): number | null {
|
export function parsePositiveBatchInt(value: unknown): number | null {
|
||||||
if (value === null || value === undefined) return null;
|
if (value === null || value === undefined) return null;
|
||||||
if (typeof value === "number") {
|
if (typeof value === "number") {
|
||||||
return Number.isInteger(value) && value > 0 ? value : null;
|
return Number.isInteger(value) && value > 0 ? value : null;
|
||||||
|
|
@ -500,13 +501,13 @@ function parsePositiveBatchInt(value: unknown): number | null {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number {
|
export function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number {
|
||||||
const envValue = parsePositiveInt(process.env.BAOYU_IMAGE_GEN_MAX_WORKERS);
|
const envValue = parsePositiveInt(process.env.BAOYU_IMAGE_GEN_MAX_WORKERS);
|
||||||
const configValue = extendConfig.batch?.max_workers ?? null;
|
const configValue = extendConfig.batch?.max_workers ?? null;
|
||||||
return Math.max(1, envValue ?? configValue ?? DEFAULT_MAX_WORKERS);
|
return Math.max(1, envValue ?? configValue ?? DEFAULT_MAX_WORKERS);
|
||||||
}
|
}
|
||||||
|
|
||||||
function getConfiguredProviderRateLimits(
|
export function getConfiguredProviderRateLimits(
|
||||||
extendConfig: Partial<ExtendConfig>
|
extendConfig: Partial<ExtendConfig>
|
||||||
): Record<Provider, ProviderRateLimit> {
|
): Record<Provider, ProviderRateLimit> {
|
||||||
const configured: Record<Provider, ProviderRateLimit> = {
|
const configured: Record<Provider, ProviderRateLimit> = {
|
||||||
|
|
@ -559,14 +560,14 @@ async function readPromptFromStdin(): Promise<string | null> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function normalizeOutputImagePath(p: string): string {
|
export function normalizeOutputImagePath(p: string): string {
|
||||||
const full = path.resolve(p);
|
const full = path.resolve(p);
|
||||||
const ext = path.extname(full);
|
const ext = path.extname(full);
|
||||||
if (ext) return full;
|
if (ext) return full;
|
||||||
return `${full}.png`;
|
return `${full}.png`;
|
||||||
}
|
}
|
||||||
|
|
||||||
function detectProvider(args: CliArgs): Provider {
|
export function detectProvider(args: CliArgs): Provider {
|
||||||
if (
|
if (
|
||||||
args.referenceImages.length > 0 &&
|
args.referenceImages.length > 0 &&
|
||||||
args.provider &&
|
args.provider &&
|
||||||
|
|
@ -619,7 +620,7 @@ function detectProvider(args: CliArgs): Provider {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function validateReferenceImages(referenceImages: string[]): Promise<void> {
|
export async function validateReferenceImages(referenceImages: string[]): Promise<void> {
|
||||||
for (const refPath of referenceImages) {
|
for (const refPath of referenceImages) {
|
||||||
const fullPath = path.resolve(refPath);
|
const fullPath = path.resolve(refPath);
|
||||||
try {
|
try {
|
||||||
|
|
@ -630,7 +631,7 @@ async function validateReferenceImages(referenceImages: string[]): Promise<void>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function isRetryableGenerationError(error: unknown): boolean {
|
export function isRetryableGenerationError(error: unknown): boolean {
|
||||||
const msg = error instanceof Error ? error.message : String(error);
|
const msg = error instanceof Error ? error.message : String(error);
|
||||||
const nonRetryableMarkers = [
|
const nonRetryableMarkers = [
|
||||||
"Reference image",
|
"Reference image",
|
||||||
|
|
@ -712,7 +713,7 @@ async function prepareSingleTask(args: CliArgs, extendConfig: Partial<ExtendConf
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> {
|
export async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> {
|
||||||
const resolvedBatchFilePath = path.resolve(batchFilePath);
|
const resolvedBatchFilePath = path.resolve(batchFilePath);
|
||||||
const content = await readFile(resolvedBatchFilePath, "utf8");
|
const content = await readFile(resolvedBatchFilePath, "utf8");
|
||||||
const parsed = JSON.parse(content.replace(/^\uFEFF/, "")) as BatchFile;
|
const parsed = JSON.parse(content.replace(/^\uFEFF/, "")) as BatchFile;
|
||||||
|
|
@ -738,11 +739,11 @@ async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks>
|
||||||
throw new Error("Invalid batch file. Expected an array of tasks or an object with a tasks array.");
|
throw new Error("Invalid batch file. Expected an array of tasks or an object with a tasks array.");
|
||||||
}
|
}
|
||||||
|
|
||||||
function resolveBatchPath(batchDir: string, filePath: string): string {
|
export function resolveBatchPath(batchDir: string, filePath: string): string {
|
||||||
return path.isAbsolute(filePath) ? filePath : path.resolve(batchDir, filePath);
|
return path.isAbsolute(filePath) ? filePath : path.resolve(batchDir, filePath);
|
||||||
}
|
}
|
||||||
|
|
||||||
function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs {
|
export function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs {
|
||||||
return {
|
return {
|
||||||
...baseArgs,
|
...baseArgs,
|
||||||
prompt: task.prompt ?? null,
|
prompt: task.prompt ?? null,
|
||||||
|
|
@ -881,7 +882,7 @@ function createProviderGate(providerRateLimits: Record<Provider, ProviderRateLim
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number {
|
export function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number {
|
||||||
const requested = jobs ?? Math.min(taskCount, maxWorkers);
|
const requested = jobs ?? Math.min(taskCount, maxWorkers);
|
||||||
return Math.max(1, Math.min(requested, taskCount, maxWorkers));
|
return Math.max(1, Math.min(requested, taskCount, maxWorkers));
|
||||||
}
|
}
|
||||||
|
|
@ -1011,8 +1012,21 @@ async function main(): Promise<void> {
|
||||||
await runSingleMode(mergedArgs, extendConfig);
|
await runSingleMode(mergedArgs, extendConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
main().catch((error) => {
|
function isDirectExecution(metaUrl: string): boolean {
|
||||||
const message = error instanceof Error ? error.message : String(error);
|
const entryPath = process.argv[1];
|
||||||
console.error(message);
|
if (!entryPath) return false;
|
||||||
process.exit(1);
|
|
||||||
});
|
try {
|
||||||
|
return path.resolve(entryPath) === fileURLToPath(metaUrl);
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isDirectExecution(import.meta.url)) {
|
||||||
|
main().catch((error) => {
|
||||||
|
const message = error instanceof Error ? error.message : String(error);
|
||||||
|
console.error(message);
|
||||||
|
process.exit(1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ function getBaseUrl(): string {
|
||||||
return base.replace(/\/+$/g, "");
|
return base.replace(/\/+$/g, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
export function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||||
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
||||||
if (!match) return null;
|
if (!match) return null;
|
||||||
const w = parseFloat(match[1]!);
|
const w = parseFloat(match[1]!);
|
||||||
|
|
@ -45,7 +45,7 @@ const STANDARD_SIZES_2K: [number, number][] = [
|
||||||
[2048, 2048],
|
[2048, 2048],
|
||||||
];
|
];
|
||||||
|
|
||||||
function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string {
|
export function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string {
|
||||||
const is2k = quality === "2k";
|
const is2k = quality === "2k";
|
||||||
const defaultSize = is2k ? "1536*1536" : "1024*1024";
|
const defaultSize = is2k ? "1536*1536" : "1024*1024";
|
||||||
|
|
||||||
|
|
@ -71,7 +71,7 @@ function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]):
|
||||||
return best;
|
return best;
|
||||||
}
|
}
|
||||||
|
|
||||||
function normalizeSize(size: string): string {
|
export function normalizeSize(size: string): string {
|
||||||
return size.replace("x", "*");
|
return size.replace("x", "*");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,16 @@ export function getDefaultModel(): string {
|
||||||
return process.env.GOOGLE_IMAGE_MODEL || "gemini-3-pro-image-preview";
|
return process.env.GOOGLE_IMAGE_MODEL || "gemini-3-pro-image-preview";
|
||||||
}
|
}
|
||||||
|
|
||||||
function normalizeGoogleModelId(model: string): string {
|
export function normalizeGoogleModelId(model: string): string {
|
||||||
return model.startsWith("models/") ? model.slice("models/".length) : model;
|
return model.startsWith("models/") ? model.slice("models/".length) : model;
|
||||||
}
|
}
|
||||||
|
|
||||||
function isGoogleMultimodal(model: string): boolean {
|
export function isGoogleMultimodal(model: string): boolean {
|
||||||
const normalized = normalizeGoogleModelId(model);
|
const normalized = normalizeGoogleModelId(model);
|
||||||
return GOOGLE_MULTIMODAL_MODELS.some((m) => normalized.includes(m));
|
return GOOGLE_MULTIMODAL_MODELS.some((m) => normalized.includes(m));
|
||||||
}
|
}
|
||||||
|
|
||||||
function isGoogleImagen(model: string): boolean {
|
export function isGoogleImagen(model: string): boolean {
|
||||||
const normalized = normalizeGoogleModelId(model);
|
const normalized = normalizeGoogleModelId(model);
|
||||||
return GOOGLE_IMAGEN_MODELS.some((m) => normalized.includes(m));
|
return GOOGLE_IMAGEN_MODELS.some((m) => normalized.includes(m));
|
||||||
}
|
}
|
||||||
|
|
@ -35,7 +35,7 @@ function getGoogleApiKey(): string | null {
|
||||||
return process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || null;
|
return process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" {
|
export function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" {
|
||||||
if (args.imageSize) return args.imageSize as "1K" | "2K" | "4K";
|
if (args.imageSize) return args.imageSize as "1K" | "2K" | "4K";
|
||||||
return args.quality === "2k" ? "2K" : "1K";
|
return args.quality === "2k" ? "2K" : "1K";
|
||||||
}
|
}
|
||||||
|
|
@ -46,7 +46,7 @@ function getGoogleBaseUrl(): string {
|
||||||
return base.replace(/\/+$/g, "");
|
return base.replace(/\/+$/g, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildGoogleUrl(pathname: string): string {
|
export function buildGoogleUrl(pathname: string): string {
|
||||||
const base = getGoogleBaseUrl();
|
const base = getGoogleBaseUrl();
|
||||||
const cleanedPath = pathname.replace(/^\/+/g, "");
|
const cleanedPath = pathname.replace(/^\/+/g, "");
|
||||||
if (base.endsWith("/v1beta")) return `${base}/${cleanedPath}`;
|
if (base.endsWith("/v1beta")) return `${base}/${cleanedPath}`;
|
||||||
|
|
@ -162,7 +162,7 @@ async function postGoogleJson<T>(pathname: string, body: unknown): Promise<T> {
|
||||||
return postGoogleJsonViaFetch<T>(url, apiKey, body);
|
return postGoogleJsonViaFetch<T>(url, apiKey, body);
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildPromptWithAspect(
|
export function buildPromptWithAspect(
|
||||||
prompt: string,
|
prompt: string,
|
||||||
ar: string | null,
|
ar: string | null,
|
||||||
quality: CliArgs["quality"],
|
quality: CliArgs["quality"],
|
||||||
|
|
@ -177,7 +177,7 @@ function buildPromptWithAspect(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
function addAspectRatioToPrompt(prompt: string, ar: string | null): string {
|
export function addAspectRatioToPrompt(prompt: string, ar: string | null): string {
|
||||||
if (!ar) return prompt;
|
if (!ar) return prompt;
|
||||||
return `${prompt} Aspect ratio: ${ar}.`;
|
return `${prompt} Aspect ratio: ${ar}.`;
|
||||||
}
|
}
|
||||||
|
|
@ -194,7 +194,7 @@ async function readImageAsBase64(
|
||||||
return { data: buf.toString("base64"), mimeType };
|
return { data: buf.toString("base64"), mimeType };
|
||||||
}
|
}
|
||||||
|
|
||||||
function extractInlineImageData(response: {
|
export function extractInlineImageData(response: {
|
||||||
candidates?: Array<{
|
candidates?: Array<{
|
||||||
content?: { parts?: Array<{ inlineData?: { data?: string } }> };
|
content?: { parts?: Array<{ inlineData?: { data?: string } }> };
|
||||||
}>;
|
}>;
|
||||||
|
|
@ -208,7 +208,7 @@ function extractInlineImageData(response: {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function extractPredictedImageData(response: {
|
export function extractPredictedImageData(response: {
|
||||||
predictions?: Array<any>;
|
predictions?: Array<any>;
|
||||||
generatedImages?: Array<any>;
|
generatedImages?: Array<any>;
|
||||||
}): string | null {
|
}): string | null {
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ export function getDefaultModel(): string {
|
||||||
|
|
||||||
type OpenAIImageResponse = { data: Array<{ url?: string; b64_json?: string }> };
|
type OpenAIImageResponse = { data: Array<{ url?: string; b64_json?: string }> };
|
||||||
|
|
||||||
function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
export function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||||
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
||||||
if (!match) return null;
|
if (!match) return null;
|
||||||
const w = parseFloat(match[1]!);
|
const w = parseFloat(match[1]!);
|
||||||
|
|
@ -23,7 +23,7 @@ type SizeMapping = {
|
||||||
portrait: string;
|
portrait: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
function getOpenAISize(
|
export function getOpenAISize(
|
||||||
model: string,
|
model: string,
|
||||||
ar: string | null,
|
ar: string | null,
|
||||||
quality: CliArgs["quality"]
|
quality: CliArgs["quality"]
|
||||||
|
|
@ -201,7 +201,7 @@ async function generateWithOpenAIEdits(
|
||||||
return extractImageFromResponse(result);
|
return extractImageFromResponse(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
function getMimeType(filename: string): string {
|
export function getMimeType(filename: string): string {
|
||||||
const ext = path.extname(filename).toLowerCase();
|
const ext = path.extname(filename).toLowerCase();
|
||||||
if (ext === ".jpg" || ext === ".jpeg") return "image/jpeg";
|
if (ext === ".jpg" || ext === ".jpeg") return "image/jpeg";
|
||||||
if (ext === ".webp") return "image/webp";
|
if (ext === ".webp") return "image/webp";
|
||||||
|
|
@ -209,7 +209,7 @@ function getMimeType(filename: string): string {
|
||||||
return "image/png";
|
return "image/png";
|
||||||
}
|
}
|
||||||
|
|
||||||
async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> {
|
export async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> {
|
||||||
const img = result.data[0];
|
const img = result.data[0];
|
||||||
|
|
||||||
if (img?.b64_json) {
|
if (img?.b64_json) {
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ function getBaseUrl(): string {
|
||||||
return base.replace(/\/+$/g, "");
|
return base.replace(/\/+$/g, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
function parseModelId(model: string): { owner: string; name: string; version: string | null } {
|
export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
|
||||||
const [ownerName, version] = model.split(":");
|
const [ownerName, version] = model.split(":");
|
||||||
const parts = ownerName!.split("/");
|
const parts = ownerName!.split("/");
|
||||||
if (parts.length !== 2 || !parts[0] || !parts[1]) {
|
if (parts.length !== 2 || !parts[0] || !parts[1]) {
|
||||||
|
|
@ -31,7 +31,7 @@ function parseModelId(model: string): { owner: string; name: string; version: st
|
||||||
return { owner: parts[0], name: parts[1], version: version || null };
|
return { owner: parts[0], name: parts[1], version: version || null };
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
|
export function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
|
||||||
const input: Record<string, unknown> = { prompt };
|
const input: Record<string, unknown> = { prompt };
|
||||||
|
|
||||||
if (args.aspectRatio) {
|
if (args.aspectRatio) {
|
||||||
|
|
@ -144,7 +144,7 @@ async function pollPrediction(apiToken: string, getUrl: string): Promise<Predict
|
||||||
throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
|
throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
|
||||||
}
|
}
|
||||||
|
|
||||||
function extractOutputUrl(prediction: PredictionResponse): string {
|
export function extractOutputUrl(prediction: PredictionResponse): string {
|
||||||
const output = prediction.output;
|
const output = prediction.output;
|
||||||
|
|
||||||
if (typeof output === "string") return output;
|
if (typeof output === "string") return output;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue