Text2SQL: Notes on the Source Code of the CHESS Paper

2024-12-17  AlienPaul

Preface

CHESS, short for CONTEXTUAL HARNESSING FOR EFFICIENT SQL SYNTHESIS, is a method for converting natural language into SQL, proposed by researchers at Stanford University.

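This post walks through the code that implements the main pipeline described in the paper. My understanding is limited, so there may be mistakes and omissions; corrections from readers are welcome.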

Execution Flow

Following the entry point main.py -> main() -> run_manager.run_tasks(), we eventually reach the worker method shown below. This method builds the entire Text2SQL execution flow.

def worker(self, task: Task) -> Tuple[Any, str, int]:
    """
    Worker function to process a single task.
    
    Args:
        task (Task): The task to be processed.
    
    Returns:
        tuple: The state of the task processing and task identifiers.
    """
    print(f"Initializing task: {task.db_id} {task.question_id}")
    DatabaseManager(db_mode=self.args.data_mode, db_id=task.db_id)
    logger = Logger(db_id=task.db_id, question_id=task.question_id, result_directory=self.result_directory)
    logger._set_log_level(self.args.log_level)
    logger.log(f"Processing task: {task.db_id} {task.question_id}", "info")

    # Build the execution flow
    # The config is read from run/configs
    team = build_team(self.args.config)
    thread_id = f"{self.args.run_start_time}_{task.db_id}_{task.question_id}"
    thread_config = {"configurable": {"thread_id": thread_id}}
    state_values =  SystemState(task=task, 
                                tentative_schema=DatabaseManager().get_db_schema(), 
                                execution_history=[])
    thread_config["recursion_limit"] = 50
    # Invoke the StateGraph to run the pipeline
    for state_dict in team.stream(state_values, thread_config, stream_mode="values"):
        logger.log("________________________________________________________________________________________")
        continue
    system_state = SystemState(**state_dict)
    return system_state, task.db_id, task.question_id

The rest of the analysis starts from the build_team method.

CHESSTeamBuilder

build_team calls the build method of CHESSTeamBuilder, shown below:

def build(self):
    # Read the team_agents entry from the config in run/configs
    agents = {agent_name: agent_config for agent_name, agent_config in self.config["team_agents"].items() 
              if agent_name in AGENT_CLASSES}
    # Add the agents
    self._add_agents(agents)
    # Add the evaluation node
    self.team.add_node("evaluation", ExecutionAccuracy())
    agents_with_evaluation = list(agents.keys()) + ["evaluation"]
    # Make the first agent in the config the entry point of the StateGraph
    self.team.set_entry_point(agents_with_evaluation[0])
    # Chain the agents together in the order they appear in the config
    connections = [(agents_with_evaluation[i], agents_with_evaluation[i+1]) 
                   for i in range(len(agents_with_evaluation)-1)]
    # Add the end node
    connections += [(agents_with_evaluation[-1], END)]
    self._add_connections(connections)

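In LangGraph (the graph library from the LangChain ecosystem that provides StateGraph), each agent can be thought of as a processing node, and a connection is an edge between two nodes. The code above chains the agents together in order through these connections.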
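The wiring rule in build can be reproduced without LangGraph. The sketch below is only an illustration: linear_connections is a hypothetical helper, and END is a plain string standing in for LangGraph's END constant. It shows that the pipeline is strictly linear: agent names not in AGENT_CLASSES are dropped, the evaluation node always comes last, and the config order alone determines the execution order.

```python
from typing import Dict, List, Tuple

# Plain-string stand-in for LangGraph's END constant (assumption for this sketch)
END = "__end__"


def linear_connections(
    configured_agents: List[str], known_agents: Dict[str, type]
) -> Tuple[str, List[Tuple[str, str]]]:
    """Return the entry node and the edges of a linear agent pipeline."""
    # Keep only agents with a registered class, in configuration order
    nodes = [name for name in configured_agents if name in known_agents]
    # The evaluation node is always appended after the last agent
    nodes.append("evaluation")
    edges = [(nodes[i], nodes[i + 1]) for i in range(len(nodes) - 1)]
    # The last node connects to END
    edges.append((nodes[-1], END))
    return nodes[0], edges
```

Judging from this code, because the graph has no conditional edges, skipping a step only requires removing it from team_agents in the config.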

According to the code, CHESS provides the following four agents, which correspond to its four key steps:

AGENT_CLASSES = {
    "information_retriever": InformationRetriever,
    "schema_selector": SchemaSelector,
    "candidate_generator": CandidateGenerator,
    "unit_tester": UnitTester
}

The four steps are analyzed in turn below.

InformationRetriever

According to the code, InformationRetriever contains the following three tools:

def __init__(self, config: dict):
    """Initialize the tools needed for information retrieval"""
    super().__init__(
        name="Information Retriever",
        task=("retrieve the most important entities and context relevant to the keywords of the question, through ",
                     "extracting keywords, retrieving entities, and retrieving context"),
        config=config
    )
    
    self.tools = {
        "extract_keywords": ExtractKeywords(**config["tools"]["extract_keywords"]),
        "retrieve_entity": RetrieveEntity(**config["tools"]["retrieve_entity"]),
        "retrieve_context": RetrieveContext(**config["tools"]["retrieve_context"])
    }

These three tools are analyzed next.

ExtractKeywords

ExtractKeywords extracts the main keywords and key phrases from the natural-language question (together with the hint).

From the paper: To search for the similar values in the database and schema description, the agent needs to extract the main keywords from the natural language question.
This tool uses a few-shot LLM call to extract the primary keywords and key phrases from the input.

The prompt is in template_extract_keywords.txt.

def _run(self, state: SystemState):
    request_kwargs = {
        "QUESTION": state.task.question,
        "HINT": state.task.evidence,
    }
    
    # Call the LLM to extract the keywords from the question and hint
    response = async_llm_chain_call(
        prompt=get_prompt(template_name=self.template_name),
        engine=get_llm_chain(**self.engine_config),
        parser=get_parser(self.parser_name),
        request_list=[request_kwargs],
        step=self.tool_name,
        sampling_count=1
    )[0]
    
    state.keywords = response[0]
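async_llm_chain_call appears to return one list of parsed samples per request, so the trailing [0] picks the first (and only) request and response[0] its first sample. The parser selected by parser_name is not shown in this post; the function below is a hypothetical stand-in that pulls a Python list literal out of the model's reply, which is roughly the job a keyword parser has to do.

```python
import ast
import re
from typing import List


def parse_keyword_list(llm_output: str) -> List[str]:
    """Hypothetical parser: extract the first [...] literal from an LLM reply."""
    match = re.search(r"\[.*\]", llm_output, re.DOTALL)
    if match is None:
        return []
    try:
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError):
        return []
    if not isinstance(parsed, list):
        return []
    # Keep non-empty strings only, with surrounding whitespace removed
    return [item.strip() for item in parsed if isinstance(item, str) and item.strip()]
```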

RetrieveEntity

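RetrieveEntity uses the LSH index saved during preprocessing (the database schema and data values must be stored in a vector database and an LSH index beforehand; this is covered later) to find columns and column values similar to the keywords. It combines lexical (string) similarity with semantic similarity.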

def _run(self, state: SystemState):
    """
    Executes the entity retrieval process.
    
    Args:
        state (SystemState): The current system state.
    """
    
    # Find similar columns
    state.similar_columns = self._get_similar_columns(keywords=state.keywords, question=state.task.question, hint=state.task.evidence)
    
    # Return similar schema entries together with example values
    state.schema_with_examples = self._get_similar_entities(keywords=state.keywords)

### Column name similarity ###

_get_similar_columns finds the table columns similar to the keywords:

def _get_similar_columns(self, keywords: List[str], question: str, hint: str) -> Dict[str, List[str]]:
    """
    Finds columns similar to given keywords based on question and hint.

    Args:
        keywords (List[str]): The list of keywords.
        question (str): The question string.
        hint (str): The hint string.

    Returns:
        Dict[str, List[str]]: A dictionary mapping table names to lists of similar column names.
    """
    selected_columns = {}
    # Analyzed below
    similar_columns = self._get_similar_column_names(keywords=keywords, question=question, hint=hint)
    # Assemble the result, keeping the table -> column hierarchy
    for table_name, column_name in similar_columns:
        if table_name not in selected_columns:
            selected_columns[table_name] = []
        if column_name not in selected_columns[table_name]:
            selected_columns[table_name].append(column_name)
    return selected_columns

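_get_similar_column_names first expands each keyword into candidate column names and fuzzy-matches them against the real column names (lexical similarity). For each match it computes the semantic similarity between the column and the question plus hint, sorts the matches by that score from highest to lowest, and outputs the table and column names. The code: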

def _get_similar_column_names(self, keywords: str, question: str, hint: str) -> List[Tuple[str, str]]:
    """
    Finds column names similar to given keywords based on question and hint.

    Args:
        keywords (str): The list of keywords.
        question (str): The question string.
        hint (str): The hint string.

    Returns:
        List[Tuple[str, str]]: A list of tuples containing table and column names.
    """
    potential_column_names = []
    for keyword in keywords:
        keyword = keyword.strip()
        # Add the keyword itself as a potential column name
        potential_column_names.append(keyword)
        
        # If the keyword contains an equals sign, split it into column and value
        column, value = self._column_value(keyword)
        # The part left of '=' is the column; it is non-empty only if the keyword contains '='
        if column:
            potential_column_names.append(column)

        # If the keyword contains text in parentheses, add that text without the parentheses
        potential_column_names.extend(self._extract_paranthesis(keyword))

        # If the keyword contains spaces, also add each whitespace-separated part
        if " " in keyword:
            potential_column_names.extend(part.strip() for part in keyword.split())
    # Read the database schema: every table and its columns
    schema = DatabaseManager().get_db_schema()
    
    to_embed_strings = []

    # Prepare the list of strings to embed
    
    # Format every column as `table`.`column`
    column_strings = [f"`{table}`.`{column}`" for table, columns in schema.items() for column in columns]
    question_hint_string = f"{question} {hint}"

    to_embed_strings.extend(column_strings)
    to_embed_strings.append(question_hint_string)

    # Get embeddings
    # Embed the table/column strings and the question + hint
    embeddings = self.embedding_function.embed_documents(to_embed_strings)

    # Separate embeddings
    # Embeddings of the table/column names
    column_embeddings = embeddings[:-1]  # All except the last one
    # Embedding of the question and hint
    question_hint_embedding = embeddings[-1]  # The last one

    # Compute similarities
    similar_column_names = []
    for i, column_embedding in enumerate(column_embeddings):
        table, column = column_strings[i].split('.')[0].strip('`'), column_strings[i].split('.')[1].strip('`')
        for potential_column_name in potential_column_names:
            # Compare the potential column name with the actual column name
            # difflib.SequenceMatcher measures string similarity; True if the ratio is above the threshold (default 0.9)
            if self._does_keyword_match_column(potential_column_name, column):
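                # Semantic similarity between the column and question + hint: a dot product, which equals cosine similarity when the embeddings are unit-normalized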
                similarity_score = np.dot(column_embedding, question_hint_embedding)
                # Record the semantic score together with the table and column name
                similar_column_names.append((table, column, similarity_score))

    # Sort by similarity, from highest to lowest
    similar_column_names.sort(key=lambda x: x[2], reverse=True)
    # Keep only the table and column names. Note: set() removes duplicates but also discards the sort order above
    table_column_pairs = list(set([(table, column) for table, column, _ in similar_column_names]))
    return table_column_pairs
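The two stages of _get_similar_column_names can be sketched separately. In the version below, column_value, expand_keyword, matches_column and rank_columns are hypothetical helpers written for illustration; the lowercase/space/underscore normalization in matches_column is an assumption (the original only shows difflib with a 0.9 threshold). rank_columns uses dict.fromkeys instead of set(), which removes duplicates while keeping the descending-score order produced by the sort.

```python
import difflib
import re
from typing import List, Optional, Tuple


def column_value(keyword: str) -> Tuple[Optional[str], Optional[str]]:
    """Split 'column = value' at the first '='; (None, None) if there is no '='."""
    if "=" not in keyword:
        return None, None
    left, _, right = keyword.partition("=")
    return left.strip(), (right.strip() or None)


def expand_keyword(keyword: str) -> List[str]:
    """Candidate column names derived from a single keyword."""
    keyword = keyword.strip()
    candidates = [keyword]
    column, _ = column_value(keyword)
    if column:
        candidates.append(column)
    # Text inside parentheses, without the parentheses
    candidates.extend(part.strip() for part in re.findall(r"\(([^)]*)\)", keyword))
    if " " in keyword:
        candidates.extend(part.strip() for part in keyword.split())
    return candidates


def matches_column(candidate: str, column: str, threshold: float = 0.9) -> bool:
    """Fuzzy match after normalizing case, spaces and underscores (assumed normalization)."""
    a = candidate.lower().replace(" ", "").replace("_", "")
    b = column.lower().replace(" ", "").replace("_", "")
    return difflib.SequenceMatcher(None, a, b).ratio() >= threshold


def rank_columns(matches: List[Tuple[str, str, float]]) -> List[Tuple[str, str]]:
    """Sort (table, column, score) by score and de-duplicate, keeping that order."""
    ordered = sorted(matches, key=lambda m: m[2], reverse=True)
    return list(dict.fromkeys((table, column) for table, column, _ in ordered))
```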

### Entity similarity ###

_get_similar_entities uses both lexical and semantic similarity to find similar column values:

def _get_similar_entities(self, keywords: List[str]) -> Dict[str, Dict[str, List[str]]]:
    """
    Retrieves similar entities from the database based on keywords.

    Args:
        keywords (List[str]): The list of keywords.

    Returns:
        Dict[str, Dict[str, List[str]]]: A dictionary mapping table and column names to similar entities.
    """
    # Build the search values: each keyword, its prefix/suffix at every space, and, if it contains '=', the value to the right of it
    to_seartch_values = self._get_to_search_values(keywords)
    # Use the LSH index to find similar values together with their tables and columns
    similar_entities_via_LSH = self._get_similar_entities_via_LSH(to_seartch_values)
    # Filter the LSH results further by edit-distance (string) similarity
    similar_entities_via_edit_distance = self._get_similar_entities_via_edit_distance(similar_entities_via_LSH)
    # Filter further by semantic similarity, keeping results with cosine similarity above 0.6
    similar_entities_via_embedding = self._get_similar_entities_via_embedding(similar_entities_via_edit_distance)
    
    selected_values = {}
    # Group the results by table, then by column
    for entity in similar_entities_via_embedding:
        table_name = entity["table_name"]
        column_name = entity["column_name"]
        if table_name not in selected_values:
            selected_values[table_name] = {}
        if column_name not in selected_values[table_name]:
            selected_values[table_name][column_name] = []
        selected_values[table_name][column_name].append(entity)
    for table_name, column_values in selected_values.items():
        for column_name, values in column_values.items():
            # Highest edit-distance similarity in this column
            max_edit_distance_similarity = max(entity["edit_distance_similarity"] for entity in values)
            # Keep values whose edit-distance similarity is at least 0.9 * the maximum
            values = [entity for entity in values if entity["edit_distance_similarity"] >= 0.9*max_edit_distance_similarity]
            # Highest embedding similarity among the remaining values
            max_embedding_similarity = max(entity["embedding_similarity"] for entity in values)
            # Likewise, keep values whose embedding similarity is at least 0.9 * the maximum
            selected_values[table_name][column_name] = [entity['similar_value'] for entity in values if entity["embedding_similarity"] >= 0.9*max_embedding_similarity]
                
    return selected_values
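The final loop applies two relative cut-offs per column: a value survives only if it is within 90% of the best edit-distance score, and then within 90% of the best embedding score among the survivors. A standalone sketch of the grouping and filtering (hypothetical helper names, same 0.9 ratio as the code) shows a consequence of the ordering: a value with the highest embedding similarity can still be dropped, because the edit-distance cut runs first.

```python
from typing import Dict, List


def group_by_column(entities: List[dict]) -> Dict[str, Dict[str, List[dict]]]:
    """Nest entity packets as {table: {column: [packet, ...]}}."""
    grouped: Dict[str, Dict[str, List[dict]]] = {}
    for entity in entities:
        grouped.setdefault(entity["table_name"], {}).setdefault(entity["column_name"], []).append(entity)
    return grouped


def keep_near_best(values: List[dict], ratio: float = 0.9) -> List[str]:
    """Two-stage relative filter: edit similarity first, then embedding similarity."""
    if not values:
        return []
    best_edit = max(v["edit_distance_similarity"] for v in values)
    survivors = [v for v in values if v["edit_distance_similarity"] >= ratio * best_edit]
    # survivors is never empty: the best value always passes its own cut
    best_embedding = max(v["embedding_similarity"] for v in survivors)
    return [v["similar_value"] for v in survivors if v["embedding_similarity"] >= ratio * best_embedding]
```

With the made-up values in the test below, "Oakley" has the best embedding score but is removed by the edit-distance cut, and "Oakland Unified" is then removed by the embedding cut.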

_get_to_search_values splits the keywords into search values:

def _get_to_search_values(self, keywords: List[str]) -> List[str]:
    """
    Extracts values to search from the keywords.

    Args:
        keywords (List[str]): The list of keywords.

    Returns:
        List[str]: A list of values to search.
    """
    def get_substring_packet(keyword: str, substring: str) -> Dict[str, str]:
        return {"keyword": keyword, "substring": substring}
    
    to_search_values = []
    for keyword in keywords:
        keyword = keyword.strip()
        to_search_values.append(get_substring_packet(keyword, keyword))
        # For every space in the keyword, split the keyword at that position
        if " " in keyword:
            for i in range(len(keyword)):
                if keyword[i] == " ":
                    first_part = keyword[:i]
                    second_part = keyword[i+1:]
                    # Store the parts before and after the space as separate substrings
                    to_search_values.append(get_substring_packet(keyword, first_part))
                    to_search_values.append(get_substring_packet(keyword, second_part))
        # Split the keyword at '=' and keep the value on the right
        hint_column, hint_value = self._column_value(keyword)
        if hint_value:
            to_search_values.append(get_substring_packet(keyword, hint_value))
    to_search_values.sort(key=lambda x: (x["keyword"], len(x["substring"]), x["substring"]), reverse=True)
    return to_search_values
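Note that the loop does not split a keyword into individual words: for every space it emits the prefix before it and the suffix after it, and the final reverse sort puts longer substrings first within each keyword. A standalone re-implementation (a sketch that mirrors the code above, with the '=' handling inlined as an assumption about _column_value) makes this visible:

```python
from typing import Dict, List


def to_search_values(keywords: List[str]) -> List[Dict[str, str]]:
    """Mirror of _get_to_search_values: prefix/suffix pairs at every space."""
    packets = []
    for keyword in keywords:
        keyword = keyword.strip()
        packets.append({"keyword": keyword, "substring": keyword})
        # One prefix and one suffix per space, not one entry per word
        for i, ch in enumerate(keyword):
            if ch == " ":
                packets.append({"keyword": keyword, "substring": keyword[:i]})
                packets.append({"keyword": keyword, "substring": keyword[i + 1:]})
        # Value to the right of the first '=', if any (assumed _column_value behavior)
        if "=" in keyword:
            value = keyword.split("=", 1)[1].strip()
            if value:
                packets.append({"keyword": keyword, "substring": value})
    # Same sort key as the original: keyword, then substring length, then substring, all descending
    packets.sort(key=lambda p: (p["keyword"], len(p["substring"]), p["substring"]), reverse=True)
    return packets
```

For "New York City" the substrings come out as "New York City", "York City", "New York", "City", "New"; the single word "York" is never searched on its own.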

_get_similar_entities_via_LSH uses the LSH index to find similar column values:

def _get_similar_entities_via_LSH(self, substring_packets: List[Dict[str, str]]) -> List[Dict[str, Any]]:
    similar_entities_via_LSH = []
    for packet in substring_packets:
        keyword = packet["keyword"]
        substring = packet["substring"]
        # Query the LSH index for values similar to the substring, grouped by table and column
        unique_similar_values = DatabaseManager().query_lsh(keyword=substring, signature_size=100, top_n=10)
        for table_name, column_values in unique_similar_values.items():
            for column_name, values in column_values.items():
                for value in values:
                    # Assemble the result, including the similar column value (data)
                    similar_entities_via_LSH.append({"keyword": keyword, 
                                            "substring": substring,
                                            "table_name": table_name,
                                            "column_name": column_name,
                                            "similar_value": value})
    return similar_entities_via_LSH
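The arguments signature_size=100 and top_n=10 suggest a MinHash-based index. The self-contained sketch below shows how MinHash LSH finds similar values. Its choices are assumptions made for illustration: character 3-gram shingles, salted blake2b hashes standing in for random permutations, and 20 bands of 5 rows. It is not the project's preprocessing code.

```python
import hashlib
from collections import defaultdict
from typing import Dict, List, Set, Tuple


def shingles(text: str, n: int = 3) -> Set[str]:
    """Character n-grams of the lowercased text (the whole text if shorter than n)."""
    text = text.lower()
    if len(text) < n:
        return {text}
    return {text[i:i + n] for i in range(len(text) - n + 1)}


def minhash_signature(text: str, num_perm: int = 100) -> List[int]:
    """One salted 64-bit hash per 'permutation'; keep the minimum over all shingles."""
    grams = shingles(text)
    return [
        min(
            int.from_bytes(hashlib.blake2b(f"{seed}:{g}".encode(), digest_size=8).digest(), "big")
            for g in grams
        )
        for seed in range(num_perm)
    ]


def estimated_jaccard(a: List[int], b: List[int]) -> float:
    """The fraction of matching signature slots estimates Jaccard similarity."""
    return sum(x == y for x, y in zip(a, b)) / len(a)


class MinHashLSH:
    def __init__(self, num_perm: int = 100, bands: int = 20):
        if num_perm % bands:
            raise ValueError("num_perm must be divisible by bands")
        self.num_perm, self.bands, self.rows = num_perm, bands, num_perm // bands
        self.buckets: Dict[Tuple[int, Tuple[int, ...]], Set[str]] = defaultdict(set)
        self.signatures: Dict[str, List[int]] = {}

    def _band_keys(self, sig: List[int]) -> List[Tuple[int, Tuple[int, ...]]]:
        # Each band of `rows` consecutive slots becomes one bucket key
        return [(b, tuple(sig[b * self.rows:(b + 1) * self.rows])) for b in range(self.bands)]

    def insert(self, value: str) -> None:
        sig = minhash_signature(value, self.num_perm)
        self.signatures[value] = sig
        for key in self._band_keys(sig):
            self.buckets[key].add(value)

    def query(self, text: str, top_n: int = 10) -> List[str]:
        sig = minhash_signature(text, self.num_perm)
        # Candidates share at least one band bucket with the query
        candidates: Set[str] = set()
        for key in self._band_keys(sig):
            candidates |= self.buckets.get(key, set())
        # Rank by estimated Jaccard similarity (ties broken alphabetically)
        ranked = sorted(candidates, key=lambda v: (-estimated_jaccard(sig, self.signatures[v]), v))
        return ranked[:top_n]
```

Banding is what makes LSH cheaper than a full scan: a stored value is only scored if it collides with the query in at least one band, so near-duplicates are found without comparing the query against every value.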

_get_similar_entities_via_edit_distance filters these values by string (edit-distance-style) similarity:

def _get_similar_entities_via_edit_distance(self, similar_entities_via_LSH: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    similar_entities_via_edit_distance_similarity = []
    for entity_packet in similar_entities_via_LSH:
        # Compare the string similarity of the substring and the column value
        edit_distance_similarity = difflib.SequenceMatcher(None, entity_packet["substring"].lower(), entity_packet["similar_value"].lower()).ratio()
        # Keep the value if the similarity is at least the threshold (0.3)
        if edit_distance_similarity >= self.edit_distance_threshold:
            entity_packet["edit_distance_similarity"] = edit_distance_similarity
            similar_entities_via_edit_distance_similarity.append(entity_packet)
    return similar_entities_via_edit_distance_similarity
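Strictly speaking, difflib.SequenceMatcher.ratio() is not Levenshtein edit distance: it returns 2*M/T, where M is the number of matched characters (Ratcliff/Obershelp matching) and T is the combined length of both strings. A minimal sketch of the same filter, assuming the 0.3 threshold mentioned above and copying each packet instead of mutating it in place:

```python
import difflib
from typing import Any, Dict, List


def filter_by_edit_similarity(packets: List[Dict[str, Any]], threshold: float = 0.3) -> List[Dict[str, Any]]:
    """Keep LSH hits whose case-insensitive SequenceMatcher ratio reaches the threshold."""
    kept = []
    for packet in packets:
        score = difflib.SequenceMatcher(
            None, packet["substring"].lower(), packet["similar_value"].lower()
        ).ratio()
        if score >= threshold:
            # Copy rather than mutate the caller's dict
            kept.append({**packet, "edit_distance_similarity": score})
    return kept
```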

RetrieveContext

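RetrieveContext finds the relevant tables and columns based on their schema descriptions. The schema information comes from the database schema descriptions stored in the vector database.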

From the paper: In addition to retrieving values, the IR agent can access the database catalog, which often includes schema metadata, such as column descriptions, extended column names (to resolve abbreviations), and value descriptions.

Retrieval is based on semantic (embedding) similarity, ensuring that the most relevant context is provided to the model.

def _run(self, state: SystemState):
    """
    Executes the context retrieval process.
    
    Args:
        state (SystemState): The current system state.
    """
    
    retrieved_columns = self._find_most_similar_columns(
        question=state.task.question,
        evidence=state.task.evidence,
        keywords=state.keywords,
        top_k=self.top_k
    )
    
    # Strip the score from each description and return the result
    state.schema_with_descriptions = self._format_retrieved_descriptions(retrieved_columns)

    # try:
    #     path = os.path.join(os.getenv("DB_ROOT_DIRECTORY"), state.task.db_id)
    #     state.schema_with_descriptions = load_tables_description(path, use_value_description=True)
    # except Exception as e:
    #     logging.error(f"Error loading tables description: {e}")
    #     state.schema_with_descriptions = {}

### Context similarity ###

_find_most_similar_columns queries the vector database with the question and evidence to find the columns whose descriptions are most similar. The code:

def _find_most_similar_columns(self, question: str, evidence: str, keywords: List[str], top_k: int) -> Dict[str, Dict[str, Dict[str, str]]]:
    """
    Finds the most similar columns based on the question and evidence.

    Args:
        question (str): The question string.
        evidence (str): The evidence string.
        keywords (List[str]): The list of keywords.
        top_k (int): The number of top similar columns to retrieve.

    Returns:
        Dict[str, Dict[str, Dict[str, str]]]: A dictionary containing the most similar columns with descriptions.
    """
    logging.info("Finding the most similar columns")
    tables_with_descriptions = {}
    
    for keyword in keywords:
        # Combine the question and the evidence with each keyword
        question_based_query = f"{question} {keyword}"
        evidence_based_query = f"{evidence} {keyword}"
        
        # Query the vector database
        retrieved_question_based_query = DatabaseManager().query_vector_db(question_based_query, top_k=top_k)
        retrieved_evidence_based_query = DatabaseManager().query_vector_db(evidence_based_query, top_k=top_k)
        
        # Merge the results into a table -> column -> description structure
        tables_with_descriptions = self._add_description(tables_with_descriptions, retrieved_question_based_query)
        tables_with_descriptions = self._add_description(tables_with_descriptions, retrieved_evidence_based_query)
    
    return tables_with_descriptions
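DatabaseManager().query_vector_db and _add_description are not shown here. The sketch below makes its assumptions explicit: a toy in-memory list of (table, column, description, embedding) tuples stands in for the vector database, retrieval is brute-force cosine top-k, and add_descriptions keeps the highest-scoring hit per column when the question-based and evidence-based queries return the same column. All names and data are hypothetical.

```python
import math
from typing import Dict, List, Tuple

# (table, column, description, embedding): a toy stand-in for the vector database
Doc = Tuple[str, str, str, List[float]]


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def query_top_k(query_vec: List[float], docs: List[Doc], top_k: int) -> List[Tuple[float, str, str, str]]:
    """Brute-force cosine top-k; a real vector store would use an index."""
    scored = [(cosine(query_vec, emb), table, column, desc) for table, column, desc, emb in docs]
    scored.sort(key=lambda s: s[0], reverse=True)
    return scored[:top_k]


def add_descriptions(acc: Dict[str, Dict[str, dict]], hits) -> Dict[str, Dict[str, dict]]:
    """Merge hits into {table: {column: {...}}}, keeping the best score per column."""
    for score, table, column, desc in hits:
        current = acc.setdefault(table, {}).get(column)
        if current is None or score > current["score"]:
            acc[table][column] = {"column_description": desc, "score": score}
    return acc
```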

SchemaSelector

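SchemaSelector selects the tables and columns relevant to the question, much like column pruning in query optimization, to reduce the size of the schema. The schema information comes from the actual database.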

From the paper: The goal of the Schema Selector (SS) agent is to reduce the schema size by selecting only the necessary tables and columns required for generating the SQL query. To achieve this, the SS agent is equipped with three tools, filter column, select tables, and select columns.

def __init__(self, config: dict):
    """Initialize the tools needed for schema selection"""
    super().__init__(
        name="schema_selector",
        task="narrow down the schema into the most relevant ones through filtering columns, selecting tables and selecting columns",
        config=config,
    )
    
    self.tools = {
        "filter_column": FilterColumn(**config["tools"]["filter_column"]),              
        "select_tables": SelectTables(**config["tools"]["select_tables"]),
        "select_columns": SelectColumns(**config["tools"]["select_columns"])
    }

FilterColumn

FilterColumn filters the columns: for each column, the LLM decides whether it is relevant to the question and hint.

def _run(self, state: SystemState):
    """
    Executes the column filtering process.
    
    Args:
        state (SystemState): The current system state.
    """
    # Get the column profiles
    column_profiles = DatabaseManager().get_column_profiles(
        schema_with_examples=state.schema_with_examples, 
        use_value_description=True, 
        with_keys=True, 
        with_references=True,
        tentative_schema=state.tentative_schema
    )

    list_of_kwargs = []
    # Pair the question and hint with each column profile in list_of_kwargs
    for table_name, columns in column_profiles.items():
        for column_name, column_profile in columns.items():
            kwargs = {
                "QUESTION": state.task.question,
                "HINT": state.task.evidence,
                "COLUMN_PROFILE": column_profile,
            }
            list_of_kwargs.append(kwargs)

    # Ask the LLM whether each column is relevant to the question and hint
    # The prompt is in template_filter_column.txt
    response = async_llm_chain_call(
        prompt=get_prompt(template_name=self.template_name),
        engine=get_llm_chain(**self.engine_config),
        parser=get_parser(self.parser_name),
        request_list=list_of_kwargs,
        step=self.tool_name, 
        sampling_count=1
    )
    
    index = 0
    tentative_schema = state.tentative_schema
    for table_name, columns in column_profiles.items():
        tentative_schema[table_name] = []
        for column_name, column_profile in columns.items():
            try:
                # If the LLM considers the column relevant, add it to tentative_schema
                chosen = (response[index][0]["is_column_information_relevant"].lower() == "yes")
                if chosen:
                    tentative_schema[table_name].append(column_name)
            except Exception as e:
                Logger().log(f"({state.task.db_id}, {state.task.question_id}) Error in column filtering: {e}", "error")
                logging.error(f"Error in column filtering for table '{table_name}', column '{column_name}': {e}")
            index += 1        
    
    # Add the similar_columns found in the RetrieveEntity step to tentative_schema
    state.add_columns_to_tentative_schema(state.similar_columns)
    # Add all primary-key and foreign-key columns
    state.add_connections_to_tentative_schema()
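FilterColumn sends one request per column and then walks column_profiles again in the same order, so response[index] lines up with the index-th column only because both loops iterate the same dict in the same order (and index is incremented even when parsing fails). A simplified sketch with hypothetical helpers, where each answer is just a "yes"/"no" string rather than the parsed dict:

```python
from typing import Dict, List


def apply_column_filter(column_order: Dict[str, List[str]], answers: List[str]) -> Dict[str, List[str]]:
    """answers[i] is the model's verdict for the i-th column in iteration order (assumed input)."""
    schema: Dict[str, List[str]] = {}
    index = 0
    for table, columns in column_order.items():
        schema[table] = []
        for column in columns:
            # A missing answer counts as "not relevant", similar to the try/except in the original
            if index < len(answers) and answers[index].strip().lower() == "yes":
                schema[table].append(column)
            # Advance for every column so later answers stay aligned
            index += 1
    return schema


def add_columns(schema: Dict[str, List[str]], extra: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Union extra columns (e.g. similar_columns) into the schema without duplicates."""
    for table, columns in extra.items():
        bucket = schema.setdefault(table, [])
        for column in columns:
            if column not in bucket:
                bucket.append(column)
    return schema
```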

SelectTables

SelectTables picks the tables needed to answer the question, based on the question and hint.

def _run(self, state: SystemState):
    """
    Executes the table selection process.
    
    Args:
        state (SystemState): The current system state.

    Returns:
        Dict[str, Any]: A dictionary containing the updated tentative schema and selected tables.
    """

    if self.mode == "ask_model":
        
        request_kwargs = {
            "DATABASE_SCHEMA": state.get_schema_string(schema_type="tentative"),
            "QUESTION": state.task.question,
            "HINT": state.task.evidence,
        }
        # The prompt is in template_select_tables.txt
        response = async_llm_chain_call(
            prompt=get_prompt(template_name=self.template_name),
            engine=get_llm_chain(**self.engine_config),
            parser=get_parser(self.parser_name),
            request_list=[request_kwargs],
            step=self.tool_name,
            sampling_count=self.sampling_count,
        )[0]
        
        # Aggregate chain_of_thought_reasoning and table_names across all LLM samples
        aggregated_result = self.aggregate_tables(response)
        self.selected_tables = aggregated_result["table_names"]
        self.chain_of_thought_reasoning = aggregated_result["chain_of_thought_reasoning"]
        
    elif self.mode == "corrects":
        # Take the table names from the gold SQL
        self.chain_of_thought_reasoning = "Tables that are appeared in the gold SQL query."
        self.selected_tables = DatabaseManager().get_sql_tables(state.task.SQL)
    else:
        logging.error(f"Unknown mode for table selection: {self.mode}")
        raise ValueError(f"Unknown mode for table selection: {self.mode}")
    
    # Restrict tentative_schema to selected_tables
    state.tentative_schema = {
        table_name: state.tentative_schema.get(table_name, [])
        for table_name in self.selected_tables
    }
    # Add the similar_columns found in the RetrieveEntity step to tentative_schema
    state.add_columns_to_tentative_schema(state.similar_columns)
    # Add all primary-key and foreign-key columns
    state.add_connections_to_tentative_schema()
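aggregate_tables is not shown in this post. Assuming it takes the union of table_names across the sampling_count samples (case-insensitive, first spelling wins) and joins the reasoning strings, the aggregation and the subsequent restriction of tentative_schema look roughly like this. Both function bodies and the separator string are assumptions:

```python
from typing import Dict, List


def aggregate_tables(samples: List[dict]) -> dict:
    """Assumed behavior: union of table_names across samples, case-insensitive, first spelling wins."""
    tables: List[str] = []
    reasoning: List[str] = []
    for sample in samples:
        reasoning.append(sample.get("chain_of_thought_reasoning", ""))
        for table in sample.get("table_names", []):
            if table.lower() not in {t.lower() for t in tables}:
                tables.append(table)
    # The separator is arbitrary
    return {"table_names": tables, "chain_of_thought_reasoning": "\n----\n".join(reasoning)}


def restrict_to_tables(tentative: Dict[str, List[str]], selected: List[str]) -> Dict[str, List[str]]:
    """Keep only the selected tables, in selection order; unknown tables get an empty list."""
    return {table: tentative.get(table, []) for table in selected}
```

The empty-list default mirrors state.tentative_schema.get(table_name, []) in the original: a table the model selects but FilterColumn emptied still appears, and its key columns are added back by add_connections_to_tentative_schema.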

SelectColumns

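SelectColumns picks the final set of columns, based on the question, the hint, and the tables and columns that survived the previous filtering steps.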

def _run(self, state: SystemState):
    """
    Executes the column selection process.
    
    Args:
        state (SystemState): The current system state.
    """
    if self.mode == "ask_model":                        
        request_kwargs = {
            "DATABASE_SCHEMA": state.get_schema_string(schema_type="tentative"),
            "QUESTION": state.task.question,
            "HINT": state.task.evidence,
        }
        
        # The prompt is in template_select_columns.txt
        response = async_llm_chain_call(
            prompt=get_prompt(template_name=self.template_name),
            engine=get_llm_chain(**self.engine_config),
            parser=get_parser(self.parser_name),
            request_list=[request_kwargs],
            step=self.tool_name,
            sampling_count=self.sampling_count,
        )[0]
        # The rest mirrors the corresponding logic in SelectTables
        aggregated_result = self.aggregate_columns(response, list(state.tentative_schema.keys()))
        self.chain_of_thought_reasoning = aggregated_result.pop("chain_of_thought_reasoning")
        # self.selected_columns = self.union_schemas(response)
        self.selected_columns = aggregated_result
        
    elif self.mode == "corrects":
        self.chain_of_thought_reasoning = "Columns that are appeared in the gold SQL query."
        self.selected_columns = DatabaseManager().get_sql_columns_dict(state.task.SQL)

    else:
        logging.error(f"Unknown mode for column selection: {self.mode}")
        raise ValueError(f"Unknown mode for column selection: {self.mode}")
    
    state.tentative_schema = self.selected_columns.copy()
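Judging from how the result is used (chain_of_thought_reasoning is popped and the remainder becomes tentative_schema), aggregate_columns returns a {table: [columns]} dict plus a reasoning entry. The sketch below assumes each LLM sample has the same shape, takes the union of columns, and drops tables that are not in the current tentative schema; the input format and the union rule are assumptions.

```python
from typing import Dict, List


def aggregate_columns(samples: List[dict], allowed_tables: List[str]) -> Dict[str, object]:
    """Union columns per table across samples, restricted to allowed_tables (assumed behavior)."""
    lookup = {t.lower(): t for t in allowed_tables}
    result: Dict[str, object] = {}
    reasoning: List[str] = []
    for sample in samples:
        reasoning.append(sample.get("chain_of_thought_reasoning", ""))
        for table, columns in sample.items():
            if table == "chain_of_thought_reasoning":
                continue
            canonical = lookup.get(table.lower())
            if canonical is None:
                continue  # table was pruned by SelectTables
            bucket = result.setdefault(canonical, [])
            for column in columns:
                if column not in bucket:
                    bucket.append(column)
    result["chain_of_thought_reasoning"] = "\n----\n".join(reasoning)
    return result
```

Used the same way as the original: pop "chain_of_thought_reasoning", and what remains is the new tentative schema.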

CandidateGenerator

From the paper: The Candidate Generator (CG) is responsible for synthesizing SQL query that answers the question asked from the database.

This step generates SQL queries that answer the question and then revises them.

def __init__(self, config: dict):
    super().__init__(
        name="Candidate Generator",
        task=("generate candidate sql queries, and revise the predicted SQL query based on task evidence and schema information",
              "revise the predicted SQL query based on task evidence and schema information"),
        config=config
    )

    self.tools = {
        "generate_candidate": GenerateCandidate(**config["tools"]["generate_candidate"]),
        "revise": Revise(**config["tools"]["revise"])
    }

GenerateCandidate

From the paper: This tool generates a single candidate query that answers the question. It takes the question, the schema, and the context (entities and descriptions) and prompts an LLM to follow a multi-step reasoning guideline to write a candidate SQL query.

def _run(self, state: SystemState):
    """
    Executes the candidate generation process.
    
    Args:
        state (SystemState): The current system state.
    """
    state.SQL_meta_infos[self.tool_name] = []
    for generator_config in self.generator_configs:
        # template_name refers to template_generate_candidate_xxx.txt
        self.generators_queries[generator_config.template_name] = []
    for generator_config in self.generator_configs:
        # Iterate over the generator configs
        if self.next_generator_to_use != "ALL" and generator_config.template_name != self.next_generator_to_use:
            continue
        request_list = []
        # Build sampling_count identical requests
        for i in range(generator_config.sampling_count):
            try:
                request_kwargs = {
                    "DATABASE_SCHEMA": state.get_schema_string(schema_type="complete"),
                    "QUESTION": state.task.question,
                    "HINT": state.task.evidence,
                }
                request_list.append(request_kwargs)
            except Exception as e:
                print(f"Error in creating request_kwargs for generator {generator_config.template_name}: {e}")
                continue
        
        try:
            # Call the LLM to generate SQL and parse its output
            # The template is template_generate_candidate_xxx.txt
            response = async_llm_chain_call(
                prompt=get_prompt(template_name=generator_config.template_name),
                engine=get_llm_chain(**generator_config.engine_config),
                parser=get_parser(generator_config.parser_name),
                request_list=request_list,
                step=f"{self.tool_name}_{generator_config.engine_config['engine_name']}",
            )
            response = [res for sublist in response for res in sublist]
        except Exception as e:
            print(f"Error in generating SQL queries for generator {generator_config.template_name}: {e}")
            continue
        # Wrap each result in SQLMetaInfo and add it to generators_queries
        for res in response:
            if not res:
                continue
            try:
                sql_meta_info = SQLMetaInfo(**res)
                # state.SQL_meta_infos[self.tool_name].append(sql_meta_info)
                self.generators_queries[generator_config.template_name].append(sql_meta_info)
            except Exception as e:
                print(f"Error in creating SQLMetaInfo for generator {generator_config.template_name}: {e}")
                continue
        request_list = []
    # Append the SQLMetaInfo objects to SQL_meta_infos under tool_name
    for generator_config in self.generator_configs:
        if len(self.generators_queries[generator_config.template_name]) > 0:
            state.SQL_meta_infos[self.tool_name] += self.generators_queries[generator_config.template_name]
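Stripped of error handling, the control flow of GenerateCandidate is: skip generators other than next_generator_to_use (unless it is "ALL"), send sampling_count identical requests, flatten the per-request lists of samples, and concatenate the results in configuration order. A sketch with a fake LLM callable; GeneratorConfig, generate_candidates, call_llm and the template names gen_a/gen_b are hypothetical:

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class GeneratorConfig:
    template_name: str
    sampling_count: int


def generate_candidates(
    configs: List[GeneratorConfig],
    next_to_use: str,
    base_request: Dict[str, str],
    call_llm: Callable[[str, List[Dict[str, str]]], List[List[dict]]],
) -> List[str]:
    """call_llm returns one list of parsed samples per request (assumed shape)."""
    per_generator: Dict[str, List[str]] = {c.template_name: [] for c in configs}
    for config in configs:
        if next_to_use != "ALL" and config.template_name != next_to_use:
            continue
        # Identical requests: any diversity has to come from LLM sampling
        requests = [dict(base_request) for _ in range(config.sampling_count)]
        nested = call_llm(config.template_name, requests)
        # Flatten [[sample, ...], ...] into [sample, ...]
        flat = [res for sublist in nested for res in sublist]
        per_generator[config.template_name].extend(res["SQL"] for res in flat if res)
    # Concatenate the results in configuration order
    return [sql for c in configs for sql in per_generator[c.template_name]]
```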

Revise

This tool fixes errors in the generated SQL:

def _run(self, state: SystemState):
    """
    Executes the SQL revision process.
    
    Args:
        state (SystemState): The current system state.
    """
    try:
        # Take the most recently added tool_name key
        key_to_refine = list(state.SQL_meta_infos.keys())[-1]
        # Get the SQLMetaInfo list associated with that tool
        target_SQL_meta_infos = state.SQL_meta_infos[key_to_refine]
    except Exception as e:
        print(f"Error in Checker: {e}")
        return
    # Build the SQL_id key for this revision round
    if key_to_refine.startswith(self.tool_name):
        id = int(key_to_refine[len(self.tool_name)+1:])
        SQL_id = self.tool_name + "_" + str(id+1)
    else:
        SQL_id = self.tool_name + "_1"  
    state.SQL_meta_infos[SQL_id] = []
    request_list = []
    for SQL_meta_info in target_SQL_meta_infos:
        try:
            # Get the execution status
            execution_status = SQL_meta_info.execution_status
            # If the status is not SYNTACTICALLY_CORRECT, mark the SQL as needing a fix
            if execution_status != ExecutionStatus.SYNTACTICALLY_CORRECT:
                SQL_meta_info.need_fixing = True
        except Exception:
            SQL_meta_info.need_fixing = True
    # Keep only the SQLMetaInfo objects that need fixing
    need_fixing_SQL_meta_infos = [(index, target_SQL_meta_info) for index, target_SQL_meta_info in enumerate(target_SQL_meta_infos) if target_SQL_meta_info.need_fixing]
    for index, target_SQL_meta_info in need_fixing_SQL_meta_infos:   
        # Assemble the arguments for the LLM call
        try:            
            request_kwargs = {
                "DATABASE_SCHEMA": state.get_schema_string(schema_type="complete"),
                "QUESTION": state.task.question,
                "HINT": state.task.evidence,
                "QUERY": target_SQL_meta_info.SQL  ,
                "RESULT": self.get_formatted_execution_result(target_SQL_meta_info)
            }
            request_list.append(request_kwargs)
        except Exception as e:
            print(f"Error in Checker while creating request list: {e}")
            continue
            
    try:
        # The prompt template is template_revise_xxx.txt
        response = async_llm_chain_call(
            prompt=get_prompt(template_name=self.template_name),
            engine=get_llm_chain(**self.engine_config),
            parser=get_parser(self.parser_name),
            request_list=request_list,
            step=self.tool_name
        )
        response = [r[0] for r in response]
    except Exception as e:
        print(f"Error in Checker while getting response: {e}")
        response = []
    index = 0
    for target_SQL_meta_info in target_SQL_meta_infos:
        try:
            if target_SQL_meta_info.need_fixing:
                refinement_response = response[index]
                index += 1
                # If the revised SQL contains no "SELECT", fall back to the original SQL
                if "SELECT" not in refinement_response["refined_sql_query"]:
                    refinement_response = {
                        "refined_sql_query": target_SQL_meta_info.SQL
                    }
            else:
                # If no fix is needed, keep the original SQL
                refinement_response = {
                    "refined_sql_query": target_SQL_meta_info.SQL
                }
        except Exception as e:
            print(f"Error in Checker while updating SQL meta info: {e}")
            refinement_response = {
                "refined_sql_query": target_SQL_meta_info.SQL
            }
        # Store the refined SQL under the new SQL_id
        if "refined_sql_query" in refinement_response:
            if refinement_response["refined_sql_query"]:
                state.SQL_meta_infos[SQL_id].append(SQLMetaInfo(**{
                    "SQL": refinement_response["refined_sql_query"]
                })) 
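Two details of Revise are easy to miss: revision rounds are stored under keys such as revise_1, revise_2, and the model's answers are matched back to candidates by position, consuming one answer per candidate whose need_fixing flag is set. The sketch below re-implements just those two rules with hypothetical helpers; the default tool_name "revise" is an assumption. Like the original, it only checks for the uppercase substring "SELECT" and otherwise falls back to the original SQL.

```python
from typing import List, Optional


def next_revision_key(last_key: str, tool_name: str = "revise") -> str:
    """revise_1 after any other tool's key; otherwise increment the round number."""
    if last_key.startswith(tool_name):
        return f"{tool_name}_{int(last_key[len(tool_name) + 1:]) + 1}"
    return f"{tool_name}_1"


def merge_revisions(original_sqls: List[str], need_fixing: List[bool], refined: List[str]) -> List[str]:
    """refined holds one answer per SQL that needed fixing, in the same order."""
    answers = iter(refined)
    merged: List[str] = []
    for sql, fix in zip(original_sqls, need_fixing):
        candidate: Optional[str] = next(answers, None) if fix else None
        # Case-sensitive check, as in the original: no "SELECT" means keep the old SQL
        chosen = candidate if candidate and "SELECT" in candidate else sql
        if chosen:
            merged.append(chosen)
    return merged
```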

UnitTester

From the paper:

UT identifies the best candidate by:

  1. generating multiple unit tests that highlight differences between the candidate queries and
  2. evaluating the candidates against these unit tests.

def __init__(self, config: dict):
    """Initialize the tools needed for unit testing"""
    super().__init__(
        name="unit_tester",
        task="generate unit tests then evaluate them",
        config=config,
    )
    
    self.tools = {
        "generate_unit_test": GenerateUnitTest(**config["tools"]["generate_unit_test"]),
        "evaluate": Evaluate(**config["tools"]["evaluate"])
    }
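The paper says the unit tests should "highlight differences between the candidate queries". The toy sketch below shows why that matters. It makes a few assumptions. Each "unit test" is a Python predicate over a candidate's execution result, whereas CHESS has an LLM write natural-language tests and judge them. The matrix layout follows `comparison_matrix` in the Evaluate code: one row per test and one column per candidate. `pass_matrix`, `discriminative_tests`, and the sample data are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Row = Tuple[object, ...]
UnitTest = Callable[[List[Row]], bool]


def pass_matrix(results: Sequence[List[Row]], tests: Sequence[UnitTest]) -> List[List[int]]:
    """One row per unit test, one column per candidate: 1 = pass, 0 = fail."""
    return [[1 if test(result) else 0 for result in results] for test in tests]


def discriminative_tests(matrix: List[List[int]]) -> List[int]:
    """Indices of unit tests on which the candidates do not all agree."""
    return [i for i, row in enumerate(matrix) if len(set(row)) > 1]


if __name__ == "__main__":
    # Execution results of three hypothetical candidate queries.
    results = [
        [("Alice",)],
        [("Alice",), ("Bob",)],
        [("Alice",)],
    ]
    tests: List[UnitTest] = [
        lambda rows: len(rows) == 1,                  # "the answer is a single name"
        lambda rows: all(len(r) == 1 for r in rows),  # "exactly one column is returned"
    ]
    matrix = pass_matrix(results, tests)
    print(matrix)                        # [[1, 0, 1], [1, 1, 1]]
    print(discriminative_tests(matrix))  # [0]
```

The second test passes for every candidate, so it adds one point to each of them and cannot change the ranking. Only tests like the first one, which split the candidates, help the selection.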

GenerateUnitTest

From the paper: This tool prompts an LLM to generate k unit tests, where k is an input parameter, designed such that only the correct SQL query can pass each of them.

def _run(self, state: SystemState):
    try:
        # Key of the most recently added entry (named after the tool that produced it)
        key_to_evaluate = list(state.SQL_meta_infos.keys())[-1]
        # The list of SQLMetaInfo candidates stored under that key
        target_SQL_meta_infos = state.SQL_meta_infos[key_to_evaluate]
    except Exception as e:
        print(f"Error in UnitTestEvaluator: {e}")
        return
    if len(target_SQL_meta_infos) <= 1:
        state.unit_tests["unit_test_generation"] = []
        return
    # Get the database schema relevant to the candidate SQL queries
    database_schema = state.get_database_schema_for_queries(
            [sql_meta_info.SQL for sql_meta_info in target_SQL_meta_infos]
        )
    formatted_candidates = ""
    # Cluster the candidate SQL queries by execution result
    clusters = self.execution_based_clustering(target_SQL_meta_infos)
    self.candidates = target_SQL_meta_infos
    # A single cluster means every candidate returns the same result, so there is nothing to tell apart: return early
    if len(clusters) == 1:
        state.unit_tests["unit_test_generation"] = []
        return
    index = 0
    # Build the candidate-query section of the prompt, one block per cluster
    for key, candidate_queries in clusters.items():
        formatted_candidates += f"Cluster #{index+1}: \n"
        for candidate_query in candidate_queries:
            formatted_candidates += f"Query: {candidate_query.SQL}\n"
            formatted_candidates += "########\n"
        formatted_candidates += f"Execution result: {self._format_sql_query_result(candidate_queries[-1])}\n"
        formatted_candidates += "=====================\n"
        index += 1
        
    request_kwargs = {
        "HINT": state.task.evidence,
        "QUESTION": state.task.question,
        "DATABASE_SCHEMA": database_schema,
        "CANDIDATE_QUERIES": formatted_candidates,
        "UNIT_TEST_CAP": self.unit_test_count
    }
    # The template is template_generate_unit_tests.txt
    responses = async_llm_chain_call(
        prompt=get_prompt(template_name=self.template_name),
        engine=get_llm_chain(**self.engine_config),
        parser=get_parser(self.parser_name),
        request_list=[request_kwargs],
        step=self.tool_name,
        sampling_count=self.sampling_count
    )[0]

    # Store the generated unit tests in the state
    state.unit_tests["unit_test_generation"] = []
    for response in responses:
        state.unit_tests["unit_test_generation"].extend(response['unit_tests'])
    state.unit_tests["unit_test_generation"].extend(HARD_CODES_TEST_CASES)
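The body of `execution_based_clustering` is not shown here. Below is a minimal sketch of the same idea, grouping candidates whose execution results are identical. It runs on the standard-library `sqlite3` module against an in-memory toy database. Several choices are assumptions, not necessarily what CHESS does:

- the function name `cluster_by_execution`;
- the schema;
- comparing results as order-insensitive sorted row lists;
- putting failing queries into a per-error-message cluster.

```python
import sqlite3
from collections import defaultdict
from typing import Dict, List, Tuple


def cluster_by_execution(conn: sqlite3.Connection, queries: List[str]) -> Dict[Tuple, List[str]]:
    """Group SQL queries by their (order-insensitive) execution result."""
    clusters: Dict[Tuple, List[str]] = defaultdict(list)
    for sql in queries:
        try:
            rows = conn.execute(sql).fetchall()
            # Sort by repr so rows with mixed value types still sort without TypeError
            key: Tuple = ("ok", tuple(sorted(rows, key=repr)))
        except sqlite3.Error as exc:
            # Failing queries share a cluster per error message
            key = ("error", str(exc))
        clusters[key].append(sql)
    return dict(clusters)


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE emp (name TEXT, dept TEXT, salary INTEGER)")
    conn.executemany(
        "INSERT INTO emp VALUES (?, ?, ?)",
        [("Alice", "eng", 120), ("Bob", "eng", 100), ("Carol", "ops", 90)],
    )
    candidates = [
        "SELECT name FROM emp WHERE dept = 'eng'",
        "SELECT name FROM emp WHERE dept = 'eng' ORDER BY salary",
        "SELECT name FROM emp WHERE salary > 95",
        "SELECT name FROM emp ORDER BY salary DESC LIMIT 1",
    ]
    for key, members in cluster_by_execution(conn, candidates).items():
        print(key, len(members))
```

With this data, the first three queries all return Alice and Bob and land in one cluster. The `LIMIT 1` query forms a second cluster. Since there are two clusters, the `len(clusters) == 1` early return above would not fire, and unit-test generation would proceed.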

Evaluate

Evaluate checks whether each generated candidate SQL passes each unit test and scores it accordingly: a candidate earns one point for every unit test it passes.
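Here is a worked example of this scoring rule. It uses the same layout as `comparison_matrix` in the code: one row per unit test, one column per candidate, 1 = pass and 0 = fail. The matrix values are made up for illustration. The first-maximum tie-breaking in `best_index` is an assumption. In CHESS the final choice is made by `pick_the_best_candidate`, which also receives the execution clusters.

```python
from typing import List


def candidate_scores(comparison_matrix: List[List[int]]) -> List[int]:
    """Sum each candidate's column, i.e. count the unit tests it passed."""
    if not comparison_matrix:
        return []
    return [sum(row[i] for row in comparison_matrix) for i in range(len(comparison_matrix[0]))]


def best_index(scores: List[int]) -> int:
    """Index of the highest score; ties go to the earliest candidate (assumption)."""
    return max(range(len(scores)), key=lambda i: scores[i])


if __name__ == "__main__":
    # 3 unit tests x 4 candidates (hypothetical values)
    matrix = [
        [1, 0, 1, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
    ]
    scores = candidate_scores(matrix)
    print(scores)              # [3, 2, 2, 0]
    print(best_index(scores))  # 0
```

This sketch guards against an empty matrix. In the original code, an empty matrix occurs when the LLM call fails and `response` is set to `[]`, and `comparison_matrix[0]` would then raise `IndexError`.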

From the paper: This tool takes multiple candidate queries and a single unit test as input, prompting an LLM to reason through each candidate and determine whether it passes the unit test.

def _run(self, state: SystemState):
    """
    Executes the unit test evaluation process.
    
    Args:
        state (SystemState): The current system state.
    """
    try:
        # Key of the most recently added entry (named after the tool that produced it)
        key_to_evaluate = list(state.SQL_meta_infos.keys())[-1]
        # The corresponding list of SQLMetaInfo candidates
        target_SQL_meta_infos = state.SQL_meta_infos[key_to_evaluate]
    except Exception as e:
        print(f"Error in UnitTestEvaluator: {e}")
        return
    # Derive this tool's SQL id, <tool_name>_<n>, incrementing n if the previous entry came from this tool
    if key_to_evaluate.startswith(self.tool_name):
        id = int(key_to_evaluate[len(self.tool_name)+1:])
        self.SQL_id = self.tool_name + "_" + str(id+1)
    else:
        self.SQL_id = self.tool_name + "_1"  
    state.SQL_meta_infos[self.SQL_id] = []
    request_list = []
    # No candidate SQL: score 0
    if len(target_SQL_meta_infos) == 0:
        state.SQL_meta_infos[self.SQL_id].append("SELECT * FROM table_name")
        self.scores = [0]
        self.comparison_matrix = [[0]]
        return
    # Exactly one candidate SQL: score 1
    if len(target_SQL_meta_infos) == 1:
        state.SQL_meta_infos[self.SQL_id].append(target_SQL_meta_infos[0])
        self.scores = [1]
        self.comparison_matrix = [[1]]
        return
    # No unit tests: keep the first candidate with score 1
    if len(state.unit_tests["unit_test_generation"]) == 0:
        state.SQL_meta_infos[self.SQL_id].append(target_SQL_meta_infos[0])
        self.scores = [1]
        self.comparison_matrix = [[1]]
        return
    # Cluster the candidate SQL queries by execution result
    candidates_clusters = self.execution_based_clustering(target_SQL_meta_infos)
    formatted_candidates = ""
    # Build the formatted_candidates section of the prompt,
    # then get the database schema relevant to the candidate SQL
    for index, candidate_query in enumerate(target_SQL_meta_infos):
        formatted_candidates += f"Candidate Response #{index+1}: Query: {candidate_query.SQL}\n, Execution Result: {self._format_sql_query_result(candidate_query)}\n"
    database_schema = state.get_database_schema_for_queries(
            [sql_meta_info.SQL for sql_meta_info in target_SQL_meta_infos]
        )
    for index, unit_test in enumerate(state.unit_tests["unit_test_generation"]): 
        try:            
            request_kwargs = {
                "DATABASE_SCHEMA": database_schema,
                "QUESTION": state.task.question,
                "HINT": state.task.evidence,
                "CANDIDATE_RESPONSES": formatted_candidates,
                "UNIT_TEST": unit_test
            }
            request_list.append(request_kwargs)
        except Exception as e:
            print(f"Error in UnitTestEvaluator while creating request list: {e}")
            continue
            
    try:
        # Call the LLM; the prompt template is template_evaluate.txt
        response = async_llm_chain_call(
            prompt=get_prompt(template_name=self.template_name),
            engine=get_llm_chain(**self.engine_config),
            # The parser is UnitTestEvaluationOutput
            parser=get_parser(self.parser_name),
            request_list=request_list,
            step=self.tool_name
        )
        response = [r[0] for r in response]
    except Exception as e:
        print(f"Error in Checker while getting response: {e}")
        response = []
    comparison_matrix = []
    for item in response:
        # if self.test_case_filtering_based_on_inter_cluster_variance(candidates_clusters, item["scores"], target_SQL_meta_infos):
        comparison_matrix.append(item["scores"])
    # sum scores across all unit tests
    self.comparison_matrix = comparison_matrix  
    # For each candidate SQL, sum its scores over all unit tests;
    # per UnitTestEvaluationOutput, a passed test scores 1 and a failed test scores 0
    scores = [sum([score[index] for score in comparison_matrix]) for index in range(len(comparison_matrix[0]))]
    self.scores = scores
    # Find the best candidate, i.e. the one with the highest total score
    best_candidate = self.pick_the_best_candidate(scores, target_SQL_meta_infos, candidates_clusters)
    state.SQL_meta_infos[self.SQL_id].append(best_candidate)
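The SQL-id bookkeeping at the top of `_run` gives each invocation of a tool its own key of the form `<tool_name>_<n>`. Repeated runs of the same tool therefore append new candidate lists instead of reusing one key. The helper below mirrors that rule. `next_sql_id` is a hypothetical name, and the tool names `evaluate` and `revise` in the example are assumptions.

```python
def next_sql_id(tool_name: str, last_key: str) -> str:
    """Hypothetical helper mirroring the SQL-id rule in Evaluate._run."""
    if last_key.startswith(tool_name):
        # e.g. "evaluate_2": parse the counter after "<tool_name>_" and bump it
        n = int(last_key[len(tool_name) + 1:])
        return f"{tool_name}_{n + 1}"
    # The previous entry came from another tool: start a fresh counter
    return f"{tool_name}_1"


if __name__ == "__main__":
    print(next_sql_id("evaluate", "revise_1"))    # evaluate_1
    print(next_sql_id("evaluate", "evaluate_2"))  # evaluate_3
```

The check is a plain `startswith`, so the scheme implicitly assumes that no tool name is a prefix of another tool's name.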

That completes the walkthrough of CHESS's full natural-language-to-SQL pipeline.

Preprocessing

Before CHESS can use a database, its schema and related information must first go through a preprocessing step.

The preprocessing logic lives in preprocess.py.

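The entry point parses the command-line arguments and then calls worker_initializer. When --db_id is 'all', it does this once per database directory, in a process pool: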

if __name__ == '__main__':
    # Setup argument parser
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--db_root_directory', type=str, required=True, help="Root directory of the databases")
    args_parser.add_argument('--signature_size', type=int, default=20, help="Size of the MinHash signature")
    args_parser.add_argument('--n_gram', type=int, default=3, help="N-gram size for the MinHash")
    args_parser.add_argument('--threshold', type=float, default=0.01, help="Threshold for the MinHash LSH")
    args_parser.add_argument('--db_id', type=str, default='all', help="Database ID or 'all' to process all databases")
    args_parser.add_argument('--verbose', type=bool, default=True, help="Enable verbose logging")
    args_parser.add_argument('--use_value_description', type=bool, default=True, help="Include value descriptions")

    args = args_parser.parse_args()

    if args.db_id == 'all':
        with multiprocessing.Pool(NUM_WORKERS) as pool:
            for db_id in os.listdir(args.db_root_directory):
                # check if the db_id is a directory
                if os.path.isdir(f"{args.db_root_directory}/{db_id}"):
                    pool.apply_async(worker_initializer, args=(db_id, args))
            pool.close()
            pool.join()
    else:
        worker_initializer(args.db_id, args)

    logging.info("Preprocessing is complete.")

worker_initializer first indexes the whole database with the LSH algorithm (make_db_lsh). It then stores the database's description information in a Chroma vector database (make_db_context_vec_db).

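LSH (locality-sensitive hashing) is a technique for fast approximate nearest-neighbor search over large volumes of high-dimensional data.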
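The preprocessing flags `--signature_size`, `--n_gram`, and `--threshold` are MinHash/LSH parameters. As an illustration only, here is a standard-library sketch of the general technique: MinHash signatures over character n-grams, plus banded LSH bucketing. This is not CHESS's `make_db_lsh`. The helper names, the salted `blake2b` hashing, and the fixed band count (used instead of deriving bands from `--threshold`) are all assumptions.

Each signature slot keeps the minimum hash of a value's n-grams under one hash function. The fraction of matching slots between two signatures estimates the Jaccard similarity of the two n-gram sets.

```python
import hashlib
from collections import defaultdict
from typing import Dict, Iterator, List, Set, Tuple


def ngrams(text: str, n: int = 3) -> Set[str]:
    """Character n-grams of a lower-cased value (the whole value if shorter than n)."""
    s = text.lower()
    if len(s) < n:
        return {s}
    return {s[i:i + n] for i in range(len(s) - n + 1)}


def _hash(token: str, seed: int) -> int:
    """One of signature_size hash functions, built from a salted blake2b digest."""
    digest = hashlib.blake2b(
        token.encode("utf-8"), digest_size=8, salt=seed.to_bytes(16, "little")
    ).digest()
    return int.from_bytes(digest, "little")


def minhash(text: str, signature_size: int = 20, n: int = 3) -> List[int]:
    """Signature slot k = minimum of hash function k over the value's n-grams."""
    grams = ngrams(text, n)
    return [min(_hash(g, seed) for g in grams) for seed in range(signature_size)]


def estimated_jaccard(a: List[int], b: List[int]) -> float:
    """Fraction of equal slots estimates the Jaccard similarity of the n-gram sets."""
    return sum(x == y for x, y in zip(a, b)) / len(a)


class SimpleLSH:
    """Banded LSH: two values become candidates if any band of their signatures matches."""

    def __init__(self, signature_size: int = 20, n_gram: int = 3, bands: int = 10) -> None:
        if signature_size % bands != 0:
            raise ValueError("signature_size must be divisible by bands")
        self.signature_size = signature_size
        self.n_gram = n_gram
        self.bands = bands
        self.rows = signature_size // bands
        self.buckets: Dict[Tuple[int, Tuple[int, ...]], Set[str]] = defaultdict(set)
        self.signatures: Dict[str, List[int]] = {}

    def _band_keys(self, sig: List[int]) -> Iterator[Tuple[int, Tuple[int, ...]]]:
        for b in range(self.bands):
            yield b, tuple(sig[b * self.rows:(b + 1) * self.rows])

    def insert(self, value: str) -> None:
        sig = minhash(value, self.signature_size, self.n_gram)
        self.signatures[value] = sig
        for key in self._band_keys(sig):
            self.buckets[key].add(value)

    def query(self, text: str, top_k: int = 3) -> List[Tuple[str, float]]:
        sig = minhash(text, self.signature_size, self.n_gram)
        candidates: Set[str] = set()
        for key in self._band_keys(sig):
            candidates |= self.buckets.get(key, set())
        # Rank only the bucketed candidates by estimated similarity
        scored = [(v, estimated_jaccard(sig, self.signatures[v])) for v in candidates]
        return sorted(scored, key=lambda p: p[1], reverse=True)[:top_k]


if __name__ == "__main__":
    lsh = SimpleLSH()
    for value in ["New York", "Newark", "Los Angeles", "San Francisco"]:
        lsh.insert(value)
    print(lsh.query("new york city"))
```

With 10 bands of 2 rows, a pair with Jaccard similarity s shares at least one band with probability 1 - (1 - s^2)^10. Near-duplicate values are therefore found without comparing the query against every stored value.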

def worker_initializer(db_id: str, args: argparse.Namespace):
    """
    Initializes the worker to create LSH and context vectors for a given database ID.
    
    Args:
        db_id (str): The database ID.
        args (argparse.Namespace): The command line arguments.
    """
    db_directory_path = f"{args.db_root_directory}/{db_id}"
    logging.info(f"Creating LSH for {db_id}")
    # Build the LSH index
    make_db_lsh(db_directory_path, 
                signature_size=args.signature_size, 
                n_gram=args.n_gram, 
                threshold=args.threshold,
                verbose=args.verbose)
    logging.info(f"LSH for {db_id} created.")
    logging.info(f"Creating context vectors for {db_id}")
    # The raw table-structure descriptions are stored in CSV files
    make_db_context_vec_db(db_directory_path,
                           use_value_description=args.use_value_description)
    logging.info(f"Context vectors for {db_id} created.")
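`make_db_context_vec_db` puts the database's description text into Chroma, so that later steps can look descriptions up by similarity. The toy sketch below shows only that index-then-retrieve flow. A real setup embeds text with an embedding model; here, bag-of-words vectors and cosine similarity from the standard library stand in for it. `ToyVectorStore`, the description strings, and the metadata layout are hypothetical and are not Chroma's API.

```python
import math
import re
from collections import Counter
from typing import Dict, List, Tuple


def embed(text: str) -> Counter:
    """Toy stand-in for an embedding model: lower-cased word counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class ToyVectorStore:
    """In-memory index of (vector, metadata) pairs, queried by cosine similarity."""

    def __init__(self) -> None:
        self.items: List[Tuple[Counter, Dict[str, str]]] = []

    def add(self, text: str, metadata: Dict[str, str]) -> None:
        self.items.append((embed(text), metadata))

    def query(self, text: str, top_k: int = 2) -> List[Tuple[float, Dict[str, str]]]:
        q = embed(text)
        scored = [(cosine(q, vec), meta) for vec, meta in self.items]
        scored.sort(key=lambda p: p[0], reverse=True)
        return scored[:top_k]


if __name__ == "__main__":
    store = ToyVectorStore()
    # Hypothetical column descriptions, one entry per column
    store.add("total number of students enrolled", {"table": "schools", "column": "enrollment"})
    store.add("telephone number of the school", {"table": "schools", "column": "phone"})
    store.add("city where the school is located", {"table": "schools", "column": "city"})
    for score, meta in store.query("which city is the school in"):
        print(f"{score:.2f}", meta)
```

Here the `city` description ranks first because it shares the most words with the question. A real embedding model would also match paraphrases that share no words.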

For the prose description, see Section 3.2 of the paper.

References

Paper: https://arxiv.org/pdf/2405.16755

Project: https://github.com/ShayanTalaei/CHESS.git
