DataWhale学习打卡Tiny-universe-taks3-TinyAgent
- Agent简介
- Tiny-Agent实现
- 定义模型
- 构造工具Tools
- 构建Agent
手搓一个最小Agent系统,本次打卡全流程均参考DataWhale的Tiny-Agent完成DataWhale原文链接
Tiny-Agent是基于React方式实现的的Agent开发
Agent简介
在计算机科学和人工智能中,Agent是一种自主软件实体,能够在某个环境中感知、推理、决策和执行任务。它们可以根据环境状态和目标做出独立决策,通常具有学习和适应能力。
Agent的特点:
自主性:Agent能够自主执行任务,无需用户的持续干预。
感知能力:Agent可以通过传感器(如用户输入、API数据等)获取环境信息,比如通过Agent调用一些天气相关的API让模型可以得到实时的天气信息。
智能决策:基于获取的信息和预设的规则或学习到的知识,Agent可以进行推理和决策。
学习能力:通过与环境的互动,Agent可以更新其知识库,提高未来决策的准确性,比较著名的我觉得斯坦福小镇应该算一个经典的案例。
Tiny-Agent实现
定义模型
首先构建一个Agent系统肯定需要围绕着模型来进行
from typing import Dict, List, Optional, Tuple, Unionimport torch
from transformers import AutoTokenizer, AutoModelForCausalLM# 先创建一个BaseModel类,我们可以在这个类中定义一些基本的方法,比如chat方法和load_model方法,方便以后扩展使用其他模型。
class BaseModel:def __init__(self, path: str = '') -> None:self.path = pathdef chat(self, prompt: str, history: List[dict]):passdef load_model(self):pass# 创建一个InternLM2类,这个类继承自BaseModel类,我们在这个类中实现chat方法和load_model方法。和正常加载InternLM2模型一样,来做一个简单的加载和返回即可。
class InternLM2Chat(BaseModel):def __init__(self, path: str = '') -> None:super().__init__(path)self.load_model()def load_model(self):print('================ Loading model ================')self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()print('================ Model loaded ================')def chat(self, prompt: str, history: List[dict], meta_instruction:str ='') -> str:response, history = self.model.chat(self.tokenizer, prompt, history, temperature=0.1, meta_instruction=meta_instruction)return response, history
构造工具Tools
定义工具函数为谷歌搜索,后面我也会自己尝试一下其他工具,比如天气或者一些数学计算的工具,现在先按照DataWhale提供的base代码尝试。
class Tools:def __init__(self) -> None:self.toolConfig = self._tools()def _tools(self):tools = [{'name_for_human': '谷歌搜索','name_for_model': 'google_search','description_for_model': '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。','parameters': [{'name': 'search_query','description': '搜索关键词或短语','required': True,'schema': {'type': 'string'},}],}]return toolsdef google_search(self, search_query: str):url = "https://google.serper.dev/search"payload = json.dumps({"q": search_query})headers = {'X-API-KEY': '修改为你自己的key','Content-Type': 'application/json'}response = requests.request("POST", url, headers=headers, data=payload).json()return response['organic'][0]['snippet']
构建Agent
使用React范式构建Agent系统
class Agent:def __init__(self, path: str = '') -> None:self.path = pathself.tool = Tools()self.system_prompt = self.build_system_input()self.model = InternLM2Chat(path)def build_system_input(self):tool_descs, tool_names = [], []for tool in self.tool.toolConfig:tool_descs.append(TOOL_DESC.format(**tool))tool_names.append(tool['name_for_model'])tool_descs = '\n\n'.join(tool_descs)tool_names = ','.join(tool_names)sys_prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names)return sys_promptdef parse_latest_plugin_call(self, text):plugin_name, plugin_args = '', ''i = text.rfind('\nAction:')j = text.rfind('\nAction Input:')k = text.rfind('\nObservation:')if 0 <= i < j: # If the text has `Action` and `Action input`,if k < j: # but does not contain `Observation`,text = text.rstrip() + '\nObservation:' # Add it back.k = text.rfind('\nObservation:')plugin_name = text[i + len('\nAction:') : j].strip()plugin_args = text[j + len('\nAction Input:') : k].strip()text = text[:k]return plugin_name, plugin_args, textdef call_plugin(self, plugin_name, plugin_args):plugin_args = json5.loads(plugin_args)if plugin_name == 'google_search':return '\nObservation:' + self.tool.google_search(**plugin_args)def text_completion(self, text, history=[]):text = "\nQuestion:" + textresponse, his = self.model.chat(text, history, self.system_prompt)print(response)plugin_name, plugin_args, response = self.parse_latest_plugin_call(response)if plugin_name:response += self.call_plugin(plugin_name, plugin_args)response, his = self.model.chat(response, history, self.system_prompt)return response, his