kinglau2008 / es-semantic-search

es_search.py
from elasticsearch import Elasticsearch
import tensorflow_hub as hub
import tensorflow.compat.v1 as tf
import pandas as pd
import numpy as np
import tensorflow_text  # noqa: F401 - imported for its op registrations, which some TF Hub text models require
# Elasticsearch connection
es = Elasticsearch("http://127.0.0.1:9200")
# Elasticsearch index name
INDEX_NAME = "yiliao_vectors"
# Load the embedding model and build the graph-mode embedding op
# (tf.placeholder needs graph mode, so disable eager execution when running on TF 2.x)
tf.disable_eager_execution()
print("Loading pre-trained embeddings")
embed = hub.load("./model")
text_ph = tf.placeholder(tf.string)
embeddings = embed(text_ph)
print("Creating tensorflow session...")
session = tf.Session()
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
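# Quick sanity check (sketch, commented out): assumes ./model is a 512-dimensional sentence
# encoder such as the multilingual Universal Sentence Encoder; the expected shape is an
# assumption, not something this repository documents.
# sample = session.run(embeddings, feed_dict={text_ph: ["test"]})
# print(sample.shape)  # expected: (1, 512), matching the dense_vector dims used below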
# Import the sample data into Elasticsearch, embedding each title and storing it in title_vector
def import_data():
    df = pd.read_csv('./yiliao.csv')
    # print(df['title'][0])
    # Embed the titles
    vectors = session.run(embeddings, feed_dict={text_ph: df['title']})
    vector = []
    for i in vectors:
        vector.append(i)
    df["Embeddings"] = vector
    # Create the index
    create_es_index()
    # actions = []
    for index, row in df.iterrows():
        doc = {
            # "id": index,
            "department": row["department"],
            "title": row["title"],
            "ask": row["ask"],
            "answer": row["answer"],
            "title_vector": row["Embeddings"]
        }
        # With many rows es.bulk can time out, so write one document at a time instead
        es.index(index=INDEX_NAME, body=doc, id=index)
        print(index)
        # action = {"index": {"_index": INDEX_NAME, "_id": index}}
        # actions.append(action)
        # actions.append(doc)
    # With many rows es.bulk can time out, so write one document at a time instead
    # es.bulk(index=INDEX_NAME, body=actions, refresh=True)
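# Bulk-import alternative (sketch): the commented-out es.bulk call above can time out on large
# files; elasticsearch.helpers.bulk streams the documents in chunks instead. The function name
# and the chunk_size/request_timeout values here are illustrative assumptions, not part of the
# original workflow.
def import_data_bulk():
    from elasticsearch import helpers
    df = pd.read_csv('./yiliao.csv')
    vectors = session.run(embeddings, feed_dict={text_ph: df['title']})
    create_es_index()
    actions = (
        {
            "_index": INDEX_NAME,
            "_id": index,
            "_source": {
                "department": row["department"],
                "title": row["title"],
                "ask": row["ask"],
                "answer": row["answer"],
                "title_vector": vectors[index].tolist(),
            },
        }
        for index, row in df.iterrows()
    )
    helpers.bulk(es, actions, chunk_size=500, request_timeout=60)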
# Create the index
def create_es_index():
    configurations = {
        "settings": {
            "index": {"number_of_replicas": 2},
            "analysis": {
                "filter": {
                    "ngram_filter": {
                        "type": "edge_ngram",
                        "min_gram": 2,
                        "max_gram": 15,
                    }
                },
                # Note: ngram_analyzer is defined here but no field in the mapping below uses it
                "analyzer": {
                    "ngram_analyzer": {
                        "type": "custom",
                        "tokenizer": "standard",
                        "filter": ["lowercase", "ngram_filter"],
                    }
                }
            }
        },
        "mappings": {
            "properties": {
                # dims must match the output size of the embedding model
                "title_vector": {
                    "type": "dense_vector",
                    "dims": 512
                },
            }
        }
    }
    if es.indices.exists(index=INDEX_NAME):
        print("Index already exists; deleting and recreating it")
        es.indices.delete(index=INDEX_NAME)
    es.indices.create(index=INDEX_NAME, body=configurations)
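# Optional check (sketch): after create_es_index() runs, the mapping can be printed to confirm
# that title_vector exists as a 512-dim dense_vector field; purely illustrative.
# print(es.indices.get_mapping(index=INDEX_NAME))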
# Embed a batch of texts and return plain Python lists
def embed_text(text):
    vectors = session.run(embeddings, feed_dict={text_ph: text})
    return [vector.tolist() for vector in vectors]
# Interactive test loop
def test():
    while True:
        try:
            query = input("Enter a search query: ")
            print(query)
            do_query(query)
        except KeyboardInterrupt:
            return
# Run a semantic search (vector query)
def do_query(query):
    query_vector = embed_text([query])[0]
    # print(query_vector)
    source_fields = ["title", "ask", "answer"]
    response = es.search(
        index=INDEX_NAME,
        body={
            "_source": source_fields,
            "query": {
                "script_score": {
                    "query": {
                        "match_all": {}
                    },
                    "script": {
                        "source": "cosineSimilarity(params.queryVector, doc['title_vector'])+1",
                        "params": {
                            "queryVector": query_vector
                        }
                    }
                }
            }
        })
    print("Matching titles:", end="\n")
    for hit in response["hits"]["hits"]:
        print(hit["_source"]["title"], end="\n")
    print("\n")
    # print(response)
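# What the script_score above computes (sketch): cosine similarity between the query vector and
# the stored title_vector, shifted by +1 so the score is never negative (script_score scores
# must be non-negative in Elasticsearch). A local numpy equivalent, for illustration only:
def cosine_score(query_vec, doc_vec):
    q = np.asarray(query_vec, dtype=float)
    d = np.asarray(doc_vec, dtype=float)
    return float(np.dot(q, d) / (np.linalg.norm(q) * np.linalg.norm(d))) + 1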
if __name__ == '__main__':
    # Import the sample data into ES and embed the titles (run once)
    # import_data()
    # After importing, run the interactive test
    test()
https://gitee.com/kinglau2008/es-semantic-search.git