Python
VexDB 为使用 Python 语言进行数据库开发的用户提供了以下两个软件包。
- vexdb-psycopg2
基于 psycopg2 开发的数据库驱动包,它允许应用程序与数据库进行交互和通信,通过 Python 程序执行原始的 SQL 语句并返回结果。
该包必须依赖 vexdb 的开发库进行编译(libpq),不能与 pyvector-vexdb 同时安装。 - pyvector-vexdb
pyvector-vexdb 是一个向量扩展包。不能与 vexdb-psycopg2 同时安装。- 通过 psycopg2 驱动的 register_adapter 接口为 psycopg2 驱动扩展了向量类型 FloatVector。
- 通过 SQLAlchemy ORM 的 TypeDecorator 接口为SQLAlchemy ORM扩展了向量类型 FloatVector。
vexdb-psycopg2
本节介绍通过 Python 驱动连接并操作 VexDB 的方式,并提供示例代码。
软件安装
- 使用驱动软件前,请确保已经参考 安装 VexDB 的内容完成了数据库的安装。
- 参考以下步骤配置环境变量:
# 应按实际的环境信息进行配置 export GAUSSHOME=/usr/local/vexdb export PATH=$GAUSSHOME/bin:$PATH export LD_LIBRARY_PATH=$GAUSSHOME/lib:$LD_LIBRARY_PATH - 安装 Python3 和 GCC。
yum install -y python3 python3-devel gcc - 使用root用户卸载已安装的 psycopg2(如有,否则跳过此步骤)
pip3 uninstall psycopg2 -y - 上传并解压 vexdb-psycopg2 源码包。
tar -xvf vexdb-psycopg2-{version}.tar.gz - 切换到解压后的目录,执行安装脚本。
cd vexdb-psycopg2-{version} python3 setup.py build # 此步骤可能需要 root 用户执行 python3 setup.py install - 检查安装结果。
进入python3命令行,执行以下命令:import psycopg2
导入成功即表示安装完成。
连接数据库
获取数据库连接
使用 psycopg2.connect() 获取 connection 对象,支持通过 dsn 与 key-value 两种格式连接数据库。
- DSN 格式:
conn = psycopg2.connect("host=127.0.0.1 port=5432 dbname=postgres user=vbadmin password=vbase@123")
对于DSN 格式,不支持用户密码中有 @ 等特殊字符。 - key-value 格式:
conn = psycopg2.connect(host="127.0.0.1", port=5432, dbname='postgres', user='vbadmin', password='vbase@123')
关闭数据库连接
关闭连接时调用 connect 的 close() 方法。
conn = psycopg2.connect(database="postgres", user="vbadmin", password="Vbase@123", host="127.0.0.1", port="5432")
conn.close()
配置数据库集群连接
psycopg2 支持在 DSN 中配置多个 Host:Port 对(Key-Value 格式则是配置多个 host),以逗号分隔。target_session_attrs 用于控制连接的数据库属性。
如下示例中,将 target_session_attrs 配置为 primary 表示仅允许连接主库,在发生主备切换后,psycopg2 会连接到切换后的新主库。
conn = psycopg2.connect(
host="xxx.xxx.xxx.xxa,xxx.xxx.xxx.xxb",
port=5432,
database='postgres',
user='vbadmin',
password='Vbase@123',
target_session_attrs='primary'
)
target_session_attrs 的可选取值包括:
- read-write
只接受可以读写的数据库。 - any
表示可以允许连接到任意数据库。 - read-only
表示仅允许连接到只读的数据库。 - primary
表示仅允许连接主库。 - standby
表示仅允许连接备库。 - prefer-standby
表示连接到任意一个只读的数据库节点,如果没有可用的只读节点,则连接到可读可写节点。
示例
import string
import ctypes
import decimal
import platform
import unittest
import random
import testutils
import ast
import csv
import datetime
import io
from io import StringIO
import json
import struct
import numpy
from typing import Any
# 注意修改ConnectingTestCase中的数据库连接信息
from testutils import ConnectingTestCase, restore_types
import psycopg2
from psycopg2.extensions import FloatVector
class FloatVectorTest(ConnectingTestCase):
def test_floatvector_dql(self):
curs = self.conn.cursor()
try:
curs.execute("select '[1,2,3]'::floatvector(3);")
result = curs.fetchone()[0]
print(result.tolist())
print(FloatVector([1,2,3]).tolist())
print(str(FloatVector([1,2,3])))
self.assertEqual(result.tolist(),FloatVector([1,2,3]).tolist())
self.assertEqual(str(result),str(FloatVector([1,2,3])))
curs.execute("select '[0]'::floatvector(1);")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([0]).tolist())
self.assertEqual(str(result),str(FloatVector([0])))
curs.execute("select '[1.00001,0.001234]'::floatvector(2);")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.00001,0.001234]).tolist())
self.assertEqual(str(result),str(FloatVector([1.00001,0.001234])))
curs.execute("SELECT ARRAY[1,1,1]::floatvector;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1,1,1]).tolist())
self.assertEqual(str(result),str(FloatVector([1,1,1])))
curs.execute("SELECT '[1,1,1]'::floatvector;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1,1,1]).tolist())
self.assertEqual(str(result),str(FloatVector([1,1,1])))
curs.execute("SELECT '[1.0]'::text::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::char(5)::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::nchar(5)::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::varchar(5)::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::nvarchar(5)::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::clob::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
curs.execute("SELECT '[1.0]'::bpchar(5)::floatvector AS vector_example;")
result = curs.fetchone()[0]
print(str(result))
print("------------------")
self.assertEqual(result.tolist(),FloatVector([1.0]).tolist())
self.assertEqual(str(result),str(FloatVector([1.0])))
print("query result: ",result)
self.conn.commit()
except Exception:
self.conn.rollback()
raise
finally:
curs.close()
接口参考
vexdb-psycopg2 为使用 Python 语言的应用程序定义了一组访问和操作 VexDB 的标准接口。
psycopg2.connect()
此方法用于创建新的数据库会话并返回新的connection对象。
返回值:connection对象(连接数据库实例的对象)。
参数说明如下:
| 关键字 | 参数说明 |
|---|---|
| dbname | 数据库名称。 |
| user | 用户名。 |
| password | 密码。 |
| host | 数据库所在主机 IP 地址。 |
| port | 连接端口号,默认为5432。 |
| sslmode | ssl 模式,ssl 连接时用。 |
| sslcert | 客户端证书路径,ssl 连接时用。 |
| sslkey | 客户端秘钥路径,ssl 连接时用。 |
| sslrootcert | 根证书路径,ssl 连接时用。 |
| target_session_attrs | 配置多个 host 时,数据库连接的优先级。该参数会使用目标库的 libpq 中的target_session_attrs。 在连接的时候,只接受可以读写的数据库。建立连接后,会发送 SHOW transaction_read_only,如果是 o n,代表是只读库,psycopg2 会把连接关闭;然后测试第二个数据库,以此类推,直至连接到支持读写的数据库为止。 表示可以允许连接到任意数据库,它会从所有配置的连接中随机选择一个,如果连接的数据库出现故障导致连接断开,会尝试连接其他数据库,从而实现故障转移。 表示仅允许连接到只读的数据库。 表示仅允许连接主库。 表示仅允许连接备库。 表示连接到任意一个只读的数据库节点,如果没有可用的只读节点,则连接到可读可写节点。 |
connection.cursor()
此方法用于返回新的 cursor 对象。参数说明如下:
| 关键字 | 参数说明 |
|---|---|
| name | cursor 名称,默认为 None |
| cursor_factory | 用于创造非标准 cursor,默认为 None |
| scrollable | 设置 SCROLL 选项,默认为 None |
| withhold | 设置 HOLD 选项,默认为 False |
cursor.execute()
此方法执行被参数化的SQL语句(即占位符,而不是SQL文字)。psycopg2模块支持用%s标志的占位符。参数说明如下:
| 关键字 | 参数说明 |
|---|---|
| query | 待执行的 sql 语句 |
| vars_list | 变量列表,匹配 query 中%s 为占位符 |
connection.commit()
此方法将当前挂起的事务提交到数据库。默认情况下,psycopg2在执行第一个命令之前打开一个事务:如果不调用commit(),任何数据操作的效果都将丢失。
参数:无。
返回值:无。
connection.rollback()
此方法回滚当前挂起事务。
参数:无。
返回值:无
cursor.fetchone()
此方法提取查询结果集的下一行,并返回一个元组。
参数:无。
返回值:单个元组,为结果集的第一条结果,当没有更多数据可用时,返回为“None”。
cursor.fetchall()
此方法获取查询结果的所有(剩余)行,并将它们作为元组列表返回。
参数:无。
返回值:元组列表,为结果集的所有结果。空行时则返回空列表。
cursor.close()
此方法用于关闭当前连接的游标。
参数:无。
返回值:无。
connection.close()
此方法用于关闭数据库连接。
参数:无。
返回值:无。
pyvector-vexdb
开源库 psycopg2 允许 Python 程序与 PostgreSQL 数据库进行交互,VexDB 提供了基于 psycopg2 开发的 SDK: pyvector-vexdb。pyvector-vexdb 在 psycopg2 的基础上进行修改,兼容 SQLAlchemy,Django ORM,Peewee 等 ORM 框架。
软件安装
说明
操作前请确保已经参考 安装VexDB 完成了数据库安装,并部署了 python3 环境。 依赖 psycopg2,pyvector-vexdb 对操作系统和 CPU 没有要求,python 版本需要 ≥3.8。
将 pyvector-vexdb 安装到本地环境中,可参考如下命令:
pip install psycopg2-binary #必须
pip install sqlalchemy #如果要使用SQLAlchemy集成
pip install numpy #可选,支持直接映射numpy到向量类型
pip install pyvector-vexdb.xxx.whl
它提供了一个功能齐全的 API 来执行 SQL 语句,可以处理事务、获取查询结果、执行批量数据插入等操作。
psycopg2 模式
原生 psycopg2 并不支持向量类型,pyvector-vexdb 通过 psycopg2 的 register_adapter 机制,可以为 psycopg2 新增向量类型 floatvector 的支持。
- 和数据库建立连接。
from vexdb.psycopg2 import register_vector conn = psycopg2.connect("dbname=postgres,user=vexdb,password=123456,host=172.0.0.1,port=5432") - 将向量类型注册到您的连接或游标。
#将向量类型注册到当前连接 register_vector(conn) #或将向量类型注册到游标 #cur = conn.cursor() #register_vector(cur) - 创建一个向量表。
cur.execute('CREATE TABLE items (id bigserial PRIMARY KEY, embedding floatvector(3))') - 插入一个向量(向量嵌入-embedding)
embedding = np.array([1, 2, 3]) cur.execute('INSERT INTO items (embedding) VALUES (%s)', (embedding,)) - 对向量进行函数/操作符操作:
- 获取一个向量最近邻
cur.execute('SELECT * FROM items ORDER BY embedding <-> %s LIMIT 5', (embedding,)) cur.fetchall() - 计算两个向量的余弦距离
余弦相似度 = 1-余弦距离,比较两个向量的方向而不是它们的大小。余弦相似度的范围在-1到1之间,1表示向量相同,0表示无关,-1表示向量指向相反方向。cur.execute('SELECT 1-(%s <->%s) as value', (embedding1,embedding2)) cur.fetchall() - 计算两个向量的负内积。内积 = -1*负内积,内积表示两个向量在相同维度上的对应分量相乘后再求和的结果,内积越大表示方向越一致。
cur.execute(' -1 * (%s <#> %s) AS value',(embedding1,embedding2)) cur.fetchall() - 两个向量逐元素相加。
cur.execute('SELECT %s + %s AS value',(embedding1,embedding2)) cur.fetchall() - floatvector_combine(double precision, double precision)
将两个double precision类型的数组逐元素相加成一个新的向量,返回这个新向量。cur.execute('SELECT floatvector_combine(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_accum(double precision, floatvector)
将一个向量累加到数组中,返回一个新的数组。其中,被累加的数组及结果数组的第一个元素为累加次数,之后的元素为各维度的累积值。cur.execute('SELECT floatvector_accum(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_cmp(floatvector, floatvector)
函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量小于第二个,返回 -1;如果相等,返回 0;如果第一个向量大于第二个,返回 1。cur.execute('SELECT floatvector_cmp(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_gt(floatvector, floatvector)
函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量大于第二个向量,则返回true,否则返回false。cur.execute('SELECT floatvector_gt(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_ge(floatvector, floatvector)
函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量大于等于第二个向量,则返回true,否则返回false。cur.execute('SELECT floatvector_ge(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_ne(floatvector, floatvector)
函数用于比较两个向量,如果两个向量的任意元素不相等,则返回true,否则返回false。cur.execute('SELECT floatvector_ne(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_eq(floatvector, floatvector)
这个函数用于比较两个向量,如果两个向量的所有元素都相等,则返回true,否则返回false。cur.execute('SELECT floatvector_eq(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_le(floatvector, floatvector)
这个函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量小于等于第二个向量,则返回true,否则返回false。cur.execute('SELECT floatvector_le(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_lt(floatvector, floatvector)
这个函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量小于于第二个向量,则返回true,否则返回false。cur.execute('SELECT floatvector_lt(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_spherical_distance(floatvector, floatvector)
这个函数用于计算两个向量之间的球面距离(spherical distance),即两个向量之间的夹角的余弦值。cur.execute('SELECT floatvector_spherical_distance(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_negative_inner_product(floatvector, floatvector)
这个函数用于计算两个向量的负内积(negative inner product),即两个向量对应位置上的元素相乘后求和并取负值。cur.execute('SELECT floatvector_negative_inner_product(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_l2_squared_distance(floatvector, floatvector)
这个函数用于计算两个向量之间的L2范数的平方距离。L2范数距离是向量元素差的平方和。cur.execute('SELECT floatvector_l2_squared_distance(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_avg(double precision)
这个函数用于计算一个double precision类型数组中所有元素的值进行N等分。N是数组的第一个元素,并返回一个N等分后的向量。cur.execute('SELECT floatvector_avg(%s)',(embedding)) cur.fetchall() - floatvector_sub(floatvector, floatvector)
这个函数用于计算两个向量的元素级相减,返回一个新的向量,其中每个元素是对应位置上两个向量元素的差。cur.execute('SELECT floatvector_sub(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_add(floatvector, floatvector)
这个函数用于计算两个向量的元素级相加,返回一个新的向量,其中每个元素是对应位置上两个向量元素的和。cur.execute('SELECT floatvector_add(%s,%s)',(embedding1,embedding2)) cur.fetchall() - floatvector_norm(floatvector)
这个函数用于计算给定向量的范数,即向量元素的平方和的平方根。cur.execute('SELECT floatvector_norm(%s)',(embedding)) cur.fetchall() - floatvector_dims(floatvector)
这个函数用于返回给定向量的维度(即向量中元素的数量)。cur.execute('SELECT floatvector_dims(%s)',(embedding)) cur.fetchall() - l2_distance(floatvector, floatvector)
这个函数用于计算两个向量之间的L2范数距离。L2范数距离也称为欧氏距离,表示两个向量之间的直线距离。cur.execute('SELECT l2_distance(%s,%s)',(embedding1,embedding2)) cur.fetchall() - inner_product(floatvector, floatvector)
这个函数用于计算两个向量的内积。内积是两个向量对应元素乘积的和。cur.execute('SELECT inner_product(%s,%s)',(embedding1,embedding2)) cur.fetchall() - cosine_distance(floatvector, floatvector)
这个函数用于计算两个向量之间的余弦距离。余弦距离是通过计算两个向量之间的夹角余弦值来衡量它们之间的相似度。cur.execute('SELECT cosine_distance(%s,%s)',(embedding1,embedding2)) cur.fetchall()
- 获取一个向量最近邻
- 添加一个近似索引。
cur.execute('CREATE INDEX ON items USING hnsw (embedding floatvector_l2_ops)') # or cur.execute('CREATE INDEX ON items USING ivfflat (embedding floatvector_l2_ops) WITH (lists = 100)')
说明
更多索引类型、参数和操作符使用详见《向量检索指南》。
示例
import numpy as np
from vexdb.psycopg2 import register_vector
import psycopg2
from psycopg2.extras import DictCursor, RealDictCursor, NamedTupleCursor
from psycopg2.pool import ThreadedConnectionPool
# 使用dsn
conn = psycopg2.connect("dbname=postgres user=test password=Test@1234 host=localhost port=5432")
# 使用关键字
# conn = psycopg2.connect(
# dbname="postgres",
# user="test",
# password="Test@1234",
# host="localhost",
# port=5432)
conn.autocommit = True
cur = conn.cursor()
cur.execute('DROP TABLE IF EXISTS psycopg2_items')
cur.execute('CREATE TABLE psycopg2_items (id bigserial PRIMARY KEY, embedding floatvector(3))')
# 注册到连接
register_vector(conn)
# 注册到游标
# register_vector(cur)
class TestPsycopg2:
def setup_method(self, test_method):
cur.execute('DELETE FROM psycopg2_items')
def test_vector(self):
embedding = np.array([1.5, 2, 3])
cur.execute('INSERT INTO psycopg2_items (embedding) VALUES (%s), (NULL)', (embedding,))
cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id')
res = cur.fetchall()
assert np.array_equal(res[0][0], embedding)
assert res[0][0].dtype == np.float32
assert res[1][0] is None
def test_query(self):
pass
def test_cursor_factory(self):
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(host='localhost', port=5432, database='postgres', user='test', password='Test@1234')
cur = conn.cursor(cursor_factory=cursor_factory)
register_vector(cur, globally=False)
conn.close()
def test_cursor_factory_connection(self):
for cursor_factory in [DictCursor, RealDictCursor, NamedTupleCursor]:
conn = psycopg2.connect(host='localhost', port=5432, database='postgres', user='test', password='Test@1234', cursor_factory=cursor_factory)
register_vector(conn, globally=False)
conn.close()
def test_pool(self):
pool = ThreadedConnectionPool(1, 1, host='localhost', port=5432, database='postgres', user='test', password='Test@1234')
conn = pool.getconn()
try:
# use globally=True for apps to ensure registered with all connections
register_vector(conn, globally=False)
finally:
pool.putconn(conn)
conn = pool.getconn()
try:
cur = conn.cursor()
cur.execute("SELECT '[1,2,3]'::floatvector")
res = cur.fetchone()
assert np.array_equal(res[0], np.array([1, 2, 3]))
finally:
pool.putconn(conn)
pool.closeall()
SQLAlchemy 兼容模式
pyvector-vexdb 提供了 SQLAlchemy 集成,可以直接映射到 VexDB 的向量类型和向量函数。这允许您直接使用向量类型 FloatVector 定义 SQLAlchemy 模型,并使用熟悉的 SQLAlchemy 语法执行向量相似性搜索。
使用方式
- 映射向量列。
from vexdb.sqlalchemy import FloatVector class Item(Base): embedding = mapped_column(FloatVector(3)) - 插入向量。
item = Item(embedding=[1, 2, 3]) session.add(item) session.commit() - 对向量进行函数/操作符操作:
- 查询近似向量:使用欧几里得距离操作符l2_distance(<->)
session.scalar(select(Item).order_by(Item.embedding.l2_distance([3, 1, 2])).limit(5))
除此以外,还支持如下操作符,具体使用详见函数和操作符。- negative_inner_product(<#>)
- consine_distance(<=>)
- add(+)
- sub(-)
- l2_distance函数,查询向量的欧几里得距离:
session.scalar(select(Item.embedding.l2_distance([3, 1, 2])))
或者查询近似向量(使用欧几里得距离)session.scalar(select(Item).filter(Item.embedding.l2_distance([3, 1, 2]) < 5)) - floatvector_combine函数,将两个double precision类型的数组逐元素相加成一个新的向量,返回这个新向量。
session.scalar(session.query(floatvector_combine([1.0,2.0,3.0], [4,5,6]))) - floatvector_accum函数,用于将一个向量累加到数组中,返回一个新的数组。其中,被累加的数组及结果数组的第一个元素为累加次数,之后的元素为各维度的累积值。
session.scalar(session.query(floatvector_accum([1.0, 2.0, 3.0], [4, 5]))) - floatvector_cmp函数,这个函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量小于第二个,返回 -1;如果相等,返回 0;如果第一个向量大于第二个,返回 1。
session.scalar(session.query(floatvector_cmp([1,1,1,1],[2,2,2,2]))) - floatvector_gt,这个函数逐个比较两个向量的元素,找到第一个不同的元素并根据它决定大小关系。如果第一个向量大于第二个向量,则返回true,否则返回false。
session.scalar(session.query(floatvector_gt([1,1,1,1],[2,2,2,2]))) - floatvector_spherical_distance函数,函数用于计算两个向量之间的球面距离,即两个向量之间的夹角的余弦值。
session.scalar(session.query(floatvector_spherical_distance([1,1,1,1],[2,2,2,2])))
除此以外,还支持如下函数,具体使用详见函数和操作符。- floatvector_ge
- floatvector_ne
- floatvector_eq
- floatvector_le
- floatvector_lt
- floatvector_negative_inner_product
- floatvector_l2_squared_distance
- floatvector_avg
- floatvector_sub
- floatvector_add
- floatvector_norm
- floatvector_dims
- l2_distance
- inner_product
- cosine_distance
- 查询近似向量:使用欧几里得距离操作符l2_distance(<->)
- 近似最近邻索引。
index = Index( 'my_index', Item.embedding, postgresql_using='hnsw', postgresql_with={'m': 16, 'ef_construction': 64}, postgresql_ops={'embedding': 'floatvector_l2_ops'} ) # or index = Index( 'my_index', Item.embedding, postgresql_using='ivfflat', postgresql_with={'ivf_nlist': 100}, postgresql_ops={'embedding': 'floatvector_l2_ops'} )
说明
更多索引类型、参数和操作符使用详见《向量检索指南》。
示例1:典型向标联合查询
以下程序展示了使用 SQLAlchemy 语法,构建在手机商城中执行向量相似性搜索的示例。
from sqlalchemy import create_engine, Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import text
from vexdb.sqlalchemy import FloatVector
import numpy as np
from typing import List, Optional, Tuple
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 数据库配置
#DATABASE_URL = "postgresql+psycopg2://test:Test%401234@172.16.100.97:5432/postgres"
DATABASE_URL = "postgresql+psycopg2://postgres:12345Aa*@172.16.97.114:5432/postgres"
# 创建基类
Base = declarative_base()
class Product(Base):
"""手机产品模型"""
__tablename__ = 'products'
id = Column(Integer, primary_key=True)
name = Column(Text, nullable=False)
color = Column(Text, nullable=False) # 颜色枚举值:'black', 'white', 'red', 'blue', 'gold'
description = Column(Text, nullable=False)
features = Column(FloatVector(512), nullable=False) # 512维特征向量
def __repr__(self):
return f"<Product(id={self.id}, name='{self.name}', color='{self.color}')>"
class ProductVectorSearch:
"""手机产品向量检索服务"""
def __init__(self, database_url: str):
self.engine = create_engine(database_url)
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
# 创建表
self._create_tables_and_extensions()
def _create_tables_and_extensions(self):
"""创建表"""
# 清空表,便于重新创建
Base.metadata.drop_all(bind=self.engine)
# 创建所有表
Base.metadata.create_all(bind=self.engine)
# 创建向量索引以提升检索性能
with self.engine.connect() as conn:
try:
conn.execute(text("""
CREATE INDEX IF NOT EXISTS products_features_idx
ON products USING ivfflat (features floatvector_cosine_ops)
WITH (ivf_nlist = 100)
"""))
conn.commit()
logger.info("向量索引创建成功")
except Exception as e:
logger.warning(f"创建向量索引失败: {e}")
def add_product(self, id: int, name: str, color: str, description: str, features: np.ndarray) -> Product:
"""添加新产品"""
with self.SessionLocal() as session:
product = Product(
id=id,
name=name,
color=color,
description=description,
features=features.tolist() # 转换为列表格式
)
session.add(product)
session.commit()
session.refresh(product)
return product
def vector_search_with_color_filter(
self,
query_vector: np.ndarray,
color_filter: Optional[str] = None,
colors_filter: Optional[List[str]] = None,
limit: int = 10,
similarity_threshold: float = 0.5
) -> List[Tuple[Product, float]]:
"""
使用向量检索结合颜色过滤的联合查询
Args:
query_vector: 查询向量 (512维)
color_filter: 单个颜色过滤条件
colors_filter: 多个颜色过滤条件列表
limit: 返回结果数量限制
similarity_threshold: 相似度阈值 (0-1)
Returns:
List[Tuple[Product, float]]: 产品和相似度分数的元组列表
"""
with self.SessionLocal() as session:
# 构建基础查询,计算余弦相似度
query = session.query(
Product,
(1 - Product.features.cosine_distance(query_vector)).label('similarity')
)
# 添加颜色过滤条件
if color_filter:
query = query.filter(Product.color == color_filter)
elif colors_filter:
query = query.filter(Product.color.in_(colors_filter))
# 添加相似度阈值过滤
query = query.filter(
(1 - Product.features.cosine_distance(query_vector)) >= similarity_threshold
)
# 按相似度降序排列并限制结果数量
results = query.order_by(
(1 - Product.features.cosine_distance(query_vector)).desc()
).limit(limit).all()
return [(product, float(similarity)) for product, similarity in results]
def hybrid_search(
self,
query_vector: np.ndarray,
color_filter: Optional[str] = None,
price_range: Optional[Tuple[float, float]] = None,
limit: int = 10
) -> List[Tuple[Product, float]]:
"""
混合搜索:向量相似度 + 多重标量字段过滤
"""
with self.SessionLocal() as session:
query = session.query(
Product,
(1 - Product.features.cosine_distance(query_vector)).label('similarity')
)
# 颜色过滤
if color_filter:
query = query.filter(Product.color == color_filter)
# 可以扩展更多过滤条件,比如价格范围
# if price_range:
# query = query.filter(Product.price.between(price_range[0], price_range[1]))
results = query.order_by(
(1 - Product.features.cosine_distance(query_vector)).desc()
).limit(limit).all()
return [(product, float(similarity)) for product, similarity in results]
def get_similar_products_by_description(
self,
description: str,
embedding_function, # 文本转向量的函数
color_filter: Optional[str] = None,
limit: int = 5
) -> List[Tuple[Product, float]]:
"""
根据描述文本查找相似产品
Args:
description: 产品描述文本
embedding_function: 将文本转换为向量的函数
color_filter: 颜色过滤
limit: 结果数量限制
"""
# 将描述转换为向量
query_vector = embedding_function(description)
return self.vector_search_with_color_filter(
query_vector=query_vector,
color_filter=color_filter,
limit=limit
)
def advanced_search(
self,
query_vector: np.ndarray,
filters: dict,
limit: int = 10
) -> List[Tuple[Product, float]]:
"""
高级搜索:支持动态过滤条件
Args:
query_vector: 查询向量
filters: 过滤条件字典,如 {'color': 'black', 'colors': ['black', 'white']}
limit: 结果数量限制
"""
with self.SessionLocal() as session:
query = session.query(
Product,
(1 - Product.features.cosine_distance(query_vector)).label('similarity')
)
# 动态添加过滤条件
if 'color' in filters:
query = query.filter(Product.color == filters['color'])
if 'colors' in filters:
query = query.filter(Product.color.in_(filters['colors']))
if 'name_contains' in filters:
query = query.filter(Product.name.ilike(f"%{filters['name_contains']}%"))
results = query.order_by(
(1 - Product.features.cosine_distance(query_vector)).desc()
).limit(limit).all()
return [(product, float(similarity)) for product, similarity in results]
# 使用示例
def example_usage():
"""使用示例"""
# 初始化搜索服务
search_service = ProductVectorSearch(DATABASE_URL)
# 模拟添加一些手机产品数据
sample_products = [
{
"id": 1,
"name": "iPhone 15 Pro",
"color": "black",
"description": "高端智能手机,配备A17 Pro芯片,钛金属机身,三摄系统",
"features": np.random.rand(512).astype(np.float32) # 实际应用中这里是真实的嵌入向量
},
{
"id": 2,
"name": "Samsung Galaxy S24",
"color": "white",
"description": "Android旗舰手机,骁龙8 Gen3处理器,AI拍照功能",
"features": np.random.rand(512).astype(np.float32)
},
{
"id": 3,
"name": "小米14 Pro",
"color": "blue",
"description": "徕卡影像系统,骁龙8 Gen3,120W快充",
"features": np.random.rand(512).astype(np.float32)
}
]
# 添加产品到数据库
for product_data in sample_products:
search_service.add_product(**product_data)
# 示例1: 基于向量查询,过滤黑色手机
query_vector = np.random.rand(512).astype(np.float32)
results = search_service.vector_search_with_color_filter(
query_vector=query_vector,
color_filter="black",
limit=5
)
print("基于向量查询,过滤黑色手机:")
for product, similarity in results:
print(f" {product.name} (相似度: {similarity:.3f})")
# 示例2: 多颜色过滤
results = search_service.vector_search_with_color_filter(
query_vector=query_vector,
colors_filter=["black", "white"],
limit=5
)
print("\n黑色或白色手机搜索结果:")
for product, similarity in results:
print(f" {product.name} - {product.color} (相似度: {similarity:.3f})")
# 示例3: 高级搜索
results = search_service.advanced_search(
query_vector=query_vector,
filters={
"colors": ["blue", "white"],
"name_contains": "Pro"
},
limit=3
)
print("\n高级搜索结果 (蓝色或白色 + 名称包含Pro):")
for product, similarity in results:
print(f" {product.name} - {product.color} (相似度: {similarity:.3f})")
if __name__ == "__main__":
# 运行示例
example_usage()
示例2:BM25多路召回和融合排序
融合排序通常用于信息检索和推荐系统中,旨在将多个排序列表(如来自不同算法或模型的搜索结果)合并,以提高排序质量。最终得到更优的排序列表。
本示例展示了一个结合了BM25全文检索和向量语义搜索的程序,并提供了两种融合方法:线性加权融合和RRF(Reciprocal Rank Fusion)融合。
- 加权求和:一种基于分数的融合方法,为每个列表分配权重,计算每个项目的加权分数总和,然后根据最终分数排序。
示例中,分别设置了向量语义检索,和BM25检索的权重,并对向量检索和BM25查询的结果进行归一化处理。 - RRF(Reciprocal Rank Fusion)融合算法:一种基于排名的融合方法,通过计算每个项目在每个列表中的排名的倒数之和来得到最终分数,常用于合并来自不同源的排序结果。
RRF算法的优势在于它不依赖于具体的评分机制,只利用排名信息,因此可以融合来自不同来源、不同评分标准的排序列表。同时,由于使用了倒数函数,排名越靠前(数值小)的文档贡献的分数越大,且随着排名靠后,贡献的分数会迅速减小。
在本例中,RRF算法将BM25和向量搜索的排名进行融合,从而结合了关键词匹配和语义相似度的优势,得到更全面的搜索结果。
另外,本例提供了分别执行BM25和向量搜索的方法,以及一个测试函数,便于对比两种融合方法的结果和性能。
"""
BM25 + 向量融合检索Demo
基于VexDB实现BM25全文检索与向量相似度融合搜索
数据表结构:
products (
id INT PRIMARY KEY, #产品ID
name TEXT NOT NULL, #产品名称
color TEXT NOT NULL, #颜色
description TEXT NOT NULL, #产品描述
title_tokens TEXT NOT NULL,#标题关键字
embedding_vec floatvector(384) #向量嵌入
)
主要特性:
1. BM25和向量搜索分别在本地Python中执行,避免SQL中的复杂计算
2. 支持两种融合方法:
- 线性融合: BM25分数经过min-max归一化后,加权融合 (默认权重: 文本0.6, 向量0.4)
- RRF融合: Reciprocal Rank Fusion,平滑参数k=60
3. 向量搜索增加了相似度阈值判断 (similarity > 0.3),过滤低相关性结果
使用示例:
demo = HybridSearchDemo()
results = demo.hybrid_search("iPhone 15 Pro", fusion_method="linear")
results = demo.hybrid_search("旗舰手机", fusion_method="rrf")
"""
import numpy as np
import psycopg2
from psycopg2 import sql
from psycopg2.extras import DictCursor
import logging
import time
from typing import List, Dict, Optional
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 尝试导入transformers库
try:
from sentence_transformers import SentenceTransformer
TRANSFORMER_AVAILABLE = True
logger.info("sentence-transformers 可用")
except ImportError:
TRANSFORMER_AVAILABLE = False
logger.warning("sentence-transformers 不可用,将使用随机向量")
class HybridSearchDemo:
"""融合检索demo"""
def __init__(self, host: str = "localhost", port: int = 5438,
dbname: str = "postgres", user: str = "test", password: str = "Test@1234",
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
self.host = host
self.port = port
self.dbname = dbname
self.user = user
self.password = password
self.table_name = "products"
self.model_name = model_name
# 初始化transformer模型
self.model = None
self.vector_size = 384 # 固定向量维度
self.transformer_available = TRANSFORMER_AVAILABLE
if self.transformer_available:
try:
logger.info(f"正在加载模型: {model_name}")
self.model = SentenceTransformer(model_name)
logger.info(f"模型加载成功,向量维度: {self.vector_size}")
except Exception as e:
logger.error(f"模型加载失败: {e}")
self.model = None
self.transformer_available = False
# 保持单一连接
self.conn = None
def get_connection(self):
"""获取单一数据库连接"""
if not self.conn:
try:
self.conn = psycopg2.connect(
host=self.host, port=self.port,
dbname=self.dbname, user=self.user, password=self.password,
connect_timeout=10
)
logger.info(f"单一连接建立成功: {self.host}:{self.port}")
except Exception as e:
logger.error(f"连接失败: {e}")
self.conn = None
return self.conn
def init_data(self):
"""初始化数据"""
conn = self.get_connection()
if not conn:
return False
try:
with conn.cursor() as cur:
# 创建表
cur.execute(f"""
DROP TABLE IF EXISTS {self.table_name};
CREATE TABLE {self.table_name} (
id INT PRIMARY KEY,
name TEXT NOT NULL,
color TEXT NOT NULL,
description TEXT NOT NULL,
title_tokens TEXT NOT NULL,
embedding_vec floatvector({self.vector_size})
);
""")
# 插入测试数据
products = [
(1, "iPhone 15 Pro Max", "蓝色",
"Apple iPhone 15 Pro Max 苹果旗舰手机,配备A17 Pro芯片,钛金属设计,支持5G网络,拥有三摄像头系统",
["苹果", "iPhone", "15", "Pro", "Max", "旗舰", "A17", "Pro", "芯片", "钛金属", "5G", "三摄像头"]),
(2, "Galaxy S24 Ultra", "黑色",
"Samsung Galaxy S24 Ultra 三星旗舰手机,搭载骁龙8 Gen3处理器,配备S Pen手写笔,支持卫星通信",
["三星", "Galaxy", "S24", "Ultra", "旗舰", "骁龙", "8", "Gen3", "S", "Pen", "手写笔", "卫星通信"]),
(3, "小米14 Pro", "白色",
"小米14 Pro 徕卡光学镜头,骁龙8 Gen3处理器,徕卡影像系统,120W快充,专业摄影手机",
["小米", "14", "Pro", "旗舰", "徕卡", "光学", "镜头", "骁龙", "8", "Gen3", "120W", "快充", "摄影"]),
(4, "Mate 60 Pro", "紫色",
"华为Mate 60 Pro,搭载麒麟9000S芯片,支持卫星通话,运行鸿蒙操作系统,超可靠玄武架构",
["华为", "Mate", "60", "Pro", "旗舰", "麒麟", "9000S", "卫星", "通话", "鸿蒙", "系统", "玄武架构"])
]
# 生成文本数据
all_texts = []
for id, name, color, desc, tokens in products:
search_text = f"{name} {color} {desc} {' '.join(tokens)}"
all_texts.append(search_text)
# 生成嵌入向量
if self.transformer_available and self.model:
logger.info(f"正在生成 {len(all_texts)} 个文本的语义嵌入向量...")
embeddings = self.model.encode(all_texts, normalize_embeddings=True)
logger.info("嵌入向量生成完成")
else:
logger.warning("使用随机向量作为回退方案")
embeddings = []
for i, text in enumerate(all_texts):
seed = hash(text) % (2**32)
np.random.seed(seed)
vec = np.random.normal(0, 1, self.vector_size)
vec = vec / np.linalg.norm(vec)
embeddings.append(vec)
# 插入数据
for i, (id, name, color, desc, tokens) in enumerate(products):
vector = embeddings[i].tolist() if hasattr(embeddings[i], 'tolist') else embeddings[i].astype(float).tolist()
# 将 tokens 数组转换为空格分隔的字符串
tokens_str = " ".join(tokens)
cur.execute(f"""
INSERT INTO {self.table_name}
(id, name, color, description, title_tokens, embedding_vec)
VALUES (%s, %s, %s, %s, %s, %s);
""", (id, name, color, desc, tokens_str, vector))
# 创建索引
try:
# BM25索引覆盖 description 和 title_tokens 字段
cur.execute(f"CREATE INDEX text_idx_{self.table_name} ON {self.table_name} USING fulltext(description, title_tokens);")
cur.execute(f"CREATE INDEX vector_idx_{self.table_name} ON {self.table_name} USING hnsw (embedding_vec floatvector_cosine_ops);")
logger.info("索引创建成功")
except Exception as e:
logger.warning(f"索引创建失败: {e}")
# 验证数据插入
cur.execute(f"SELECT COUNT(*) FROM {self.table_name};")
count = cur.fetchone()[0]
logger.info(f"数据初始化成功,插入了 {count} 条记录")
conn.commit()
return True
except Exception as e:
logger.error(f"数据初始化失败: {e}")
import traceback
traceback.print_exc()
return False
def hybrid_search(self, query: str, text_weight: float = 0.6, vector_weight: float = 0.4,
top_n: int = 5, fusion_method: str = "linear"):
"""BM25 + 向量融合搜索
Args:
query: 搜索查询
text_weight: 文本搜索权重 (线性融合时使用)
vector_weight: 向量搜索权重 (线性融合时使用)
top_n: 返回结果数量
fusion_method: 融合方法,可选 'linear' 或 'rrf'
"""
conn = self.get_connection()
if not conn:
return []
try:
# 确保没有未完成的交易
conn.rollback()
# 生成查询向量
if self.transformer_available and self.model:
query_embedding = self.model.encode(query, normalize_embeddings=True)
query_vector = query_embedding.tolist() if hasattr(query_embedding, 'tolist') else query_embedding.astype(float).tolist()
else:
seed = hash(query) % (2**32)
np.random.seed(seed)
query_vector = np.random.normal(0, 1, self.vector_size)
query_vector = (query_vector / np.linalg.norm(query_vector)).tolist()
# 1. 执行BM25文本搜索
bm25_results = self._bm25_search(conn, query, top_n * 2)
# 2. 执行向量搜索
vector_results = self._vector_search(conn, query_vector, top_n * 2)
# 3. 在本地进行融合
if fusion_method == "linear":
return self._linear_fusion(bm25_results, vector_results, top_n, text_weight, vector_weight)
elif fusion_method == "rrf":
return self._rrf_fusion(bm25_results, vector_results, top_n, k=60)
else:
logger.warning(f"未知的融合方法: {fusion_method},使用线性融合")
return self._linear_fusion(bm25_results, vector_results, top_n, text_weight, vector_weight)
except Exception as e:
logger.error(f"搜索失败: {e}")
import traceback
traceback.print_exc()
return []
def _bm25_search(self, conn, query: str, limit: int):
"""执行BM25文本搜索"""
try:
with conn.cursor(cursor_factory=DictCursor) as cur:
desc_query = query + " @<PARAMS:BOOST=1.0>@"
title_tokens_query = query + " @<PARAMS:BOOST=2.0>@"
sql = f"""
SELECT id, name, color, description, bm25_score() as "SCORE"
FROM {self.table_name}
WHERE description @~@ '{desc_query}' AND title_tokens @~@ '{title_tokens_query}'
ORDER BY bm25_score DESC
LIMIT {limit}
"""
cur.execute(sql)
results = cur.fetchall()
return [dict(row) for row in results]
except Exception as e:
logger.warning(f"BM25搜索失败: {e}")
return []
def _vector_search(self, conn, query_vector: List[float], limit: int):
"""执行向量搜索"""
try:
with conn.cursor(cursor_factory=DictCursor) as cur:
sql = f"""
SELECT id, name, color, description,
(1 - (embedding_vec <=> ARRAY{query_vector})) AS "SIMILARITY"
FROM {self.table_name}
WHERE (1 - (embedding_vec <=> ARRAY{query_vector})) > 0.3
ORDER BY embedding_vec <=> ARRAY{query_vector}
LIMIT {limit}
"""
cur.execute(sql)
results = cur.fetchall()
return [dict(row) for row in results]
except Exception as e:
logger.warning(f"向量搜索失败: {e}")
return []
def _linear_fusion(self, bm25_results: List[Dict], vector_results: List[Dict],
top_n: int, text_weight: float, vector_weight: float):
"""线性融合:对BM25分数进行min-max归一化后加权融合"""
if not bm25_results and not vector_results:
return []
# 将结果按id索引
all_results = {}
# 处理BM25结果 - 进行min-max归一化
if bm25_results:
scores = [r.get("SCORE", 0) for r in bm25_results]
min_score = min(scores)
max_score = max(scores)
score_range = max_score - min_score if max_score > min_score else 1.0
for result in bm25_results:
doc_id = result["id"]
raw_score = result.get("SCORE", 0)
# min-max归一化
normalized_score = (raw_score - min_score) / score_range if score_range > 0 else 0
if doc_id not in all_results:
all_results[doc_id] = {
"id": doc_id,
"name": result["name"],
"color": result["color"],
"description": result["description"],
"text_score": normalized_score,
"vector_score": 0,
"hybrid_score": 0
}
else:
all_results[doc_id]["text_score"] = normalized_score
# 处理向量结果
if vector_results:
for result in vector_results:
doc_id = result["id"]
similarity = result.get("SIMILARITY", 0)
if doc_id not in all_results:
all_results[doc_id] = {
"id": doc_id,
"name": result["name"],
"color": result["color"],
"description": result["description"],
"text_score": 0,
"vector_score": similarity,
"hybrid_score": 0
}
else:
all_results[doc_id]["vector_score"] = similarity
# 计算融合分数
for doc_id, result in all_results.items():
result["hybrid_score"] = (result["text_score"] * text_weight +
result["vector_score"] * vector_weight)
# 按融合分数排序并返回top_n
sorted_results = sorted(all_results.values(), key=lambda x: x["hybrid_score"], reverse=True)
return sorted_results[:top_n]
def _rrf_fusion(self, bm25_results: List[Dict], vector_results: List[Dict], top_n: int, k: int = 60):
"""RRF融合:Reciprocal Rank Fusion"""
rrf_scores = {}
# 处理BM25结果
for rank, result in enumerate(bm25_results, start=1):
doc_id = result["id"]
score = 1.0 / (k + rank)
if doc_id not in rrf_scores:
rrf_scores[doc_id] = {
"id": doc_id,
"name": result["name"],
"color": result["color"],
"description": result["description"],
"rrf_score": score,
"text_rank": rank,
"vector_rank": None
}
else:
rrf_scores[doc_id]["rrf_score"] += score
rrf_scores[doc_id]["text_rank"] = rank
# 处理向量结果
for rank, result in enumerate(vector_results, start=1):
doc_id = result["id"]
score = 1.0 / (k + rank)
if doc_id not in rrf_scores:
rrf_scores[doc_id] = {
"id": doc_id,
"name": result["name"],
"color": result["color"],
"description": result["description"],
"rrf_score": score,
"text_rank": None,
"vector_rank": rank
}
else:
rrf_scores[doc_id]["rrf_score"] += score
rrf_scores[doc_id]["vector_rank"] = rank
# 转换为列表并排序
results = list(rrf_scores.values())
results.sort(key=lambda x: x["rrf_score"], reverse=True)
# 只返回top_n,并添加文本分数和向量分数(如果有的话)
final_results = results[:top_n]
for result in final_results:
# 从原始结果中查找对应的文本分数和向量分数
bm25_match = next((r for r in bm25_results if r["id"] == result["id"]), None)
vector_match = next((r for r in vector_results if r["id"] == result["id"]), None)
result["text_score"] = bm25_match["SCORE"] if bm25_match else 0
result["vector_score"] = vector_match["SIMILARITY"] if vector_match else 0
result["hybrid_score"] = result["rrf_score"]
return final_results
def search_separate(self, query: str, top_n: int = 5):
"""分别执行BM25和向量搜索,用于对比"""
conn = self.get_connection()
if not conn:
return {"bm25": [], "vector": []}
try:
conn.rollback()
# 生成查询向量
if self.transformer_available and self.model:
query_embedding = self.model.encode(query, normalize_embeddings=True)
query_vector = query_embedding.tolist() if hasattr(query_embedding, 'tolist') else query_embedding.astype(float).tolist()
else:
seed = hash(query) % (2**32)
np.random.seed(seed)
query_vector = np.random.normal(0, 1, self.vector_size)
query_vector = (query_vector / np.linalg.norm(query_vector)).tolist()
# 使用新的辅助方法进行搜索
bm25_results = self._bm25_search(conn, query, top_n)
vector_results = self._vector_search(conn, query_vector, top_n)
return {
"bm25": bm25_results,
"vector": vector_results
}
except Exception as e:
logger.error(f"搜索失败: {e}")
return {"bm25": [], "vector": []}
def test_search(self, query: str):
"""测试搜索,比较两种融合方法"""
print(f"\n搜索: {query}")
print("=" * 80)
# 1. 线性融合搜索
print("[线性融合搜索] BM25分数min-max归一化 + 加权融合 (文本权重0.6, 向量权重0.4)")
print("-" * 40)
start_time = time.time()
linear_results = self.hybrid_search(query, fusion_method="linear")
linear_time = time.time() - start_time
if linear_results:
print(f"找到 {len(linear_results)} 条结果 (耗时: {linear_time:.3f}s):")
for i, result in enumerate(linear_results, 1):
print(f"{i}. 【{result['name']}】颜色: {result['color']}")
print(f" 描述: {result['description'][:80]}...")
print(f" BM25分数: {result.get('text_score', 0):.4f}")
print(f" 向量相似度: {result.get('vector_score', 0):.4f}")
print(f" 融合分数: {result.get('hybrid_score', 0):.4f}")
print()
# 2. RRF融合搜索
print("\n[RRF融合搜索] Reciprocal Rank Fusion,k=60")
print("-" * 40)
start_time = time.time()
rrf_results = self.hybrid_search(query, fusion_method="rrf")
rrf_time = time.time() - start_time
if rrf_results:
print(f"找到 {len(rrf_results)} 条结果 (耗时: {rrf_time:.3f}s):")
for i, result in enumerate(rrf_results, 1):
text_rank = result.get('text_rank') if result.get('text_rank') else 'N/A'
vector_rank = result.get('vector_rank') if result.get('vector_rank') else 'N/A'
print(f"{i}. 【{result['name']}】颜色: {result['color']}")
print(f" 描述: {result['description'][:80]}...")
print(f" BM25排名: {text_rank}, 向量排名: {vector_rank}")
print(f" RRF分数: {result.get('hybrid_score', 0):.4f}")
print()
# 3. 分别搜索对比
print("\n[分别搜索对比]:")
print("-" * 40)
start_time = time.time()
separate_results = self.search_separate(query)
separate_time = time.time() - start_time
# BM25结果
print(f"[BM25搜索结果]:")
if separate_results["bm25"]:
for i, result in enumerate(separate_results["bm25"], 1):
print(f" {i}. 【{result['name']}】颜色: {result['color']}(分数: {result.get('SCORE', 0):.4f})")
else:
print(" 无结果")
# 向量搜索结果
print(f"\n[向量搜索结果]:")
if separate_results["vector"]:
for i, result in enumerate(separate_results["vector"], 1):
print(f" {i}. 【{result['name']}】颜色: {result['color']}(相似度: {result.get('SIMILARITY', 0):.4f})")
else:
print(" 无结果")
# 4. 结果分析
print(f"\n[搜索性能与结果分析]:")
print(f" 线性融合搜索耗时: {linear_time:.3f}s")
print(f" RRF融合搜索耗时: {rrf_time:.3f}s")
print(f" 分别搜索耗时: {separate_time:.3f}s")
if linear_results:
bm25_hits = sum(1 for r in linear_results if r.get('text_score', 0) > 0)
vector_hits = sum(1 for r in linear_results if r.get('vector_score', 0) > 0)
print(f" 线性融合中 - BM25命中: {bm25_hits}条, 向量命中: {vector_hits}条")
if rrf_results and linear_results:
# 比较两种融合方法的结果差异
linear_ids = [r['id'] for r in linear_results]
rrf_ids = [r['id'] for r in rrf_results]
common = set(linear_ids) & set(rrf_ids)
print(f" 两种融合方法共同命中: {len(common)}条")
print(f" 线性融合特有: {len(set(linear_ids) - set(rrf_ids))}条")
print(f" RRF融合特有: {len(set(rrf_ids) - set(linear_ids))}条")
# 语义相似度测试
if self.transformer_available and len(linear_results) >= 2:
try:
text1 = f"{linear_results[0]['name']} {linear_results[0]['color']} {linear_results[0]['description']}"
text2 = f"{linear_results[1]['name']} {linear_results[1]['color']} {linear_results[1]['description']}"
similarity = self.model.similarity([text1], [text2])[0][0]
print(f" Top1与Top2语义相似度: {similarity:.4f}")
except Exception as e:
logger.debug(f"语义相似度计算失败: {e}")
return linear_results
def close(self):
"""关闭连接"""
if self.conn:
self.conn.close()
self.conn = None
logger.info("数据库连接已关闭")
def main():
print("=" * 80)
print("BM25 + 向量融合检索Demo")
print("使用sentence-transformers生成语义嵌入向量")
print("基于VexDB实现BM25全文检索与向量相似度融合")
print("=" * 80)
demo = HybridSearchDemo()
if not demo.transformer_available:
print("\n[sentence-transformers库未安装]")
print("请运行: pip install sentence-transformers")
print("当前将使用随机向量进行演示")
print()
try:
if not demo.init_data():
print("数据初始化失败")
return
print("\n数据准备完成,开始测试搜索\n")
# 测试查询
test_queries = [
"苹果 iPhone",
"徕卡相机",
"卫星通话",
"快充手机",
"华为手机",
"小米 徕卡"
]
for query in test_queries:
demo.test_search(query)
print("=" * 80)
print("所有测试完成!")
except Exception as e:
logger.error(f"测试失败: {e}")
import traceback
traceback.print_exc()
finally:
demo.close()
if __name__ == "__main__":
main()
说明