当我们训练好自己的语言模型后,我们希望可以把它封装成类似 from langchain.llms import OpenAI 这样的类,在langchain中灵活操作。

下面是我对MiniCPM模型的封装代码,希望对大家有帮助。

from pydantic import Field
from typing import Any, List, Optional
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

CPM_DEVICE = "cuda:0"
MAX_NEW_TOKENS = 4096
TEMPERATURE = 0.7
TOP_P = 0.7
REPETITION_PENALTY = 1.02 # repetition_penalty是在使用预训练语言模型进行文本生成时,用于控制生成文本中重复词或短语的惩罚系数。这个参数在 Hugging Face Transformers 库中被引入,以帮助减少生成文本中的重复和循环模式,提高生成文本的多样性和连贯性


class LangChainMiniCPMModel(LLM):
    tokenizer: Any = Field(default=None)
    model: Any = Field(default=None)

    def __init__(self, model_path: str):
        """
        继承langchain的MiniCPM模型

        参数:
        model_path (str): 需要加载的MiniCPM模型路径。

        返回:
        self.model: 加载的MiniCPM模型。
        self.tokenizer: 加载的MiniCPM模型的tokenizer。
        """
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.float16
        ).to(CPM_DEVICE)
        self.model = self.model.eval()

    def _call(self, prompt, stop: Optional[List[str]] = None):
        """
        langchain.llm的调用

        参数:
        prompt (str): 传入的prompt文本

        返回:
        responds (str): 模型在prompt下生成的文本
        """

        inputs = self.tokenizer("<用户>{}".format(prompt), return_tensors="pt")
        inputs = inputs.to(CPM_DEVICE)
        # Generate
        generate_ids = self.model.generate(
            inputs.input_ids,
            max_length=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REPETITION_PENALTY,
        )
        responds = self.tokenizer.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        # responds, history = self.model.chat(self.tokenizer, prompt, temperature=args.temperature, top_p=args.top_p, repetition_penalty=1.02)

        return responds

    @property
    def _llm_type(self) -> str:
        return "LangChainMiniCPMModel"

if __name__ == "__main__":
    llm = LangChainMiniCPMModel("openbmb/MiniCPM3-4B")

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐