Skip to main content

WTF Langchain极简入门: 09. 回调 (Callback)

最近在学习Langchain框架,顺手写一个“WTF Langchain极简入门”,供小白们使用(编程大佬可以另找教程)。本教程默认以下前提:

  • 使用Python版本的Langchain
  • LLM使用OpenAI的模型
  • Langchain目前还处于快速发展阶段,版本迭代频繁,为避免示例代码失效,本教程统一使用版本 0.0.235

根据Langchain的代码约定,Python版本 ">=3.8.1,<4.0"。

推特:@verysmallwoods

所有代码和教程开源在github: github.com/sugarforever/wtf-langchain


简介

CallbackLangChain 提供的回调机制,允许我们在 LLM 应用程序的各个阶段使用 Hook(钩子)。这对于记录日志、监控、流式传输等任务非常有用。这些任务的执行逻辑由回调处理器(CallbackHandler)定义。

Python 程序中, 回调处理器通过继承 BaseCallbackHandler 来实现。BaseCallbackHandler 接口对每一个可订阅的事件定义了一个回调函数。BaseCallbackHandler 的子类可以实现这些回调函数来处理事件。当事件触发时,LangChain 的回调管理器 CallbackManager 会调用相应的回调函数。

以下是 BaseCallbackHandler 的定义。请参考源代码

class BaseCallbackHandler:
"""Base callback handler that can be used to handle callbacks from langchain."""

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""

def on_chat_model_start(
self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any
) -> Any:
"""Run when Chat Model starts running."""

def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""

def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when LLM errors."""

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
"""Run when chain starts running."""

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""

def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when chain errors."""

def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""

def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when tool errors."""

def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""

LangChain 内置支持了一系列回调处理器,我们也可以按需求自定义处理器,以实现特定的业务。

内置处理器

StdOutCallbackHandlerLangChain 所支持的最基本的处理器。它将所有的回调信息打印到标准输出。这对于调试非常有用。

LangChain 链的基类 Chain 提供了一个 callbacks 参数来指定要使用的回调处理器。请参考Chain源码,其中代码片段为:

class Chain(Serializable, ABC):
"""Abstract base class for creating structured sequences of calls to components.
...
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""

用法如下:

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

handler = StdOutCallbackHandler()
llm = OpenAI()
prompt = PromptTemplate.from_template("Who is {name}?")
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
chain.run(name="Super Mario")

你应该期望如下输出:

> Entering new LLMChain chain...
Prompt after formatting:
Who is Super Mario?

> Finished chain.

\n\nSuper Mario is the protagonist of the popular video game franchise of the same name created by Nintendo. He is a fictional character who stars in video games, television shows, comic books, and films. He is a plumber who is usually portrayed as a portly Italian-American, who is often accompanied by his brother Luigi. He is well known for his catchphrase "It\'s-a me, Mario!"

自定义处理器

我们可以通过继承 BaseCallbackHandler 来实现自定义的回调处理器。下面是一个简单的例子,TimerHandler 将跟踪 ChainLLM 交互的起止时间,并统计每次交互的处理耗时。

from langchain.callbacks.base import BaseCallbackHandler
import time

class TimerHandler(BaseCallbackHandler):

def __init__(self) -> None:
super().__init__()
self.previous_ms = None
self.durations = []

def current_ms(self):
return int(time.time() * 1000 + time.perf_counter() % 1 * 1000)

def on_chain_start(self, serialized, inputs, **kwargs) -> None:
self.previous_ms = self.current_ms()

def on_chain_end(self, outputs, **kwargs) -> None:
if self.previous_ms:
duration = self.current_ms() - self.previous_ms
self.durations.append(duration)

def on_llm_start(self, serialized, prompts, **kwargs) -> None:
self.previous_ms = self.current_ms()

def on_llm_end(self, response, **kwargs) -> None:
if self.previous_ms:
duration = self.current_ms() - self.previous_ms
self.durations.append(duration)

llm = OpenAI()
timerHandler = TimerHandler()
prompt = PromptTemplate.from_template("What is the HEX code of color {color_name}?")
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[timerHandler])
response = chain.run(color_name="blue")
print(response)
response = chain.run(color_name="purple")
print(response)

timerHandler.durations

你应该期望如下输出:

The HEX code for blue is #0000FF.
The HEX code of the color purple is #800080.
[1589, 1097]

回调处理器的适用场景

通过 LLMChain 的构造函数参数设置 callbacks 仅仅是众多适用场景之一。接下来我们简明地列出其他使用场景和示例代码。

对于 ModelAgentTool,以及 Chain 都可以通过以下方式设置回调处理器:

1. 构造函数参数 callbacks 设置

关于 Chain,以 LLMChain 为例,请参考本讲上一部分内容。注意在 Chain 上的回调器监听的是 chain 相关的事件,因此回调器的如下函数会被调用:

  • on_chain_start
  • on_chain_end
  • on_chain_error

AgentTool,以及 Chain 上的回调器会分别被调用相应的回调函数。

下面分享关于 Modelcallbacks 的使用示例:

timerHandler = TimerHandler()
llm = OpenAI(callbacks=[timerHandler])
response = llm.predict("What is the HEX code of color BLACK?")
print(response)

timerHandler.durations

你应该期望看到类似如下的输出:

['What is the HEX code of color BLACK?']
generations=[[Generation(text='\n\nThe hex code of black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 21, 'completion_tokens': 11}, 'model_name': 'text-davinci-003'} run=None


The hex code of black is #000000.

[1223]

2. 通过运行时的函数调用

ModelAgentTool,以及 Chain 的请求执行函数都接受 callbacks 参数,比如 LLMChainrun 函数,OpenAIpredict 函数,等都能接受 callbacks 参数,在运行时指定回调处理器。

OpenAI 模型类为例:

timerHandler = TimerHandler()
llm = OpenAI()
response = llm.predict("What is the HEX code of color BLACK?", callbacks=[timerHandler])
print(response)

timerHandler.durations

你应该同样期望如下输出:

['What is the HEX code of color BLACK?']
generations=[[Generation(text='\n\nThe hex code of black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 21, 'completion_tokens': 11}, 'model_name': 'text-davinci-003'} run=None

The hex code of black is #000000.

[1138]

关于 AgentTool 等的使用,请参考官方文档API。

总结

本节课程中,我们学习了什么是 Callback 回调,如何使用回调处理器,以及在哪些场景下可以接入回调处理器。下一讲,我们将一起完成一个完整的应用案例,来巩固本系列课程的知识点。

本节课程的完整示例代码,请参考 09_Callbacks.ipynb

相关文档资料链接:

  1. Python Langchain官方文档