基于Langchain封装自定义语言模型类,该类继承from langchain.llms.base import LLM,可直接用于Agent的ReAct。
当我们训练好自己的语言模型后,我们希望可以把它封装成类似 from langchain.llms import OpenAI 这样的类,在langchain中灵活操作。下面是我对MiniCPM模型的封装代码,希望对大家有帮助。
·
当我们训练好自己的语言模型后,我们希望可以把它封装成类似 from langchain.llms import OpenAI 这样的类,在langchain中灵活操作。
下面是我对MiniCPM模型的封装代码,希望对大家有帮助。
from pydantic import Field
from typing import Any, List, Optional
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
CPM_DEVICE = "cuda:0"
MAX_NEW_TOKENS = 4096
TEMPERATURE = 0.7
TOP_P = 0.7
REPETITION_PENALTY = 1.02 # repetition_penalty是在使用预训练语言模型进行文本生成时,用于控制生成文本中重复词或短语的惩罚系数。这个参数在 Hugging Face Transformers 库中被引入,以帮助减少生成文本中的重复和循环模式,提高生成文本的多样性和连贯性
class LangChainMiniCPMModel(LLM):
tokenizer: Any = Field(default=None)
model: Any = Field(default=None)
def __init__(self, model_path: str):
"""
继承langchain的MiniCPM模型
参数:
model_path (str): 需要加载的MiniCPM模型路径。
返回:
self.model: 加载的MiniCPM模型。
self.tokenizer: 加载的MiniCPM模型的tokenizer。
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, torch_dtype=torch.float16
).to(CPM_DEVICE)
self.model = self.model.eval()
def _call(self, prompt, stop: Optional[List[str]] = None):
"""
langchain.llm的调用
参数:
prompt (str): 传入的prompt文本
返回:
responds (str): 模型在prompt下生成的文本
"""
inputs = self.tokenizer("<用户>{}".format(prompt), return_tensors="pt")
inputs = inputs.to(CPM_DEVICE)
# Generate
generate_ids = self.model.generate(
inputs.input_ids,
max_length=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
repetition_penalty=REPETITION_PENALTY,
)
responds = self.tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
# responds, history = self.model.chat(self.tokenizer, prompt, temperature=args.temperature, top_p=args.top_p, repetition_penalty=1.02)
return responds
@property
def _llm_type(self) -> str:
return "LangChainMiniCPMModel"
if __name__ == "__main__":
llm = LangChainMiniCPMModel("openbmb/MiniCPM3-4B")
更多推荐


所有评论(0)