text2vec
Advanced tools
| exclude tests/* |
+1
-1
| Metadata-Version: 2.1 | ||
| Name: text2vec | ||
| Version: 1.3.1 | ||
| Version: 1.3.2 | ||
| Summary: Text to vector Tool, encode text | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/shibing624/text2vec |
| Metadata-Version: 2.1 | ||
| Name: text2vec | ||
| Version: 1.3.1 | ||
| Version: 1.3.2 | ||
| Summary: Text to vector Tool, encode text | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/shibing624/text2vec |
| LICENSE | ||
| MANIFEST.in | ||
| README.md | ||
| setup.py | ||
| tests/test_hf_dataset.py | ||
| tests/test_issue.py | ||
| tests/test_lcqmc_similarity.py | ||
| tests/test_longtext_simscore.py | ||
| tests/test_model_spearman.py | ||
| tests/test_multi_process.py | ||
| tests/test_qps.py | ||
| tests/test_rankbm25.py | ||
| tests/test_sbert_embeddings.py | ||
| tests/test_similarity.py | ||
| tests/test_w2v_embeddings.py | ||
| text2vec/__init__.py | ||
@@ -16,0 +6,0 @@ text2vec/bertmatching_dataset.py |
@@ -7,2 +7,2 @@ # -*- coding: utf-8 -*- | ||
| __version__ = '1.3.1' | ||
| __version__ = '1.3.2' |
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from datasets import load_dataset | ||
class DatasetTestCase(unittest.TestCase):
    """Smoke-test loading the STS-B test split from the HuggingFace hub."""

    def test_data_diff(self):
        # NOTE(review): needs network access to fetch shibing624/nli_zh.
        test_dataset = load_dataset("shibing624/nli_zh", "STS-B", split="test")
        # Collect sentence pairs and labels for inspection.
        srcs, trgs, labels = [], [], []
        for record in test_dataset:
            src, trg, label = record['sentence1'], record['sentence2'], record['label']
            srcs.append(src)
            trgs.append(trg)
            labels.append(label)
            # Stop early once an over-long source sentence appears.
            if len(src) > 100:
                break
        print(f'{test_dataset[0]}')
        print(f'{srcs[0]}')


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import SentenceModel, cos_sim | ||
| from text2vec import BM25 | ||
# Shared sentence encoder; loaded once at import time so all tests reuse it.
sbert_model = SentenceModel()


def sbert_sim_score(str_a, str_b):
    """Return the cosine similarity between the embeddings of two strings."""
    emb_a = sbert_model.encode(str_a)
    emb_b = sbert_model.encode(str_b)
    return cos_sim(emb_a, emb_b)
class IssueTestCase(unittest.TestCase):
    """Regression tests for similarity scores reported in GitHub issues."""

    def test_sim_diff(self):
        # Unrelated sentences should keep the previously recorded low score.
        text_a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
        text_b = '英汉互译比较语言学'
        score = sbert_sim_score(text_a, text_b)
        print(text_a, text_b, score)
        self.assertTrue(abs(float(score) - 0.4098) < 0.001)

    def test_sim_same(self):
        # Near-paraphrases should keep the previously recorded high score.
        text_a = '汉英翻译比较语言学'
        text_b = '英汉互译比较语言学'
        score = sbert_sim_score(text_a, text_b)
        print(text_a, text_b, score)
        self.assertTrue(abs(float(score) - 0.8905) < 0.001)

    def test_search_sim(self):
        # BM25 over a deduplicated corpus must score every remaining doc.
        sentences = [
            '原称《车骑出行》。', '画面从左至右为:四导骑,', '两两并排前行,', '骑手面相对,', '似交谈;', '三导骑,',
            '并排前行;', '二马轺车,', '轮有辐,',
            '车上一驭者执鞭,', '一尊者坐,', '回首;', '二导从,', '骑手面相对,', '似交谈;', '二马轺车,', '轮有辐,',
            '车上一驭者执鞭,', '一尊者坐;', '四导从,', '两两相对并排前行;', '两骑手,', '反身张弓射虎;', '虎,',
            '跃起前扑。', '上下右三边有框,', '上沿双边框内填刻三角纹,', '下沿双边框内填刻斜条纹。']
        self.assertEqual(len(sentences), 28)
        uniq_sentences = list(set(sentences))
        print(uniq_sentences)
        print(len(uniq_sentences))
        self.assertEqual(len(uniq_sentences), 23)
        search_sim = BM25(corpus=uniq_sentences)
        print(len(search_sim.corpus))
        query = '上沿双边框内填刻三角形纹'
        scores = search_sim.get_scores(query=query, top_k=None)
        print(scores)
        print(len(scores))
        self.assertEqual(len(scores), 23)


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import SentenceModel, cos_sim | ||
# Shared sentence encoder; loaded once for every test case in this module.
sbert_model = SentenceModel()
# Each case is [query1, query2, matching?]: the boolean is a human judgement
# of whether the two queries should count as semantically equivalent.
case_same_keywords = [['飞行员没钱买房怎么办?', '父母没钱买房子', False],
                      ['聊天室都有哪些好的', '聊天室哪个好', True],
                      ['不锈钢上贴的膜怎么去除', '不锈钢上的胶怎么去除', True],
                      ['动漫人物的口头禅', '白羊座的动漫人物', False]]
# Same [query1, query2, matching?] layout, drawn from category-style questions.
case_categories_corresponding_pairs = [['从广州到长沙在哪里定高铁票', '在长沙哪里坐高铁回广州?', False],
                                       ['请问现在最好用的听音乐软件是什么啊', '听歌用什么软件比较好', True],
                                       ['谁有吃过完美的产品吗?如何?', '完美产品好不好', True],
                                       ['朱熹是哪个朝代的诗人', '朱熹是明理学的集大成者,他生活在哪个朝代', True],
                                       ['这是哪个奥特曼?', '这是什么奥特曼...', True],
                                       ['网上找工作可靠吗', '网上找工作靠谱吗', True],
                                       ['你们都喜欢火影忍者里的谁啊', '火影忍者里你最喜欢谁', True]]
def sbert_sim_score(str_a, str_b):
    """Cosine similarity between two sentences, as a plain Python float."""
    return cos_sim(sbert_model.encode(str_a), sbert_model.encode(str_b)).item()
def apply_sbert_case(cases):
    """Print the predicted similarity next to the expected label for each case.

    :param cases: iterable of [query1, query2, expected_match] triples.
    """
    for q1, q2, expected in cases:
        score = sbert_sim_score(q1, q2)
        print(f'q1:{q1}, q2:{q2}, expect:{expected}, actual:{score:.4f}')
class LcqTestCase(unittest.TestCase):
    """Print SBERT similarity for LCQMC-style query pairs (visual check only)."""

    def test_sbert(self):
        """Run both case banks through the SBERT scorer."""
        apply_sbert_case(case_same_keywords)
        apply_sbert_case(case_categories_corresponding_pairs)
        # Previously observed outputs, kept for reference:
        # q1: 飞行员没钱买房怎么办?, q2: 父母没钱买房子, expect: False, actual: 0.3742
        # q1: 聊天室都有哪些好的, q2: 聊天室哪个好, expect: True, actual: 0.9497
        # q1: 不锈钢上贴的膜怎么去除, q2: 不锈钢上的胶怎么去除, expect: True, actual: 0.8708
        # q1: 动漫人物的口头禅, q2: 白羊座的动漫人物, expect: False, actual: 0.8510
        # q1: 从广州到长沙在哪里定高铁票, q2: 在长沙哪里坐高铁回广州?, expect: False, actual: 0.9163
        # q1: 请问现在最好用的听音乐软件是什么啊, q2: 听歌用什么软件比较好, expect: True, actual: 0.9182
        # q1: 谁有吃过完美的产品吗?如何?, q2: 完美产品好不好, expect: True, actual: 0.7370
        # q1: 朱熹是哪个朝代的诗人, q2: 朱熹是明理学的集大成者,他生活在哪个朝代, expect: True, actual: 0.7382
        # q1: 这是哪个奥特曼?, q2: 这是什么奥特曼..., expect: True, actual: 0.8744
        # q1: 网上找工作可靠吗, q2: 网上找工作靠谱吗, expect: True, actual: 0.9531
        # q1: 你们都喜欢火影忍者里的谁啊, q2: 火影忍者里你最喜欢谁, expect: True, actual: 0.9643


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import SentenceModel, cos_sim | ||
# Shared sentence encoder for this module.
sbert_model = SentenceModel()
# Two long, heavily overlapping passages: `b` is a truncated suffix-overlap of
# `a`, so their similarity should be high but not 1.0.
a = '你们都喜欢火影忍者里的谁啊,你说的到底是谁?看Bert里面extract_features.py这个文件,可以得到类似预训练的词向量组成的句子表示,' \
    '类似于Keras里面第一步Embedding层。以题主所说的句子相似度计算为例,只需要把两个句子用分隔符隔开送到bert的输入(首位加特殊标记符' \
    'CLS的embedding),然后取bert输出中和CLS对应的那个vector(记为c)进行变换就可以了。原文中提到的是多分类任务,给出的输出变换是' \
    ')就可以了。至于题主提到的句向量表示,上文中提到的向量c即可一定程度表' \
    '示整个句子的语义,原文中有提到“ The final hidden state (i.e., output of Transformer) corresponding to this token ' \
    'is used as the aggregate sequence representation for classification tasks.”' \
    '这句话中的“this token”就是CLS位。补充:除了直接使用bert的句对匹配之外,还可以只用bert来对每个句子求embedding。之后再通过向' \
    'Siamese Network这样的经典模式去求相似度也可以'
b = '你说的到底是谁?看Bert里面extract_features.py这个文件,可以得到类似预训练的词向量组成的句子表示,' \
    '类似于Keras里面第一步Embedding层。以题主所说的句子相似度计算为例,只需要把两个句子用分隔符隔开送到bert的输入(首位加特殊标记符' \
    'CLS的embedding),然后取bert输出中和CLS对应的那个vector(记为c)进行变换就可以了。原文中提到的是多分类任务,给出的输出变换是' \
    ')就可以了。至于题主提到的句向量表示,上文中提到的向量c即可一定程度表'
def sbert_sim_score(str_a, str_b):
    """Cosine similarity between two (possibly long) texts as a float."""
    vec_a = sbert_model.encode(str_a)
    vec_b = sbert_model.encode(str_b)
    return cos_sim(vec_a, vec_b).item()
class TestCase(unittest.TestCase):
    """Long-text similarity should stay near the previously recorded value."""

    def test_bert_sim(self):
        r = sbert_sim_score(a, b)
        print(r)
        # Fixed anti-idiom: assertEqual(cond, True) -> assertTrue(cond).
        # Same check (|r - 0.872| < 0.2), clearer failure message.
        self.assertTrue(abs(r - 0.872) < 0.2)


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import os | ||
| import sys | ||
| import unittest | ||
| from time import time | ||
| sys.path.append('..') | ||
| from text2vec import Similarity, SimilarityType, EmbeddingType, compute_spearmanr | ||
| from text2vec import load_jsonl | ||
pwd_path = os.path.abspath(os.path.dirname(__file__))
# Debug mode caps each benchmark file at about a dozen pairs for fast runs.
is_debug = True


def load_test_data(path):
    """Read a tab-separated similarity file into (sents1, sents2, labels).

    Each valid line is ``sentence1<TAB>sentence2<TAB>label``; malformed lines
    are skipped.  Returns three empty lists when *path* is not a file.
    """
    sents1, sents2, labels = [], [], []
    if not os.path.isfile(path):
        return sents1, sents2, labels
    with open(path, 'r', encoding='utf8') as f:
        for raw_line in f:
            parts = raw_line.strip().split('\t')
            if len(parts) != 3:
                continue
            sents1.append(parts[0])
            sents2.append(parts[1])
            labels.append(int(parts[2]))
            # In debug mode stop once a handful of pairs has been collected.
            if is_debug and len(sents1) > 10:
                break
    return sents1, sents2, labels
def get_corr(model, test_path):
    """Score the sentence pairs in *test_path* and print spearman correlation.

    :param model: similarity model exposing ``get_scores(sents1, sents2, ...)``.
    :param test_path: tab-separated benchmark file readable by load_test_data.
    :return: spearman correlation between predicted scores and labels.
    """
    sents1, sents2, labels = load_test_data(test_path)
    start = time()
    score_matrix = model.get_scores(sents1, sents2, only_aligned=True)
    # Aligned pair scores sit on the diagonal of the score matrix.
    sims = [score_matrix[i][i] for i in range(len(sents1))]
    spend_time = max(time() - start, 1e-9)
    corr = compute_spearmanr(sims, labels)
    print('scores:', sims[:10])
    print('labels:', labels[:10])
    print(f'{test_path} spearman corr:', corr)
    print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)
    return corr
class SimModelTestCase(unittest.TestCase):
    """Spearman-correlation benchmarks for several embedding backends.

    The ``test_*_model`` methods containing only comments are a record of
    previously measured results and execute nothing.

    Refactor: the five standard benchmark files were listed verbatim in three
    methods, and the SOHU jsonl evaluation was duplicated twice; both are now
    factored into ``_std_corrs`` and ``_jsonl_corr``.
    """

    # Benchmark files, relative to ../examples/data.
    STD_TEST_FILES = (
        'STS-B/STS-B.test.data',
        'ATEC/ATEC.test.data',
        'BQ/BQ.test.data',
        'LCQMC/LCQMC.test.data',
        'PAWSX/PAWSX.test.data',
    )

    def _std_corrs(self, m):
        """Run get_corr over the five standard benchmark files; return corrs."""
        corrs = []
        for rel_path in self.STD_TEST_FILES:
            test_path = os.path.join(pwd_path, '../examples/data', rel_path)
            corrs.append(get_corr(m, test_path))
        return corrs

    def _jsonl_corr(self, m, test_path):
        """Score a jsonl pair file and print its spearman correlation.

        Same reporting as get_corr, but reads sentence1/sentence2/label
        records via load_jsonl instead of a TSV file.
        """
        data = load_jsonl(test_path)
        sents1, sents2, labels = [], [], []
        for item in data:
            sents1.append(item['sentence1'])
            sents2.append(item['sentence2'])
            labels.append(item['label'])
        t1 = time()
        scores = m.get_scores(sents1, sents2)
        # Aligned pair scores sit on the diagonal of the score matrix.
        sims = [scores[i][i] for i in range(len(sents1))]
        spend_time = max(time() - t1, 1e-9)
        corr = compute_spearmanr(sims, labels)
        print('scores:', sims[:10])
        print('labels:', labels[:10])
        print(f'{test_path} spearman corr:', corr)
        print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)
        return corr

    def test_w2v_sim_batch(self):
        """Word2Vec cosine similarity over the standard benchmarks."""
        model_name = 'w2v-light-tencent-chinese'
        print(model_name)
        m = Similarity(model_name, similarity_type=SimilarityType.COSINE, embedding_type=EmbeddingType.WORD2VEC)
        self._std_corrs(m)

    def test_sbert_sim_stsb_batch(self):
        """SBERT multilingual MiniLM over the standard benchmarks."""
        model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        print(model_name)
        m = Similarity(
            model_name,
            similarity_type=SimilarityType.COSINE,
            embedding_type=EmbeddingType.BERT,
            encoder_type="MEAN"
        )
        self._std_corrs(m)

    def test_set_sim_model_batch(self):
        """text2vec-base-chinese over the standard benchmarks plus SOHU."""
        m = Similarity(
            'shibing624/text2vec-base-chinese',
            similarity_type=SimilarityType.COSINE,
            embedding_type=EmbeddingType.BERT,
            encoder_type="MEAN"
        )
        print(m)
        corrs = self._std_corrs(m)
        # SOHU-dd
        corrs.append(self._jsonl_corr(m, os.path.join(pwd_path, '../examples/data/SOHU/dd-test.jsonl')))
        # SOHU-dc
        corrs.append(self._jsonl_corr(m, os.path.join(pwd_path, '../examples/data/SOHU/dc-test.jsonl')))
        print('average spearman corr:', sum(corrs) / len(corrs))

    def test_uer_sbert_nli_model(self):
        # uer/sbert-base-chinese-nli
        # STS-B spearman corr: 0.7179
        # ATEC spearman corr: 0.2953
        # BQ spearman corr: 0.4332
        # LCQMC spearman corr: 0.6239
        # PAWSX spearman corr: 0.1345
        # avg: 0.44096
        pass

    def test_ernie3_0_nano_model(self):
        # nghuyong/ernie-3.0-nano-zh
        # STS-B spearman corr: 0.6677
        # ATEC spearman corr: 0.2331
        # BQ spearman corr: 0.3716
        # LCQMC spearman corr: 0.6007
        # PAWSX spearman corr: 0.0970
        # avg: 0.3918
        # V100 QPS: 2858
        pass

    def test_ernie3_0_base_model(self):
        # nghuyong/ernie-3.0-base-zh
        # training with first_last_avg pooling and inference with mean pooling
        # STS-B spearman corr: 0.7981
        # ATEC spearman corr: 0.2965
        # BQ spearman corr: 0.3535
        # LCQMC spearman corr: 0.7184
        # PAWSX spearman corr: 0.1453
        # avg: 0.4619
        # V100 QPS: 1547
        # training with first_last_avg pooling and inference with first_last_avg pooling
        # STS-B spearman corr: 0.7931
        # ATEC spearman corr: 0.2997
        # BQ spearman corr: 0.3749
        # LCQMC spearman corr: 0.7110
        # PAWSX spearman corr: 0.1326
        # avg: 0.4421
        # V100 QPS: 1613
        # training with mean pooling and inference with mean pooling
        # STS-B spearman corr: 0.8153
        # ATEC spearman corr: 0.3319
        # BQ spearman corr: 0.4284
        # LCQMC spearman corr: 0.7293
        # PAWSX spearman corr: 0.1499
        # avg: 0.4909 (best)
        # sohu-dd spearman corr: 0.7032
        # sohu-dc spearman corr: 0.5723
        # add sohu-dd and sohu-dc avg: 0.5329
        # V100 QPS: 1588
        # training with mean pooling and inference with mean pooling
        # retrain with 512 length
        # STS-B spearman corr: 0.7962
        # ATEC spearman corr: 0.2852
        # BQ spearman corr: 0.34746
        # LCQMC spearman corr: 0.7073
        # PAWSX spearman corr: 0.16109
        # avg:
        # V100 QPS: 1552
        # training with mean pooling and inference with mean pooling
        # retrain with 256 length
        # STS-B spearman corr: 0.8080
        # ATEC spearman corr: 0.3006
        # BQ spearman corr: 0.3927
        # LCQMC spearman corr: 0.71993
        # PAWSX spearman corr: 0.1371
        # avg:
        # V100 QPS:
        # training with mean pooling and inference with mean pooling
        # training data: STS-B + ATEC + BQ + LCQMC + PAWSX
        # STS-B spearman corr: 0.8070
        # ATEC spearman corr: 0.5126
        # BQ spearman corr: 0.6872
        # LCQMC spearman corr: 0.7913
        # PAWSX spearman corr: 0.3428
        # avg: 0.6281
        # sohu-dd spearman corr: 0.6939
        # sohu-dc spearman corr: 0.5544
        # add sohu-dd and sohu-dc avg: 0.6270
        # V100 QPS: 1526
        # training with mean pooling and inference with mean pooling
        # training data: all nli-zh-all random sampled data
        # STS-B spearman corr: 0.7742
        # ATEC spearman corr: 0.4394
        # BQ spearman corr: 0.6436
        # LCQMC spearman corr: 0.7345
        # PAWSX spearman corr: 0.3914
        # avg: 0.5966
        # sohu-dd spearman corr: 0.49124
        # sohu-dc spearman corr: 0.3653
        # V100 QPS: 1603
        # training with mean pooling and inference with mean pooling
        # training data: all nli-zh-all human sampled data (v2)
        # STS-B spearman corr: 0.7825
        # ATEC spearman corr: 0.4337
        # BQ spearman corr: 0.6143
        # LCQMC spearman corr: 0.7348
        # PAWSX spearman corr: 0.3890
        # avg: 0.5908
        # sohu-dd spearman corr: 0.7034
        # sohu-dc spearman corr: 0.5491
        # add sohu avg: 0.6009
        # V100 QPS: 1560
        # training with mean pooling and inference with mean pooling
        # training data: all sts-zh sampled data only stsb + sts22 (v3) (drop)
        # STS-B spearman corr: 0.8163
        # ATEC spearman corr: 0.3236
        # BQ spearman corr: 0.3905
        # LCQMC spearman corr: 0.7279
        # PAWSX spearman corr: 0.1377
        # avg: 0.4811
        # sohu-dd spearman corr: 0.7162
        # sohu-dc spearman corr: 0.5721
        # V100 QPS: 1557
        # training with mean pooling and inference with mean pooling
        # training data: all sts-zh human sampled data only stsb + nli (v4)
        # STS-B spearman corr: 0.7893
        # ATEC spearman corr: 0.4489
        # BQ spearman corr: 0.6358
        # LCQMC spearman corr: 0.7424
        # PAWSX spearman corr: 0.4090
        # avg: 0.6072
        # sohu-ddb spearman corr: 0.7670
        # sohu-dcb spearman corr: 0.6330
        # add sohu avg: 0.6308
        # V100 QPS: 1601
        pass

    def test_ernie3_0_xbase_model(self):
        # nghuyong/ernie-3.0-xbase-zh
        # STS-B spearman corr: 0.7827
        # ATEC spearman corr: 0.3463
        # BQ spearman corr: 0.4267
        # LCQMC spearman corr: 0.7181
        # PAWSX spearman corr: 0.1318
        # avg: 0.4811
        # V100 QPS: 468
        pass

    def test_hfl_chinese_bert_wwm_ext_model(self):
        # hfl/chinese-bert-wwm-ext
        # STS-B spearman corr: 0.7635
        # ATEC spearman corr: 0.2708
        # BQ spearman corr: 0.3480
        # LCQMC spearman corr: 0.7056
        # PAWSX spearman corr: 0.1699
        # avg: 0.4515
        # V100 QPS: 1507
        pass

    def test_hfl_chinese_roberta_wwm_ext_model(self):
        # hfl/chinese-roberta-wwm-ext
        # training with first_last_avg pooling and inference with mean pooling
        # STS-B spearman corr: 0.7894
        # ATEC spearman corr: 0.3241
        # BQ spearman corr: 0.4362
        # LCQMC spearman corr: 0.7107
        # PAWSX spearman corr: 0.1446
        # avg: 0.4808
        # V100 QPS: 1472
        # hfl/chinese-roberta-wwm-ext
        # training with first_last_avg pooling and inference with first_last_avg pooling
        # STS-B spearman corr: 0.7854
        # ATEC spearman corr: 0.3234
        # BQ spearman corr: 0.4402
        # LCQMC spearman corr: 0.7029
        # PAWSX spearman corr: 0.1295
        # avg: 0.4739
        # V100 QPS: 1581
        # hfl/chinese-roberta-wwm-ext
        # training with mean pooling and inference with mean pooling
        # STS-B spearman corr: 0.7996
        # ATEC spearman corr: 0.3315
        # BQ spearman corr: 0.4364
        # LCQMC spearman corr: 0.7175
        # PAWSX spearman corr: 0.1472
        # avg: 0.4864
        # V100 QPS: 1487
        pass

    def test_hfl_chinese_macbert_large_model(self):
        # hfl/chinese-macbert-large
        # STS-B spearman corr: 0.7495
        # ATEC spearman corr: 0.3222
        # BQ spearman corr: 0.4608
        # LCQMC spearman corr: 0.6784
        # PAWSX spearman corr: 0.1081
        # avg: 0.4634
        # V100 QPS: 672
        pass

    def test_m3e_base_model(self):
        # moka-ai/m3e-base
        # STS-B spearman corr: 0.7696
        # ATEC spearman corr: 0.4127
        # BQ spearman corr: 0.6381
        # LCQMC spearman corr: 0.7487
        # PAWSX spearman corr: 0.1220
        # avg: 0.5378
        # V100 QPS: 1490
        # sohu-dd spearman corr: 0.7583
        # sohu-dc spearman corr: 0.6055
        # add sohu avg: 0.5793
        pass

    def test_bge_large_zh_noinstruct_model(self):
        # BAAI/bge-large-zh-noinstruct
        # STS-B spearman corr: 0.7292
        # ATEC spearman corr: 0.4466
        # BQ spearman corr: 0.54995
        # LCQMC spearman corr: 0.69834
        # PAWSX spearman corr: 0.15612
        # avg: 0.51606
        # V100 QPS: 470
        # sohu-dd spearman corr: 0.53378
        # sohu-dc spearman corr: 0.198637
        # add sohu avg: 0.4732
        pass

    def test_bge_large_zh_noinstruct_cosent_model(self):
        # BAAI/bge-large-zh-noinstruct with sts-b cosent finetuned
        # STS-B spearman corr: 0.8059
        # ATEC spearman corr: 0.4234
        # BQ spearman corr: 0.515842
        # LCQMC spearman corr: 0.7291
        # PAWSX spearman corr: 0.1249
        # avg: 0.5198
        # V100 QPS: 498
        # sohu-dd spearman corr: 0.7243
        # sohu-dc spearman corr: 0.58399
        # add sohu avg: 0.5582
        pass

    def test_bge_large_zh_noinstruct_cosent_passage_model(self):
        # BAAI/bge-large-zh-noinstruct with sts-b cosent finetuned v2
        # STS-B spearman corr: 0.7644
        # ATEC spearman corr: 0.38411
        # BQ spearman corr: 0.61348
        # LCQMC spearman corr: 0.717220
        # PAWSX spearman corr: 0.351538
        # avg: 0.5661
        # V100 QPS:427
        # sohu-dd spearman corr: 0.7181
        # sohu-dc spearman corr: 0.631528
        # add sohu avg: 0.5972
        pass

    def test_bge_large_zh_noinstruct_bge_model(self):
        # BAAI/bge-large-zh-noinstruct with bge finetuned v2
        # STS-B spearman corr: 0.8093
        # ATEC spearman corr: 0.45839
        # BQ spearman corr: 0.56505
        # LCQMC spearman corr: 0.742664
        # PAWSX spearman corr: 0.11136
        # avg: 0.53736
        # V100 QPS: 605
        # sohu-dd spearman corr: 0.566741
        # sohu-dc spearman corr: 0.2098
        # add sohu avg: 0.4947
        pass


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| code copy from: SentenceTransformers.tests.test_multi_process.py | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import SentenceModel | ||
| import numpy as np | ||
class ComputeMultiProcessTest(unittest.TestCase):
    """Multi-process encoding must match single-process encoding closely."""

    def setUp(self):
        self.model = SentenceModel()

    def test_multi_gpu_encode(self):
        # Two CPU workers stand in for multiple CUDA devices here.
        pool = self.model.start_multi_process_pool(['cpu', 'cpu'])
        sentences = ["This is sentence {}".format(i) for i in range(1000)]
        # Embeddings computed through the worker pool...
        multi_emb = self.model.encode_multi_process(sentences, pool, chunk_size=50)
        assert multi_emb.shape == (len(sentences), 768)
        # ...should agree elementwise with the in-process path.
        single_emb = self.model.encode(sentences)
        diff = np.max(np.abs(multi_emb - single_emb))
        print("Max multi proc diff", diff)
        assert diff < 0.001
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import os | ||
| import sys | ||
| import unittest | ||
| from loguru import logger | ||
| import time | ||
| import os | ||
| import torch | ||
| from transformers import AutoTokenizer, AutoModel | ||
| sys.path.append('..') | ||
| from text2vec import Word2Vec, SentenceModel | ||
| from sentence_transformers import SentenceTransformer | ||
# Run on GPU when available; all encoders below move tensors to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Work around duplicate-OpenMP-runtime crashes (common on macOS/conda).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
pwd_path = os.path.abspath(os.path.dirname(__file__))
# Benchmark output also goes to test.log via loguru.
logger.add('test.log')
# Tiny seed batch; each benchmark round doubles it (data * 2**j).
data = ['如何更换花呗绑定银行卡',
        '花呗更改绑定银行卡']
print("data:", data)
# Character count of the seed batch, used for tokens/s reporting.
num_tokens = sum([len(i) for i in data])
use_cuda = torch.cuda.is_available()
# Fewer doubling rounds on CPU to keep the benchmark fast.
repeat = 10 if use_cuda else 4
class TransformersEncoder:
    """Mean-pooled sentence encoder built directly on HuggingFace transformers."""

    def __init__(self, model_name='shibing624/text2vec-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)

    def encode(self, sentences):
        """Return mean-pooled sentence embeddings for *sentences*."""
        def mean_pooling(model_output, attention_mask):
            # First element of model_output contains all token embeddings;
            # average them under the attention mask.
            token_embeddings = model_output[0]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        return mean_pooling(model_output, encoded_input['attention_mask'])
class SentenceTransformersEncoder:
    """Thin wrapper over sentence-transformers' SentenceTransformer.encode."""

    def __init__(self, model_name="shibing624/text2vec-base-chinese"):
        self.model = SentenceTransformer(model_name)

    def encode(self, sentences, convert_to_numpy=True):
        return self.model.encode(sentences, convert_to_numpy=convert_to_numpy)
class QPSEncoderTestCase(unittest.TestCase):
    """Throughput (QPS) benchmarks for the various encoder implementations.

    Refactor: the doubling-batch benchmark loop was copy-pasted six times;
    it now lives once in ``_benchmark``.
    """

    def _benchmark(self, encode_fn, numpy_result=True):
        """Encode doubling batch sizes and log samples/s and tokens/s.

        :param encode_fn: callable taking a list of sentences and returning
            embeddings.
        :param numpy_result: True to log ``r.shape`` on the first round
            (array-like result), False to log ``len(r)`` (list result).
        """
        for j in range(repeat):
            tmp = data * (2 ** j)
            c_num_tokens = num_tokens * (2 ** j)
            start_t = time.time()
            r = encode_fn(tmp)
            assert r is not None
            if j == 0:
                shape = r.shape if numpy_result else len(r)
                logger.info(f"result shape: {shape}, emb: {r[0][:10]}")
            time_t = time.time() - start_t
            logger.info('encoding %d sentences, spend %.2fs, %4d samples/s, %6d tokens/s' %
                        (len(tmp), time_t, int(len(tmp) / time_t), int(c_num_tokens / time_t)))

    def test_cosent_speed(self):
        """Benchmark the CoSENT SentenceModel."""
        logger.info("\n---- cosent:")
        model = SentenceModel('shibing624/text2vec-base-chinese')
        logger.info(' convert_to_numpy=True:')
        self._benchmark(lambda s: model.encode(s, convert_to_numpy=True))
        logger.info(' convert_to_numpy=False:')
        self._benchmark(lambda s: model.encode(s, convert_to_numpy=False), numpy_result=False)

    def test_origin_transformers_speed(self):
        """Benchmark raw HuggingFace transformers encoding."""
        logger.info("\n---- origin transformers:")
        model = TransformersEncoder('shibing624/text2vec-base-chinese')
        self._benchmark(model.encode)

    def test_origin_sentence_transformers_speed(self):
        """Benchmark the sentence-transformers wrapper."""
        logger.info("\n---- origin sentence_transformers:")
        model = SentenceTransformersEncoder('shibing624/text2vec-base-chinese')
        logger.info(' convert_to_numpy=True:')
        self._benchmark(lambda s: model.encode(s, convert_to_numpy=True))
        logger.info(' convert_to_numpy=False:')
        self._benchmark(lambda s: model.encode(s, convert_to_numpy=False), numpy_result=False)

    def test_sbert_speed(self):
        """Benchmark the multilingual MiniLM SBERT model."""
        logger.info("\n---- sbert:")
        model = SentenceModel('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        self._benchmark(model.encode)

    def test_w2v_speed(self):
        """Benchmark the Word2Vec encoder."""
        logger.info("\n---- w2v:")
        model = Word2Vec()
        self._benchmark(model.encode)


if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec.utils.rank_bm25 import BM25Okapi | ||
| from text2vec.utils.tokenizer import segment | ||
class RankTestCase(unittest.TestCase):
    """Tests for BM25Okapi scoring and top-n retrieval on English and Chinese corpora."""

    def test_en_topn(self):
        """BM25 scores and top-2 retrieval on a whitespace-tokenized English corpus."""
        corpus = [
            "Hello there good man!",
            "It is quite windy in London",
            "How is the weather today?"
        ]
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        query = "windy London"
        tokenized_query = query.split(" ")
        doc_scores = bm25.get_scores(tokenized_query)
        print(doc_scores)
        # assertEqual instead of assertTrue(a == b): gives a readable diff on failure.
        self.assertEqual(' '.join(["{:.3f}".format(i) for i in doc_scores]), "0.000 0.937 0.000")
        top2 = bm25.get_top_n(tokenized_query, corpus, n=2)
        print(top2)
        self.assertEqual(top2, ['It is quite windy in London', 'How is the weather today?'])

    def test_zh_topn(self):
        """BM25 top-3 retrieval on a Chinese corpus tokenized with segment()."""
        corpus = ['女网红能火的只是一小部分', '当下最火的男明星为鹿晗', "How is the weather today?", "你觉得哪个女的明星最红?"]
        tokenized_corpus = [segment(doc) for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        query = '当下最火的女的明星是谁?'
        tokenized_query = segment(query)
        doc_scores = bm25.get_scores(tokenized_query)
        print(doc_scores)
        top3 = bm25.get_top_n(tokenized_query, corpus, n=3)
        print(top3)
        self.assertEqual(top3, ['你觉得哪个女的明星最红?', '当下最火的男明星为鹿晗', '女网红能火的只是一小部分'])
# Run the BM25 ranking tests when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import SentenceModel | ||
def use_transformers(sentences=('如何更换花呗绑定银行卡', '花呗更改绑定银行卡')):
    """Encode *sentences* with a raw HuggingFace BERT plus mean pooling.

    Mirrors what SentenceModel does internally so the two outputs can be
    compared. Returns a torch tensor of sentence embeddings — presumably
    shape (len(sentences), 768) for this model (test_encode_text asserts
    768 dims for the same checkpoint); verify if relied upon.
    """
    from transformers import BertTokenizer, BertModel
    import torch

    # Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Load model from HuggingFace Hub
    tokenizer = BertTokenizer.from_pretrained('shibing624/text2vec-base-chinese')
    model = BertModel.from_pretrained('shibing624/text2vec-base-chinese')
    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling. This is MEAN pooling (the previous comment said "max
    # pooling", but mean_pooling above averages token embeddings).
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    print(sentence_embeddings.shape)
    return sentence_embeddings
class SBERTEmbeddingsTestCase(unittest.TestCase):
    """Smoke tests for SentenceModel encoding and the raw-transformers comparison path."""

    def test_encode_text(self):
        """Encoding one sentence with text2vec-base-chinese yields a 768-dim vector."""
        sentence = '如何更换花呗绑定银行卡'
        model = SentenceModel('shibing624/text2vec-base-chinese')
        embedding = model.encode(sentence)
        print(sentence)
        self.assertEqual(embedding.shape, (768,))

    def test_tr_emb(self):
        """Raw transformers + mean pooling runs and produces sentence embeddings."""
        embeddings = use_transformers()
        print(embeddings.shape)
        print("Sentence embeddings:")
        print(embeddings)

    def test_sbert_encode_text(self):
        """A multilingual MiniLM sentence-transformers model yields a 384-dim vector."""
        sentence = '如何更换花呗绑定银行卡'
        model = SentenceModel('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        embedding = model.encode(sentence)
        print(sentence)
        self.assertEqual(embedding.shape, (384,))

    def test_sbert_dim(self):
        """Read the sentence-embedding size off the underlying BERT pooler layer."""
        model = SentenceModel('shibing624/text2vec-base-chinese')
        print('dim:', model.bert.pooler.dense.out_features)

        def get_sentence_embedding_dimension(m):
            # Returns the pooler output width, or None when unavailable/falsy.
            if m:
                dim = getattr(m.bert.pooler.dense, "out_features", None)
                if dim:
                    return dim
            return None

        print(get_sentence_embedding_dimension(model))
# Run the SBERT embedding tests when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| from time import time | ||
| import os | ||
| sys.path.append('..') | ||
| from text2vec import Similarity, SimilarityType, EmbeddingType, compute_spearmanr | ||
# Directory containing this test file; used to resolve data paths relative to it.
pwd_path = os.path.abspath(os.path.dirname(__file__))
# STS-B test split shipped with the repo's examples (tab-separated: sent1, sent2, label).
sts_test_path = os.path.join(pwd_path, '../examples/data/STS-B/STS-B.test.data')
def load_test_data(path, limit=10):
    """Load sentence pairs and integer labels from a tab-separated STS file.

    Each useful line is ``sent1<TAB>sent2<TAB>label``; lines that do not split
    into exactly three fields (headers, blanks, malformed rows) are skipped.
    Reading stops once more than *limit* pairs have been collected, so at most
    ``limit + 1`` pairs are returned — this preserves the original hard-coded
    behavior (11 pairs) while letting callers pick a different cutoff.

    Args:
        path: path to the TSV file.
        limit: stop reading once more than this many pairs are loaded
            (default 10, i.e. up to 11 pairs).

    Returns:
        Tuple of three parallel lists: (sents1, sents2, labels).
    """
    sents1, sents2, labels = [], [], []
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) != 3:
                continue  # skip header / malformed rows
            sents1.append(parts[0])
            sents2.append(parts[1])
            labels.append(int(parts[2]))
            if len(sents1) > limit:
                break
    return sents1, sents2, labels
class SimTestCase(unittest.TestCase):
    """Spearman-correlation smoke tests for cosine similarity with W2V and BERT embeddings."""

    def test_w2v_sim_each(self):
        """Score STS-B pairs one at a time with Word2Vec cosine similarity."""
        sim = Similarity(similarity_type=SimilarityType.COSINE, embedding_type=EmbeddingType.WORD2VEC)
        print(sim)
        sents1, sents2, labels = load_test_data(sts_test_path)
        start = time()
        scores = [sim.get_score(a, b) for a, b in zip(sents1, sents2)]
        spend_time = time() - start
        corr = compute_spearmanr(scores, labels)
        print('scores:', scores[:10])
        print('labels:', labels[:10])
        print('w2v_each_sim spearman corr:', corr)
        print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)

    def test_w2v_sim_batch(self):
        """Score STS-B pairs in one batch; the matrix diagonal holds the pair scores."""
        sim = Similarity(similarity_type=SimilarityType.COSINE, embedding_type=EmbeddingType.WORD2VEC)
        sents1, sents2, labels = load_test_data(sts_test_path)
        start = time()
        score_matrix = sim.get_scores(sents1, sents2)
        diag = [score_matrix[i][i] for i in range(len(sents1))]
        spend_time = time() - start
        corr = compute_spearmanr(diag, labels)
        print('scores:', diag[:10])
        print('labels:', labels[:10])
        print('w2v_batch_sim spearman corr:', corr)
        print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)

    def test_sbert_sim_each(self):
        """Score STS-B pairs one at a time with BERT cosine similarity."""
        sim = Similarity(similarity_type=SimilarityType.COSINE, embedding_type=EmbeddingType.BERT)
        sents1, sents2, labels = load_test_data(sts_test_path)
        start = time()
        scores = [sim.get_score(a, b) for a, b in zip(sents1, sents2)]
        spend_time = time() - start
        corr = compute_spearmanr(scores, labels)
        print('scores:', scores[:10])
        print('labels:', labels[:10])
        print('sbert_each_sim spearman corr:', corr)
        print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)

    def test_sbert_sim_batch(self):
        """Batch BERT similarity; prints the full cross matrix, then scores the diagonal."""
        sim = Similarity(similarity_type=SimilarityType.COSINE, embedding_type=EmbeddingType.BERT)
        sents1, sents2, labels = load_test_data(sts_test_path)
        start = time()
        score_matrix = sim.get_scores(sents1, sents2)
        sims = []
        # Dump every sentence-pair combination for eyeballing, one row per sents1 entry.
        for i in range(len(sents1)):
            for j in range(len(sents2)):
                print(score_matrix[i][j], sents1[i], sents2[j])
            print()
        for i in range(len(sents1)):
            sims.append(score_matrix[i][i])
            print(score_matrix[i][i], sents1[i], sents2[i])
        spend_time = time() - start
        corr = compute_spearmanr(sims, labels)
        print('scores:', sims[:10])
        print('labels:', labels[:10])
        print('sbert_batch_sim spearman corr:', corr)
        print('spend time:', spend_time, ' seconds count:', len(sents1) * 2, 'qps:', len(sents1) * 2 / spend_time)
# Run the similarity tests when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
| # -*- coding: utf-8 -*- | ||
| """ | ||
| @author:XuMing(xuming624@qq.com) | ||
| @description: | ||
| """ | ||
| import sys | ||
| import unittest | ||
| sys.path.append('..') | ||
| from text2vec import Word2Vec | ||
| import numpy as np | ||
# Shared module-level model so every test reuses one loaded Word2Vec instance
# instead of reloading the vectors per test.
w2v_model = Word2Vec()
class EmbeddingsTestCase(unittest.TestCase):
    """Smoke tests for Word2Vec.encode on chars, words, sentences and OOV input."""

    def test_encode_char(self):
        """A single character encodes to a 200-dim numpy vector with known leading values."""
        char = '卡'
        emb = w2v_model.encode(char)
        print(type(emb))
        self.assertTrue(type(emb) == np.ndarray)
        print(char, emb.shape)
        self.assertEqual(emb.shape, (200,))
        head = ' '.join(["{:.3f}".format(i) for i in emb[:3]])
        print(head)
        self.assertTrue(head == "0.068 -0.110 -0.048")

    def test_encode_word(self):
        """A multi-character word encodes to a 200-dim vector with a known first value."""
        word = '银行卡'
        emb = w2v_model.encode(word)
        print(word, emb[:10])
        self.assertEqual(emb.shape, (200,))
        self.assertTrue(abs(emb[0] - 0.0103) < 0.001)

    def test_encode_text(self):
        """A full sentence encodes to a 200-dim vector with a known first value."""
        sentence = '如何更换花呗绑定银行卡'
        emb = w2v_model.encode(sentence)
        print(sentence, emb[:10])
        self.assertEqual(emb.shape, (200,))
        self.assertTrue(abs(emb[0] - 0.02396) < 0.001)

    def test_oov_emb(self):
        """OOV punctuation maps to zeros; a compound word equals the mean of its parts."""
        comma_res = w2v_model.encode(',')
        print(',', comma_res)
        self.assertEqual(comma_res[0], 0.0)
        compound = w2v_model.encode('特价机票')
        print('特价机票', compound[:10])
        part1 = w2v_model.encode('特价')
        print('特价', part1[:10])
        part2 = w2v_model.encode('机票')
        print('机票', part2[:10])
        average = np.array([part1, part2]).sum(axis=0) / 2.0
        print('r_average:', average[:10], average.shape)
        print('same' if compound[0] == average[0] else 'diff')
        self.assertTrue(compound[0] == average[0])
# Run the Word2Vec embedding tests when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
336640
-11.46%37
-21.28%3502
-22.28%