A Look at langchain4j's AiServices
Preface
This article takes a close look at AiServices in langchain4j.
Examples
Plain version
public interface Assistant {
String chat(String userMessage);
}
Construction
Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);
String resp = assistant.chat(userMessage);
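To make the snippet above runnable, a concrete ChatLanguageModel is needed. Below is a minimal sketch, assuming the langchain4j-open-ai module is on the classpath and an OPENAI_API_KEY environment variable is set; any other ChatLanguageModel implementation would work the same way:
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.AiServices;

public class AssistantDemo {
    public static void main(String[] args) {
        // any ChatLanguageModel implementation works; OpenAiChatModel is just one choice
        ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .build();
        Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);
        System.out.println(assistant.chat("Hello"));
    }
}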
Spring Boot version
@AiService
public interface AssistantV2 {
@SystemMessage("You are a polite assistant")
String chat(String userMessage);
}
After that, you can inject it like any other managed bean and use it directly:
@Autowired
AssistantV2 assistantV2;
@GetMapping("/ai-service")
public String aiService(@RequestParam("prompt") String prompt) {
return assistantV2.chat(prompt);
}
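The @AiService annotation (from the langchain4j Spring Boot starter) wires the interface with beans found in the application context, so a ChatLanguageModel bean has to be provided somewhere. A sketch of one way to do that, again assuming the OpenAI module:
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class LlmConfig {
    // the @AiService-generated proxy picks this bean up automatically
    @Bean
    public ChatLanguageModel chatLanguageModel() {
        return OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName("gpt-4o-mini")
                .build();
    }
}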
Source Code
AiServices
dev/langchain4j/service/AiServices.java
public abstract class AiServices<T> {
protected static final String DEFAULT = "default";
protected final AiServiceContext context;
private boolean retrieverSet = false;
private boolean contentRetrieverSet = false;
private boolean retrievalAugmentorSet = false;
protected AiServices(AiServiceContext context) {
this.context = context;
}
/**
* Creates an AI Service (an implementation of the provided interface), that is backed by the provided chat model.
* This convenience method can be used to create simple AI Services.
* For more complex cases, please use {@link #builder}.
*
* @param aiService The class of the interface to be implemented.
* @param chatLanguageModel The chat model to be used under the hood.
* @return An instance of the provided interface, implementing all its defined methods.
*/
public static <T> T create(Class<T> aiService, ChatLanguageModel chatLanguageModel) {
return builder(aiService).chatLanguageModel(chatLanguageModel).build();
}
/**
* Creates an AI Service (an implementation of the provided interface), that is backed by the provided streaming chat model.
* This convenience method can be used to create simple AI Services.
* For more complex cases, please use {@link #builder}.
*
* @param aiService The class of the interface to be implemented.
* @param streamingChatLanguageModel The streaming chat model to be used under the hood.
* The return type of all methods should be {@link TokenStream}.
* @return An instance of the provided interface, implementing all its defined methods.
*/
public static <T> T create(Class<T> aiService, StreamingChatLanguageModel streamingChatLanguageModel) {
return builder(aiService)
.streamingChatLanguageModel(streamingChatLanguageModel)
.build();
}
/**
* Begins the construction of an AI Service.
*
* @param aiService The class of the interface to be implemented.
* @return builder
*/
public static <T> AiServices<T> builder(Class<T> aiService) {
AiServiceContext context = new AiServiceContext(aiService);
for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
return factory.create(context);
}
return new DefaultAiServices<>(context);
}
/**
* Configures chat model that will be used under the hood of the AI Service.
*
* Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
* but not both at the same time.
*
* @param chatLanguageModel Chat model that will be used under the hood of the AI Service.
* @return builder
*/
public AiServices<T> chatLanguageModel(ChatLanguageModel chatLanguageModel) {
context.chatModel = chatLanguageModel;
return this;
}
/**
* Configures streaming chat model that will be used under the hood of the AI Service.
* The methods of the AI Service must return a {@link TokenStream} type.
*
* Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
* but not both at the same time.
*
* @param streamingChatLanguageModel Streaming chat model that will be used under the hood of the AI Service.
* @return builder
*/
public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
context.streamingChatModel = streamingChatLanguageModel;
return this;
}
/**
* Configures the system message provider, which provides a system message to be used each time an AI service is invoked.
*
* When both {@code @SystemMessage} and the system message provider are configured,
* {@code @SystemMessage} takes precedence.
*
* @param systemMessageProvider A {@link Function} that accepts a chat memory ID
* (a value of a method parameter annotated with @{@link MemoryId})
* and returns a system message to be used.
* If there is no parameter annotated with {@code @MemoryId},
* the value of memory ID is "default".
* The returned {@link String} can be either a complete system message
* or a system message template containing unresolved template variables (e.g. "{{name}}"),
* which will be resolved using the values of method parameters annotated with @{@link V}.
* @return builder
*/
public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
context.systemMessageProvider = systemMessageProvider.andThen(Optional::ofNullable);
return this;
}
//......
public abstract T build();
}
AiServices is an abstract class. Its static builder method creates a DefaultAiServices by default (or delegates to an AiServicesFactory if one is found on the classpath). It provides setters for chatLanguageModel, streamingChatLanguageModel, systemMessageProvider, chatMemory, chatMemoryProvider, moderationModel, tools, toolProvider, contentRetriever and retrievalAugmentor, and it declares the abstract build method for subclasses to implement.
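Several of these setters are typically combined. The following is a minimal sketch, not a complete program: MessageWindowChatMemory is langchain4j's sliding-window chat memory, while model and CalculatorTools (a hypothetical class with @Tool-annotated methods) are assumed to exist:
// a sketch combining several builder setters
Assistant assistant = AiServices.builder(Assistant.class)
        .chatLanguageModel(model)
        .systemMessageProvider(memoryId -> "You are a polite assistant")
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
        .tools(new CalculatorTools()) // hypothetical class with @Tool-annotated methods
        .build();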
DefaultAiServices
dev/langchain4j/service/DefaultAiServices.java
class DefaultAiServices<T> extends AiServices<T> {
private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
private final Collection<TokenStreamAdapter> tokenStreamAdapters = loadFactories(TokenStreamAdapter.class);
DefaultAiServices(AiServiceContext context) {
super(context);
}
//......
public T build() {
performBasicValidation();
for (Method method : context.aiServiceClass.getMethods()) {
if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
throw illegalConfiguration(
"The @Moderate annotation is present, but the moderationModel is not set up. "
+ "Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
if (method.getReturnType() == Result.class
|| method.getReturnType() == List.class
|| method.getReturnType() == Set.class) {
TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
}
if (context.chatMemoryProvider == null) {
for (Parameter parameter : method.getParameters()) {
if (parameter.isAnnotationPresent(MemoryId.class)) {
throw illegalConfiguration(
"In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
context.aiServiceClass.getName());
}
}
}
}
Object proxyInstance = Proxy.newProxyInstance(
context.aiServiceClass.getClassLoader(),
new Class<?>[] {context.aiServiceClass},
new InvocationHandler() {
private final ExecutorService executor = Executors.newCachedThreadPool();
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
if (method.getDeclaringClass() == Object.class) {
// methods like equals(), hashCode() and toString() should not be handled by this proxy
return method.invoke(this, args);
}
validateParameters(method);
Object memoryId = findMemoryId(method, args).orElse(DEFAULT);
Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);
AugmentationResult augmentationResult = null;
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
userMessage = (UserMessage) augmentationResult.chatMessage();
}
// TODO give user ability to provide custom OutputParser
Type returnType = method.getGenericReturnType();
boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);
boolean supportsJsonSchema =
supportsJsonSchema(); // TODO should it be called for returnType==String?
Optional<JsonSchema> jsonSchema = Optional.empty();
if (supportsJsonSchema && !streaming) {
jsonSchema = jsonSchemaFrom(returnType);
}
if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) {
// TODO append after storing in the memory?
userMessage = appendOutputFormatInstructions(returnType, userMessage);
}
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
systemMessage.ifPresent(chatMemory::add);
chatMemory.add(userMessage);
}
List<ChatMessage> messages;
if (context.hasChatMemory()) {
messages = context.chatMemory(memoryId).messages();
} else {
messages = new ArrayList<>();
systemMessage.ifPresent(messages::add);
messages.add(userMessage);
}
Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);
ToolExecutionContext toolExecutionContext =
context.toolService.executionContext(memoryId, userMessage);
if (streaming) {
TokenStream tokenStream = new AiServiceTokenStream(
messages,
toolExecutionContext.toolSpecifications(),
toolExecutionContext.toolExecutors(),
augmentationResult != null ? augmentationResult.contents() : null,
context,
memoryId);
// TODO moderation
if (returnType == TokenStream.class) {
return tokenStream;
} else {
return adapt(tokenStream, returnType);
}
}
ResponseFormat responseFormat = null;
if (supportsJsonSchema && jsonSchema.isPresent()) {
responseFormat = ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema.get())
.build();
}
ChatRequestParameters parameters = ChatRequestParameters.builder()
.toolSpecifications(toolExecutionContext.toolSpecifications())
.responseFormat(responseFormat)
.build();
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.parameters(parameters)
.build();
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
verifyModerationIfNeeded(moderationFuture);
ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
chatResponse,
parameters,
messages,
context.chatModel,
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
memoryId,
toolExecutionContext.toolExecutors());
chatResponse = toolExecutionResult.chatResponse();
FinishReason finishReason = chatResponse.metadata().finishReason();
Response<AiMessage> response = Response.from(
chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);
Object parsedResponse = serviceOutputParser.parse(response, returnType);
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
.tokenUsage(toolExecutionResult.tokenUsageAccumulator())
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(finishReason)
.toolExecutions(toolExecutionResult.toolExecutions())
.build();
} else {
return parsedResponse;
}
}
private boolean canAdaptTokenStreamTo(Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return true;
}
}
return false;
}
private Object adapt(TokenStream tokenStream, Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return tokenStreamAdapter.adapt(tokenStream);
}
}
throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
}
private boolean supportsJsonSchema() {
return context.chatModel != null
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}
private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
String text = userMessage.singleText() + outputFormatInstructions;
if (isNotNullOrBlank(userMessage.name())) {
userMessage = UserMessage.from(userMessage.name(), text);
} else {
userMessage = UserMessage.from(text);
}
return userMessage;
}
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
if (method.isAnnotationPresent(Moderate.class)) {
return executor.submit(() -> {
List<ChatMessage> messagesToModerate = removeToolMessages(messages);
return context.moderationModel
.moderate(messagesToModerate)
.content();
});
}
return null;
}
});
return (T) proxyInstance;
}
//......
}
DefaultAiServices extends AiServices. Its build method creates the implementation class via Proxy.newProxyInstance. The InvocationHandler implementation prepares the systemMessage and userMessage, builds the chatMemory and toolExecutionContext, then assembles a ChatRequest, executes it via context.chatModel.chat(chatRequest), and finally parses and adapts the output.
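The dynamic proxy mechanism itself is plain JDK. Below is a stripped-down sketch of what build() does, with a hypothetical Greeter interface standing in for the AI service and the handler body standing in for the message preparation, LLM call and output parsing shown above:
import java.lang.reflect.Proxy;

interface Greeter {
    String greet(String name);
}

public class ProxyDemo {
    public static void main(String[] args) {
        Greeter greeter = (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[] {Greeter.class},
                (proxy, method, methodArgs) -> {
                    // DefaultAiServices does its real work here:
                    // prepare messages, call the LLM, parse the response
                    return "Hello, " + methodArgs[0];
                });
        System.out.println(greeter.greet("world")); // prints: Hello, world
    }
}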
Summary
langchain4j provides low-level components such as ChatLanguageModel, ChatMessage and ChatMemory, as well as high-level ones such as Chains and AI Services, which orchestrate multiple components (prompt templates, ChatMemory, the LLM, output parsing, and RAG components such as embedding models and scoring models) into a single workflow. Chains was ported from Python's LangChain, but it turned out hard to customize, so no new features are being added to it; AI Services is its replacement. It is conceptually similar to JPA or Retrofit: you declare a simple interface and a proxy implementation is generated automatically, taking care of formatting the LLM input, parsing the LLM output, ChatMemory, Tools and RAG.
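As an example of the output parsing mentioned above, an AI Service method can declare a typed return value. A hypothetical sketch (Person and PersonExtractor are not part of langchain4j): the framework appends format instructions to the user message (or passes a JSON schema when the model supports it) and parses the reply via ServiceOutputParser:
// hypothetical types illustrating typed output parsing
record Person(String name, int age) {}

interface PersonExtractor {
    @UserMessage("Extract a person from: {{it}}")
    Person extractFrom(String text);

    // wrapping in Result additionally exposes token usage, sources and the finish reason
    Result<Person> extractFromWithMetadata(String text);
}

PersonExtractor extractor = AiServices.create(PersonExtractor.class, chatLanguageModel);
Person person = extractor.extractFrom("John is 42 years old");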
langchain4j's AiServices creates a DefaultAiServices by default, which in turn builds the implementation class via the JDK's Proxy.newProxyInstance.
doc
- ai-services