13 KiB
カスタムモデルのインテグレーション
イントロダクション
ベンダー統合が完了した後、次にベンダーの下でモデルのインテグレーションを行います。ここでは、全体のプロセスを理解するために、例としてXinferenceを使用して、段階的にベンダーのインテグレーションを完了します。
注意が必要なのは、カスタムモデルの場合、各モデルのインテグレーションには完全なベンダークレデンシャルの記入が必要です。
事前定義モデルとは異なり、カスタムベンダーのインテグレーション時には常に以下の2つのパラメータが存在し、ベンダー yaml に定義する必要はありません。

前述したように、ベンダーはvalidate_provider_credentialを実装する必要はなく、Runtimeがユーザーが選択したモデルタイプとモデル名に基づいて、対応するモデル層のvalidate_credentialsを呼び出して検証を行います。
ベンダー yaml の作成
まず、インテグレーションを行うベンダーがどのタイプのモデルをサポートしているかを確認します。
現在サポートされているモデルタイプは以下の通りです:
llmテキスト生成モデルtext_embeddingテキスト Embedding モデルrerankRerank モデルspeech2text音声からテキスト変換ttsテキストから音声変換moderationモデレーション
XinferenceはLLM、Text Embedding、Rerankをサポートしているため、xinference.yamlを作成します。
provider: xinference # ベンダー識別子
label: # ベンダー表示名、en_US 英語、zh_Hans 中国語の両方の言語で設定可能、zh_Hans が設定されていない場合は en_US がデフォルト
en_US: Xorbits Inference
icon_small: # 小アイコン、他のベンダーのアイコンを参考にし、対応するベンダー実装ディレクトリの _assets ディレクトリに保存
en_US: icon_s_en.svg
icon_large: # 大アイコン
en_US: icon_l_en.svg
help: # ヘルプ
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # サポートされるモデルタイプ、XinferenceはLLM/Text Embedding/Rerankをサポート
- llm
- text-embedding
- rerank
configurate_methods: # Xinferenceはローカルデプロイのベンダーであり、事前定義モデルがないため、必要なモデルを自分でデプロイする必要があるので、ここではカスタムモデルのみサポート
- customizable-model
provider_credential_schema:
credential_form_schemas:
その後、Xinferenceでモデルを定義するために必要なクレデンシャルを考えます。
- 3つの異なるモデルをサポートするため、
model_typeを使用してこのモデルのタイプを指定する必要があります。3つのタイプがあるので、次のように記述します。
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: 模型类型
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 言語モデル
- value: embeddings
label:
en_US: Text Embedding
- value: reranking
label:
en_US: Rerank
- 各モデルには独自の名称
model_nameがあるため、ここで定義する必要があります。
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: モデル名
required: true
placeholder:
zh_Hans: 填写模型名称
en_US: Input model name
- Xinferenceのローカルデプロイのアドレスを記入します。
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
- 各モデルには一意の model_uid があるため、ここで定義する必要があります。
- variable: model_uid
label:
zh_Hans: 模型 UID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid
これで、ベンダーの基本定義が完了しました。
モデルコードの作成
次に、llmタイプを例にとって、xinference.llm.llm.pyを作成します。
llm.py内で、Xinference LLM クラスを作成し、XinferenceAILargeLanguageModel(任意の名前)と名付けて、__base.large_language_model.LargeLanguageModel基底クラスを継承し、以下のメソッドを実装します:
-
LLM 呼び出し
LLM 呼び出しのコアメソッドを実装し、ストリームレスポンスと同期レスポンスの両方をサポートします。
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ Invoke large language model :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages :param model_parameters: model parameters :param tools: tools for tool calling :param stop: stop words :param stream: is stream response :param user: unique user id :return: full response or stream response chunk generator result """実装時には、同期レスポンスとストリームレスポンスを処理するために2つの関数を使用してデータを返す必要があります。Pythonは
yieldキーワードを含む関数をジェネレータ関数として認識し、返されるデータ型は固定でジェネレーターになります。そのため、同期レスポンスとストリームレスポンスは別々に実装する必要があります。以下のように実装します(例では簡略化されたパラメータを使用していますが、実際の実装では上記のパラメータリストに従って実装してください):def _invoke(self, stream: bool, **kwargs) \ -> Union[LLMResult, Generator]: if stream: return self._handle_stream_response(**kwargs) return self._handle_sync_response(**kwargs) def _handle_stream_response(self, **kwargs) -> Generator: for chunk in response: yield chunk def _handle_sync_response(self, **kwargs) -> LLMResult: return LLMResult(**response) -
予測トークン数の計算
モデルが予測トークン数の計算インターフェースを提供していない場合、直接0を返すことができます。
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ Get number of tokens for given prompt messages :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages :param tools: tools for tool calling :return: """時には、直接0を返す必要がない場合もあります。その場合は
self._get_num_tokens_by_gpt2(text: str)を使用して予測トークン数を取得することができます。このメソッドはAIModel基底クラスにあり、GPT2のTokenizerを使用して計算を行いますが、代替方法として使用されるものであり、完全に正確ではありません。 -
モデルクレデンシャル検証
ベンダークレデンシャル検証と同様に、ここでは個々のモデルについて検証を行います。
def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials :param model: model name :param credentials: model credentials :return: """ -
モデルパラメータスキーマ
カスタムタイプとは異なり、yamlファイルでモデルがサポートするパラメータを定義していないため、動的にモデルパラメータのスキーマを生成する必要があります。
例えば、Xinferenceは
max_tokens、temperature、top_pの3つのモデルパラメータをサポートしています。しかし、ベンダーによっては異なるモデルに対して異なるパラメータをサポートしている場合があります。例えば、ベンダー
OpenLLMはtop_kをサポートしていますが、全てのモデルがtop_kをサポートしているわけではありません。ここでは、例としてAモデルがtop_kをサポートし、Bモデルがtop_kをサポートしていない場合、以下のように動的にモデルパラメータのスキーマを生成します:def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ used to define customizable model schema """ rules = [ ParameterRule( name='temperature', type=ParameterType.FLOAT, use_template='temperature', label=I18nObject( zh_Hans='温度', en_US='Temperature' ) ), ParameterRule( name='top_p', type=ParameterType.FLOAT, use_template='top_p', label=I18nObject( zh_Hans='Top P', en_US='Top P' ) ), ParameterRule( name='max_tokens', type=ParameterType.INT, use_template='max_tokens', min=1, default=512, label=I18nObject( zh_Hans='最大生成长度', en_US='Max Tokens' ) ) ] # if model is A, add top_k to rules if model == 'A': rules.append( ParameterRule( name='top_k', type=ParameterType.INT, use_template='top_k', min=1, default=50, label=I18nObject( zh_Hans='Top K', en_US='Top K' ) ) ) """ some NOT IMPORTANT code here """ entity = AIModelEntity( model=model, label=I18nObject( en_US=model ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=model_type, model_properties={ ModelPropertyKey.MODE: ModelType.LLM, }, parameter_rules=rules ) return entity -
呼び出しエラーマッピングテーブル
モデル呼び出し時にエラーが発生した場合、Runtimeが指定する
InvokeErrorタイプにマッピングする必要があります。これにより、Difyは異なるエラーに対して異なる後続処理を行うことができます。Runtime エラー:
InvokeConnectionError呼び出し接続エラーInvokeServerUnavailableError呼び出しサービスが利用不可InvokeRateLimitError呼び出し回数制限に達したInvokeAuthorizationError認証エラーInvokeBadRequestError不正なリクエストパラメータ
@property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ Map model invoke error to uni ```
-> dict[type[呼び出しエラー], list[type[例外]]]: """ モデル呼び出しエラーを統一エラーにマッピングする キーは呼び出し元に投げられるエラータイプ バリューはモデルが投げるエラータイプであり、 呼び出し元に対して統一エラータイプに変換する必要があります。
:return: 呼び出しエラーのマッピング
"""
インターフェース方法の詳細については:[インターフェース](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/docs/zh_Hans/interfaces.md)をご覧ください。具体的な実装例については、[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)を参照してください。