阅读量:0
【2024】Datawhale AI夏令营 Task1
本文对赛事【逻辑推理赛道:复杂推理能力评估】的baseline代码进行解释。
1、安装必要的第三方库
!pip install scipy openai tiktoken retry dashscope loguru
2、导入模块,设置日志记录配置,为后续代码的运行提供必要的工具和配置环境。
from multiprocessing import Process, Manager import json import os from pprint import pprint import re from tqdm import tqdm import random import uuid import openai import tiktoken import json import numpy as np import requests from retry import retry from scipy import sparse #from rank_bm25 import BM25Okapi #import jieba from http import HTTPStatus import dashscope from concurrent.futures import ThreadPoolExecutor, as_completed from loguru import logger import json import time from tqdm import tqdm logger.remove() # 移除默认的控制台输出 logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip") MODEL_NAME = 'qwen1.5-1.8b-chat' # 后续推理使用的模型为qwen1.5-1.8b-chat
3、设置API-KEY
dashscope.api_key="sk-"
4、定义函数 api_retry
,用于调用 call_qwen_api
函数,并在调用失败时进行重试。
def api_retry(MODEL_NAME, query): max_retries = 5 ## 最大重试次数 retry_delay = 60 # in seconds ## 重试间隔时间,设为60秒 attempts = 0 ## 记录当前重试次数,初始值为0 while attempts < max_retries: try: ## try块内尝试调用call_qwen_api(MODEL_NAME, query)并返回结果 return call_qwen_api(MODEL_NAME, query) except Exception as e: attempts += 1 if attempts < max_retries: logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...") time.sleep(retry_delay) else: logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}") raise
5、定义一个名为 call_qwen_api
的函数,用于通过Dashscope API调用指定模型进行推理,并返回结果.
def call_qwen_api(MODEL_NAME, query): # 这里采用dashscope的api调用模型推理,通过http传输的json封装返回结果 messages = [ {'role': 'user', 'content': query}] response = dashscope.Generation.call( MODEL_NAME, messages=messages, result_format='message', # set the result is message format. ) if response.status_code == HTTPStatus.OK: # print(response) return response['output']['choices'][0]['message']['content'] else: print('Request id: %s, Status code: %s, error code: %s, error message: %s' % ( response.request_id, response.status_code, response.code, response.message )) raise Exception()
6、定义prompt推理模版
# 这里定义了prompt推理模版 def get_prompt(problem, question, options): options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options)) prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下: ### 题目: {problem} ### 问题: {question} {options} """ # print(prompt) return prompt
7、定义一个名为 extract
的函数,用于从输入文本中提取答案。
# 这里使用extract抽取模获得抽取的结果 def extract(input_text): ans_pattern = re.compile(r"答案是:(.)", re.S) problems = ans_pattern.findall(input_text) # print(problems) if(problems == ''): return 'A' return problems[0]
8、定义一个名为 process_datas
的函数,用于并发处理一组数据,多线程执行api_retry
函数并对结果进行处理和提取。
def process_datas(datas,MODEL_NAME): results = [] with ThreadPoolExecutor(max_workers=16) as executor: future_data = {} lasttask = '' lastmark = 0 lens = 0 for data in tqdm(datas, desc="Submitting tasks", total=len(datas)): problem = data['problem'] for id,question in enumerate(data['questions']): prompt = get_prompt(problem, question['question'], question['options'], ) future = executor.submit(api_retry, MODEL_NAME, prompt) future_data[future] = (data,id) time.sleep(0.6) # 控制每0.5秒提交一个任务 lens += 1 for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"): # print('data',data) data = future_data[future][0] problem_id = future_data[future][1] try: res = future.result() extract_response = extract(res) # print('res',extract_response) data['questions'][problem_id]['answer'] = extract_response results.append(data) # print('data',data) except Exception as e: logger.error(f"Failed to process text: {data}. Error: {e}") return results
9、定义一个名为 main
的函数,用于从输入文件读取数据、处理数据,并将结果返回。
def main(ifn, ofn): if os.path.exists(ofn): pass data = [] # 按行读取数据 with open(ifn) as reader: for line in reader: sample = json.loads(line) data.append(sample) datas = data # print(data) # 均匀地分成多个数据集 return_list = process_datas(datas,MODEL_NAME) print(len(return_list)) print("All tasks finished!") return return_list
10、定义一个名为 evaluate
的函数,用于评估模型的预测结果。
def evaluate(ofn): data = [] with open(ofn) as reader: for line in reader: sample = json.loads(line) data.append(sample) pse = 0 cnt = 0 tot = 0 for task in data: for question in task['questions']: if MODEL_NAME in question: tot += 1 cnt += question[MODEL_NAME] == question['answer'] else: pse += 1 print(cnt, tot, cnt/tot, pse)
11、运行 extract
函数并打印结果,调用 main
函数并处理数据
if __name__ == '__main__': a = extract("""根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下: 1. 判断6和7是否相等:不相等。 2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。 3. 现在计算6和1的最大公约数。 4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。 5. 再计算5和1的最大公约数。 6. 5 > 1,用5减去1得到结果4。 7. 再计算4和1的最大公约数。 8. 4 > 1,用4减去1得到结果3。 9. 再计算3和1的最大公约数。 10. 3 > 1,用3减去1得到结果2。 11. 再计算2和1的最大公约数。 12. 2 > 1,用2减去1得到结果1。 13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。 因此,6和7的最大公约数是1。 答案是:C.""") print(a) return_list = main('round1_test_data.jsonl', 'upload.jsonl')
12、定义两个函数 has_complete_answer
和 filter_problems
,用于处理和过滤数据中的问题
def has_complete_answer(questions): # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键 for question in questions: if 'answer' not in question: return False return True def filter_problems(data): result = [] problem_set = set() for item in data: # print('处理的item' ,item) problem = item['problem'] if problem in problem_set: # 找到已存在的字典 for existing_item in result: if existing_item['problem'] == problem: # 如果当前字典有完整答案,替换已存在的字典 if has_complete_answer(item['questions']): existing_item['questions'] = item['questions'] existing_item['id'] = item['id'] break else: # 如果当前字典有完整答案,添加到结果列表 if has_complete_answer(item['questions']): result.append(item) problem_set.add(problem) return result
13、
return_list return_list = filter_problems(return_list) sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:])) print(sorted_data)
14、
sorted_data
15、
def find_missing_ids(dict_list): # 提取所有序号 extracted_ids = {int(d['id'][-3:]) for d in dict_list} # 创建0-500的序号集合 all_ids = set(range(500)) # 找出缺失的序号 missing_ids = all_ids - extracted_ids return sorted(missing_ids) # 示例字典列表 dict_list = sorted_data # 找出缺失的序号 missing_ids = find_missing_ids(dict_list) print("缺失的序号:", missing_ids)
16、
len(missing_ids)
17、
data = [] with open('round1_test_data.jsonl') as reader: for id,line in enumerate(reader): if(id in missing_ids): sample = json.loads(line) for question in sample['questions']: question['answer'] = 'A' sorted_data.append(sample) sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))
18、将结果写入文件。
with open('upload.jsonl', 'w') as writer: for sample in sorted_data: writer.write(json.dumps(sample, ensure_ascii=False)) writer.write('\n')