本文介绍如何借助 Spring AI 快速实现一个简单的、类似 OpenAI Agents SDK 的多 Agent 功能。
最终效果是：只需在 application.yml 中配置各个 Agent 的信息，例如：
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
spring:
  ai:
    agents:
      - name: historyTutor
        instructions: You provide assistance with historical queries. Explain important events and context clearly.
        handoffDescription: Specialist agent for historical questions

      - name: mathTutor
        instructions: You provide help with math problems. Explain your reasoning at each step and include examples
        handoffDescription: Specialist agent for math questions

      - name: triageAgent
        instructions: You determine which agent/tools to use based on the user's homework question
        handoffs:
          - historyTutor
          - mathTutor
|
使用时，按名称注入对应的 Agent Bean 即可：
1 2 3 4 5 6 7 8 9 10 11 12 13
| @Slf4j @RestController public class AgentController {
@Resource private Agent triageAgent;
@GetMapping("/triage") public Flux<String> triage(String input) { return triageAgent.asyncExecute(input); } }
|
下面来看具体实现。
- 首先定义一个 Agent 类，封装与大模型交互所需的信息（名称、指令、所用模型、工具以及可移交的子 Agent 等）。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
| import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Slf4j public class Agent { @Setter private String name; @Setter private String instructions; @Setter private String handoffDescription; @Setter private ChatModel chatModel; @Setter private List<Object> tools = new ArrayList<>(); @Setter private List<Agent> handoffs = new ArrayList<>();
private ChatClient chatClient;
public void init() { Assert.notNull(chatModel, "ChatModel must not be null"); log.info("Initializing agent: {}", name);
final List<ToolCallback> list = handoffs.stream() .map(Agent::getToolCallback) .toList();
final ChatClient.Builder builder = ChatClient.builder(chatModel); chatClient = builder.defaultSystem(instructions) .defaultTools(tools.toArray()) .defaultAdvisors(new SimpleLoggerAdvisor()) .defaultToolCallbacks(list) .build(); }
public String execute(String input) { log.info("Executing agent: {}", name); return chatClient.prompt() .user( input) .call() .content(); }
public Flux<String> asyncExecute(String input) { log.info("async Executing agent: {}", name); return chatClient.prompt() .user(input) .stream() .content(); }
public ToolCallback getToolCallback() { return new AgentToolCallback(this); }
public static class AgentToolCallback implements ToolCallback {
private final Agent agent;
public AgentToolCallback(Agent agent) { this.agent = agent; }
@Override public ToolDefinition getToolDefinition() { final Method callMethod; try { callMethod = AgentToolCallback.class.getMethod("call", String.class); } catch (NoSuchMethodException e) { log.error("Error generating JSON schema for method: {}", e.getMessage(), e); throw new RuntimeException(e); } final String methodInput = JsonSchemaGenerator.generateForMethodInput(callMethod); return ToolDefinition.builder() .name(agent.name) .description(agent.handoffDescription) .inputSchema(methodInput) .build(); }
@Override public String call(String toolInput) { return agent.execute(toolInput); } } }
|
- 接着定义一个配置类，与 application.yml 中的配置项一一对应：
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| import lombok.Data;
import java.util.ArrayList; import java.util.List;
/**
 * Binding target for one entry of the {@code spring.ai.agents} list in application.yml.
 *
 * <p>Lombok's {@code @Data} generates the accessors, {@code equals}/{@code hashCode} and
 * {@code toString}. The list-valued properties start out empty rather than {@code null}.
 */
@Data
public class AgentConfig {

    /** Agent name; the registrar also uses it as the bean name. */
    private String name;

    /** System prompt text, or a {@code classpath:} location whose content becomes the prompt. */
    private String instructions;

    /** Description shown to agents that can hand off to this one. */
    private String handoffDescription;

    /** Bean name of the chat model; the registrar defaults to {@code openAiChatModel} when blank. */
    private String chatModel;

    /** Bean names of the tool objects available to the agent. */
    private List<String> tools = new ArrayList<>();

    /** Names of the agents this agent may hand off to. */
    private List<String> handoffs = new ArrayList<>();
}
|
- 然后实现 `ImportBeanDefinitionRegistrar`，根据配置信息动态注册 Agent Bean：
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
| import com.github.zavier.spring.agents.agent.Agent; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.ManagedList; import org.springframework.boot.context.properties.bind.Bindable; import org.springframework.boot.context.properties.bind.Binder; import org.springframework.context.EnvironmentAware; import org.springframework.context.ResourceLoaderAware; import org.springframework.context.annotation.ImportBeanDefinitionRegistrar; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.env.Environment; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.core.type.AnnotationMetadata; import org.springframework.util.StreamUtils;
import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List;
/**
 * Registers one {@link Agent} bean definition per entry of the {@code spring.ai.agents} list.
 *
 * <p>Each agent is registered under its configured name. Its chat model, tools and handoffs are
 * wired as references to other beans, and {@code Agent#init} is used as the init method.
 */
public class AgentBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar, EnvironmentAware,
        BeanFactoryAware, PriorityOrdered, ResourceLoaderAware {

    /** Prefix marking an {@code instructions} value as a resource location instead of literal text. */
    private static final String CLASSPATH_PREFIX = "classpath:";

    /** Chat model bean used when an agent does not name one explicitly. */
    private static final String DEFAULT_CHAT_MODEL_BEAN = "openAiChatModel";

    private BeanFactory beanFactory;
    private Environment environment;
    private ResourceLoader resourceLoader;

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
    }

    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }

    @Override
    public void setResourceLoader(ResourceLoader resourceLoader) {
        this.resourceLoader = resourceLoader;
    }

    /**
     * Binds {@code spring.ai.agents} and registers an {@link Agent} bean for each entry.
     *
     * @throws IllegalStateException if an agent name is blank or already used by another bean,
     *     or a referenced tool/handoff bean or instructions resource cannot be found
     */
    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        List<AgentConfig> agents = Binder.get(environment)
                .bind("spring.ai.agents", Bindable.listOf(AgentConfig.class))
                .orElse(List.of());

        // Collect every configured name up front: a handoff may point at an agent that is
        // declared later in the list and therefore has no bean definition yet.
        List<String> agentNames = agents.stream().map(AgentConfig::getName).toList();

        for (AgentConfig agentConfig : agents) {
            String beanName = agentConfig.getName();
            if (StringUtils.isBlank(beanName)) {
                throw new IllegalStateException("Agent name must not be blank: " + agentConfig);
            }
            if (registry.containsBeanDefinition(beanName)) {
                throw new IllegalStateException("Bean already exists: " + beanName);
            }
            registry.registerBeanDefinition(beanName, createAgentBeanDefinition(agentConfig, agentNames));
        }
    }

    /**
     * Builds the bean definition for a single agent.
     *
     * @param agentConfig the bound configuration entry
     * @param agentNames the names of all configured agents (valid handoff targets)
     */
    private BeanDefinition createAgentBeanDefinition(AgentConfig agentConfig, List<String> agentNames) {
        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(Agent.class);

        builder.addPropertyValue("name", agentConfig.getName());
        builder.addPropertyValue("instructions", resolveInstructions(agentConfig.getInstructions()));
        builder.addPropertyValue("handoffDescription", agentConfig.getHandoffDescription());

        String chatModelBean = StringUtils.isNotBlank(agentConfig.getChatModel())
                ? agentConfig.getChatModel()
                : DEFAULT_CHAT_MODEL_BEAN;
        builder.addPropertyReference("chatModel", chatModelBean);

        List<String> tools = agentConfig.getTools();
        if (tools != null && !tools.isEmpty()) {
            builder.addPropertyValue("tools", toBeanReferences(tools, "Tool", List.of()));
        }

        List<String> handoffs = agentConfig.getHandoffs();
        if (handoffs != null && !handoffs.isEmpty()) {
            if (handoffs.contains(agentConfig.getName())) {
                throw new IllegalStateException("Agent cannot hand off to itself: " + agentConfig.getName());
            }
            // NOTE: longer handoff cycles (A -> B -> A) are not detected here.
            builder.addPropertyValue("handoffs", toBeanReferences(handoffs, "Handoff", agentNames));
        }

        builder.setInitMethodName("init");
        return builder.getBeanDefinition();
    }

    /**
     * Returns the literal instructions, or the UTF-8 content of the resource when the value
     * starts with {@code classpath:}.
     *
     * @throws IllegalStateException if the referenced resource cannot be read
     */
    private String resolveInstructions(String instructions) {
        if (instructions == null || !instructions.startsWith(CLASSPATH_PREFIX)) {
            return instructions;
        }
        Resource resource = resourceLoader.getResource(instructions);
        // StreamUtils.copyToString leaves the stream open, so close it here.
        try (var in = resource.getInputStream()) {
            return StreamUtils.copyToString(in, StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to read instructions file: " + instructions, e);
        }
    }

    /**
     * Converts bean names into runtime references, failing fast on names that cannot be resolved.
     *
     * @param beanNames names to reference
     * @param kind label used in the error message ("Tool" or "Handoff")
     * @param pendingBeanNames names that will be registered by this registrar and are therefore
     *     valid even though the bean factory does not know them yet
     * @throws IllegalStateException if a name is neither pending nor known to the bean factory
     */
    private ManagedList<RuntimeBeanReference> toBeanReferences(List<String> beanNames, String kind,
                                                               List<String> pendingBeanNames) {
        ManagedList<RuntimeBeanReference> references = new ManagedList<>();
        for (String beanName : beanNames) {
            if (!pendingBeanNames.contains(beanName) && !beanFactory.containsBean(beanName)) {
                throw new IllegalStateException(kind + " bean not found: " + beanName);
            }
            references.add(new RuntimeBeanReference(beanName));
        }
        return references;
    }

    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE;
    }
}
|
- 最后，添加一个配置类，通过 `@Import` 让 AgentBeanDefinitionRegistrar 生效：
1 2 3 4 5 6 7
| import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import;
/**
 * Activates {@link AgentBeanDefinitionRegistrar}, which registers one {@code Agent} bean per
 * entry of {@code spring.ai.agents}.
 *
 * <p>This class declares no {@code @Bean} methods, so CGLIB proxying is unnecessary and disabled.
 */
@Configuration(proxyBeanMethods = false)
@Import(AgentBeanDefinitionRegistrar.class)
public class AgentAutoConfiguration {
}
|
这样，项目启动时会自动读取配置文件并创建相应的 Agent Bean，在需要的地方按名称直接注入即可使用。
完整代码可参考：https://github.com/zavier/spring-ai-agents