【2024】Datawhale AI夏令营 Task1

avatar
作者
猴君
阅读量:0

【2024】Datawhale AI夏令营 Task1

本文对赛事【逻辑推理赛道:复杂推理能力评估】的baseline代码进行解释。

1、安装必要的第三方库

!pip install scipy openai tiktoken retry dashscope loguru 

2、导入模块,设置日志记录配置,为后续代码的运行提供必要的工具和配置环境。

from multiprocessing import Process, Manager import json import os from pprint import pprint import re from tqdm import tqdm import random  import uuid import openai import tiktoken import json import numpy as np import requests from retry import retry from scipy import sparse #from rank_bm25 import BM25Okapi #import jieba from http import HTTPStatus import dashscope   from concurrent.futures import ThreadPoolExecutor, as_completed from loguru import logger import json import time from tqdm import tqdm  logger.remove()  # 移除默认的控制台输出 logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")  MODEL_NAME = 'qwen1.5-1.8b-chat'  # 后续推理使用的模型为qwen1.5-1.8b-chat 

3、设置API-KEY

dashscope.api_key="sk-" 

4、定义函数 api_retry,用于调用 call_qwen_api 函数,并在调用失败时进行重试。

def api_retry(MODEL_NAME, query):     max_retries = 5 ## 最大重试次数     retry_delay = 60  # in seconds ## 重试间隔时间,设为60秒     attempts = 0 ## 记录当前重试次数,初始值为0     while attempts < max_retries:         try: ## try块内尝试调用call_qwen_api(MODEL_NAME, query)并返回结果             return call_qwen_api(MODEL_NAME, query)         except Exception as e:             attempts += 1                if attempts < max_retries:                 logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")                 time.sleep(retry_delay)             else:                 logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")                 raise 

5、定义一个名为 call_qwen_api 的函数,用于通过Dashscope API调用指定模型进行推理,并返回结果.

def call_qwen_api(MODEL_NAME, query):     # 这里采用dashscope的api调用模型推理,通过http传输的json封装返回结果     messages = [         {'role': 'user', 'content': query}]     response = dashscope.Generation.call(         MODEL_NAME,         messages=messages,         result_format='message',  # set the result is message format.     )     if response.status_code == HTTPStatus.OK:         # print(response)         return response['output']['choices'][0]['message']['content']     else:         print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (             response.request_id, response.status_code,             response.code, response.message         ))         raise Exception() 

6、定义prompt推理模版

# 这里定义了prompt推理模版  def get_prompt(problem, question, options):      options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))      prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:  ### 题目: {problem}  ### 问题: {question} {options} """     # print(prompt)     return prompt  

7、定义一个名为 extract 的函数,用于从输入文本中提取答案。

# 这里使用extract抽取模获得抽取的结果  def extract(input_text):     ans_pattern = re.compile(r"答案是:(.)", re.S)      problems = ans_pattern.findall(input_text)     # print(problems)     if(problems == ''):         return 'A'     return problems[0] 

8、定义一个名为 process_datas 的函数,用于并发处理一组数据,多线程执行api_retry函数并对结果进行处理和提取。

def process_datas(datas,MODEL_NAME):     results = []     with ThreadPoolExecutor(max_workers=16) as executor:         future_data = {}         lasttask = ''         lastmark = 0         lens = 0         for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):             problem = data['problem']             for id,question in enumerate(data['questions']):                 prompt = get_prompt(problem,                                      question['question'],                                      question['options'],                                     )                  future = executor.submit(api_retry, MODEL_NAME, prompt)                                  future_data[future] = (data,id)                 time.sleep(0.6)  # 控制每0.5秒提交一个任务                 lens += 1         for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):             # print('data',data)             data = future_data[future][0]             problem_id = future_data[future][1]             try:                 res  = future.result()                 extract_response = extract(res)                 # print('res',extract_response)                 data['questions'][problem_id]['answer'] = extract_response                 results.append(data)                 # print('data',data)                              except Exception as e:                 logger.error(f"Failed to process text: {data}. Error: {e}")          return results 

9、定义一个名为 main 的函数,用于从输入文件读取数据、处理数据,并将结果返回。

def main(ifn, ofn):     if os.path.exists(ofn):         pass     data = []     # 按行读取数据     with open(ifn) as reader:         for line in reader:             sample = json.loads(line)             data.append(sample)     datas = data     # print(data)     # 均匀地分成多个数据集     return_list = process_datas(datas,MODEL_NAME)     print(len(return_list))     print("All tasks finished!")     return return_list 

10、定义一个名为 evaluate 的函数,用于评估模型的预测结果。

def evaluate(ofn):     data = []     with open(ofn) as reader:         for line in reader:             sample = json.loads(line)             data.append(sample)      pse = 0     cnt = 0     tot = 0     for task in data:         for question in task['questions']:                          if MODEL_NAME in question:                 tot += 1                 cnt += question[MODEL_NAME] == question['answer']             else:                 pse += 1      print(cnt, tot, cnt/tot, pse) 

11、运行 extract 函数并打印结果,调用 main 函数并处理数据

if __name__ == '__main__':      a = extract("""根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下:  1. 判断6和7是否相等:不相等。 2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。 3. 现在计算6和1的最大公约数。 4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。 5. 再计算5和1的最大公约数。 6. 5 > 1,用5减去1得到结果4。 7. 再计算4和1的最大公约数。 8. 4 > 1,用4减去1得到结果3。 9. 再计算3和1的最大公约数。 10. 3 > 1,用3减去1得到结果2。 11. 再计算2和1的最大公约数。 12. 2 > 1,用2减去1得到结果1。 13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。  因此,6和7的最大公约数是1。  答案是:C.""")      print(a)     return_list = main('round1_test_data.jsonl', 'upload.jsonl')  

12、定义两个函数 has_complete_answerfilter_problems,用于处理和过滤数据中的问题

def has_complete_answer(questions):     # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键     for question in questions:         if 'answer' not in question:             return False     return True  def filter_problems(data):     result = []     problem_set = set()      for item in data:         # print('处理的item' ,item)         problem = item['problem']         if problem in problem_set:             # 找到已存在的字典             for existing_item in result:                 if existing_item['problem'] == problem:                     # 如果当前字典有完整答案,替换已存在的字典                     if has_complete_answer(item['questions']):                         existing_item['questions'] = item['questions']                         existing_item['id'] = item['id']                     break         else:             # 如果当前字典有完整答案,添加到结果列表             if has_complete_answer(item['questions']):                 result.append(item)                 problem_set.add(problem)      return result 

13、

return_list return_list = filter_problems(return_list) sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:])) print(sorted_data) 

14、

sorted_data 

15、

def find_missing_ids(dict_list):     # 提取所有序号     extracted_ids = {int(d['id'][-3:]) for d in dict_list}          # 创建0-500的序号集合     all_ids = set(range(500))          # 找出缺失的序号     missing_ids = all_ids - extracted_ids          return sorted(missing_ids)  # 示例字典列表 dict_list = sorted_data  # 找出缺失的序号 missing_ids = find_missing_ids(dict_list) print("缺失的序号:", missing_ids) 

16、

len(missing_ids) 

17、

data  = [] with open('round1_test_data.jsonl') as reader:     for id,line in enumerate(reader):         if(id in missing_ids):             sample = json.loads(line)             for question in sample['questions']:                 question['answer'] = 'A'             sorted_data.append(sample) sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))          

18、将结果写入文件。

with open('upload.jsonl', 'w') as writer:     for sample in sorted_data:         writer.write(json.dumps(sample, ensure_ascii=False))         writer.write('\n') 

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!