
Model API Calls in Practice: Building a Unified Interface

Wrap model calls behind one unified layer that can switch between Claude, OpenAI, and local models

Why a Unified Interface

When a project uses several AI providers at once, the code quickly ends up looking like this:

// ❌ if-else everywhere
if (provider === "anthropic") {
  const response = await anthropic.messages.create({
    model: "claude-sonnet-4-20250514",
    max_tokens: 1024,
    messages: [{ role: "user", content: prompt }],
  });
  return response.content[0].text;
} else if (provider === "openai") {
  const response = await openai.chat.completions.create({
    model: "gpt-4o",
    max_tokens: 1024,
    messages: [{ role: "user", content: prompt }],
  });
  return response.choices[0].message.content;
} else if (provider === "ollama") {
  // yet another different API...
}

Every provider differs in API shape, error handling, and streaming protocol. Without a wrapper, switching models becomes a nightmare.

We need a unified interface layer, so that the business code above it never has to care which model is underneath.
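Concretely, the call site we are aiming for looks like this (the `AIClient` behind it is built step by step below):

// ✅ One call shape; the provider is just configuration
const response = await ai.chat(
  [{ role: "user", content: prompt }],
  { provider: "anthropic" } // or "openai", "ollama"; same code path
);
console.log(response.content);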


Designing the Unified Interface

Core Type Definitions

// types.ts - shared type definitions

export interface Message {
  role: "system" | "user" | "assistant";
  content: string | ContentBlock[];
}

export interface ContentBlock {
  type: "text" | "image";
  text?: string;
  imageUrl?: string;
  imageBase64?: string;
  mediaType?: string;
}

export interface ChatOptions {
  model: string;
  messages: Message[];
  maxTokens?: number;
  temperature?: number;
  topP?: number;
  stream?: boolean;
  tools?: ToolDefinition[];
  responseFormat?: "text" | "json";
}

export interface ChatResponse {
  content: string;
  model: string;
  usage: {
    inputTokens: number;
    outputTokens: number;
    totalTokens: number;
  };
  finishReason: "stop" | "max_tokens" | "tool_use";
  toolCalls?: ToolCall[];
}

export interface StreamChunk {
  content: string;
  done: boolean;
}

export interface ToolDefinition {
  name: string;
  description: string;
  parameters: Record<string, unknown>;
}

export interface ToolCall {
  id: string;
  name: string;
  arguments: Record<string, unknown>;
}
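As an example of how `ContentBlock[]` is used, here is a multimodal message mixing text and an image (the base64 payload is a placeholder):

const visionMessage: Message = {
  role: "user",
  content: [
    { type: "text", text: "What is in this picture?" },
    {
      type: "image",
      imageBase64: "<base64-encoded-image-bytes>",
      mediaType: "image/png",
    },
  ],
};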

The Provider Interface

// provider.ts - the provider abstraction

export interface AIProvider {
  readonly name: string;

  chat(options: ChatOptions): Promise<ChatResponse>;

  chatStream(
    options: ChatOptions
  ): AsyncGenerator<StreamChunk, void, unknown>;
}
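Because every provider satisfies the same contract, a stub is enough for unit tests. A minimal sketch (the `MockProvider` name is ours, not part of any SDK):

// providers/mock.ts - a canned provider for tests, no network calls
import type { AIProvider } from "../provider";
import type { ChatOptions, ChatResponse, StreamChunk } from "../types";

export class MockProvider implements AIProvider {
  readonly name = "mock";

  async chat(options: ChatOptions): Promise<ChatResponse> {
    return {
      content: "mock reply",
      model: options.model,
      usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
      finishReason: "stop",
    };
  }

  async *chatStream(_options: ChatOptions): AsyncGenerator<StreamChunk> {
    yield { content: "mock ", done: false };
    yield { content: "reply", done: false };
    yield { content: "", done: true };
  }
}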

Implementing Each Provider

Anthropic Provider

// providers/anthropic.ts
import Anthropic from "@anthropic-ai/sdk";
import type { ChatOptions, ChatResponse, ContentBlock, StreamChunk } from "../types";
import type { AIProvider } from "../provider";

export class AnthropicProvider implements AIProvider {
  readonly name = "anthropic";
  private client: Anthropic;

  constructor(apiKey?: string) {
    this.client = new Anthropic({
      apiKey: apiKey || process.env.ANTHROPIC_API_KEY,
    });
  }

  async chat(options: ChatOptions): Promise<ChatResponse> {
    // Anthropic takes the system prompt as a separate parameter, so split it out
    const systemMessage = options.messages.find(m => m.role === "system");
    const nonSystemMessages = options.messages.filter(m => m.role !== "system");

    const response = await this.client.messages.create({
      model: options.model,
      max_tokens: options.maxTokens || 4096,
      temperature: options.temperature,
      top_p: options.topP,
      system: systemMessage ? String(systemMessage.content) : undefined,
      messages: nonSystemMessages.map(m => ({
        role: m.role as "user" | "assistant",
        content: this.formatContent(m.content),
      })),
      tools: options.tools?.map(t => ({
        name: t.name,
        description: t.description,
        input_schema: t.parameters as Anthropic.Tool.InputSchema,
      })),
    });

    // Extract the text content
    const textContent = response.content
      .filter(block => block.type === "text")
      .map(block => block.text)
      .join("");

    // Extract tool calls
    const toolCalls = response.content
      .filter(block => block.type === "tool_use")
      .map(block => ({
        id: block.id,
        name: block.name,
        arguments: block.input as Record<string, unknown>,
      }));

    return {
      content: textContent,
      model: response.model,
      usage: {
        inputTokens: response.usage.input_tokens,
        outputTokens: response.usage.output_tokens,
        totalTokens: response.usage.input_tokens + response.usage.output_tokens,
      },
      finishReason:
        response.stop_reason === "tool_use"
          ? "tool_use"
          : response.stop_reason === "max_tokens"
            ? "max_tokens"
            : "stop",
      toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
    };
  }

  async *chatStream(options: ChatOptions): AsyncGenerator<StreamChunk> {
    const systemMessage = options.messages.find(m => m.role === "system");
    const nonSystemMessages = options.messages.filter(m => m.role !== "system");

    const stream = this.client.messages.stream({
      model: options.model,
      max_tokens: options.maxTokens || 4096,
      temperature: options.temperature,
      system: systemMessage ? String(systemMessage.content) : undefined,
      messages: nonSystemMessages.map(m => ({
        role: m.role as "user" | "assistant",
        content: String(m.content),
      })),
    });

    for await (const event of stream) {
      if (
        event.type === "content_block_delta" &&
        event.delta.type === "text_delta"
      ) {
        yield { content: event.delta.text, done: false };
      }
    }
    yield { content: "", done: true };
  }

  private formatContent(content: string | ContentBlock[]): string | Anthropic.ContentBlockParam[] {
    if (typeof content === "string") return content;

    return content.map(block => {
      if (block.type === "text") {
        return { type: "text" as const, text: block.text! };
      }
      if (block.type === "image" && block.imageBase64) {
        return {
          type: "image" as const,
          source: {
            type: "base64" as const,
            media_type: (block.mediaType || "image/png") as "image/png",
            data: block.imageBase64,
          },
        };
      }
      return { type: "text" as const, text: "" };
    });
  }
}

OpenAI Provider

// providers/openai.ts
import OpenAI from "openai";
import type { ChatOptions, ChatResponse, StreamChunk } from "../types";
import type { AIProvider } from "../provider";

export class OpenAIProvider implements AIProvider {
  readonly name = "openai";
  private client: OpenAI;

  constructor(config?: { apiKey?: string; baseURL?: string }) {
    this.client = new OpenAI({
      apiKey: config?.apiKey || process.env.OPENAI_API_KEY,
      baseURL: config?.baseURL,
    });
  }

  async chat(options: ChatOptions): Promise<ChatResponse> {
    const response = await this.client.chat.completions.create({
      model: options.model,
      max_tokens: options.maxTokens || 4096,
      temperature: options.temperature,
      top_p: options.topP,
      messages: options.messages.map(m => ({
        role: m.role,
        content: String(m.content),
      })),
      tools: options.tools?.map(t => ({
        type: "function" as const,
        function: {
          name: t.name,
          description: t.description,
          parameters: t.parameters,
        },
      })),
      response_format: options.responseFormat === "json"
        ? { type: "json_object" }
        : undefined,
    });

    const choice = response.choices[0];
    const toolCalls = choice.message.tool_calls?.map(tc => ({
      id: tc.id,
      name: tc.function.name,
      arguments: JSON.parse(tc.function.arguments),
    }));

    return {
      content: choice.message.content || "",
      model: response.model,
      usage: {
        inputTokens: response.usage?.prompt_tokens || 0,
        outputTokens: response.usage?.completion_tokens || 0,
        totalTokens: response.usage?.total_tokens || 0,
      },
      finishReason:
        choice.finish_reason === "tool_calls"
          ? "tool_use"
          : choice.finish_reason === "length"
            ? "max_tokens"
            : "stop",
      toolCalls: toolCalls?.length ? toolCalls : undefined,
    };
  }

  async *chatStream(options: ChatOptions): AsyncGenerator<StreamChunk> {
    const stream = await this.client.chat.completions.create({
      model: options.model,
      max_tokens: options.maxTokens || 4096,
      temperature: options.temperature,
      messages: options.messages.map(m => ({
        role: m.role,
        content: String(m.content),
      })),
      stream: true,
    });

    for await (const chunk of stream) {
      const content = chunk.choices[0]?.delta?.content;
      if (content) {
        yield { content, done: false };
      }
    }
    yield { content: "", done: true };
  }
}

Ollama Provider (Local Models)

// providers/ollama.ts
import type { ChatOptions, ChatResponse, StreamChunk } from "../types";
import type { AIProvider } from "../provider";

export class OllamaProvider implements AIProvider {
  readonly name = "ollama";
  private baseURL: string;

  constructor(baseURL = "http://localhost:11434") {
    this.baseURL = baseURL;
  }

  async chat(options: ChatOptions): Promise<ChatResponse> {
    const response = await fetch(`${this.baseURL}/api/chat`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        model: options.model,
        messages: options.messages.map(m => ({
          role: m.role,
          content: String(m.content),
        })),
        stream: false,
        options: {
          temperature: options.temperature,
          top_p: options.topP,
          num_predict: options.maxTokens,
        },
        format: options.responseFormat === "json" ? "json" : undefined,
      }),
    });

    // fetch does not throw on HTTP errors, so check the status explicitly
    if (!response.ok) {
      throw new Error(`Ollama request failed: ${response.status} ${await response.text()}`);
    }
    const data = await response.json();

    return {
      content: data.message.content,
      model: data.model,
      usage: {
        inputTokens: data.prompt_eval_count || 0,
        outputTokens: data.eval_count || 0,
        totalTokens: (data.prompt_eval_count || 0) + (data.eval_count || 0),
      },
      finishReason: "stop",
    };
  }

  async *chatStream(options: ChatOptions): AsyncGenerator<StreamChunk> {
    const response = await fetch(`${this.baseURL}/api/chat`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        model: options.model,
        messages: options.messages.map(m => ({
          role: m.role,
          content: String(m.content),
        })),
        stream: true,
        options: {
          temperature: options.temperature,
          num_predict: options.maxTokens,
        },
      }),
    });

    const reader = response.body!.getReader();
    const decoder = new TextDecoder();
    let buffer = "";

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      // A network chunk can end mid-line, so buffer the trailing partial line
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop() ?? "";

      for (const line of lines) {
        if (!line.trim()) continue;
        const data = JSON.parse(line);
        if (data.message?.content) {
          yield { content: data.message.content, done: false };
        }
      }
    }
    // Match the other providers' contract: always end with a done sentinel
    yield { content: "", done: true };
  }
}

The Unified Client

// client.ts - the unified AI client

import type { ChatOptions, ChatResponse, StreamChunk, Message } from "./types";
import type { AIProvider } from "./provider";
import { AnthropicProvider } from "./providers/anthropic";
import { OpenAIProvider } from "./providers/openai";
import { OllamaProvider } from "./providers/ollama";

interface ClientConfig {
  defaultProvider: string;
  defaultModel: string;
  providers: Record<string, AIProvider>;
}

export class AIClient {
  private providers: Record<string, AIProvider>;
  private defaultProvider: string;
  private defaultModel: string;

  constructor(config: ClientConfig) {
    this.providers = config.providers;
    this.defaultProvider = config.defaultProvider;
    this.defaultModel = config.defaultModel;
  }

  /**
   * Quickly create a client with the default providers pre-registered
   */
  static create(options?: {
    defaultProvider?: string;
    defaultModel?: string;
    ollamaURL?: string;
  }): AIClient {
    return new AIClient({
      defaultProvider: options?.defaultProvider || "anthropic",
      defaultModel: options?.defaultModel || "claude-sonnet-4-20250514",
      providers: {
        anthropic: new AnthropicProvider(),
        openai: new OpenAIProvider(),
        ollama: new OllamaProvider(options?.ollamaURL),
      },
    });
  }

  async chat(
    messages: Message[],
    options?: Partial<ChatOptions> & { provider?: string }
  ): Promise<ChatResponse> {
    const provider = this.getProvider(options?.provider);
    return provider.chat({
      model: options?.model || this.defaultModel,
      messages,
      maxTokens: options?.maxTokens,
      temperature: options?.temperature,
      topP: options?.topP,
      tools: options?.tools,
      responseFormat: options?.responseFormat,
    });
  }

  async *stream(
    messages: Message[],
    options?: Partial<ChatOptions> & { provider?: string }
  ): AsyncGenerator<StreamChunk> {
    const provider = this.getProvider(options?.provider);
    yield* provider.chatStream({
      model: options?.model || this.defaultModel,
      messages,
      maxTokens: options?.maxTokens,
      temperature: options?.temperature,
    });
  }

  private getProvider(name?: string): AIProvider {
    const providerName = name || this.defaultProvider;
    const provider = this.providers[providerName];
    if (!provider) {
      throw new Error(
        `Provider "${providerName}" is not registered. Available: ${Object.keys(this.providers).join(", ")}`
      );
    }
    return provider;
  }
}

Adding Error Handling and Retries

// middleware/retry.ts
import type { AIProvider } from "../provider";

interface RetryConfig {
  maxRetries: number;
  baseDelay: number;     // milliseconds
  maxDelay: number;
  retryableErrors: string[];
}

const DEFAULT_RETRY_CONFIG: RetryConfig = {
  maxRetries: 3,
  baseDelay: 1000,
  maxDelay: 30000,
  retryableErrors: ["rate_limit", "timeout", "server_error"],
};

export function withRetry(provider: AIProvider, config = DEFAULT_RETRY_CONFIG): AIProvider {
  return {
    name: provider.name,

    async chat(options) {
      let lastError: Error | null = null;

      for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
        try {
          return await provider.chat(options);
        } catch (error) {
          lastError = error as Error;
          const errorType = classifyError(error);

          if (!config.retryableErrors.includes(errorType)) {
            throw error; // rethrow non-retryable errors immediately
          }

          if (attempt < config.maxRetries) {
            const delay = Math.min(
              config.baseDelay * Math.pow(2, attempt),
              config.maxDelay
            );
            console.warn(
              `[${provider.name}] retry ${attempt + 1}, waiting ${delay}ms...`
            );
            await sleep(delay);
          }
        }
      }

      throw lastError;
    },

    async *chatStream(options) {
      // Retrying a stream mid-flight is trickier; here we pass it through unchanged
      // (see the sketch after this block)
      yield* provider.chatStream(options);
    },
  };
}

function classifyError(error: unknown): string {
  const message = String(error);
  if (message.includes("rate") || message.includes("429")) return "rate_limit";
  if (message.includes("timeout") || message.includes("ETIMEDOUT")) return "timeout";
  if (message.includes("500") || message.includes("502") || message.includes("503")) return "server_error";
  return "unknown";
}

function sleep(ms: number): Promise<void> {
  return new Promise(resolve => setTimeout(resolve, ms));
}
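If you do want retry coverage for streams, a common compromise is to retry only when the failure happens before the first chunk reaches the caller; once output has been yielded, replaying it would duplicate content. A sketch living in the same module as withRetry, reusing classifyError and sleep:

// Additional imports this sketch assumes:
// import type { ChatOptions, StreamChunk } from "../types";

async function* streamWithEarlyRetry(
  provider: AIProvider,
  options: ChatOptions,
  config: RetryConfig = DEFAULT_RETRY_CONFIG
): AsyncGenerator<StreamChunk> {
  for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
    let yieldedAny = false;
    try {
      for await (const chunk of provider.chatStream(options)) {
        yieldedAny = true;
        yield chunk;
      }
      return; // stream finished normally
    } catch (error) {
      // Give up if the caller already saw output, we are out of attempts,
      // or the error is not retryable
      if (
        yieldedAny ||
        attempt === config.maxRetries ||
        !config.retryableErrors.includes(classifyError(error))
      ) {
        throw error;
      }
      await sleep(Math.min(config.baseDelay * 2 ** attempt, config.maxDelay));
    }
  }
}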

Adding Fallback Support

// middleware/fallback.ts
import type { AIProvider } from "../provider";

export function withFallback(
  primary: AIProvider,
  fallbacks: AIProvider[]
): AIProvider {
  return {
    name: `${primary.name}+fallback`,

    async chat(options) {
      const providers = [primary, ...fallbacks];

      for (let i = 0; i < providers.length; i++) {
        try {
          const response = await providers[i].chat(options);
          if (i > 0) {
            console.warn(
              `Fell back to provider: ${providers[i].name}`
            );
          }
          return response;
        } catch (error) {
          console.error(
            `Provider ${providers[i].name} failed:`,
            error
          );
          if (i === providers.length - 1) throw error;
        }
      }

      throw new Error("所有 Provider 都失败了");
    },

    async *chatStream(options) {
      // Streaming fallback. Caveat: if the primary fails mid-stream,
      // chunks it already yielded are not replayed
      try {
        yield* primary.chatStream(options);
      } catch {
        for (const fallback of fallbacks) {
          try {
            yield* fallback.chatStream(options);
            return;
          } catch {
            continue;
          }
        }
        throw new Error("所有 Provider 的流式响应都失败了");
      }
    },
  };
}

Complete Usage Examples

Basic Usage

import { AIClient } from "./client";

// Create the client
const ai = AIClient.create({
  defaultProvider: "anthropic",
  defaultModel: "claude-sonnet-4-20250514",
});

// A simple conversation
const response = await ai.chat([
  { role: "system", content: "你是一个有帮助的助手。" },
  { role: "user", content: "什么是 TypeScript?" },
]);

console.log(response.content);
console.log(`Token usage: ${response.usage.totalTokens}`);

Switching Providers

// Use OpenAI
const openaiResponse = await ai.chat(
  [{ role: "user", content: "Hello!" }],
  { provider: "openai", model: "gpt-4o" }
);

// Use local Ollama
const localResponse = await ai.chat(
  [{ role: "user", content: "你好!" }],
  { provider: "ollama", model: "qwen2.5:7b" }
);
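Because OpenAIProvider accepts a baseURL, the same class also covers any service that speaks the OpenAI wire format. A sketch with a hypothetical gateway URL, env var, and model name:

// Register an OpenAI-compatible endpoint as its own provider
// (the URL, env var, and model name below are placeholders)
import { AIClient } from "./client";
import { OpenAIProvider } from "./providers/openai";

const gatewayAI = new AIClient({
  defaultProvider: "gateway",
  defaultModel: "my-gateway-model",
  providers: {
    gateway: new OpenAIProvider({
      apiKey: process.env.GATEWAY_API_KEY,
      baseURL: "https://llm-gateway.example.com/v1",
    }),
  },
});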

Streaming Responses

// Stream the output
const messages = [{ role: "user" as const, content: "Write a poem about programming" }];

for await (const chunk of ai.stream(messages)) {
  if (!chunk.done) {
    process.stdout.write(chunk.content);
  }
}
console.log(); // trailing newline

Tool Calling

const tools = [{
  name: "get_weather",
  description: "获取天气信息",
  parameters: {
    type: "object",
    properties: {
      city: { type: "string", description: "城市名" },
    },
    required: ["city"],
  },
}];

const response = await ai.chat(
  [{ role: "user", content: "北京天气怎么样?" }],
  { tools }
);

if (response.toolCalls) {
  console.log("需要调用工具:", response.toolCalls);
  // 执行工具调用,然后把结果返回给模型
}
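Closing the loop means sending the tool result back. Our unified Message type has no dedicated tool-result role, so this sketch folds the result into a follow-up user message; getWeather is a hypothetical local implementation:

// Hypothetical local implementation of the get_weather tool
async function getWeather(city: string): Promise<string> {
  return `${city}: sunny, 25°C`; // placeholder data
}

if (response.toolCalls) {
  const call = response.toolCalls[0];
  const result = await getWeather(String(call.arguments.city));

  // Simplification: feed the tool result back as a plain user message,
  // since the unified types do not model tool-result blocks
  const followUp = await ai.chat([
    { role: "user", content: "What's the weather like in Beijing?" },
    { role: "assistant", content: response.content || `Calling ${call.name}...` },
    { role: "user", content: `Tool ${call.name} returned: ${result}` },
  ]);
  console.log(followUp.content);
}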

Production Setup with Retry and Fallback

import { AnthropicProvider } from "./providers/anthropic";
import { OpenAIProvider } from "./providers/openai";
import { OllamaProvider } from "./providers/ollama";
import { withRetry } from "./middleware/retry";
import { withFallback } from "./middleware/fallback";

// Production configuration
const anthropic = withRetry(new AnthropicProvider());
const openai = withRetry(new OpenAIProvider());
const ollama = new OllamaProvider();

// Primary: Anthropic; fall back to OpenAI; last resort: the local model
const provider = withFallback(anthropic, [openai, ollama]);

const ai = new AIClient({
  defaultProvider: "primary",
  defaultModel: "claude-sonnet-4-20250514",
  providers: { primary: provider },
});

Going Further: Adding Observability

// middleware/logging.ts
import type { AIProvider } from "../provider";

export function withLogging(provider: AIProvider): AIProvider {
  return {
    name: provider.name,

    async chat(options) {
      const startTime = Date.now();
      const requestId = crypto.randomUUID();

      console.log(`[${requestId}] request ${provider.name}/${options.model}`);

      try {
        const response = await provider.chat(options);
        const duration = Date.now() - startTime;

        console.log(
          `[${requestId}] done | ` +
          `${duration}ms | ` +
          `${response.usage.totalTokens} tokens | ` +
          `${response.finishReason}`
        );

        return response;
      } catch (error) {
        const duration = Date.now() - startTime;
        console.error(
          `[${requestId}] failed | ${duration}ms | ${error}`
        );
        throw error;
      }
    },

    async *chatStream(options) {
      yield* provider.chatStream(options);
    },
  };
}
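All middlewares take and return an AIProvider, so they compose by plain nesting. For example:

// Logging wraps the whole retry loop, so all attempts show up as one timed request
const provider = withLogging(withRetry(new AnthropicProvider()));

const ai = new AIClient({
  defaultProvider: "primary",
  defaultModel: "claude-sonnet-4-20250514",
  providers: { primary: provider },
});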

Summary

Wrapping model calls behind one unified layer gives us:

  • One codebase that supports Claude, OpenAI, and local models, with seamless switching
  • Cross-cutting concerns such as retries, fallback, and logging, added through the middleware pattern
  • Business code that is completely unaware of which provider sits underneath
  • The ability to react quickly to the model market and move to a better or cheaper model at any time

The core design principle is simple: define one interface, have every provider implement it, and layer extra capabilities on through composition.

A good abstraction does not hide complexity; it puts complexity in the right place. A unified interface keeps things simple when you need simplicity, and hands you control when you need control.
