Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions astrbot/dashboard/api/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astrbot.dashboard.async_utils import run_maybe_async
from astrbot.dashboard.responses import error, ok
from astrbot.dashboard.schemas import (
KnowledgeBaseCreateRequest,
KnowledgeBaseImportRequest,
KnowledgeBaseRequest,
KnowledgeBaseRetrieveRequest,
Expand Down Expand Up @@ -53,14 +54,6 @@ def _to_int(value: Any, default: int) -> int:
return default


def _model_dict(payload) -> dict[str, Any]:
if payload is None:
return {}
if hasattr(payload, "model_dump"):
return payload.model_dump(exclude_none=True)
return payload if isinstance(payload, dict) else {}


async def _run(operation, *, prefix: str):
Comment thread
lxfight marked this conversation as resolved.
try:
result = await run_maybe_async(operation)
Expand Down Expand Up @@ -102,12 +95,12 @@ async def list_knowledge_bases(

@router.post("/knowledge-bases")
async def create_knowledge_base(
payload: KnowledgeBaseRequest,
payload: KnowledgeBaseCreateRequest,
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
return await _run(
lambda: service.create_kb(_model_dict(payload)),
lambda: service.create_kb(payload.canonical_payload()),
prefix="创建知识库失败",
)

Expand Down Expand Up @@ -140,9 +133,8 @@ async def update_knowledge_base(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
return await _run(
lambda: service.update_kb({"kb_id": kb_id, **body}),
lambda: service.update_kb({**payload.canonical_payload(), "kb_id": kb_id}),
prefix="更新知识库失败",
)

Expand Down Expand Up @@ -213,7 +205,7 @@ async def import_knowledge_base_documents(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.import_documents({"kb_id": kb_id, **body}),
prefix="导入文档失败",
Expand All @@ -227,7 +219,7 @@ async def import_knowledge_base_document_url(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.upload_document_from_url({"kb_id": kb_id, **body}),
prefix="从URL上传文档失败",
Expand Down Expand Up @@ -307,7 +299,7 @@ async def retrieve_knowledge_base(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.retrieve({"kb_id": kb_id, **body}),
prefix="检索失败",
Expand Down
40 changes: 38 additions & 2 deletions astrbot/dashboard/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,49 @@ class ImMessageRequest(OpenModel):


class KnowledgeBaseRequest(OpenModel):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
kb_id: str | None = None
name: str | None = None
kb_name: str | None = Field(None, alias="name")
description: str | None = None
emoji: str | None = None
embedding_provider_id: str | None = None
rerank_provider_id: str | None = None
chunk_size: int | None = None
chunk_overlap: int | None = None
top_k_dense: int | None = None
top_k_sparse: int | None = None
top_m_final: int | None = None

model_config = ConfigDict(populate_by_name=True, extra="allow")

def canonical_payload(self) -> dict[str, Any]:
"""Return the service-facing knowledge base payload.

Returns:
Dictionary accepted by KnowledgeBaseService.
"""
return self.model_dump(
exclude_unset=True,
include={
"kb_name",
"description",
"emoji",
"embedding_provider_id",
"rerank_provider_id",
"chunk_size",
"chunk_overlap",
"top_k_dense",
"top_k_sparse",
"top_m_final",
},
by_alias=False,
)


class KnowledgeBaseCreateRequest(KnowledgeBaseRequest):
model_config = ConfigDict(
populate_by_name=True,
extra="allow",
json_schema_extra={"required": ["name", "embedding_provider_id"]},
)


class KnowledgeBaseImportRequest(OpenModel):
Comment thread
lxfight marked this conversation as resolved.
Expand Down
44 changes: 25 additions & 19 deletions astrbot/dashboard/services/knowledge_base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.dashboard.schemas import KnowledgeBaseRequest
from astrbot.dashboard.utils import generate_tsne_visualization


Expand All @@ -29,6 +30,19 @@ def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}

@staticmethod
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
def _canonical_kb_payload(data: object) -> dict[str, Any]:
"""Normalize knowledge base create/update payloads.

Uses KnowledgeBaseRequest to handle the legacy ``name`` →
``kb_name`` migration while preserving operational fields
like ``kb_id``.
"""
raw = KnowledgeBaseService._payload(data)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Modifying the input data dictionary in-place via raw.update(canonical) can lead to unexpected side-effects in callers that reuse the dictionary. It is safer to create a copy of the dictionary first.

Suggested change
raw = KnowledgeBaseService._payload(data)
raw = dict(KnowledgeBaseService._payload(data))

canonical = KnowledgeBaseRequest(**raw).canonical_payload()
raw.update(canonical)
return raw

def get_kb_manager(self):
return self.core_lifecycle.kb_manager

Expand Down Expand Up @@ -293,7 +307,7 @@ async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, A

async def create_kb(self, data: object) -> tuple[dict[str, Any], str]:
kb_manager = self.get_kb_manager()
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_name = payload.get("kb_name")
if not kb_name:
raise KnowledgeBaseServiceError("知识库名称不能为空")
Expand Down Expand Up @@ -363,7 +377,7 @@ async def get_kb_from_dashboard_query(self, kb_id: str | None) -> dict[str, Any]
return await self.get_kb(kb_id)

async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
Expand All @@ -380,28 +394,20 @@ async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
"top_k_sparse",
"top_m_final",
]
if all(payload.get(key) is None for key in update_keys):
provided_updates = {key: payload[key] for key in update_keys if key in payload}
if not provided_updates:
raise KnowledgeBaseServiceError("至少需要提供一个更新字段")

current_kb = await self.get_kb_manager().get_kb(kb_id)
kb_name = payload.get("kb_name")
if kb_name is None:
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
kb_name = current_kb.kb.kb_name
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
current = current_kb.kb
update_data = {key: getattr(current, key, None) for key in update_keys}
update_data.update(provided_updates)

kb_helper = await self.get_kb_manager().update_kb(
kb_id=kb_id,
kb_name=kb_name,
description=payload.get("description"),
emoji=payload.get("emoji"),
embedding_provider_id=payload.get("embedding_provider_id"),
rerank_provider_id=payload.get("rerank_provider_id"),
chunk_size=payload.get("chunk_size"),
chunk_overlap=payload.get("chunk_overlap"),
top_k_dense=payload.get("top_k_dense"),
top_k_sparse=payload.get("top_k_sparse"),
top_m_final=payload.get("top_m_final"),
**update_data,
)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
Expand Down Expand Up @@ -762,11 +768,11 @@ async def retrieve(self, data: object) -> dict[str, Any]:

if not query:
raise KnowledgeBaseServiceError("缺少参数 query")
kb_manager = self.get_kb_manager()
if not kb_names or not isinstance(kb_names, list):
raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误")

top_k = payload.get("top_k", 5)
kb_manager = self.get_kb_manager()
results = await kb_manager.retrieve(
query=query,
kb_names=kb_names,
Expand Down
22 changes: 15 additions & 7 deletions dashboard/src/api/generated/openapi-v1/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,22 @@ export type JsonSchema = {
[key: string]: unknown;
};

export type KnowledgeBaseCreateRequest = KnowledgeBaseRequest & {
kb_name: string;
embedding_provider_id: string;
};

export type KnowledgeBaseRequest = {
name: string;
kb_name?: string;
description?: string;
embedding_provider_id?: string;
rerank_provider_id?: string;
chunking?: DynamicConfig;
metadata?: DynamicConfig;
emoji?: string;
embedding_provider_id?: (string) | null;
rerank_provider_id?: (string) | null;
chunk_size?: number;
chunk_overlap?: number;
top_k_dense?: number;
top_k_sparse?: number;
top_m_final?: number;
};

export type KnowledgeDocumentImportRequest = {
Expand All @@ -271,7 +280,6 @@ export type KnowledgeDocumentImportRequest = {

export type KnowledgeDocumentUploadRequest = {
file: (Blob | File);
parser?: string;
};

export type KnowledgeDocumentUrlImportRequest = {
Expand Down Expand Up @@ -2606,7 +2614,7 @@ export type ListKnowledgeBasesResponse = (SuccessEnvelope);
export type ListKnowledgeBasesError = unknown;

export type CreateKnowledgeBaseData = {
body: KnowledgeBaseRequest;
body: KnowledgeBaseCreateRequest;
};

export type CreateKnowledgeBaseResponse = (SuccessEnvelope);
Expand Down
10 changes: 6 additions & 4 deletions dashboard/src/api/v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import {
type DynamicConfig,
type EnabledPatch,
type GhproxyTestRequest,
type KnowledgeBaseCreateRequest,
type KnowledgeBaseRequest,
type LoginRequest,
type ListConversationsData,
type McpServerConfig,
Expand Down Expand Up @@ -1366,16 +1368,16 @@ export const knowledgeApi = {
openApiV1.getKnowledgeBase({ path: { kb_id: kbId } }),
);
},
create(config: OpenConfig) {
create(config: KnowledgeBaseCreateRequest) {
return typed<OpenConfig>(
openApiV1.createKnowledgeBase({ body: config as any }),
openApiV1.createKnowledgeBase({ body: config }),
);
},
update(kbId: string, config: OpenConfig) {
update(kbId: string, config: KnowledgeBaseRequest) {
return typed<OpenConfig>(
openApiV1.updateKnowledgeBase({
path: { kb_id: kbId },
body: config as any,
body: config,
}),
);
},
Expand Down
9 changes: 7 additions & 2 deletions dashboard/src/views/knowledge-base/KBList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@

<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null" hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null"
:rules="[v => editingKB !== null || !!v || t('create.embeddingModelRequired')]" required
hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
Expand Down Expand Up @@ -455,7 +457,10 @@ const submitForm = async () => {
if (editingKB.value) {
response = await knowledgeApi.update(editingKB.value.kb_id, payload)
} else {
response = await knowledgeApi.create(payload)
response = await knowledgeApi.create({
...payload,
embedding_provider_id: formData.value.embedding_provider_id!
})
}

if (response.data.status === 'ok') {
Expand Down
38 changes: 27 additions & 11 deletions openspec/openapi-v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3381,7 +3381,7 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/KnowledgeBaseRequest"
$ref: "#/components/schemas/KnowledgeBaseCreateRequest"
responses:
"200":
$ref: "#/components/responses/Ok"
Expand Down Expand Up @@ -5700,31 +5700,47 @@ components:

KnowledgeBaseRequest:
type: object
required: [name]
properties:
name:
kb_name:
type: string
description:
type: string
embedding_provider_id:
emoji:
type: string
embedding_provider_id:
type: [string, "null"]
rerank_provider_id:
type: string
chunking:
$ref: "#/components/schemas/DynamicConfig"
metadata:
$ref: "#/components/schemas/DynamicConfig"
type: [string, "null"]
chunk_size:
type: integer
chunk_overlap:
type: integer
top_k_dense:
type: integer
top_k_sparse:
type: integer
top_m_final:
type: integer
additionalProperties: false

KnowledgeBaseCreateRequest:
allOf:
- $ref: "#/components/schemas/KnowledgeBaseRequest"
- type: object
required: [kb_name, embedding_provider_id]
properties:
kb_name:
type: string
embedding_provider_id:
type: string

KnowledgeDocumentUploadRequest:
type: object
required: [file]
properties:
file:
type: string
format: binary
parser:
type: string

KnowledgeDocumentImportRequest:
type: object
Expand Down
Loading
Loading