llama-factory 系列教程 (五),SFT 微调后的模型,结合langchain进行推理

avatar
作者
筋斗云
阅读量:0

背景

微调了一个 glm4-9B的大模型。微调后得到Lora权重,部署成vllm 的API,然后通过langchain接入完成相关任务的推理。

关于SFT 微调模型的部分就不做介绍了,大家可以参考前面的文章,将自己的数据集 在 Llamafactory 的 dataset_info.json 里进行注册。

llamafactory-cli webui
通过可视化界面进行微调,或者拿到预览的命令,在命令行中运行。

llamafactory API 部署模型

使用 llamafactory 训练模型,再使用llamafactory 部署API 简单又省事,就是慢了一点,但很方便。

如果你想追求极致的推理速度,建议你阅读这篇文章:llama-factory SFT 系列教程 (四),lora sft 微调后,使用vllm加速推理

运行下述代码,完成API部署:

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \     --model_name_or_path /home/root/.cache/modelscope/hub/ZhipuAI/glm-4-9b-chat \     --adapter_name_or_path ./saves/GLM-4-9B-Chat/lora/train_2024-07-30-15-53-random-500 \     --template glm4 \     --finetuning_type lora \     --infer_backend vllm \     --vllm_enforce_eager 

adapter_name_or_path:lora 插件地址;
建议使用vllm进行部署,huggingface 容易报错。

langchain

from datasets import load_dataset from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.output_parsers import StrOutputParser  parser = StrOutputParser() 
port = 8000 model = ChatOpenAI(     api_key="0",     base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),     temperature=0 ) 

加载本地的json 文件,作为推理用的数据集:

valid_dataset = load_dataset(     "json",     data_files="../valid.json" )["train"] 
preds = [] for item in tqdm(valid_dataset):     # 修改 messages, 填入自己的数据即可     messages = [         SystemMessage(content=item['instruction']),         HumanMessage(content=item['input']),     ]     chain = model | parser     pred = chain.invoke(messages).strip()     preds.append(pred) 

如上述所示,即可轻松实现利用 langchain 结合训练后的模型,完成推理任务。

参考资料

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!