refactor(baoyu-image-gen): export functions for testability and add module entry guard
This commit is contained in:
parent
a11613c11b
commit
3398509d9e
|
|
@ -1,6 +1,7 @@
|
|||
import path from "node:path";
|
||||
import process from "node:process";
|
||||
import { homedir } from "node:os";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import { access, mkdir, readFile, writeFile } from "node:fs/promises";
|
||||
import type {
|
||||
BatchFile,
|
||||
|
|
@ -136,7 +137,7 @@ Environment variables:
|
|||
Env file load order: CLI args > EXTEND.md > process.env > <cwd>/.baoyu-skills/.env > ~/.baoyu-skills/.env`);
|
||||
}
|
||||
|
||||
function parseArgs(argv: string[]): CliArgs {
|
||||
export function parseArgs(argv: string[]): CliArgs {
|
||||
const out: CliArgs = {
|
||||
prompt: null,
|
||||
promptFiles: [],
|
||||
|
|
@ -338,12 +339,12 @@ async function loadEnv(): Promise<void> {
|
|||
}
|
||||
}
|
||||
|
||||
function extractYamlFrontMatter(content: string): string | null {
|
||||
export function extractYamlFrontMatter(content: string): string | null {
|
||||
const match = content.match(/^---\s*\n([\s\S]*?)\n---\s*$/m);
|
||||
return match ? match[1] : null;
|
||||
}
|
||||
|
||||
function parseSimpleYaml(yaml: string): Partial<ExtendConfig> {
|
||||
export function parseSimpleYaml(yaml: string): Partial<ExtendConfig> {
|
||||
const config: Partial<ExtendConfig> = {};
|
||||
const lines = yaml.split("\n");
|
||||
let currentKey: string | null = null;
|
||||
|
|
@ -473,7 +474,7 @@ async function loadExtendConfig(): Promise<Partial<ExtendConfig>> {
|
|||
return {};
|
||||
}
|
||||
|
||||
function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
||||
export function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
||||
return {
|
||||
...args,
|
||||
provider: args.provider ?? extend.default_provider ?? null,
|
||||
|
|
@ -483,13 +484,13 @@ function mergeConfig(args: CliArgs, extend: Partial<ExtendConfig>): CliArgs {
|
|||
};
|
||||
}
|
||||
|
||||
function parsePositiveInt(value: string | undefined): number | null {
|
||||
export function parsePositiveInt(value: string | undefined): number | null {
|
||||
if (!value) return null;
|
||||
const parsed = parseInt(value, 10);
|
||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
|
||||
}
|
||||
|
||||
function parsePositiveBatchInt(value: unknown): number | null {
|
||||
export function parsePositiveBatchInt(value: unknown): number | null {
|
||||
if (value === null || value === undefined) return null;
|
||||
if (typeof value === "number") {
|
||||
return Number.isInteger(value) && value > 0 ? value : null;
|
||||
|
|
@ -500,13 +501,13 @@ function parsePositiveBatchInt(value: unknown): number | null {
|
|||
return null;
|
||||
}
|
||||
|
||||
function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number {
|
||||
export function getConfiguredMaxWorkers(extendConfig: Partial<ExtendConfig>): number {
|
||||
const envValue = parsePositiveInt(process.env.BAOYU_IMAGE_GEN_MAX_WORKERS);
|
||||
const configValue = extendConfig.batch?.max_workers ?? null;
|
||||
return Math.max(1, envValue ?? configValue ?? DEFAULT_MAX_WORKERS);
|
||||
}
|
||||
|
||||
function getConfiguredProviderRateLimits(
|
||||
export function getConfiguredProviderRateLimits(
|
||||
extendConfig: Partial<ExtendConfig>
|
||||
): Record<Provider, ProviderRateLimit> {
|
||||
const configured: Record<Provider, ProviderRateLimit> = {
|
||||
|
|
@ -559,14 +560,14 @@ async function readPromptFromStdin(): Promise<string | null> {
|
|||
}
|
||||
}
|
||||
|
||||
function normalizeOutputImagePath(p: string): string {
|
||||
export function normalizeOutputImagePath(p: string): string {
|
||||
const full = path.resolve(p);
|
||||
const ext = path.extname(full);
|
||||
if (ext) return full;
|
||||
return `${full}.png`;
|
||||
}
|
||||
|
||||
function detectProvider(args: CliArgs): Provider {
|
||||
export function detectProvider(args: CliArgs): Provider {
|
||||
if (
|
||||
args.referenceImages.length > 0 &&
|
||||
args.provider &&
|
||||
|
|
@ -619,7 +620,7 @@ function detectProvider(args: CliArgs): Provider {
|
|||
);
|
||||
}
|
||||
|
||||
async function validateReferenceImages(referenceImages: string[]): Promise<void> {
|
||||
export async function validateReferenceImages(referenceImages: string[]): Promise<void> {
|
||||
for (const refPath of referenceImages) {
|
||||
const fullPath = path.resolve(refPath);
|
||||
try {
|
||||
|
|
@ -630,7 +631,7 @@ async function validateReferenceImages(referenceImages: string[]): Promise<void>
|
|||
}
|
||||
}
|
||||
|
||||
function isRetryableGenerationError(error: unknown): boolean {
|
||||
export function isRetryableGenerationError(error: unknown): boolean {
|
||||
const msg = error instanceof Error ? error.message : String(error);
|
||||
const nonRetryableMarkers = [
|
||||
"Reference image",
|
||||
|
|
@ -712,7 +713,7 @@ async function prepareSingleTask(args: CliArgs, extendConfig: Partial<ExtendConf
|
|||
};
|
||||
}
|
||||
|
||||
async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> {
|
||||
export async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks> {
|
||||
const resolvedBatchFilePath = path.resolve(batchFilePath);
|
||||
const content = await readFile(resolvedBatchFilePath, "utf8");
|
||||
const parsed = JSON.parse(content.replace(/^\uFEFF/, "")) as BatchFile;
|
||||
|
|
@ -738,11 +739,11 @@ async function loadBatchTasks(batchFilePath: string): Promise<LoadedBatchTasks>
|
|||
throw new Error("Invalid batch file. Expected an array of tasks or an object with a tasks array.");
|
||||
}
|
||||
|
||||
function resolveBatchPath(batchDir: string, filePath: string): string {
|
||||
export function resolveBatchPath(batchDir: string, filePath: string): string {
|
||||
return path.isAbsolute(filePath) ? filePath : path.resolve(batchDir, filePath);
|
||||
}
|
||||
|
||||
function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs {
|
||||
export function createTaskArgs(baseArgs: CliArgs, task: BatchTaskInput, batchDir: string): CliArgs {
|
||||
return {
|
||||
...baseArgs,
|
||||
prompt: task.prompt ?? null,
|
||||
|
|
@ -881,7 +882,7 @@ function createProviderGate(providerRateLimits: Record<Provider, ProviderRateLim
|
|||
};
|
||||
}
|
||||
|
||||
function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number {
|
||||
export function getWorkerCount(taskCount: number, jobs: number | null, maxWorkers: number): number {
|
||||
const requested = jobs ?? Math.min(taskCount, maxWorkers);
|
||||
return Math.max(1, Math.min(requested, taskCount, maxWorkers));
|
||||
}
|
||||
|
|
@ -1011,8 +1012,21 @@ async function main(): Promise<void> {
|
|||
await runSingleMode(mergedArgs, extendConfig);
|
||||
}
|
||||
|
||||
function isDirectExecution(metaUrl: string): boolean {
|
||||
const entryPath = process.argv[1];
|
||||
if (!entryPath) return false;
|
||||
|
||||
try {
|
||||
return path.resolve(entryPath) === fileURLToPath(metaUrl);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (isDirectExecution(import.meta.url)) {
|
||||
main().catch((error) => {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.error(message);
|
||||
process.exit(1);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ function getBaseUrl(): string {
|
|||
return base.replace(/\/+$/g, "");
|
||||
}
|
||||
|
||||
function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||
export function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
||||
if (!match) return null;
|
||||
const w = parseFloat(match[1]!);
|
||||
|
|
@ -45,7 +45,7 @@ const STANDARD_SIZES_2K: [number, number][] = [
|
|||
[2048, 2048],
|
||||
];
|
||||
|
||||
function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string {
|
||||
export function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]): string {
|
||||
const is2k = quality === "2k";
|
||||
const defaultSize = is2k ? "1536*1536" : "1024*1024";
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ function getSizeFromAspectRatio(ar: string | null, quality: CliArgs["quality"]):
|
|||
return best;
|
||||
}
|
||||
|
||||
function normalizeSize(size: string): string {
|
||||
export function normalizeSize(size: string): string {
|
||||
return size.replace("x", "*");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,16 +17,16 @@ export function getDefaultModel(): string {
|
|||
return process.env.GOOGLE_IMAGE_MODEL || "gemini-3-pro-image-preview";
|
||||
}
|
||||
|
||||
function normalizeGoogleModelId(model: string): string {
|
||||
export function normalizeGoogleModelId(model: string): string {
|
||||
return model.startsWith("models/") ? model.slice("models/".length) : model;
|
||||
}
|
||||
|
||||
function isGoogleMultimodal(model: string): boolean {
|
||||
export function isGoogleMultimodal(model: string): boolean {
|
||||
const normalized = normalizeGoogleModelId(model);
|
||||
return GOOGLE_MULTIMODAL_MODELS.some((m) => normalized.includes(m));
|
||||
}
|
||||
|
||||
function isGoogleImagen(model: string): boolean {
|
||||
export function isGoogleImagen(model: string): boolean {
|
||||
const normalized = normalizeGoogleModelId(model);
|
||||
return GOOGLE_IMAGEN_MODELS.some((m) => normalized.includes(m));
|
||||
}
|
||||
|
|
@ -35,7 +35,7 @@ function getGoogleApiKey(): string | null {
|
|||
return process.env.GOOGLE_API_KEY || process.env.GEMINI_API_KEY || null;
|
||||
}
|
||||
|
||||
function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" {
|
||||
export function getGoogleImageSize(args: CliArgs): "1K" | "2K" | "4K" {
|
||||
if (args.imageSize) return args.imageSize as "1K" | "2K" | "4K";
|
||||
return args.quality === "2k" ? "2K" : "1K";
|
||||
}
|
||||
|
|
@ -46,7 +46,7 @@ function getGoogleBaseUrl(): string {
|
|||
return base.replace(/\/+$/g, "");
|
||||
}
|
||||
|
||||
function buildGoogleUrl(pathname: string): string {
|
||||
export function buildGoogleUrl(pathname: string): string {
|
||||
const base = getGoogleBaseUrl();
|
||||
const cleanedPath = pathname.replace(/^\/+/g, "");
|
||||
if (base.endsWith("/v1beta")) return `${base}/${cleanedPath}`;
|
||||
|
|
@ -162,7 +162,7 @@ async function postGoogleJson<T>(pathname: string, body: unknown): Promise<T> {
|
|||
return postGoogleJsonViaFetch<T>(url, apiKey, body);
|
||||
}
|
||||
|
||||
function buildPromptWithAspect(
|
||||
export function buildPromptWithAspect(
|
||||
prompt: string,
|
||||
ar: string | null,
|
||||
quality: CliArgs["quality"],
|
||||
|
|
@ -177,7 +177,7 @@ function buildPromptWithAspect(
|
|||
return result;
|
||||
}
|
||||
|
||||
function addAspectRatioToPrompt(prompt: string, ar: string | null): string {
|
||||
export function addAspectRatioToPrompt(prompt: string, ar: string | null): string {
|
||||
if (!ar) return prompt;
|
||||
return `${prompt} Aspect ratio: ${ar}.`;
|
||||
}
|
||||
|
|
@ -194,7 +194,7 @@ async function readImageAsBase64(
|
|||
return { data: buf.toString("base64"), mimeType };
|
||||
}
|
||||
|
||||
function extractInlineImageData(response: {
|
||||
export function extractInlineImageData(response: {
|
||||
candidates?: Array<{
|
||||
content?: { parts?: Array<{ inlineData?: { data?: string } }> };
|
||||
}>;
|
||||
|
|
@ -208,7 +208,7 @@ function extractInlineImageData(response: {
|
|||
return null;
|
||||
}
|
||||
|
||||
function extractPredictedImageData(response: {
|
||||
export function extractPredictedImageData(response: {
|
||||
predictions?: Array<any>;
|
||||
generatedImages?: Array<any>;
|
||||
}): string | null {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ export function getDefaultModel(): string {
|
|||
|
||||
type OpenAIImageResponse = { data: Array<{ url?: string; b64_json?: string }> };
|
||||
|
||||
function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||
export function parseAspectRatio(ar: string): { width: number; height: number } | null {
|
||||
const match = ar.match(/^(\d+(?:\.\d+)?):(\d+(?:\.\d+)?)$/);
|
||||
if (!match) return null;
|
||||
const w = parseFloat(match[1]!);
|
||||
|
|
@ -23,7 +23,7 @@ type SizeMapping = {
|
|||
portrait: string;
|
||||
};
|
||||
|
||||
function getOpenAISize(
|
||||
export function getOpenAISize(
|
||||
model: string,
|
||||
ar: string | null,
|
||||
quality: CliArgs["quality"]
|
||||
|
|
@ -201,7 +201,7 @@ async function generateWithOpenAIEdits(
|
|||
return extractImageFromResponse(result);
|
||||
}
|
||||
|
||||
function getMimeType(filename: string): string {
|
||||
export function getMimeType(filename: string): string {
|
||||
const ext = path.extname(filename).toLowerCase();
|
||||
if (ext === ".jpg" || ext === ".jpeg") return "image/jpeg";
|
||||
if (ext === ".webp") return "image/webp";
|
||||
|
|
@ -209,7 +209,7 @@ function getMimeType(filename: string): string {
|
|||
return "image/png";
|
||||
}
|
||||
|
||||
async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> {
|
||||
export async function extractImageFromResponse(result: OpenAIImageResponse): Promise<Uint8Array> {
|
||||
const img = result.data[0];
|
||||
|
||||
if (img?.b64_json) {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ function getBaseUrl(): string {
|
|||
return base.replace(/\/+$/g, "");
|
||||
}
|
||||
|
||||
function parseModelId(model: string): { owner: string; name: string; version: string | null } {
|
||||
export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
|
||||
const [ownerName, version] = model.split(":");
|
||||
const parts = ownerName!.split("/");
|
||||
if (parts.length !== 2 || !parts[0] || !parts[1]) {
|
||||
|
|
@ -31,7 +31,7 @@ function parseModelId(model: string): { owner: string; name: string; version: st
|
|||
return { owner: parts[0], name: parts[1], version: version || null };
|
||||
}
|
||||
|
||||
function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
|
||||
export function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
|
||||
const input: Record<string, unknown> = { prompt };
|
||||
|
||||
if (args.aspectRatio) {
|
||||
|
|
@ -144,7 +144,7 @@ async function pollPrediction(apiToken: string, getUrl: string): Promise<Predict
|
|||
throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
|
||||
}
|
||||
|
||||
function extractOutputUrl(prediction: PredictionResponse): string {
|
||||
export function extractOutputUrl(prediction: PredictionResponse): string {
|
||||
const output = prediction.output;
|
||||
|
||||
if (typeof output === "string") return output;
|
||||
|
|
|
|||
Loading…
Reference in New Issue