From fa343ef17094027d0005822bec23bf36cbbb9aa0 Mon Sep 17 00:00:00 2001 From: Jun FENG <99384777+6fj@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:57:06 +0800 Subject: [PATCH] repo-sync-2024-02-19T16:30:32+0800 (#78) --- .github/workflows/scorecard.yml | 2 +- README.md | 9 +- RELEASE.md | 5 + bazel/patches/apsi.patch | 173 ++++---- bazel/patches/grpc.patch | 1 - bazel/patches/ippcp.patch | 1 - bazel/patches/perfetto.patch | 2 +- bazel/sparsehash.BUILD | 2 +- docs/conf.py | 2 +- docs/getting_started.rst | 11 +- docs/reference/launch_config.md | 6 +- docs/reference/psi_v2_config.md | 2 +- docs/requirements.txt | 6 +- docs/user_guide/faq.md | 90 +++++ docs/user_guide/index.rst | 1 + docs/user_guide/psi.rst | 12 +- docs/user_guide/psi_v2.rst | 30 +- examples/pir/BUILD.bazel | 18 +- examples/pir/keyword_pir_client.cc | 16 +- examples/pir/keyword_pir_mem_server.cc | 21 +- examples/pir/keyword_pir_server.cc | 20 +- examples/pir/keyword_pir_setup.cc | 18 +- psi/BUILD.bazel | 81 +++- .../core/labeled_psi => apsi}/BUILD.bazel | 170 ++++++-- psi/{psi/core/labeled_psi => apsi}/README.md | 0 .../apsi_bench.cc => apsi/apsi_benchmark.cc} | 31 +- .../labeled_psi => apsi}/apsi_label_test.cc | 33 +- .../core/labeled_psi => apsi}/apsi_test.cc | 33 +- psi/{psi/core/labeled_psi => apsi}/kv_test.cc | 54 +-- psi/{psi/core/labeled_psi => apsi}/package.cc | 8 +- psi/{psi/core/labeled_psi => apsi}/package.h | 10 +- psi/{psi/core/labeled_psi => apsi}/padding.cc | 6 +- psi/{psi/core/labeled_psi => apsi}/padding.h | 4 +- psi/{pir => apsi}/pir.cc | 201 +++++---- psi/{pir => apsi}/pir.h | 6 +- psi/{pir => apsi}/pir_test.cc | 74 ++-- .../core/labeled_psi => apsi}/psi_params.cc | 40 +- .../core/labeled_psi => apsi}/psi_params.h | 26 +- .../core/labeled_psi => apsi}/receiver.cc | 77 ++-- psi/{psi/core/labeled_psi => apsi}/receiver.h | 41 +- psi/{psi/core/labeled_psi => apsi}/sender.cc | 84 ++-- psi/{psi/core/labeled_psi => apsi}/sender.h | 16 +- .../core/labeled_psi => apsi}/sender_db.cc | 49 +-- .../core/labeled_psi => apsi}/sender_db.h | 48 +-- .../core/labeled_psi => apsi}/sender_kvdb.cc | 133 +++--- .../core/labeled_psi => apsi}/sender_kvdb.h | 16 +- .../core/labeled_psi => apsi}/sender_memdb.cc | 169 ++++---- .../core/labeled_psi => apsi}/sender_memdb.h | 28 +- .../labeled_psi => apsi}/serializable.proto | 2 +- .../core/labeled_psi => apsi}/serialize.h | 48 +-- psi/{psi/core/bc22_psi => bc22}/BUILD.bazel | 16 +- psi/{psi/core/bc22_psi => bc22}/bc22_psi.cc | 12 +- psi/{psi/core/bc22_psi => bc22}/bc22_psi.h | 10 +- .../bc22_psi_benchmark.cc} | 6 +- .../core/bc22_psi => bc22}/bc22_psi_test.cc | 6 +- psi/{psi/core/bc22_psi => bc22}/emp_vole.cc | 6 +- psi/{psi/core/bc22_psi => bc22}/emp_vole.h | 10 +- .../core/bc22_psi => bc22}/emp_vole_test.cc | 8 +- .../generalized_cuckoo_hash.cc | 6 +- .../generalized_cuckoo_hash.h | 6 +- .../generalized_cuckoo_hash_test.cc | 6 +- psi/{psi => }/cryptor/BUILD.bazel | 2 +- psi/{psi => }/cryptor/cryptor_selector.cc | 16 +- psi/{psi => }/cryptor/cryptor_selector.h | 6 +- psi/{psi => }/cryptor/ecc_cryptor.cc | 6 +- psi/{psi => }/cryptor/ecc_cryptor.h | 4 +- psi/{psi => }/cryptor/ecc_utils.h | 4 +- psi/{psi => }/cryptor/ecc_utils_test.cc | 6 +- psi/{psi => }/cryptor/fourq_cryptor.cc | 6 +- psi/{psi => }/cryptor/fourq_cryptor.h | 6 +- psi/{psi => }/cryptor/fpga_ecc_cryptor.h | 6 +- .../cryptor/hash_to_curve_elligator2.cc | 6 +- .../cryptor/hash_to_curve_elligator2.h | 4 +- .../cryptor/hash_to_curve_elligator2_test.cc | 6 +- psi/{psi => }/cryptor/ipp_ecc_cryptor.cc | 8 +- psi/{psi => }/cryptor/ipp_ecc_cryptor.h | 6 +- psi/{psi => }/cryptor/sm2_cryptor.cc | 8 +- psi/{psi => }/cryptor/sm2_cryptor.h | 6 +- psi/{psi => }/cryptor/sm2_cryptor_test.cc | 6 +- .../cryptor/sodium_curve25519_cryptor.cc | 8 +- .../cryptor/sodium_curve25519_cryptor.h | 6 +- psi/ecdh/BUILD.bazel | 237 +++++++++++ .../ecdh_oprf => ecdh}/basic_ecdh_oprf.cc | 8 +- .../core/ecdh_oprf => ecdh}/basic_ecdh_oprf.h | 10 +- .../basic_ecdh_oprf_test.cc | 8 +- psi/{psi => }/ecdh/client.cc | 12 +- psi/{psi => }/ecdh/client.h | 8 +- psi/{psi => }/ecdh/common.h | 4 +- psi/{psi/core => ecdh}/ecdh_3pc_psi.cc | 8 +- psi/{psi/core => ecdh}/ecdh_3pc_psi.h | 8 +- .../ecdh_3pc_psi_benchmark.cc} | 24 +- psi/{psi/core => ecdh}/ecdh_3pc_psi_test.cc | 77 ++-- psi/ecdh/ecdh_logger.h | 41 ++ psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf.cc | 6 +- psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf.h | 6 +- psi/{psi/core => ecdh}/ecdh_oprf_psi.cc | 13 +- psi/{psi/core => ecdh}/ecdh_oprf_psi.h | 14 +- psi/{psi/core => ecdh}/ecdh_oprf_psi_test.cc | 16 +- .../ecdh_oprf => ecdh}/ecdh_oprf_selector.cc | 10 +- .../ecdh_oprf => ecdh}/ecdh_oprf_selector.h | 6 +- psi/{psi/core => ecdh}/ecdh_psi.cc | 37 +- psi/{psi/core => ecdh}/ecdh_psi.h | 20 +- .../ecdh_psi_benchmark.cc} | 20 +- psi/{psi/core => ecdh}/ecdh_psi_test.cc | 10 +- psi/{psi => }/ecdh/receiver.cc | 14 +- psi/{psi => }/ecdh/receiver.h | 10 +- psi/{psi => }/ecdh/sender.cc | 14 +- psi/{psi => }/ecdh/sender.h | 10 +- psi/{psi => }/ecdh/server.cc | 12 +- psi/{psi => }/ecdh/server.h | 8 +- psi/{psi => }/factory.cc | 22 +- psi/{psi => }/factory.h | 6 +- psi/{psi => }/interface.cc | 20 +- psi/{psi => }/interface.h | 10 +- psi/kkrt/BUILD.bazel | 88 ++++ psi/{psi => }/kkrt/common.cc | 8 +- psi/{psi => }/kkrt/common.h | 6 +- psi/{psi/core => kkrt}/kkrt_psi.cc | 12 +- psi/{psi/core => kkrt}/kkrt_psi.h | 4 +- .../kkrt_psi_benchmark.cc} | 10 +- psi/{psi/core => kkrt}/kkrt_psi_test.cc | 6 +- psi/{psi => }/kkrt/receiver.cc | 22 +- psi/{psi => }/kkrt/receiver.h | 8 +- psi/{psi => }/kkrt/sender.cc | 22 +- psi/{psi => }/kkrt/sender.h | 8 +- psi/{psi => }/launch.cc | 14 +- psi/{psi => }/launch.h | 6 +- psi/{psi/operator => legacy}/BUILD.bazel | 101 ++++- psi/{psi/operator => legacy}/base_operator.cc | 10 +- psi/{psi/operator => legacy}/base_operator.h | 4 +- .../operator => legacy}/bc22_2party_psi.cc | 12 +- .../operator => legacy}/bc22_2party_psi.h | 6 +- psi/{psi => legacy}/bucket_psi.cc | 33 +- psi/{psi => legacy}/bucket_psi.h | 16 +- psi/{psi => legacy}/bucket_psi_test.cc | 8 +- psi/{psi => legacy}/bucket_ub_psi.cc | 74 ++-- psi/{psi => legacy}/bucket_ub_psi.h | 24 +- psi/{psi => legacy}/bucket_ub_psi_test.cc | 10 +- psi/{psi/operator => legacy}/dp_2party_psi.cc | 23 +- psi/{psi/operator => legacy}/dp_2party_psi.h | 12 +- psi/{psi/core => legacy}/dp_psi/BUILD.bazel | 33 +- psi/{psi/core => legacy}/dp_psi/dp_psi.cc | 30 +- psi/{psi/core => legacy}/dp_psi/dp_psi.h | 6 +- .../dp_psi/dp_psi_benchmark.cc} | 6 +- .../dp_psi/dp_psi_payload_benchmark.cc} | 8 +- .../core => legacy}/dp_psi/dp_psi_test.cc | 6 +- .../core => legacy}/dp_psi/dp_psi_utils.cc | 6 +- .../core => legacy}/dp_psi/dp_psi_utils.h | 4 +- .../operator => legacy}/ecdh_3party_psi.cc | 14 +- .../operator => legacy}/ecdh_3party_psi.h | 10 +- psi/{psi/operator => legacy}/factory.h | 6 +- .../operator => legacy}/kkrt_2party_psi.cc | 19 +- .../operator => legacy}/kkrt_2party_psi.h | 7 +- psi/{psi => legacy}/memory_psi.cc | 18 +- psi/{psi => legacy}/memory_psi.h | 4 +- psi/{psi => legacy}/memory_psi_test.cc | 8 +- .../mini_psi}/BUILD.bazel | 43 +- psi/{psi/core => legacy/mini_psi}/mini_psi.cc | 20 +- psi/{psi/core => legacy/mini_psi}/mini_psi.h | 5 +- .../core => legacy/mini_psi}/mini_psi_demo.cc | 18 +- .../core => legacy/mini_psi}/mini_psi_test.cc | 8 +- .../mini_psi}/polynomial.cc | 6 +- .../mini_psi}/polynomial.h | 4 +- .../mini_psi}/polynomial_test.cc | 6 +- psi/{psi/operator => legacy}/nparty_psi.cc | 24 +- psi/{psi/operator => legacy}/nparty_psi.h | 8 +- .../operator => legacy}/nparty_psi_test.cc | 8 +- .../operator => legacy}/rr22_2party_psi.cc | 21 +- .../operator => legacy}/rr22_2party_psi.h | 10 +- psi/main.cc | 13 +- psi/{psi => }/prelude.h | 16 +- psi/proto/entry.proto | 6 +- psi/proto/pir.proto | 2 +- psi/proto/psi.proto | 2 +- psi/proto/psi_v2.proto | 4 +- psi/psi/BUILD.bazel | 181 --------- psi/psi/benchmark/BUILD.bazel | 64 --- psi/psi/benchmark/mparty_bench.cc | 102 ----- psi/psi/benchmark/mparty_bench.h | 310 -------------- psi/psi/benchmark/standalone_bench.cc | 45 --- psi/psi/benchmark/standalone_bench.h | 380 ------------------ psi/psi/core/BUILD.bazel | 224 ----------- psi/psi/core/ecdh_oprf/BUILD.bazel | 84 ---- psi/psi/core/generate_psi.py | 75 ---- psi/psi/ecdh/BUILD.bazel | 67 --- psi/psi/kkrt/BUILD.bazel | 52 --- psi/psi/rr22/BUILD.bazel | 51 --- psi/{psi => }/psi_test.cc | 10 +- psi/{psi/core/vole_psi => rr22}/BUILD.bazel | 76 +++- psi/{psi => }/rr22/common.cc | 8 +- psi/{psi => }/rr22/common.h | 10 +- .../vole_psi => rr22}/davis_meyer_hash.cc | 8 +- .../core/vole_psi => rr22}/davis_meyer_hash.h | 4 +- .../davis_meyer_hash_test.cc | 6 +- .../core/vole_psi => rr22}/okvs/BUILD.bazel | 0 .../core/vole_psi => rr22}/okvs/aes_crhash.cc | 6 +- .../core/vole_psi => rr22}/okvs/aes_crhash.h | 4 +- .../vole_psi => rr22}/okvs/aes_crhash_test.cc | 8 +- psi/{psi/core/vole_psi => rr22}/okvs/baxos.cc | 8 +- psi/{psi/core/vole_psi => rr22}/okvs/baxos.h | 12 +- .../core/vole_psi => rr22}/okvs/baxos_test.cc | 6 +- .../core/vole_psi => rr22}/okvs/dense_mtx.cc | 6 +- .../core/vole_psi => rr22}/okvs/dense_mtx.h | 6 +- .../core/vole_psi => rr22}/okvs/galois128.cc | 8 +- .../core/vole_psi => rr22}/okvs/galois128.h | 6 +- .../vole_psi => rr22}/okvs/galois128_test.cc | 6 +- psi/{psi/core/vole_psi => rr22}/okvs/paxos.cc | 6 +- psi/{psi/core/vole_psi => rr22}/okvs/paxos.h | 12 +- .../core/vole_psi => rr22}/okvs/paxos_hash.cc | 6 +- .../core/vole_psi => rr22}/okvs/paxos_hash.h | 8 +- .../vole_psi => rr22}/okvs/paxos_hash_test.cc | 6 +- .../core/vole_psi => rr22}/okvs/paxos_test.cc | 6 +- .../vole_psi => rr22}/okvs/paxos_utils.cc | 4 +- .../core/vole_psi => rr22}/okvs/paxos_utils.h | 6 +- .../vole_psi => rr22}/okvs/simple_index.cc | 6 +- .../vole_psi => rr22}/okvs/simple_index.h | 4 +- psi/{psi => }/rr22/receiver.cc | 22 +- psi/{psi => }/rr22/receiver.h | 8 +- psi/{psi/core/vole_psi => rr22}/rr22_oprf.cc | 10 +- psi/{psi/core/vole_psi => rr22}/rr22_oprf.h | 6 +- .../core/vole_psi => rr22}/rr22_oprf_test.cc | 8 +- psi/{psi/core/vole_psi => rr22}/rr22_psi.cc | 22 +- psi/{psi/core/vole_psi => rr22}/rr22_psi.h | 14 +- .../rr22_psi_benchmark.cc} | 16 +- .../core/vole_psi => rr22}/rr22_psi_test.cc | 12 +- psi/{psi/core/vole_psi => rr22}/rr22_utils.cc | 10 +- psi/{psi/core/vole_psi => rr22}/rr22_utils.h | 2 +- psi/{psi => }/rr22/sender.cc | 20 +- psi/{psi => }/rr22/sender.h | 8 +- .../core/vole_psi => rr22}/sparseconfig.h | 0 psi/{pir => seal_pir}/BUILD.bazel | 30 +- psi/{pir => seal_pir}/seal_mpir.cc | 16 +- psi/{pir => seal_pir}/seal_mpir.h | 14 +- psi/{pir => seal_pir}/seal_mpir_test.cc | 19 +- psi/{pir => seal_pir}/seal_pir.cc | 12 +- psi/{pir => seal_pir}/seal_pir.h | 8 +- psi/{pir => seal_pir}/seal_pir_test.cc | 16 +- psi/{pir => seal_pir}/seal_pir_utils.cc | 6 +- psi/{pir => seal_pir}/seal_pir_utils.h | 4 +- psi/{pir => seal_pir}/serializable.proto | 2 +- psi/{psi => }/trace_categories.cc | 2 +- psi/{psi => }/trace_categories.h | 0 psi/{psi => }/utils/BUILD.bazel | 48 ++- psi/{psi => }/utils/advanced_join.cc | 10 +- psi/{psi => }/utils/advanced_join.h | 4 +- psi/{psi => }/utils/advanced_join_test.cc | 6 +- .../utils/arrow_csv_batch_provider.cc | 8 +- .../utils/arrow_csv_batch_provider.h | 6 +- .../utils/arrow_csv_batch_provider_test.cc | 6 +- psi/{psi => }/utils/batch_provider.cc | 8 +- psi/{psi => }/utils/batch_provider.h | 8 +- psi/{psi => }/utils/bucket.cc | 10 +- psi/{psi => }/utils/bucket.h | 10 +- psi/{psi/core => utils}/communication.cc | 6 +- psi/{psi/core => utils}/communication.h | 6 +- psi/{psi => }/utils/csv_checker.cc | 10 +- psi/{psi => }/utils/csv_checker.h | 4 +- psi/{psi => }/utils/csv_checker_test.cc | 8 +- psi/{psi => }/utils/csv_header_analyzer.h | 4 +- psi/{psi => }/utils/csv_header_parser.cc | 6 +- psi/{psi => }/utils/csv_header_parser.h | 4 +- psi/{psi => }/utils/csv_header_parser_test.cc | 6 +- psi/{psi/core => utils}/cuckoo_index.cc | 6 +- psi/{psi/core => utils}/cuckoo_index.h | 4 +- psi/{psi/core => utils}/cuckoo_index_test.cc | 6 +- psi/{psi => }/utils/ec.cc | 10 +- psi/{psi => }/utils/ec.h | 4 +- psi/{psi => }/utils/ec_point_store.cc | 8 +- psi/{psi => }/utils/ec_point_store.h | 8 +- psi/{psi => }/utils/emp_io_adapter.cc | 2 +- psi/{psi => }/utils/emp_io_adapter.h | 0 psi/{psi => }/utils/emp_io_adapter_test.cc | 2 +- psi/{psi => }/utils/hash_bucket_cache.cc | 8 +- psi/{psi => }/utils/hash_bucket_cache.h | 8 +- psi/{psi => }/utils/index_store.cc | 6 +- psi/{psi => }/utils/index_store.h | 4 +- psi/{psi => }/utils/index_store_test.cc | 6 +- psi/{psi => }/utils/io.cc | 6 +- psi/{psi => }/utils/io.h | 4 +- psi/{psi => }/utils/key.cc | 10 +- psi/{psi => }/utils/key.h | 4 +- psi/{psi => }/utils/multiplex_disk_cache.cc | 6 +- psi/{psi => }/utils/multiplex_disk_cache.h | 6 +- .../utils/multiplex_disk_cache_test.cc | 8 +- psi/{psi => }/utils/progress.cc | 6 +- psi/{psi => }/utils/progress.h | 4 +- psi/{psi => }/utils/progress_test.cc | 6 +- psi/{psi => }/utils/recovery.cc | 8 +- psi/{psi => }/utils/recovery.h | 6 +- psi/{psi => }/utils/recovery_test.cc | 8 +- psi/{psi => }/utils/resource.cc | 6 +- psi/{psi => }/utils/resource.h | 4 +- psi/{psi => }/utils/serializable.proto | 2 +- psi/{psi => }/utils/serialize.h | 6 +- psi/{psi => }/utils/sync.cc | 8 +- psi/{psi => }/utils/sync.h | 6 +- psi/{psi => }/utils/test_utils.h | 4 +- psi/{psi => }/utils/ub_psi_cache.cc | 8 +- psi/{psi => }/utils/ub_psi_cache.h | 8 +- psi/{psi => }/utils/ub_psi_cache_test.cc | 6 +- psi/version.h | 2 +- 301 files changed, 2808 insertions(+), 3617 deletions(-) create mode 100644 docs/user_guide/faq.md rename psi/{psi/core/labeled_psi => apsi}/BUILD.bazel (52%) rename psi/{psi/core/labeled_psi => apsi}/README.md (100%) rename psi/{psi/core/labeled_psi/apsi_bench.cc => apsi/apsi_benchmark.cc} (91%) rename psi/{psi/core/labeled_psi => apsi}/apsi_label_test.cc (90%) rename psi/{psi/core/labeled_psi => apsi}/apsi_test.cc (90%) rename psi/{psi/core/labeled_psi => apsi}/kv_test.cc (90%) rename psi/{psi/core/labeled_psi => apsi}/package.cc (93%) rename psi/{psi/core/labeled_psi => apsi}/package.h (83%) rename psi/{psi/core/labeled_psi => apsi}/padding.cc (94%) rename psi/{psi/core/labeled_psi => apsi}/padding.h (94%) rename psi/{pir => apsi}/pir.cc (77%) rename psi/{pir => apsi}/pir.h (94%) rename psi/{pir => apsi}/pir_test.cc (80%) rename psi/{psi/core/labeled_psi => apsi}/psi_params.cc (91%) rename psi/{psi/core/labeled_psi => apsi}/psi_params.h (73%) rename psi/{psi/core/labeled_psi => apsi}/receiver.cc (90%) rename psi/{psi/core/labeled_psi => apsi}/receiver.h (76%) rename psi/{psi/core/labeled_psi => apsi}/sender.cc (89%) rename psi/{psi/core/labeled_psi => apsi}/sender.h (83%) rename psi/{psi/core/labeled_psi => apsi}/sender_db.cc (87%) rename psi/{psi/core/labeled_psi => apsi}/sender_db.h (88%) rename psi/{psi/core/labeled_psi => apsi}/sender_kvdb.cc (87%) rename psi/{psi/core/labeled_psi => apsi}/sender_kvdb.h (93%) rename psi/{psi/core/labeled_psi => apsi}/sender_memdb.cc (86%) rename psi/{psi/core/labeled_psi => apsi}/sender_memdb.h (89%) rename psi/{psi/core/labeled_psi => apsi}/serializable.proto (98%) rename psi/{psi/core/labeled_psi => apsi}/serialize.h (75%) rename psi/{psi/core/bc22_psi => bc22}/BUILD.bazel (90%) rename psi/{psi/core/bc22_psi => bc22}/bc22_psi.cc (98%) rename psi/{psi/core/bc22_psi => bc22}/bc22_psi.h (93%) rename psi/{psi/core/bc22_psi/bc22_psi_bench.cc => bc22/bc22_psi_benchmark.cc} (97%) rename psi/{psi/core/bc22_psi => bc22}/bc22_psi_test.cc (97%) rename psi/{psi/core/bc22_psi => bc22}/emp_vole.cc (97%) rename psi/{psi/core/bc22_psi => bc22}/emp_vole.h (93%) rename psi/{psi/core/bc22_psi => bc22}/emp_vole_test.cc (96%) rename psi/{psi/core/bc22_psi => bc22}/generalized_cuckoo_hash.cc (98%) rename psi/{psi/core/bc22_psi => bc22}/generalized_cuckoo_hash.h (97%) rename psi/{psi/core/bc22_psi => bc22}/generalized_cuckoo_hash_test.cc (97%) rename psi/{psi => }/cryptor/BUILD.bazel (99%) rename psi/{psi => }/cryptor/cryptor_selector.cc (90%) rename psi/{psi => }/cryptor/cryptor_selector.h (88%) rename psi/{psi => }/cryptor/ecc_cryptor.cc (97%) rename psi/{psi => }/cryptor/ecc_cryptor.h (98%) rename psi/{psi => }/cryptor/ecc_utils.h (99%) rename psi/{psi => }/cryptor/ecc_utils_test.cc (92%) rename psi/{psi => }/cryptor/fourq_cryptor.cc (96%) rename psi/{psi => }/cryptor/fourq_cryptor.h (92%) rename psi/{psi => }/cryptor/fpga_ecc_cryptor.h (90%) rename psi/{psi => }/cryptor/hash_to_curve_elligator2.cc (99%) rename psi/{psi => }/cryptor/hash_to_curve_elligator2.h (96%) rename psi/{psi => }/cryptor/hash_to_curve_elligator2_test.cc (95%) rename psi/{psi => }/cryptor/ipp_ecc_cryptor.cc (94%) rename psi/{psi => }/cryptor/ipp_ecc_cryptor.h (93%) rename psi/{psi => }/cryptor/sm2_cryptor.cc (95%) rename psi/{psi => }/cryptor/sm2_cryptor.h (95%) rename psi/{psi => }/cryptor/sm2_cryptor_test.cc (97%) rename psi/{psi => }/cryptor/sodium_curve25519_cryptor.cc (94%) rename psi/{psi => }/cryptor/sodium_curve25519_cryptor.h (95%) create mode 100644 psi/ecdh/BUILD.bazel rename psi/{psi/core/ecdh_oprf => ecdh}/basic_ecdh_oprf.cc (99%) rename psi/{psi/core/ecdh_oprf => ecdh}/basic_ecdh_oprf.h (96%) rename psi/{psi/core/ecdh_oprf => ecdh}/basic_ecdh_oprf_test.cc (94%) rename psi/{psi => }/ecdh/client.cc (95%) rename psi/{psi => }/ecdh/client.h (89%) rename psi/{psi => }/ecdh/common.h (92%) rename psi/{psi/core => ecdh}/ecdh_3pc_psi.cc (98%) rename psi/{psi/core => ecdh}/ecdh_3pc_psi.h (97%) rename psi/{psi/core/ecdh_3pc_psi_bench.cc => ecdh/ecdh_3pc_psi_benchmark.cc} (78%) rename psi/{psi/core => ecdh}/ecdh_3pc_psi_test.cc (78%) create mode 100644 psi/ecdh/ecdh_logger.h rename psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf.cc (96%) rename psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf.h (98%) rename psi/{psi/core => ecdh}/ecdh_oprf_psi.cc (98%) rename psi/{psi/core => ecdh}/ecdh_oprf_psi.h (96%) rename psi/{psi/core => ecdh}/ecdh_oprf_psi_test.cc (97%) rename psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf_selector.cc (96%) rename psi/{psi/core/ecdh_oprf => ecdh}/ecdh_oprf_selector.h (92%) rename psi/{psi/core => ecdh}/ecdh_psi.cc (92%) rename psi/{psi/core => ecdh}/ecdh_psi.h (92%) rename psi/{psi/core/ecdh_psi_bench.cc => ecdh/ecdh_psi_benchmark.cc} (82%) rename psi/{psi/core => ecdh}/ecdh_psi_test.cc (97%) rename psi/{psi => }/ecdh/receiver.cc (94%) rename psi/{psi => }/ecdh/receiver.h (87%) rename psi/{psi => }/ecdh/sender.cc (94%) rename psi/{psi => }/ecdh/sender.h (87%) rename psi/{psi => }/ecdh/server.cc (96%) rename psi/{psi => }/ecdh/server.h (89%) rename psi/{psi => }/factory.cc (87%) rename psi/{psi => }/factory.h (92%) rename psi/{psi => }/interface.cc (97%) rename psi/{psi => }/interface.h (96%) create mode 100644 psi/kkrt/BUILD.bazel rename psi/{psi => }/kkrt/common.cc (88%) rename psi/{psi => }/kkrt/common.h (89%) rename psi/{psi/core => kkrt}/kkrt_psi.cc (98%) rename psi/{psi/core => kkrt}/kkrt_psi.h (98%) rename psi/{psi/core/kkrt_psi_bench.cc => kkrt/kkrt_psi_benchmark.cc} (88%) rename psi/{psi/core => kkrt}/kkrt_psi_test.cc (98%) rename psi/{psi => }/kkrt/receiver.cc (93%) rename psi/{psi => }/kkrt/receiver.h (90%) rename psi/{psi => }/kkrt/sender.cc (92%) rename psi/{psi => }/kkrt/sender.h (89%) rename psi/{psi => }/launch.cc (96%) rename psi/{psi => }/launch.h (94%) rename psi/{psi/operator => legacy}/BUILD.bazel (57%) rename psi/{psi/operator => legacy}/base_operator.cc (87%) rename psi/{psi/operator => legacy}/base_operator.h (96%) rename psi/{psi/operator => legacy}/bc22_2party_psi.cc (88%) rename psi/{psi/operator => legacy}/bc22_2party_psi.h (92%) rename psi/{psi => legacy}/bucket_psi.cc (96%) rename psi/{psi => legacy}/bucket_psi.h (95%) rename psi/{psi => legacy}/bucket_psi_test.cc (99%) rename psi/{psi => legacy}/bucket_ub_psi.cc (87%) rename psi/{psi => legacy}/bucket_ub_psi.h (76%) rename psi/{psi => legacy}/bucket_ub_psi_test.cc (98%) rename psi/{psi/operator => legacy}/dp_2party_psi.cc (78%) rename psi/{psi/operator => legacy}/dp_2party_psi.h (81%) rename psi/{psi/core => legacy}/dp_psi/BUILD.bazel (77%) rename psi/{psi/core => legacy}/dp_psi/dp_psi.cc (94%) rename psi/{psi/core => legacy}/dp_psi/dp_psi.h (96%) rename psi/{psi/core/dp_psi/dp_psi_bench.cc => legacy/dp_psi/dp_psi_benchmark.cc} (98%) rename psi/{psi/core/dp_psi/dp_psi_payload_bench.cc => legacy/dp_psi/dp_psi_payload_benchmark.cc} (98%) rename psi/{psi/core => legacy}/dp_psi/dp_psi_test.cc (97%) rename psi/{psi/core => legacy}/dp_psi/dp_psi_utils.cc (98%) rename psi/{psi/core => legacy}/dp_psi/dp_psi_utils.h (95%) rename psi/{psi/operator => legacy}/ecdh_3party_psi.cc (91%) rename psi/{psi/operator => legacy}/ecdh_3party_psi.h (95%) rename psi/{psi/operator => legacy}/factory.h (96%) rename psi/{psi/operator => legacy}/kkrt_2party_psi.cc (80%) rename psi/{psi/operator => legacy}/kkrt_2party_psi.h (90%) rename psi/{psi => legacy}/memory_psi.cc (90%) rename psi/{psi => legacy}/memory_psi.h (96%) rename psi/{psi => legacy}/memory_psi_test.cc (98%) rename psi/{psi/core/polynomial => legacy/mini_psi}/BUILD.bazel (50%) rename psi/{psi/core => legacy/mini_psi}/mini_psi.cc (97%) rename psi/{psi/core => legacy/mini_psi}/mini_psi.h (96%) rename psi/{psi/core => legacy/mini_psi}/mini_psi_demo.cc (91%) rename psi/{psi/core => legacy/mini_psi}/mini_psi_test.cc (95%) rename psi/{psi/core/polynomial => legacy/mini_psi}/polynomial.cc (98%) rename psi/{psi/core/polynomial => legacy/mini_psi}/polynomial.h (96%) rename psi/{psi/core/polynomial => legacy/mini_psi}/polynomial_test.cc (96%) rename psi/{psi/operator => legacy}/nparty_psi.cc (94%) rename psi/{psi/operator => legacy}/nparty_psi.h (95%) rename psi/{psi/operator => legacy}/nparty_psi_test.cc (97%) rename psi/{psi/operator => legacy}/rr22_2party_psi.cc (89%) rename psi/{psi/operator => legacy}/rr22_2party_psi.h (85%) rename psi/{psi => }/prelude.h (63%) delete mode 100644 psi/psi/BUILD.bazel delete mode 100644 psi/psi/benchmark/BUILD.bazel delete mode 100644 psi/psi/benchmark/mparty_bench.cc delete mode 100644 psi/psi/benchmark/mparty_bench.h delete mode 100644 psi/psi/benchmark/standalone_bench.cc delete mode 100644 psi/psi/benchmark/standalone_bench.h delete mode 100644 psi/psi/core/BUILD.bazel delete mode 100644 psi/psi/core/ecdh_oprf/BUILD.bazel delete mode 100644 psi/psi/core/generate_psi.py delete mode 100644 psi/psi/ecdh/BUILD.bazel delete mode 100644 psi/psi/kkrt/BUILD.bazel delete mode 100644 psi/psi/rr22/BUILD.bazel rename psi/{psi => }/psi_test.cc (99%) rename psi/{psi/core/vole_psi => rr22}/BUILD.bazel (74%) rename psi/{psi => }/rr22/common.cc (91%) rename psi/{psi => }/rr22/common.h (84%) rename psi/{psi/core/vole_psi => rr22}/davis_meyer_hash.cc (94%) rename psi/{psi/core/vole_psi => rr22}/davis_meyer_hash.h (96%) rename psi/{psi/core/vole_psi => rr22}/davis_meyer_hash_test.cc (93%) rename psi/{psi/core/vole_psi => rr22}/okvs/BUILD.bazel (100%) rename psi/{psi/core/vole_psi => rr22}/okvs/aes_crhash.cc (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/aes_crhash.h (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/aes_crhash_test.cc (92%) rename psi/{psi/core/vole_psi => rr22}/okvs/baxos.cc (99%) rename psi/{psi/core/vole_psi => rr22}/okvs/baxos.h (95%) rename psi/{psi/core/vole_psi => rr22}/okvs/baxos_test.cc (95%) rename psi/{psi/core/vole_psi => rr22}/okvs/dense_mtx.cc (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/dense_mtx.h (98%) rename psi/{psi/core/vole_psi => rr22}/okvs/galois128.cc (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/galois128.h (95%) rename psi/{psi/core/vole_psi => rr22}/okvs/galois128_test.cc (95%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos.cc (99%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos.h (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_hash.cc (99%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_hash.h (97%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_hash_test.cc (93%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_test.cc (96%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_utils.cc (88%) rename psi/{psi/core/vole_psi => rr22}/okvs/paxos_utils.h (98%) rename psi/{psi/core/vole_psi => rr22}/okvs/simple_index.cc (99%) rename psi/{psi/core/vole_psi => rr22}/okvs/simple_index.h (93%) rename psi/{psi => }/rr22/receiver.cc (93%) rename psi/{psi => }/rr22/receiver.h (89%) rename psi/{psi/core/vole_psi => rr22}/rr22_oprf.cc (99%) rename psi/{psi/core/vole_psi => rr22}/rr22_oprf.h (98%) rename psi/{psi/core/vole_psi => rr22}/rr22_oprf_test.cc (96%) rename psi/{psi/core/vole_psi => rr22}/rr22_psi.cc (91%) rename psi/{psi/core/vole_psi => rr22}/rr22_psi.h (86%) rename psi/{psi/core/vole_psi/rr22_psi_bench.cc => rr22/rr22_psi_benchmark.cc} (90%) rename psi/{psi/core/vole_psi => rr22}/rr22_psi_test.cc (91%) rename psi/{psi/core/vole_psi => rr22}/rr22_utils.cc (96%) rename psi/{psi/core/vole_psi => rr22}/rr22_utils.h (97%) rename psi/{psi => }/rr22/sender.cc (92%) rename psi/{psi => }/rr22/sender.h (89%) rename psi/{psi/core/vole_psi => rr22}/sparseconfig.h (100%) rename psi/{pir => seal_pir}/BUILD.bazel (75%) rename psi/{pir => seal_pir}/seal_mpir.cc (96%) rename psi/{pir => seal_pir}/seal_mpir.h (94%) rename psi/{pir => seal_pir}/seal_mpir_test.cc (91%) rename psi/{pir => seal_pir}/seal_pir.cc (99%) rename psi/{pir => seal_pir}/seal_pir.h (98%) rename psi/{pir => seal_pir}/seal_pir_test.cc (90%) rename psi/{pir => seal_pir}/seal_pir_utils.cc (95%) rename psi/{pir => seal_pir}/seal_pir_utils.h (98%) rename psi/{pir => seal_pir}/serializable.proto (98%) rename psi/{psi => }/trace_categories.cc (94%) rename psi/{psi => }/trace_categories.h (100%) rename psi/{psi => }/utils/BUILD.bazel (88%) rename psi/{psi => }/utils/advanced_join.cc (99%) rename psi/{psi => }/utils/advanced_join.h (98%) rename psi/{psi => }/utils/advanced_join_test.cc (99%) rename psi/{psi => }/utils/arrow_csv_batch_provider.cc (95%) rename psi/{psi => }/utils/arrow_csv_batch_provider.h (94%) rename psi/{psi => }/utils/arrow_csv_batch_provider_test.cc (96%) rename psi/{psi => }/utils/batch_provider.cc (99%) rename psi/{psi => }/utils/batch_provider.h (97%) rename psi/{psi => }/utils/bucket.cc (95%) rename psi/{psi => }/utils/bucket.h (89%) rename psi/{psi/core => utils}/communication.cc (96%) rename psi/{psi/core => utils}/communication.h (97%) rename psi/{psi => }/utils/csv_checker.cc (98%) rename psi/{psi => }/utils/csv_checker.h (96%) rename psi/{psi => }/utils/csv_checker_test.cc (98%) rename psi/{psi => }/utils/csv_header_analyzer.h (99%) rename psi/{psi => }/utils/csv_header_parser.cc (96%) rename psi/{psi => }/utils/csv_header_parser.h (96%) rename psi/{psi => }/utils/csv_header_parser_test.cc (94%) rename psi/{psi/core => utils}/cuckoo_index.cc (98%) rename psi/{psi/core => utils}/cuckoo_index.h (99%) rename psi/{psi/core => utils}/cuckoo_index_test.cc (97%) rename psi/{psi => }/utils/ec.cc (89%) rename psi/{psi => }/utils/ec.h (93%) rename psi/{psi => }/utils/ec_point_store.cc (98%) rename psi/{psi => }/utils/ec_point_store.h (96%) rename psi/{psi => }/utils/emp_io_adapter.cc (99%) rename psi/{psi => }/utils/emp_io_adapter.h (100%) rename psi/{psi => }/utils/emp_io_adapter_test.cc (98%) rename psi/{psi => }/utils/hash_bucket_cache.cc (95%) rename psi/{psi => }/utils/hash_bucket_cache.h (94%) rename psi/{psi => }/utils/index_store.cc (98%) rename psi/{psi => }/utils/index_store.h (97%) rename psi/{psi => }/utils/index_store_test.cc (97%) rename psi/{psi => }/utils/io.cc (97%) rename psi/{psi => }/utils/io.h (98%) rename psi/{psi => }/utils/key.cc (95%) rename psi/{psi => }/utils/key.h (97%) rename psi/{psi => }/utils/multiplex_disk_cache.cc (96%) rename psi/{psi => }/utils/multiplex_disk_cache.h (96%) rename psi/{psi => }/utils/multiplex_disk_cache_test.cc (95%) rename psi/{psi => }/utils/progress.cc (97%) rename psi/{psi => }/utils/progress.h (97%) rename psi/{psi => }/utils/progress_test.cc (96%) rename psi/{psi => }/utils/recovery.cc (98%) rename psi/{psi => }/utils/recovery.h (97%) rename psi/{psi => }/utils/recovery_test.cc (96%) rename psi/{psi => }/utils/resource.cc (95%) rename psi/{psi => }/utils/resource.h (94%) rename psi/{psi => }/utils/serializable.proto (97%) rename psi/{psi => }/utils/serialize.h (94%) rename psi/{psi => }/utils/sync.cc (94%) rename psi/{psi => }/utils/sync.h (97%) rename psi/{psi => }/utils/test_utils.h (96%) rename psi/{psi => }/utils/ub_psi_cache.cc (97%) rename psi/{psi => }/utils/ub_psi_cache.h (95%) rename psi/{psi => }/utils/ub_psi_cache_test.cc (96%) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index f84f2f1f..188012c1 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@379614612a29c9e28f31f39a59013eb8012a51f0 # v3.24.3 + uses: github/codeql-action/upload-sarif@0b21cf2492b6b02c465a3e5d7c473717ad7721ba # v3.23.1 with: sarif_file: results.sarif diff --git a/README.md b/README.md index 7ada0339..9cd9a228 100644 --- a/README.md +++ b/README.md @@ -119,13 +119,13 @@ sender.config: In the first terminal, run the following command ```bash -docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config receiver/receiver.config" +docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config receiver/receiver.config" ``` In the other terminal, run the following command simultaneously. ```bash -docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config sender/sender.config" +docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config sender/sender.config" ``` ## Building SecretFlow PSI Library @@ -155,9 +155,12 @@ docker exec -it psi-dev-$(whoami) bash #### Linux ```sh -Install gcc>=11.2, cmake>=3.26, ninja, nasm>=2.15, python>=3.8, bazel==6.4.0, golang, xxd, lld +Install gcc>=11.2, cmake>=3.26, ninja, nasm>=2.15, python>=3.8, bazel, golang, xxd, lld ``` +> **Note**
+Please install bazel with version in .bazelversion or use bazelisk. + ### Build & UnitTest diff --git a/RELEASE.md b/RELEASE.md index 23b10952..819569a2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,6 +5,11 @@ > - `[API]` prefix for API changes. > - `[Improvement]` prefix for implementation improvement. +## v0.2.0.dev240219 + +- [Feature] add ecdh logger for debug purposes. +- [API] modify repo structure. + ## v0.2.0.dev240123 - [Feature] add RFC9380 25519 elligator2 hash_to_curve. diff --git a/bazel/patches/apsi.patch b/bazel/patches/apsi.patch index 34bd1d09..13eb7e79 100644 --- a/bazel/patches/apsi.patch +++ b/bazel/patches/apsi.patch @@ -95,9 +95,88 @@ index 3b15780..5085038 100644 if(APSI_FOURQ_AMD64) add_subdirectory(amd64) +diff --git a/common/apsi/fourq/FourQ_internal.h b/common/apsi/fourq/FourQ_internal.h +index 009bb1d..5aa9886 100644 +--- a/common/apsi/fourq/FourQ_internal.h ++++ b/common/apsi/fourq/FourQ_internal.h +@@ -143,7 +143,7 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) + + // 64x64-bit multiplication + #define MUL128(multiplier, multiplicand, product) \ +- mp_mul( \ ++ fq_mp_mul( \ + (digit_t *)&(multiplier), \ + (digit_t *)&(multiplicand), \ + (digit_t *)&(product), \ +@@ -151,12 +151,12 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) + + // 128-bit addition, inputs < 2^127 + #define ADD128(addend1, addend2, addition) \ +- mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); ++ fq_mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); + + // 128-bit addition with output carry + #define ADC128(addend1, addend2, carry, addition) \ + (carry) = \ +- mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); ++ fq_mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); + + #elif (TARGET == TARGET_AMD64 && OS_TARGET == OS_WIN) + +@@ -257,10 +257,10 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) + bool is_zero_ct(digit_t *a, unsigned int nwords); + + // Multiprecision addition, c = a+b. Returns the carry bit +-unsigned int mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords); ++unsigned int fq_mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords); + + // Schoolbook multiprecision multiply, c = a*b +-void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords); ++void fq_mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords); + + // Multiprecision subtraction, c = a-b. Returns the borrow bit + #if defined(GENERIC_IMPLEMENTATION) +diff --git a/common/apsi/fourq/generic/fp.h b/common/apsi/fourq/generic/fp.h +index f475de1..e24a26a 100644 +--- a/common/apsi/fourq/generic/fp.h ++++ b/common/apsi/fourq/generic/fp.h +@@ -172,7 +172,7 @@ void mod1271(felm_t a) + ADDC(borrow, a[NWORDS_FIELD - 1], (mask >> 1), borrow, a[NWORDS_FIELD - 1]); + } + +-void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) ++void fq_mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) + { // Schoolbook multiprecision multiply, c = a*b + unsigned int i, j; + digit_t u, v, UV[2]; +@@ -195,7 +195,7 @@ void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int n + } + } + +-unsigned int mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords) ++unsigned int fq_mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords) + { // Multiprecision addition, c = a+b, where lng(a) = lng(b) = nwords. Returns the carry bit + unsigned int i, carry = 0; + +@@ -263,13 +263,13 @@ void fpinv1271(felm_t a) + static void multiply(const digit_t *a, const digit_t *b, digit_t *c) + { // Schoolbook multiprecision multiply, c = a*b + +- mp_mul(a, b, c, NWORDS_ORDER); ++ fq_mp_mul(a, b, c, NWORDS_ORDER); + } + + static unsigned int add(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) + { // Multiprecision addition, c = a+b, where lng(a) = lng(b) = nwords. Returns the carry bit + +- return mp_add((digit_t *)a, (digit_t *)b, c, (unsigned int)nwords); ++ return fq_mp_add((digit_t *)a, (digit_t *)b, c, (unsigned int)nwords); + } + + unsigned int subtract(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) diff --git a/common/apsi/fourq/kex.c b/common/apsi/fourq/kex.c new file mode 100644 -index 0000000..5c37c60 +index 0000000..d59af6d --- /dev/null +++ b/common/apsi/fourq/kex.c @@ -0,0 +1,181 @@ @@ -282,6 +361,19 @@ index 0000000..5c37c60 + + return Status; +} +diff --git a/common/apsi/util/stopwatch.h b/common/apsi/util/stopwatch.h +index e09a53b..c5e4bab 100644 +--- a/common/apsi/util/stopwatch.h ++++ b/common/apsi/util/stopwatch.h +@@ -22,7 +22,7 @@ + + // Measure a block + #define STOPWATCH(stopwatch, name) \ +- apsi::util::StopwatchScope UNIQUE_STOPWATCH_NAME(stopwatchscope)(stopwatch, name); ++ ::apsi::util::StopwatchScope UNIQUE_STOPWATCH_NAME(stopwatchscope)(stopwatch, name); + + namespace apsi { + namespace util { diff --git a/receiver/apsi/CMakeLists.txt b/receiver/apsi/CMakeLists.txt index afce298..7757b68 100644 --- a/receiver/apsi/CMakeLists.txt @@ -355,82 +447,3 @@ index fd245d7..99e4228 100644 DESTINATION ${APSI_INCLUDES_INSTALL_DIR}/apsi ) - diff --git a/common/apsi/fourq/FourQ_internal.h b/common/apsi/fourq/FourQ_internal.h -index 009bb1d..5aa9886 100644 ---- a/common/apsi/fourq/FourQ_internal.h -+++ b/common/apsi/fourq/FourQ_internal.h -@@ -143,7 +143,7 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) - - // 64x64-bit multiplication - #define MUL128(multiplier, multiplicand, product) \ -- mp_mul( \ -+ fq_mp_mul( \ - (digit_t *)&(multiplier), \ - (digit_t *)&(multiplicand), \ - (digit_t *)&(product), \ -@@ -151,12 +151,12 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) - - // 128-bit addition, inputs < 2^127 - #define ADD128(addend1, addend2, addition) \ -- mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); -+ fq_mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); - - // 128-bit addition with output carry - #define ADC128(addend1, addend2, carry, addition) \ - (carry) = \ -- mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); -+ fq_mp_add((digit_t *)(addend1), (digit_t *)(addend2), (digit_t *)(addition), NWORDS_FIELD); - - #elif (TARGET == TARGET_AMD64 && OS_TARGET == OS_WIN) - -@@ -257,10 +257,10 @@ static __inline unsigned int is_digit_lessthan_ct(digit_t x, digit_t y) - bool is_zero_ct(digit_t *a, unsigned int nwords); - - // Multiprecision addition, c = a+b. Returns the carry bit --unsigned int mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords); -+unsigned int fq_mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords); - - // Schoolbook multiprecision multiply, c = a*b --void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords); -+void fq_mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords); - - // Multiprecision subtraction, c = a-b. Returns the borrow bit - #if defined(GENERIC_IMPLEMENTATION) -diff --git a/common/apsi/fourq/generic/fp.h b/common/apsi/fourq/generic/fp.h -index f475de1..e24a26a 100644 ---- a/common/apsi/fourq/generic/fp.h -+++ b/common/apsi/fourq/generic/fp.h -@@ -172,7 +172,7 @@ void mod1271(felm_t a) - ADDC(borrow, a[NWORDS_FIELD - 1], (mask >> 1), borrow, a[NWORDS_FIELD - 1]); - } - --void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) -+void fq_mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) - { // Schoolbook multiprecision multiply, c = a*b - unsigned int i, j; - digit_t u, v, UV[2]; -@@ -195,7 +195,7 @@ void mp_mul(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int n - } - } - --unsigned int mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords) -+unsigned int fq_mp_add(digit_t *a, digit_t *b, digit_t *c, unsigned int nwords) - { // Multiprecision addition, c = a+b, where lng(a) = lng(b) = nwords. Returns the carry bit - unsigned int i, carry = 0; - -@@ -263,13 +263,13 @@ void fpinv1271(felm_t a) - static void multiply(const digit_t *a, const digit_t *b, digit_t *c) - { // Schoolbook multiprecision multiply, c = a*b - -- mp_mul(a, b, c, NWORDS_ORDER); -+ fq_mp_mul(a, b, c, NWORDS_ORDER); - } - - static unsigned int add(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) - { // Multiprecision addition, c = a+b, where lng(a) = lng(b) = nwords. Returns the carry bit - -- return mp_add((digit_t *)a, (digit_t *)b, c, (unsigned int)nwords); -+ return fq_mp_add((digit_t *)a, (digit_t *)b, c, (unsigned int)nwords); - } - - unsigned int subtract(const digit_t *a, const digit_t *b, digit_t *c, const unsigned int nwords) diff --git a/bazel/patches/grpc.patch b/bazel/patches/grpc.patch index 08bde547..fd8e09fb 100644 --- a/bazel/patches/grpc.patch +++ b/bazel/patches/grpc.patch @@ -30,4 +30,3 @@ index 72e1b6609e..aded52d0db 100644 #include #include #include - diff --git a/bazel/patches/ippcp.patch b/bazel/patches/ippcp.patch index ad5a909c..0af05b25 100644 --- a/bazel/patches/ippcp.patch +++ b/bazel/patches/ippcp.patch @@ -248,4 +248,3 @@ index 315d1a3..8b11c7a 100644 endif(${ARCH} MATCHES "ia32") endif(APPLE) endif(UNIX) - diff --git a/bazel/patches/perfetto.patch b/bazel/patches/perfetto.patch index 70ecb72e..99254d1b 100644 --- a/bazel/patches/perfetto.patch +++ b/bazel/patches/perfetto.patch @@ -13,4 +13,4 @@ index 4ebb0576b..6322273b8 100644 + #include #include - #include + #include \ No newline at end of file diff --git a/bazel/sparsehash.BUILD b/bazel/sparsehash.BUILD index baa8d805..d53da1fb 100644 --- a/bazel/sparsehash.BUILD +++ b/bazel/sparsehash.BUILD @@ -19,6 +19,6 @@ cc_library( includes = ["src"], visibility = ["//visibility:public"], deps = [ - "@psi//psi/psi/core/vole_psi:sparsehash_config", + "@psi//psi/rr22:sparsehash_config", ], ) diff --git a/docs/conf.py b/docs/conf.py index e89f2a53..e132f2f7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -39,7 +39,7 @@ # global variables extlinks = { "psi_doc_host": ("https://www.secretflow.org.cn/docs/psi/en/", "doc "), - "psi_code_host": ("https://github.com/secretflow", "code "), + "psi_code_host": ("https://github.com/secretflow/psi/", "code "), } html_theme = "pydata_sphinx_theme" diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 91cdcb66..7073f8f8 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -7,7 +7,7 @@ Welcome to SecretFlow PSI Library. There are multiple methods to use PSI/PIR. * Python packages * `SPU `_ warps the library as Python bindings. You could call PSI/PIR with spu. - * `SecretFlow `_ warps SPU further. + * `SecretFlow `_ warps SPU further with user-friendly APIs. * Applications @@ -147,11 +147,13 @@ You need to install: * ninja * nasm>=2.15 * python>=3.8 -* bazel==6.4.0 +* bazel * golang * xxd * lld +For bazel, please check version in `.bazelversion `_ or use bazelisk instead. + Build & UnitTest ^^^^^^^^^^^^^^^^ @@ -172,8 +174,3 @@ Reporting an Issue Please create an issue at `Github Issues `_. We will look into issues and get back to you soon. - -Frequently Asked Questions (FAQ) --------------------------------- - -We will collect some popular questions from users and update this part promptly. diff --git a/docs/reference/launch_config.md b/docs/reference/launch_config.md index 1c1aa1ba..bd6622f1 100644 --- a/docs/reference/launch_config.md +++ b/docs/reference/launch_config.md @@ -41,9 +41,9 @@ Please check psi.v2.PsiConfig and psi.v2.UbPsiConfig at **PSI v2 Configuration** | ----- | ---- | ----------- | | link_config | [ yacl.link.ContextDescProto](#yacllinkcontextdescproto) | Configs for network. | | self_link_party | [ string](#string) | With link_config. | -| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.legacy_psi_config | [ psi.BucketPsiConfig](#psibucketpsiconfig) | Please check at psi.proto. | -| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.psi_config | [ psi.v2.PsiConfig](#psiv2psiconfig) | Please check at psi_v2.proto. | -| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.ub_psi_config | [ psi.v2.UbPsiConfig](#psiv2ubpsiconfig) | Please check at psi_v2.proto. | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.legacy_psi_config | [ BucketPsiConfig](#bucketpsiconfig) | Please check at psi.proto. | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.psi_config | [ v2.PsiConfig](#v2psiconfig) | Please check at psi_v2.proto. | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) runtime_config.ub_psi_config | [ v2.UbPsiConfig](#v2ubpsiconfig) | Please check at psi_v2.proto. | diff --git a/docs/reference/psi_v2_config.md b/docs/reference/psi_v2_config.md index a7cffa61..aaa67bc6 100644 --- a/docs/reference/psi_v2_config.md +++ b/docs/reference/psi_v2_config.md @@ -66,7 +66,7 @@ Configs for ECDH protocol. | Field | Type | Description | | ----- | ---- | ----------- | -| curve | [ psi.psi.CurveType](#psipsicurvetype) | none | +| curve | [ psi.CurveType](#psicurvetype) | none | diff --git a/docs/requirements.txt b/docs/requirements.txt index d2ff723c..673bb04c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ -nbsphinx==0.9.3 -sphinx==7.2.6 -myst-parser==2.0.0 +nbsphinx==0.8.9 +sphinx==5.3.0 +myst-parser==0.18.1 sphinx-intl==2.1.0 pydata-sphinx-theme diff --git a/docs/user_guide/faq.md b/docs/user_guide/faq.md new file mode 100644 index 00000000..8e92122f --- /dev/null +++ b/docs/user_guide/faq.md @@ -0,0 +1,90 @@ +# Frequently Asked Questions (FAQ) + +We will collect some popular questions from users and update this part promptly. + +## Config Issues + +1. In PSI config, what is difference of **broadcast_result** and **receiver**? Is it safe to turn on **broadcast_result**? + +In PSI protocols, the parties who are promised to receive the intersection are called **receiver**s, the other parties are called **sender**s. +When **broadcast_result** is turn on, **sender**s also receive the intersection. Both parties must agree on the value of **broadcast_result**, otherwise the program will stop. + +If **broadcast_result** is turn on, only **receiver**s and **sender**s could receive the result while any third parties could not see. So it is safe to set **broadcast_result** to true, if both **receiver**s and **sender**s wish to get the result. + +2. What is **IO_TYPE_UNSPECIFIED**? + +You must select a type as IoType. **IO_TYPE_UNSPECIFIED** is the default value of **IoType**, which is meaningless. At this moment, we only support **IO_TYPE_FILE_CSV**. + +3. What is **ADVANCED_JOIN_TYPE_UNSPECIFIED**? + +PSI protocols doesn‘t allow duplicates in ids of inputs. However, sometimes we may intend to have duplicates in ids and perform LEFT / RIGHT / FULL join following rules of SQL. This is called **AdvancedJoinType**. + +**ADVANCED_JOIN_TYPE_UNSPECIFIED** is default value of AdvancedJoinType, which means default implementation of PSI configs and duplicates is disallowed. If ids of inputs contains duplicates at this moment, the behavior is undefined. + +4. What is the recommendation value of bucket size? + +The default value is 2^20. You shouldn't set this value unless you have very limited computation resource. + +5. What is **disable_alignment**? + +If **disable_alignment** turns on, the intersection received by **receiver**s and **sender**s are not promised to be aligned(the order doesn't match) and save time. + +If any **AdvancedJoinType** is specified, aligement is promised due to implementation, **disable_alignment** is ignored. + +6. What is **RetryOptionsProto** in **ContextDescProto**? + +We have proper default values for all fields. You shouldn't set any values unless the network is pretty bad. + +## Feature Issues + +1. How to enable SSL? + +We support mTLS and you should provide proper **ContextDescProto**: + +- **enable_ssl** is enabled. +- In **client_ssl_opts**, set **verify_depth** and provide peer CA file with **ca_file_path** +- In **server_ssl_opts**, provide self certificate and private key file with **certificate_path** and **private_key_path** +- You must provide these settings at both sides. + +``` +{ + "psi_config": {}, + "link_config": { + "parties": [ + { + "id": "receiver", + "host": "127.0.0.1:5300" + }, + { + "id": "sender", + "host": "127.0.0.1:5400" + } + ], + "enable_ssl": true, + "client_ssl_opts": { + "verify_depth": 1, + "ca_file_path": "/path/to/peer/CA/file" + }, + "server_ssl_opts": { + "certificate_path": "/path/to/self/certificate/file", + "private_key_path": "/path/to/self/private/key/file" + } + }, + "self_link_party": "sender" +} +``` + +2. How to use recovery? + +We provide recovery feature in PSI v2. + +You have to provide a proper **RecoveryConfig**: + +- **enabled** set to true. +- **folder** is provided to store checkpoints. + +If a PSI task fails, just restart the task with the same config, the progress will resume. + +3. What is **Easy PSI**? Why and when to use **Easy PSI**? + +[Easy PSI](https://www.secretflow.org.cn/docs/quickstart/easy-psi) is a standalone PSI product powered by this library. It provides a simple User Interface and utilize [Kuscia](https://www.secretflow.org.cn/docs/kuscia) to launch PSI binaries between both parties. diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst index ee78d35f..dad825f4 100644 --- a/docs/user_guide/index.rst +++ b/docs/user_guide/index.rst @@ -9,3 +9,4 @@ PSI v2 is recommended to use. We are still working on PIR code refactoring. psi psi_v2 pir + faq \ No newline at end of file diff --git a/docs/user_guide/psi.rst b/docs/user_guide/psi.rst index 5e63e564..d03fc3d1 100644 --- a/docs/user_guide/psi.rst +++ b/docs/user_guide/psi.rst @@ -7,9 +7,9 @@ Quick start with Private Set Intersection (PSI) V1 APIs. Supported Protocols ---------------------- -.. The :psi_code_host:`ECDH-PSI ` is favorable if the bandwidth is the bottleneck. -.. If the computing is the bottleneck, you should try the BaRK-OPRF based -.. PSI :psi_code_host:`KKRT-PSI API `. +The :psi_code_host:`ECDH-PSI ` is favorable if the bandwidth is the bottleneck. +If the computing is the bottleneck, you should try the BaRK-OPRF based +PSI :psi_code_host:`KKRT-PSI `. +---------------+--------------+--------------+--------------+ | PSI protocols | Threat Model | Party Number | PsiTypeCode | @@ -69,12 +69,12 @@ Run PSI In the first terminal, run the following command:: - docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config receiver/receiver.config" + docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config receiver/receiver.config" In the other terminal, run the following command simultaneously:: - docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config sender/sender.config" + docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:latest bash -c "./main --config sender/sender.config" Building from source @@ -97,7 +97,7 @@ benchmark result without data load time ecdh-psi Benchmark >>>>>>>>>>>>>>>>>> -:psi_code_host:`DH-PSI benchmark code ` +:psi_code_host:`DH-PSI benchmark code ` cpu limited by docker(--cpu) diff --git a/docs/user_guide/psi_v2.rst b/docs/user_guide/psi_v2.rst index b0dda180..eff234d3 100644 --- a/docs/user_guide/psi_v2.rst +++ b/docs/user_guide/psi_v2.rst @@ -96,20 +96,20 @@ To launch PSI, please check LaunchConfig at :doc:`/reference/launch_config` and "check_hash_digest": false, "recovery_config": { "enabled": false - }, - "link_config": { - "parties": [ - { - "id": "receiver", - "host": "127.0.0.1:5300" - }, - { - "id": "sender", - "host": "127.0.0.1:5400" - } - ] } }, + "link_config": { + "parties": [ + { + "id": "receiver", + "host": "127.0.0.1:5300" + }, + { + "id": "sender", + "host": "127.0.0.1:5400" + } + ] + }, "self_link_party": "sender" } @@ -123,7 +123,7 @@ You need to prepare following files: +------------------------+------------------------------------------------+-------------------------------------------------------------------------------+ | sender.config | /tmp/sender/sender.config | Config for sender. | +------------------------+------------------------------------------------+-------------------------------------------------------------------------------+ -| receiver_input.csv | /tmp/receiver/receiver_input.config | SupInput for receiver. Make sure the file contains two id keys - id0 and id1. | +| receiver_input.csv | /tmp/receiver/receiver_input.config | Input for receiver. Make sure the file contains two id keys - id0 and id1. | +------------------------+------------------------------------------------+-------------------------------------------------------------------------------+ | sender_input.csv | /tmp/sender/sender_input.config | Input for sender. Make sure the file contains two id keys - id0 and id1. | +------------------------+------------------------------------------------+-------------------------------------------------------------------------------+ @@ -134,12 +134,12 @@ Run PSI In the first terminal, run the following command:: - docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config receiver/receiver.config" + docker run -it --rm --network host --mount type=bind,source=/tmp/receiver,target=/root/receiver --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config receiver/receiver.config" In the other terminal, run the following command simultaneously:: - docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender -w /root --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config sender/sender.config" + docker run -it --rm --network host --mount type=bind,source=/tmp/sender,target=/root/sender --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/psi-anolis8:0.1.0beta bash -c "./main --config sender/sender.config" Building from source diff --git a/examples/pir/BUILD.bazel b/examples/pir/BUILD.bazel index 56bfe1c8..20cf9d12 100644 --- a/examples/pir/BUILD.bazel +++ b/examples/pir/BUILD.bazel @@ -35,8 +35,7 @@ psi_cc_binary( ], deps = [ ":utils", - "//psi/pir", - "//psi/psi/core/labeled_psi", + "//psi/apsi:pir", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/utils:rand", ], @@ -50,9 +49,8 @@ psi_cc_binary( ], deps = [ ":utils", - "//psi/pir", - "//psi/psi/core/labeled_psi", - "//psi/psi/utils:serialize", + "//psi/apsi:pir", + "//psi/utils:serialize", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/utils:rand", ], @@ -66,9 +64,8 @@ psi_cc_binary( ], deps = [ ":utils", - "//psi/pir", - "//psi/psi/core/labeled_psi", - "//psi/psi/utils:serialize", + "//psi/apsi:pir", + "//psi/utils:serialize", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/utils:rand", ], @@ -82,9 +79,8 @@ psi_cc_binary( ], deps = [ ":utils", - "//psi/pir", - "//psi/psi/core/labeled_psi", - "//psi/psi/utils:serialize", + "//psi/apsi:pir", + "//psi/utils:serialize", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/utils:rand", "@yacl//yacl/io/rw:csv_writer", diff --git a/examples/pir/keyword_pir_client.cc b/examples/pir/keyword_pir_client.cc index 6a87112e..533849ee 100644 --- a/examples/pir/keyword_pir_client.cc +++ b/examples/pir/keyword_pir_client.cc @@ -25,11 +25,11 @@ #include "examples/pir/utils.h" #include "yacl/io/rw/csv_writer.h" -#include "psi/pir/pir.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/serialize.h" +#include "psi/apsi/pir.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/serialize.h" #include "psi/proto/pir.pb.h" @@ -56,15 +56,15 @@ int main(int argc, char **argv) { std::vector ids = absl::StrSplit(FLAGS_key_columns, ','); - psi::pir::PirClientConfig config; + psi::PirClientConfig config; - config.set_pir_protocol(psi::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_pir_protocol(psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); config.set_input_path(FLAGS_in_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); config.set_output_path(FLAGS_out_path); - psi::pir::PirResultReport report = psi::pir::PirClient(link_ctx, config); + psi::PirResultReport report = psi::apsi::PirClient(link_ctx, config); SPDLOG_INFO("data count:{}", report.data_count()); diff --git a/examples/pir/keyword_pir_mem_server.cc b/examples/pir/keyword_pir_mem_server.cc index 2ff6eb3f..10372954 100644 --- a/examples/pir/keyword_pir_mem_server.cc +++ b/examples/pir/keyword_pir_mem_server.cc @@ -26,12 +26,12 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/pir/pir.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/utils/serialize.h" +#include "psi/apsi/pir.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/ecdh//ecdh_oprf_selector.h" +#include "psi/utils/serialize.h" #include "psi/proto/pir.pb.h" @@ -72,10 +72,10 @@ int main(int argc, char **argv) { std::vector ids = absl::StrSplit(FLAGS_key_columns, ','); std::vector labels = absl::StrSplit(FLAGS_label_columns, ','); - psi::pir::PirSetupConfig config; + psi::PirSetupConfig config; - config.set_pir_protocol(psi::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(psi::pir::KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(psi::KvStoreType::LEVELDB_KV_STORE); config.set_input_path(FLAGS_in_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); @@ -89,8 +89,7 @@ int main(int argc, char **argv) { config.set_bucket_size(FLAGS_bucket); config.set_max_items_per_bin(FLAGS_max_items_per_bin); - psi::pir::PirResultReport report = - psi::pir::PirMemoryServer(link_ctx, config); + psi::PirResultReport report = psi::apsi::PirMemoryServer(link_ctx, config); SPDLOG_INFO("data count:{}", report.data_count()); diff --git a/examples/pir/keyword_pir_server.cc b/examples/pir/keyword_pir_server.cc index c969e466..2d9d49c2 100644 --- a/examples/pir/keyword_pir_server.cc +++ b/examples/pir/keyword_pir_server.cc @@ -26,12 +26,12 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/pir/pir.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/utils/serialize.h" +#include "psi/apsi/pir.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/ecdh//ecdh_oprf_selector.h" +#include "psi/utils/serialize.h" #include "psi/proto/pir.pb.h" @@ -57,15 +57,15 @@ int main(int argc, char **argv) { link_ctx->SetRecvTimeout(kLinkRecvTimeout); - psi::pir::PirServerConfig config; + psi::PirServerConfig config; - config.set_pir_protocol(psi::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(psi::pir::KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(psi::KvStoreType::LEVELDB_KV_STORE); config.set_oprf_key_path(FLAGS_oprf_key_path); config.set_setup_path(FLAGS_setup_path); - psi::pir::PirResultReport report = psi::pir::PirServer(link_ctx, config); + psi::PirResultReport report = psi::apsi::PirServer(link_ctx, config); SPDLOG_INFO("data count:{}", report.data_count()); diff --git a/examples/pir/keyword_pir_setup.cc b/examples/pir/keyword_pir_setup.cc index 8aecc0a9..f9f82dce 100644 --- a/examples/pir/keyword_pir_setup.cc +++ b/examples/pir/keyword_pir_setup.cc @@ -30,11 +30,11 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/pir/pir.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" +#include "psi/apsi/pir.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/ecdh//ecdh_oprf_selector.h" #include "psi/proto/pir.pb.h" @@ -74,10 +74,10 @@ int main(int argc, char **argv) { SPDLOG_INFO("key columns: {}", FLAGS_key_columns); SPDLOG_INFO("label columns: {}", FLAGS_label_columns); - psi::pir::PirSetupConfig config; + psi::PirSetupConfig config; - config.set_pir_protocol(psi::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(psi::pir::KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(psi::KvStoreType::LEVELDB_KV_STORE); config.set_input_path(FLAGS_in_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); @@ -91,7 +91,7 @@ int main(int argc, char **argv) { config.set_bucket_size(FLAGS_bucket); config.set_max_items_per_bin(FLAGS_max_items_per_bin); - psi::pir::PirResultReport report = psi::pir::PirSetup(config); + psi::PirResultReport report = psi::apsi::PirSetup(config); SPDLOG_INFO("data count:{}", report.data_count()); diff --git a/psi/BUILD.bazel b/psi/BUILD.bazel index 6b26509f..fda9b50f 100644 --- a/psi/BUILD.bazel +++ b/psi/BUILD.bazel @@ -16,6 +16,85 @@ load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") package(default_visibility = ["//visibility:public"]) +psi_cc_library( + name = "prelude", + hdrs = [ + "prelude.h", + ], + deps = [ + "//psi/proto:psi_cc_proto", + "//psi/proto:psi_v2_cc_proto", + ], +) + +psi_cc_library( + name = "interface", + srcs = ["interface.cc"], + hdrs = ["interface.h"], + deps = [ + ":trace_categories", + "//psi/legacy:bucket_psi", + "//psi/proto:psi_v2_cc_proto", + "//psi/utils:advanced_join", + "//psi/utils:index_store", + "//psi/utils:recovery", + "@boost//:uuid", + "@com_github_google_perfetto//:perfetto", + "@com_google_absl//absl/status", + "@yacl//yacl/link", + ], +) + +psi_cc_library( + name = "factory", + srcs = ["factory.cc"], + hdrs = ["factory.h"], + deps = [ + "//psi/ecdh:client", + "//psi/ecdh:receiver", + "//psi/ecdh:sender", + "//psi/ecdh:server", + "//psi/kkrt:receiver", + "//psi/kkrt:sender", + "//psi/rr22:receiver", + "//psi/rr22:sender", + "@yacl//yacl/base:exception", + ], +) + +psi_cc_library( + name = "launch", + srcs = ["launch.cc"], + hdrs = ["launch.h"], + deps = [ + ":factory", + ":trace_categories", + "//psi/legacy:bucket_psi", + "@boost//:algorithm", + "@boost//:uuid", + ], +) + +psi_cc_library( + name = "trace_categories", + srcs = ["trace_categories.cc"], + hdrs = ["trace_categories.h"], + deps = [ + "@com_github_google_perfetto//:perfetto", + ], +) + +psi_cc_test( + name = "psi_test", + srcs = ["psi_test.cc"], + deps = [ + ":factory", + "//psi/utils:arrow_csv_batch_provider", + "@boost//:uuid", + "@yacl//yacl/utils:scope_guard", + ], +) + psi_cc_library( name = "version", hdrs = ["version.h"], @@ -51,8 +130,8 @@ psi_cc_binary( deps = [ ":kuscia_adapter", ":version", + "//psi:launch", "//psi/proto:entry_cc_proto", - "//psi/psi:launch", "@com_github_gflags_gflags//:gflags", ], ) diff --git a/psi/psi/core/labeled_psi/BUILD.bazel b/psi/apsi/BUILD.bazel similarity index 52% rename from psi/psi/core/labeled_psi/BUILD.bazel rename to psi/apsi/BUILD.bazel index d0ada59f..7ef4e026 100644 --- a/psi/psi/core/labeled_psi/BUILD.bazel +++ b/psi/apsi/BUILD.bazel @@ -18,51 +18,115 @@ load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) +proto_library( + name = "serializable_proto", + srcs = ["serializable.proto"], +) + +cc_proto_library( + name = "serializable_cc_proto", + deps = [":serializable_proto"], +) + psi_cc_library( - name = "labeled_psi", - srcs = [ - "package.cc", - "psi_params.cc", - "receiver.cc", - "sender.cc", - "sender_db.cc", - "sender_kvdb.cc", - "sender_memdb.cc", - ], - hdrs = [ - "package.h", - "psi_params.h", - "receiver.h", - "sender.h", - "sender_db.h", - "sender_kvdb.h", - "sender_memdb.h", - "serialize.h", + name = "padding", + srcs = ["padding.cc"], + hdrs = ["padding.h"], + deps = [ + "@yacl//yacl/base:byte_container_view", + "@yacl//yacl/base:exception", ], +) + +psi_cc_library( + name = "package", + srcs = ["package.cc"], + hdrs = ["package.h"], + deps = [ + "@com_github_microsoft_apsi//:apsi", + "@yacl//yacl/base:exception", + ], +) + +psi_cc_library( + name = "psi_params", + srcs = ["psi_params.cc"], + hdrs = ["psi_params.h"], deps = [ - ":padding", ":serializable_cc_proto", - "//psi/psi/core/ecdh_oprf:ecdh_oprf_selector", - "//psi/psi/utils:batch_provider", "@com_github_microsoft_apsi//:apsi", - "@com_google_absl//absl/strings", "@yacl//yacl/base:exception", - "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/link", + ], +) + +psi_cc_library( + name = "serialize", + hdrs = ["serialize.h"], + deps = [ + ":serializable_cc_proto", + "@com_github_microsoft_apsi//:apsi", + ], +) + +psi_cc_library( + name = "sender_db", + srcs = ["sender_db.cc"], + hdrs = ["sender_db.h"], + deps = [ + ":serialize", + "//psi/ecdh:ecdh_oprf_selector", + "//psi/utils:batch_provider", + "@yacl//yacl/base:byte_container_view", + "@yacl//yacl/base:exception", "@yacl//yacl/io/kv:leveldb_kvstore", "@yacl//yacl/io/kv:memory_kvstore", - "@yacl//yacl/link", - "@yacl//yacl/utils:parallel", ], ) -proto_library( - name = "serializable_proto", - srcs = ["serializable.proto"], +psi_cc_library( + name = "sender_kvdb", + srcs = ["sender_kvdb.cc"], + hdrs = ["sender_kvdb.h"], + deps = [ + ":padding", + ":sender_db", + ], ) -cc_proto_library( - name = "serializable_cc_proto", - deps = [":serializable_proto"], +psi_cc_library( + name = "sender_memdb", + srcs = ["sender_memdb.cc"], + hdrs = ["sender_memdb.h"], + deps = [ + ":padding", + ":sender_db", + ], +) + +psi_cc_library( + name = "receiver", + srcs = ["receiver.cc"], + hdrs = ["receiver.h"], + deps = [ + ":package", + ":padding", + ":psi_params", + "//psi/ecdh:ecdh_oprf_selector", + ], +) + +psi_cc_library( + name = "sender", + srcs = ["sender.cc"], + hdrs = ["sender.h"], + deps = [ + ":package", + ":psi_params", + ":sender_kvdb", + ":sender_memdb", + "@yacl//yacl/link", + ], ) psi_cc_test( @@ -71,7 +135,8 @@ psi_cc_test( "apsi_test.cc", ], deps = [ - ":labeled_psi", + ":receiver", + ":sender", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:scope_guard", @@ -84,7 +149,8 @@ psi_cc_test( "apsi_label_test.cc", ], deps = [ - ":labeled_psi", + ":receiver", + ":sender", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:scope_guard", @@ -97,8 +163,9 @@ psi_cc_test( "kv_test.cc", ], deps = [ - ":labeled_psi", - "//psi/psi/utils:sync", + ":receiver", + ":sender", + "//psi/utils:sync", "@com_google_absl//absl/strings", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:scope_guard", @@ -106,12 +173,13 @@ psi_cc_test( ) psi_cc_binary( - name = "apsi_bench", + name = "apsi_benchmark", srcs = [ - "apsi_bench.cc", + "apsi_benchmark.cc", ], deps = [ - ":labeled_psi", + ":receiver", + ":sender", "@com_github_google_benchmark//:benchmark_main", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -120,11 +188,25 @@ psi_cc_binary( ) psi_cc_library( - name = "padding", - srcs = ["padding.cc"], - hdrs = ["padding.h"], + name = "pir", + srcs = ["pir.cc"], + hdrs = ["pir.h"], deps = [ - "@yacl//yacl/base:byte_container_view", - "@yacl//yacl/base:exception", + ":receiver", + ":sender", + "//psi/proto:pir_cc_proto", + "//psi/utils:serialize", + "//psi/utils:sync", + "@yacl//yacl/crypto/base/block_cipher:symmetric_crypto", + ], +) + +psi_cc_test( + name = "pir_test", + srcs = ["pir_test.cc"], + deps = [ + ":pir", + "//psi/utils:io", + "@yacl//yacl/utils:scope_guard", ], ) diff --git a/psi/psi/core/labeled_psi/README.md b/psi/apsi/README.md similarity index 100% rename from psi/psi/core/labeled_psi/README.md rename to psi/apsi/README.md diff --git a/psi/psi/core/labeled_psi/apsi_bench.cc b/psi/apsi/apsi_benchmark.cc similarity index 91% rename from psi/psi/core/labeled_psi/apsi_bench.cc rename to psi/apsi/apsi_benchmark.cc index 91199925..c24d33bd 100644 --- a/psi/psi/core/labeled_psi/apsi_bench.cc +++ b/psi/apsi/apsi_benchmark.cc @@ -26,13 +26,13 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/ecdh//ecdh_oprf_selector.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -140,17 +140,17 @@ static void BM_LabeledPsi(benchmark::State& state) { state.ResumeTiming(); - apsi::PSIParams psi_params = GetPsiParams(nr, ns); + ::apsi::PSIParams psi_params = GetPsiParams(nr, ns); // step 1: PsiParams Request and Response std::future f_sender_params = std::async([&] { return LabelPsiSender::RunPsiParams(ns, ctxs[0]); }); - std::future f_receiver_params = std::async( + std::future<::apsi::PSIParams> f_receiver_params = std::async( [&] { return LabelPsiReceiver::RequestPsiParams(nr, ctxs[1]); }); f_sender_params.get(); - apsi::PSIParams psi_params2 = f_receiver_params.get(); + ::apsi::PSIParams psi_params2 = f_receiver_params.get(); EXPECT_EQ(psi_params.table_params().table_size, psi_params2.table_params().table_size); @@ -194,8 +194,9 @@ static void BM_LabeledPsi(benchmark::State& state) { SPDLOG_INFO("after set db, bin_bundle_count:{}, packing_rate:{}", sender_db->GetBinBundleCount(), sender_db->GetPackingRate()); - std::unique_ptr oprf_server = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + CurveType::CURVE_FOURQ); LabelPsiSender sender(sender_db); @@ -208,13 +209,13 @@ static void BM_LabeledPsi(benchmark::State& state) { std::future f_sender_oprf = std::async( [&] { return sender.RunOPRF(std::move(oprf_server), ctxs[0]); }); - std::future< - std::pair, std::vector>> + std::future, + std::vector<::apsi::LabelKey>>> f_receiver_oprf = std::async( [&] { return receiver.RequestOPRF(receiver_items, ctxs[1]); }); f_sender_oprf.get(); - std::pair, std::vector> + std::pair, std::vector<::apsi::LabelKey>> oprf_pair = f_receiver_oprf.get(); const auto oprf_end = std::chrono::system_clock::now(); @@ -275,4 +276,4 @@ BENCHMARK(BM_LabeledPsi) BENCHMARK_MAIN(); -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/apsi_label_test.cc b/psi/apsi/apsi_label_test.cc similarity index 90% rename from psi/psi/core/labeled_psi/apsi_label_test.cc rename to psi/apsi/apsi_label_test.cc index 4088f860..2c8b4b7c 100644 --- a/psi/psi/core/labeled_psi/apsi_label_test.cc +++ b/psi/apsi/apsi_label_test.cc @@ -28,14 +28,14 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" -#include "psi/psi/core/labeled_psi/sender_memdb.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/apsi/sender_memdb.h" +#include "psi/ecdh//ecdh_oprf_selector.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -103,17 +103,17 @@ class LabelPsiTest : public testing::TestWithParam {}; TEST_P(LabelPsiTest, Works) { auto params = GetParam(); auto ctxs = yacl::link::test::SetupWorld(2); - apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); + ::apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); // step 1: PsiParams Request and Response std::future f_sender_params = std::async( [&] { return LabelPsiSender::RunPsiParams(params.ns, ctxs[0]); }); - std::future f_receiver_params = std::async( + std::future<::apsi::PSIParams> f_receiver_params = std::async( [&] { return LabelPsiReceiver::RequestPsiParams(params.nr, ctxs[1]); }); f_sender_params.get(); - apsi::PSIParams psi_params2 = f_receiver_params.get(); + ::apsi::PSIParams psi_params2 = f_receiver_params.get(); EXPECT_EQ(psi_params.table_params().table_size, psi_params2.table_params().table_size); @@ -183,7 +183,7 @@ TEST_P(LabelPsiTest, Works) { SPDLOG_INFO("after set db, bin_bundle_count:{}, packing_rate:{}", sender_db->GetBinBundleCount(), sender_db->GetPackingRate()); - const apsi::PSIParams apsi_params = sender_db->GetParams(); + const ::apsi::PSIParams apsi_params = sender_db->GetParams(); SPDLOG_INFO("params.bundle_idx_count={}", apsi_params.bundle_idx_count()); LabelPsiSender sender(sender_db); @@ -195,19 +195,20 @@ TEST_P(LabelPsiTest, Works) { const auto oprf_start = std::chrono::system_clock::now(); std::future f_sender_oprf = std::async([&] { - std::unique_ptr oprf_server = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + CurveType::CURVE_FOURQ); return sender.RunOPRF(std::move(oprf_server), ctxs[0]); }); std::future< - std::pair, std::vector>> + std::pair, std::vector<::apsi::LabelKey>>> f_receiver_oprf = std::async( [&] { return receiver.RequestOPRF(receiver_items, ctxs[1]); }); f_sender_oprf.get(); - std::pair, std::vector> + std::pair, std::vector<::apsi::LabelKey>> oprf_pair = f_receiver_oprf.get(); const auto oprf_end = std::chrono::system_clock::now(); @@ -266,4 +267,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, LabelPsiTest, #endif ); -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/apsi_test.cc b/psi/apsi/apsi_test.cc similarity index 90% rename from psi/psi/core/labeled_psi/apsi_test.cc rename to psi/apsi/apsi_test.cc index 1cb5c5bb..4666189e 100644 --- a/psi/psi/core/labeled_psi/apsi_test.cc +++ b/psi/apsi/apsi_test.cc @@ -28,14 +28,14 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" -#include "psi/psi/core/labeled_psi/sender_memdb.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/apsi/sender_memdb.h" +#include "psi/ecdh//ecdh_oprf_selector.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -94,17 +94,17 @@ class LabelPsiTest : public testing::TestWithParam {}; TEST_P(LabelPsiTest, Works) { auto params = GetParam(); auto ctxs = yacl::link::test::SetupWorld(2); - apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); + ::apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); // step 1: PsiParams Request and Response std::future f_sender_params = std::async( [&] { return LabelPsiSender::RunPsiParams(params.ns, ctxs[0]); }); - std::future f_receiver_params = std::async( + std::future<::apsi::PSIParams> f_receiver_params = std::async( [&] { return LabelPsiReceiver::RequestPsiParams(params.nr, ctxs[1]); }); f_sender_params.get(); - apsi::PSIParams psi_params2 = f_receiver_params.get(); + ::apsi::PSIParams psi_params2 = f_receiver_params.get(); EXPECT_EQ(psi_params.table_params().table_size, psi_params2.table_params().table_size); @@ -168,14 +168,15 @@ TEST_P(LabelPsiTest, Works) { SPDLOG_INFO("after set db, bin_bundle_count:{}, packing_rate:{}", sender_db->GetBinBundleCount(), sender_db->GetPackingRate()); - const apsi::PSIParams apsi_params = sender_db->GetParams(); + const ::apsi::PSIParams apsi_params = sender_db->GetParams(); SPDLOG_INFO("params.bundle_idx_count={}", apsi_params.bundle_idx_count()); for (size_t i = 0; i < apsi_params.bundle_idx_count(); ++i) { SPDLOG_INFO("i={},count={}", i, sender_db->GetBinBundleCount(i)); } - std::unique_ptr oprf_server = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + CurveType::CURVE_FOURQ); LabelPsiSender sender(sender_db); @@ -189,12 +190,12 @@ TEST_P(LabelPsiTest, Works) { [&] { return sender.RunOPRF(std::move(oprf_server), ctxs[0]); }); std::future< - std::pair, std::vector>> + std::pair, std::vector<::apsi::LabelKey>>> f_receiver_oprf = std::async( [&] { return receiver.RequestOPRF(receiver_items, ctxs[1]); }); f_sender_oprf.get(); - std::pair, std::vector> + std::pair, std::vector<::apsi::LabelKey>> oprf_pair = f_receiver_oprf.get(); const auto oprf_end = std::chrono::system_clock::now(); @@ -253,4 +254,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, LabelPsiTest, #endif ); -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/kv_test.cc b/psi/apsi/kv_test.cc similarity index 90% rename from psi/psi/core/labeled_psi/kv_test.cc rename to psi/apsi/kv_test.cc index 680798b9..dae6ba3e 100644 --- a/psi/psi/core/labeled_psi/kv_test.cc +++ b/psi/apsi/kv_test.cc @@ -28,15 +28,15 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/padding.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" -#include "psi/psi/utils/sync.h" +#include "psi/apsi/padding.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/ecdh//ecdh_oprf_selector.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -63,25 +63,25 @@ std::vector GenerateData(size_t seed, size_t item_count) { return items; } -std::vector GenerateSenderData( +std::vector<::apsi::Item> GenerateSenderData( size_t seed, size_t item_count, const absl::Span &receiver_items, std::vector *intersection_idx) { - std::vector sender_items; + std::vector<::apsi::Item> sender_items; yacl::crypto::Prg prg(seed); for (size_t i = 0; i < item_count; ++i) { - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; prg.Fill(absl::MakeSpan(value)); sender_items.emplace_back(value); } for (size_t i = 0; i < receiver_items.size(); i += 3) { - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; std::memcpy(value.data(), receiver_items[i].data(), receiver_items[i].length()); - apsi::Item item(value); + ::apsi::Item item(value); sender_items[kPsiStartPos + i * 5] = item; (*intersection_idx).emplace_back(i); } @@ -89,18 +89,18 @@ std::vector GenerateSenderData( return sender_items; } -std::vector> GenerateSenderData( +std::vector> GenerateSenderData( size_t seed, size_t item_count, size_t label_byte_count, const absl::Span &receiver_items, std::vector *intersection_idx, std::vector *intersection_label) { - std::vector> sender_items; + std::vector> sender_items; yacl::crypto::Prg prg(seed); for (size_t i = 0; i < item_count; ++i) { - apsi::Item item; - apsi::Label label; + ::apsi::Item item; + ::apsi::Label label; label.resize(label_byte_count); prg.Fill(absl::MakeSpan(item.value())); prg.Fill(absl::MakeSpan(label)); @@ -108,7 +108,7 @@ std::vector> GenerateSenderData( } for (size_t i = 0; i < receiver_items.size(); i += 3) { - apsi::Item item; + ::apsi::Item item; std::memcpy(item.value().data(), receiver_items[i].data(), receiver_items[i].length()); @@ -149,17 +149,17 @@ TEST_P(LabelPsiTest, Works) { SPDLOG_INFO("d1 len:{} d2 len:{}", d1.size(), d2.length()); } - apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); + ::apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); // step 1: PsiParams Request and Response std::future f_sender_params = std::async( [&] { return LabelPsiSender::RunPsiParams(params.ns, ctxs[0]); }); - std::future f_receiver_params = std::async( + std::future<::apsi::PSIParams> f_receiver_params = std::async( [&] { return LabelPsiReceiver::RequestPsiParams(params.nr, ctxs[1]); }); f_sender_params.get(); - apsi::PSIParams psi_params2 = f_receiver_params.get(); + ::apsi::PSIParams psi_params2 = f_receiver_params.get(); EXPECT_EQ(psi_params.table_params().table_size, psi_params2.table_params().table_size); @@ -203,7 +203,7 @@ TEST_P(LabelPsiTest, Works) { const auto setdb_start = std::chrono::system_clock::now(); if (params.label_bytes == 0) { - std::vector sender_items = GenerateSenderData( + std::vector<::apsi::Item> sender_items = GenerateSenderData( rd(), item_count, absl::MakeSpan(receiver_items), &intersection_idx); // sender_db->SetData(sender_items); @@ -222,7 +222,7 @@ TEST_P(LabelPsiTest, Works) { sender_db->SetData(batch_provider); } else { - std::vector> sender_items = + std::vector> sender_items = GenerateSenderData(rd(), item_count, label_byte_count - 6, absl::MakeSpan(receiver_items), &intersection_idx, &intersection_label); @@ -258,7 +258,7 @@ TEST_P(LabelPsiTest, Works) { SPDLOG_INFO("after set db, bin_bundle_count:{}, packing_rate:{}", sender_db->GetBinBundleCount(), sender_db->GetPackingRate()); - const apsi::PSIParams &apsi_params = sender_db->GetParams(); + const ::apsi::PSIParams &apsi_params = sender_db->GetParams(); SPDLOG_INFO("bundle_idx_count:{}", apsi_params.bundle_idx_count()); SPDLOG_INFO("BinBundleCount:{}", sender_db->GetBinBundleCount()); for (size_t i = 0; i < apsi_params.bundle_idx_count(); ++i) { @@ -281,12 +281,12 @@ TEST_P(LabelPsiTest, Works) { [&] { return sender.RunOPRF(std::move(oprf_server), ctxs[0]); }); std::future< - std::pair, std::vector>> + std::pair, std::vector<::apsi::LabelKey>>> f_receiver_oprf = std::async( [&] { return receiver.RequestOPRF(receiver_items, ctxs[1]); }); f_sender_oprf.get(); - std::pair, std::vector> + std::pair, std::vector<::apsi::LabelKey>> oprf_pair = f_receiver_oprf.get(); const auto oprf_end = std::chrono::system_clock::now(); @@ -360,4 +360,4 @@ INSTANTIATE_TEST_SUITE_P( #endif ); -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/package.cc b/psi/apsi/package.cc similarity index 93% rename from psi/psi/core/labeled_psi/package.cc rename to psi/apsi/package.cc index 55d262ae..47e7daf8 100644 --- a/psi/psi/core/labeled_psi/package.cc +++ b/psi/apsi/package.cc @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/labeled_psi/package.h" +#include "psi/apsi/package.h" #include #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi::apsi { PlainResultPackage ResultPackage::extract( - const apsi::CryptoContext &crypto_context) { + const ::apsi::CryptoContext &crypto_context) { YACL_ENFORCE(crypto_context.decryptor(), "decryptor is not configured in CryptoContext"); @@ -60,4 +60,4 @@ PlainResultPackage ResultPackage::extract( return plain_rp; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/package.h b/psi/apsi/package.h similarity index 83% rename from psi/psi/core/labeled_psi/package.h rename to psi/apsi/package.h index 6139e8fd..f1ec893c 100644 --- a/psi/psi/core/labeled_psi/package.h +++ b/psi/apsi/package.h @@ -21,7 +21,7 @@ #include "gsl/span" #include "seal/seal.h" -namespace psi::psi { +namespace psi::apsi { struct PlainResultPackage { std::uint32_t bundle_idx; @@ -37,19 +37,19 @@ struct PlainResultPackage { class ResultPackage { public: - PlainResultPackage extract(const apsi::CryptoContext& crypto_context); + PlainResultPackage extract(const ::apsi::CryptoContext& crypto_context); std::uint32_t bundle_idx; seal::compr_mode_type compr_mode = seal::Serialization::compr_mode_default; - apsi::SEALObject psi_result; + ::apsi::SEALObject psi_result; std::uint32_t label_byte_count; std::uint32_t nonce_byte_count; - std::vector> label_result; + std::vector<::apsi::SEALObject> label_result; }; // struct ResultPackage -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/padding.cc b/psi/apsi/padding.cc similarity index 94% rename from psi/psi/core/labeled_psi/padding.cc rename to psi/apsi/padding.cc index d2f1fcba..e17c0138 100644 --- a/psi/psi/core/labeled_psi/padding.cc +++ b/psi/apsi/padding.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/labeled_psi/padding.h" +#include "psi/apsi/padding.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi::apsi { // pad data to max_len bytes // format len(32bit)||data||00..00 @@ -45,4 +45,4 @@ std::string UnPaddingData(yacl::ByteContainerView data) { return ret; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/padding.h b/psi/apsi/padding.h similarity index 94% rename from psi/psi/core/labeled_psi/padding.h rename to psi/apsi/padding.h index a81750bf..d596cbf5 100644 --- a/psi/psi/core/labeled_psi/padding.h +++ b/psi/apsi/padding.h @@ -16,10 +16,10 @@ #include "yacl/base/byte_container_view.h" -namespace psi::psi { +namespace psi::apsi { std::vector PaddingData(yacl::ByteContainerView data, size_t max_len); std::string UnPaddingData(yacl::ByteContainerView data); -} // namespace psi::psi \ No newline at end of file +} // namespace psi::apsi diff --git a/psi/pir/pir.cc b/psi/apsi/pir.cc similarity index 77% rename from psi/pir/pir.cc rename to psi/apsi/pir.cc index 48389a87..6db91d54 100644 --- a/psi/pir/pir.cc +++ b/psi/apsi/pir.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/pir.h" +#include "psi/apsi/pir.h" #include #include @@ -26,19 +26,19 @@ #include "yacl/io/kv/leveldb_kvstore.h" #include "yacl/io/rw/csv_writer.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/receiver.h" -#include "psi/psi/core/labeled_psi/sender.h" -#include "psi/psi/core/labeled_psi/sender_db.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" -#include "psi/psi/core/labeled_psi/sender_memdb.h" -#include "psi/psi/cryptor/ecc_cryptor.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/apsi/receiver.h" +#include "psi/apsi/sender.h" +#include "psi/apsi/sender_db.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/apsi/sender_memdb.h" +#include "psi/cryptor/ecc_cryptor.h" +#include "psi/ecdh//ecdh_oprf_selector.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/io.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::pir { +namespace psi::apsi { namespace { @@ -49,15 +49,13 @@ std::vector ReadEcSecretKeyFile(const std::string &file_path) { } catch (std::filesystem::filesystem_error &e) { YACL_THROW("ReadEcSecretKeyFile {} Error: {}", file_path, e.what()); } - YACL_ENFORCE(file_byte_size == ::psi::psi::kEccKeySize, - "error format: key file bytes is not {}", - ::psi::psi::kEccKeySize); + YACL_ENFORCE(file_byte_size == ::psi::kEccKeySize, + "error format: key file bytes is not {}", ::psi::kEccKeySize); - std::vector secret_key(::psi::psi::kEccKeySize); + std::vector secret_key(::psi::kEccKeySize); - auto in = ::psi::psi::io::BuildInputStream( - ::psi::psi::io::FileIoOptions(file_path)); - in->Read(secret_key.data(), ::psi::psi::kEccKeySize); + auto in = ::psi::io::BuildInputStream(::psi::io::FileIoOptions(file_path)); + in->Read(secret_key.data(), ::psi::kEccKeySize); in->Close(); return secret_key; @@ -67,8 +65,8 @@ size_t CsvFileDataCount(const std::string &file_path, const std::vector &ids) { size_t data_count = 0; - std::shared_ptr<::psi::psi::IBasicBatchProvider> batch_provider = - std::make_shared<::psi::psi::CsvBatchProvider>(file_path, ids, 4096); + std::shared_ptr<::psi::IBasicBatchProvider> batch_provider = + std::make_shared<::psi::CsvBatchProvider>(file_path, ids, 4096); while (true) { auto batch = batch_provider->ReadNextBatch(); @@ -96,7 +94,7 @@ constexpr size_t kNonceByteCount = 16; void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, size_t count_per_query, size_t label_byte_count, const std::vector &label_cloumns, - const apsi::PSIParams &psi_params, size_t bucket_count, + const ::apsi::PSIParams &psi_params, size_t bucket_count, size_t bucket_size = 1000000, bool compressed = false) { std::string meta_store_name = fmt::format("{}/{}", setup_path, kMetaInfoStoreName); @@ -110,7 +108,7 @@ void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, meta_info_store->Put(kBucketSize, fmt::format("{}", bucket_size)); meta_info_store->Put(kCompressed, fmt::format("{}", compressed ? 1 : 0)); - ::psi::psi::proto::StrItemsProto proto; + ::psi::proto::StrItemsProto proto; for (const auto &label_cloumn : label_cloumns) { proto.add_items(label_cloumn); } @@ -119,7 +117,7 @@ void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, meta_info_store->Put(kLabelColumns, buf); - yacl::Buffer params_buffer = ::psi::psi::PsiParamsToBuffer(psi_params); + yacl::Buffer params_buffer = PsiParamsToBuffer(psi_params); meta_info_store->Put(kPsiParams, params_buffer); meta_info_store->Put(kBucketCount, fmt::format("{}", bucket_count)); @@ -136,12 +134,13 @@ size_t GetSizeFromStore( return key_value; } -apsi::PSIParams ReadMetaInfo(const std::string &setup_path, - size_t *server_data_count, size_t *count_per_query, - size_t *label_byte_count, - std::vector *label_cloumns, - size_t *bucket_count, size_t *bucket_size, - bool *compressed) { +::apsi::PSIParams ReadMetaInfo(const std::string &setup_path, + size_t *server_data_count, + size_t *count_per_query, + size_t *label_byte_count, + std::vector *label_cloumns, + size_t *bucket_count, size_t *bucket_size, + bool *compressed) { std::string meta_store_name = fmt::format("{}/{}", setup_path, kMetaInfoStoreName); std::shared_ptr meta_info_store = @@ -153,7 +152,7 @@ apsi::PSIParams ReadMetaInfo(const std::string &setup_path, yacl::Buffer label_columns_buf; meta_info_store->Get(kLabelColumns, &label_columns_buf); - ::psi::psi::proto::StrItemsProto proto; + ::psi::proto::StrItemsProto proto; proto.ParseFromArray(label_columns_buf.data(), label_columns_buf.size()); (*label_cloumns).reserve(proto.items_size()); for (auto item : proto.items()) { @@ -162,7 +161,7 @@ apsi::PSIParams ReadMetaInfo(const std::string &setup_path, yacl::Buffer params_buffer; meta_info_store->Get(kPsiParams, ¶ms_buffer); - apsi::PSIParams psi_params = ::psi::psi::ParsePsiParamsProto(params_buffer); + ::apsi::PSIParams psi_params = ParsePsiParamsProto(params_buffer); *bucket_count = GetSizeFromStore(meta_info_store, kBucketCount); *bucket_size = GetSizeFromStore(meta_info_store, kBucketSize); @@ -214,8 +213,8 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { } std::filesystem::create_directory(kv_store_path); - apsi::PSIParams psi_params = ::psi::psi::GetPsiParams( - count_per_query, bucket_size, config.max_items_per_bin()); + ::apsi::PSIParams psi_params = + GetPsiParams(count_per_query, bucket_size, config.max_items_per_bin()); SPDLOG_INFO("table_params hash_func_count:{}", psi_params.table_params().hash_func_count); @@ -231,8 +230,8 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { label_byte_count, label_columns, psi_params, bucket_count, config.bucket_size(), config.compressed()); - std::shared_ptr<::psi::psi::ILabeledBatchProvider> batch_provider = - std::make_shared<::psi::psi::CsvBatchProvider>( + std::shared_ptr<::psi::ILabeledBatchProvider> batch_provider = + std::make_shared<::psi::CsvBatchProvider>( config.input_path(), key_columns, bucket_size, label_columns); for (size_t i = 0; i < bucket_count; i++) { std::string bucket_setup_path = @@ -251,18 +250,17 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { std::filesystem::create_directory(bucket_setup_path); - apsi::PSIParams bucket_psi_params = ::psi::psi::GetPsiParams( + ::apsi::PSIParams bucket_psi_params = GetPsiParams( count_per_query, std::min(bucket_size, batch_ids.size()), config.max_items_per_bin()); - std::shared_ptr<::psi::psi::ISenderDB> sender_db = - std::make_shared<::psi::psi::SenderKvDB>( - bucket_psi_params, oprf_key, bucket_setup_path, label_byte_count, - nonce_byte_count, config.compressed()); + std::shared_ptr sender_db = std::make_shared( + bucket_psi_params, oprf_key, bucket_setup_path, label_byte_count, + nonce_byte_count, config.compressed()); - std::shared_ptr<::psi::psi::IBatchProvider> bucket_batch_provider = - std::make_shared<::psi::psi::MemoryBatchProvider>( - batch_ids, bucket_size, batch_labels); + std::shared_ptr<::psi::IBatchProvider> bucket_batch_provider = + std::make_shared<::psi::MemoryBatchProvider>(batch_ids, bucket_size, + batch_labels); sender_db->SetData(bucket_batch_provider); SPDLOG_INFO("finish bucket:{}", i); @@ -276,30 +274,29 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { PirResultReport LabeledPirServer( const std::shared_ptr &link_ctx, - const std::shared_ptr<::psi::psi::ISenderDB> &sender_db, - const std::vector &oprf_key, const apsi::PSIParams &psi_params, + const std::shared_ptr &sender_db, + const std::vector &oprf_key, const ::apsi::PSIParams &psi_params, const std::vector &label_columns, size_t bucket_count, size_t /*server_data_count*/, size_t count_per_query, size_t /*label_byte_count*/, uint32_t /*bucket_size*/) { // send count_per_query link_ctx->SendAsync(link_ctx->NextRank(), - ::psi::psi::utils::SerializeSize(count_per_query), + ::psi::utils::SerializeSize(count_per_query), fmt::format("count_per_query:{}", count_per_query)); - yacl::Buffer labels_buffer = - ::psi::psi::utils::SerializeStrItems(label_columns); + yacl::Buffer labels_buffer = ::psi::utils::SerializeStrItems(label_columns); // send labels column name link_ctx->SendAsync(link_ctx->NextRank(), labels_buffer, fmt::format("send label columns name")); // send psi params - yacl::Buffer params_buffer = ::psi::psi::PsiParamsToBuffer(psi_params); + yacl::Buffer params_buffer = PsiParamsToBuffer(psi_params); link_ctx->SendAsync(link_ctx->NextRank(), params_buffer, fmt::format("send psi params")); // bucket_count link_ctx->SendAsync(link_ctx->NextRank(), - ::psi::psi::utils::SerializeSize(bucket_count), + ::psi::utils::SerializeSize(bucket_count), fmt::format("bucket_count:{}", bucket_count)); // const auto total_query_start = std::chrono::system_clock::now(); @@ -307,11 +304,11 @@ PirResultReport LabeledPirServer( size_t query_count = 0; size_t data_count = 0; - ::psi::psi::LabelPsiSender sender(sender_db); + LabelPsiSender sender(sender_db); while (true) { // recv current batch_size - size_t batch_data_size = ::psi::psi::utils::DeserializeSize( + size_t batch_data_size = ::psi::utils::DeserializeSize( link_ctx->Recv(link_ctx->NextRank(), fmt::format("batch_data_size"))); SPDLOG_INFO("client data size: {}", batch_data_size); @@ -321,9 +318,9 @@ PirResultReport LabeledPirServer( data_count += batch_data_size; // oprf - std::unique_ptr<::psi::psi::IEcdhOprfServer> oprf_server = - ::psi::psi::CreateEcdhOprfServer(oprf_key, ::psi::psi::OprfType::Basic, - ::psi::psi::CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + ::psi::CurveType::CURVE_FOURQ); // const auto oprf_start = std::chrono::system_clock::now(); sender.RunOPRF(std::move(oprf_server), link_ctx); @@ -362,7 +359,7 @@ PirResultReport LabeledPirServer( size_t bucket_size; bool compressed; - apsi::PSIParams psi_params = + ::apsi::PSIParams psi_params = ReadMetaInfo(config.setup_path(), &server_data_count, &count_per_query, &label_byte_count, &label_columns, &bucket_count, &bucket_size, &compressed); @@ -373,52 +370,51 @@ PirResultReport LabeledPirServer( // server and client sync auto run_f = std::async([&] { return 0; }); - ::psi::psi::SyncWait(link_ctx, &run_f); + ::psi::SyncWait(link_ctx, &run_f); SPDLOG_INFO("table_params hash_func_count:{}", psi_params.table_params().hash_func_count); size_t nonce_byte_count = kNonceByteCount; - std::vector> sender_db(bucket_count); - std::vector> sender(bucket_count); + std::vector> sender_db(bucket_count); + std::vector> sender(bucket_count); - std::unique_ptr<::psi::psi::IEcdhOprfServer> oprf_server = - ::psi::psi::CreateEcdhOprfServer(oprf_key, ::psi::psi::OprfType::Basic, - ::psi::psi::CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + ::psi::CurveType::CURVE_FOURQ); for (size_t bucket_idx = 0; bucket_idx < bucket_count; ++bucket_idx) { std::string bucket_setup_path = fmt::format("{}/bucket_{}", config.setup_path(), bucket_idx); - sender_db[bucket_idx] = std::make_shared<::psi::psi::SenderKvDB>( + sender_db[bucket_idx] = std::make_shared( psi_params, oprf_key, bucket_setup_path, label_byte_count, nonce_byte_count, compressed); sender[bucket_idx] = - std::make_shared<::psi::psi::LabelPsiSender>(sender_db[bucket_idx]); + std::make_shared(sender_db[bucket_idx]); } SPDLOG_INFO("db GetItemCount:{}", sender_db[0]->GetItemCount()); // send count_per_query link_ctx->SendAsync(link_ctx->NextRank(), - ::psi::psi::utils::SerializeSize(count_per_query), + ::psi::utils::SerializeSize(count_per_query), fmt::format("count_per_query:{}", count_per_query)); - yacl::Buffer labels_buffer = - ::psi::psi::utils::SerializeStrItems(label_columns); + yacl::Buffer labels_buffer = ::psi::utils::SerializeStrItems(label_columns); // send labels column name link_ctx->SendAsync(link_ctx->NextRank(), labels_buffer, fmt::format("send label columns name")); // send psi params - yacl::Buffer params_buffer = ::psi::psi::PsiParamsToBuffer(psi_params); + yacl::Buffer params_buffer = PsiParamsToBuffer(psi_params); link_ctx->SendAsync(link_ctx->NextRank(), params_buffer, fmt::format("send psi params")); // send bucket_count link_ctx->SendAsync(link_ctx->NextRank(), - ::psi::psi::utils::SerializeSize(bucket_count), + ::psi::utils::SerializeSize(bucket_count), fmt::format("bucket_count:{}", bucket_count)); // const auto total_query_start = std::chrono::system_clock::now(); @@ -428,7 +424,7 @@ PirResultReport LabeledPirServer( while (true) { // recv current batch_size - size_t batch_data_size = ::psi::psi::utils::DeserializeSize( + size_t batch_data_size = ::psi::utils::DeserializeSize( link_ctx->Recv(link_ctx->NextRank(), fmt::format("batch_data_size"))); SPDLOG_INFO("client data size: {}", batch_data_size); @@ -438,9 +434,9 @@ PirResultReport LabeledPirServer( data_count += batch_data_size; // oprf - std::unique_ptr<::psi::psi::IEcdhOprfServer> oprf_server = - ::psi::psi::CreateEcdhOprfServer(oprf_key, ::psi::psi::OprfType::Basic, - ::psi::psi::CurveType::CURVE_FOURQ); + std::unique_ptr oprf_server = + ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + ::psi::CurveType::CURVE_FOURQ); // const auto oprf_start = std::chrono::system_clock::now(); sender[0]->RunOPRF(std::move(oprf_server), link_ctx); @@ -487,31 +483,29 @@ PirResultReport LabeledPirMemoryServer( YACL_ENFORCE(server_data_count <= config.bucket_size(), "data_count:{} bucket_size:{}", config.bucket_size()); - apsi::PSIParams psi_params = ::psi::psi::GetPsiParams( + ::apsi::PSIParams psi_params = GetPsiParams( count_per_query, server_data_count, config.max_items_per_bin()); - std::vector oprf_key = - yacl::crypto::RandBytes(::psi::psi::kEccKeySize); + std::vector oprf_key = yacl::crypto::RandBytes(::psi::kEccKeySize); size_t label_byte_count = config.label_max_len(); size_t nonce_byte_count = kNonceByteCount; - std::shared_ptr<::psi::psi::ISenderDB> sender_db = - std::make_shared<::psi::psi::SenderMemDB>( - psi_params, oprf_key, label_byte_count, nonce_byte_count, - config.compressed()); + std::shared_ptr sender_db = + std::make_shared(psi_params, oprf_key, label_byte_count, + nonce_byte_count, config.compressed()); // server and client sync auto run_f = std::async([&] { - std::shared_ptr<::psi::psi::IBatchProvider> batch_provider = - std::make_shared<::psi::psi::CsvBatchProvider>( + std::shared_ptr<::psi::IBatchProvider> batch_provider = + std::make_shared<::psi::CsvBatchProvider>( config.input_path(), key_columns, 500000, label_columns); sender_db->SetData(batch_provider); return 0; }); - ::psi::psi::SyncWait(link_ctx, &run_f); + ::psi::SyncWait(link_ctx, &run_f); SPDLOG_INFO("sender_db->GetItemCount:{}", sender_db->GetItemCount()); @@ -534,10 +528,10 @@ PirResultReport LabeledPirClient( // server and client sync auto run_f = std::async([&] { return 0; }); - ::psi::psi::SyncWait(link_ctx, &run_f); + ::psi::SyncWait(link_ctx, &run_f); // recv count_per_query - size_t count_per_query = ::psi::psi::utils::DeserializeSize( + size_t count_per_query = ::psi::utils::DeserializeSize( link_ctx->Recv(link_ctx->NextRank(), fmt::format("count_per_query"))); YACL_ENFORCE(count_per_query > 0, "Invalid nr:{}", count_per_query); @@ -547,8 +541,7 @@ PirResultReport LabeledPirClient( link_ctx->NextRank(), fmt::format("recv label columns name")); std::vector label_columns_name; - ::psi::psi::utils::DeserializeStrItems(label_columns_buffer, - &label_columns_name); + ::psi::utils::DeserializeStrItems(label_columns_buffer, &label_columns_name); yacl::io::Schema s; for (size_t i = 0; i < key_columns.size(); ++i) { @@ -565,8 +558,8 @@ PirResultReport LabeledPirClient( yacl::io::WriterOptions w_op; w_op.file_schema = s; - auto out = ::psi::psi::io::BuildOutputStream( - ::psi::psi::io::FileIoOptions(config.output_path())); + auto out = ::psi::io::BuildOutputStream( + ::psi::io::FileIoOptions(config.output_path())); yacl::io::CsvWriter writer(w_op, std::move(out)); writer.Init(); @@ -574,7 +567,7 @@ PirResultReport LabeledPirClient( yacl::Buffer params_buffer = link_ctx->Recv(link_ctx->NextRank(), fmt::format("recv psi params")); - apsi::PSIParams psi_params = ::psi::psi::ParsePsiParamsProto(params_buffer); + ::apsi::PSIParams psi_params = ParsePsiParamsProto(params_buffer); SPDLOG_INFO("table_params hash_func_count:{}", psi_params.table_params().hash_func_count); @@ -586,11 +579,11 @@ PirResultReport LabeledPirClient( SPDLOG_INFO("query_params query_powers size:{}", psi_params.query_params().query_powers.size()); - size_t bucket_count = ::psi::psi::utils::DeserializeSize( + size_t bucket_count = ::psi::utils::DeserializeSize( link_ctx->Recv(link_ctx->NextRank(), fmt::format("bucket_count"))); SPDLOG_INFO("bucket_count:{}", bucket_count); - ::psi::psi::LabelPsiReceiver receiver(psi_params, true); + LabelPsiReceiver receiver(psi_params, true); // const auto total_query_start = std::chrono::system_clock::now(); @@ -604,9 +597,9 @@ PirResultReport LabeledPirClient( size_t batch_read_count = std::max(table_items, count_per_query); SPDLOG_INFO("batch_read_count:{}", batch_read_count); - std::shared_ptr<::psi::psi::IBasicBatchProvider> query_batch_provider = - std::make_shared<::psi::psi::CsvBatchProvider>( - config.input_path(), key_columns, batch_read_count); + std::shared_ptr<::psi::IBasicBatchProvider> query_batch_provider = + std::make_shared<::psi::CsvBatchProvider>(config.input_path(), + key_columns, batch_read_count); while (true) { auto query_batch_items = query_batch_provider->ReadNextBatch(); @@ -614,7 +607,7 @@ PirResultReport LabeledPirClient( // send count_batch_size link_ctx->SendAsync( link_ctx->NextRank(), - ::psi::psi::utils::SerializeSize(query_batch_items.size()), + ::psi::utils::SerializeSize(query_batch_items.size()), fmt::format("count_batch_size:{}", query_batch_items.size())); if (query_batch_items.empty()) { @@ -623,7 +616,7 @@ PirResultReport LabeledPirClient( data_count += query_batch_items.size(); // const auto oprf_start = std::chrono::system_clock::now(); - std::pair, std::vector> + std::pair, std::vector<::apsi::LabelKey>> items_oprf = receiver.RequestOPRF(query_batch_items, link_ctx); // const auto oprf_end = std::chrono::system_clock::now(); @@ -700,7 +693,7 @@ PirResultReport LabeledPirClient( } PirResultReport PirSetup(const PirSetupConfig &config) { - if (config.pir_protocol() != KEYWORD_PIR_LABELED_PSI) { + if (config.pir_protocol() != ::psi::KEYWORD_PIR_LABELED_PSI) { YACL_THROW("Unsupported pir protocol {}", PirProtocol_Name(config.pir_protocol())); } @@ -710,7 +703,7 @@ PirResultReport PirSetup(const PirSetupConfig &config) { PirResultReport PirServer(const std::shared_ptr &link_ctx, const PirServerConfig &config) { - if (config.pir_protocol() != KEYWORD_PIR_LABELED_PSI) { + if (config.pir_protocol() != ::psi::KEYWORD_PIR_LABELED_PSI) { YACL_THROW("Unsupported pir protocol {}", PirProtocol_Name(config.pir_protocol())); } @@ -721,7 +714,7 @@ PirResultReport PirServer(const std::shared_ptr &link_ctx, PirResultReport PirMemoryServer( const std::shared_ptr &link_ctx, const PirSetupConfig &config) { - if (config.pir_protocol() != KEYWORD_PIR_LABELED_PSI) { + if (config.pir_protocol() != ::psi::KEYWORD_PIR_LABELED_PSI) { YACL_THROW("Unsupported pir protocol {}", PirProtocol_Name(config.pir_protocol())); } @@ -731,7 +724,7 @@ PirResultReport PirMemoryServer( PirResultReport PirClient(const std::shared_ptr &link_ctx, const PirClientConfig &config) { - if (config.pir_protocol() != KEYWORD_PIR_LABELED_PSI) { + if (config.pir_protocol() != ::psi::KEYWORD_PIR_LABELED_PSI) { YACL_THROW("Unsupported pir protocol {}", PirProtocol_Name(config.pir_protocol())); } @@ -739,4 +732,4 @@ PirResultReport PirClient(const std::shared_ptr &link_ctx, return LabeledPirClient(link_ctx, config); } -} // namespace psi::pir +} // namespace psi::apsi diff --git a/psi/pir/pir.h b/psi/apsi/pir.h similarity index 94% rename from psi/pir/pir.h rename to psi/apsi/pir.h index 6f70b7e6..e12fc081 100644 --- a/psi/pir/pir.h +++ b/psi/apsi/pir.h @@ -18,11 +18,11 @@ #include "yacl/link/link.h" -#include "psi/psi/core/labeled_psi/sender_db.h" +#include "psi/apsi/sender_db.h" #include "psi/proto/pir.pb.h" -namespace psi::pir { +namespace psi::apsi { PirResultReport PirSetup(const PirSetupConfig &config); @@ -51,4 +51,4 @@ PirResultReport LabeledPirClient( const std::shared_ptr &link_ctx, const PirClientConfig &config); -} // namespace psi::pir +} // namespace psi::apsi diff --git a/psi/pir/pir_test.cc b/psi/apsi/pir_test.cc similarity index 80% rename from psi/pir/pir_test.cc rename to psi/apsi/pir_test.cc index b65c66c9..d6d7e1c9 100644 --- a/psi/pir/pir_test.cc +++ b/psi/apsi/pir_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/pir.h" +#include "psi/apsi/pir.h" #include #include @@ -27,8 +27,8 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/utils/io.h" +#include "psi/apsi/psi_params.h" +#include "psi/utils/io.h" namespace { @@ -46,8 +46,7 @@ struct TestParams { void WriteCsvFile(const std::string &file_name, const std::string &id_name, const std::vector &items) { - auto out = - psi::psi::io::BuildOutputStream(psi::psi::io::FileIoOptions(file_name)); + auto out = psi::io::BuildOutputStream(psi::io::FileIoOptions(file_name)); out->Write(fmt::format("{}\n", id_name)); for (size_t i = 0; i < items.size(); ++i) { out->Write(fmt::format("{}\n", items[i])); @@ -59,8 +58,7 @@ void WriteCsvFile(const std::string &file_name, const std::string &id_name, const std::string &label_name, const std::vector &items, const std::vector &labels) { - auto out = - psi::psi::io::BuildOutputStream(psi::psi::io::FileIoOptions(file_name)); + auto out = psi::io::BuildOutputStream(psi::io::FileIoOptions(file_name)); out->Write(fmt::format("{},{}\n", id_name, label_name)); for (size_t i = 0; i < items.size(); ++i) { out->Write(fmt::format("{},{}\n", items[i], labels[i])); @@ -126,14 +124,14 @@ GenerateSenderData(size_t seed, size_t item_count, size_t label_byte_count, } // namespace -namespace psi::pir { +namespace psi::apsi { class PirTest : public testing::TestWithParam {}; TEST_P(PirTest, Works) { auto params = GetParam(); auto ctxs = yacl::link::test::SetupWorld(2); - apsi::PSIParams psi_params = ::psi::psi::GetPsiParams(params.nr, params.ns); + ::apsi::PSIParams psi_params = GetPsiParams(params.nr, params.ns); std::string tmp_store_path = fmt::format("data_{}_{}", yacl::crypto::FastRandU64(), params.ns); @@ -189,10 +187,10 @@ TEST_P(PirTest, Works) { std::vector labels = {label_cloumn_name}; if (params.use_filedb) { - PirSetupConfig config; + ::psi::PirSetupConfig config; - config.set_pir_protocol(PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(::psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(::psi::KvStoreType::LEVELDB_KV_STORE); config.set_input_path(server_csv_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); @@ -205,48 +203,48 @@ TEST_P(PirTest, Works) { config.set_compressed(params.compressed); config.set_bucket_size(params.bucket_size); - PirResultReport setup_report = PirSetup(config); + ::psi::PirResultReport setup_report = PirSetup(config); EXPECT_EQ(setup_report.data_count(), params.ns); - std::future f_server = std::async([&] { - PirServerConfig config; + std::future<::psi::PirResultReport> f_server = std::async([&] { + ::psi::PirServerConfig config; - config.set_pir_protocol(PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(::psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(::psi::KvStoreType::LEVELDB_KV_STORE); config.set_oprf_key_path(oprf_key_path); config.set_setup_path(setup_path); - PirResultReport report = PirServer(ctxs[0], config); + ::psi::PirResultReport report = PirServer(ctxs[0], config); return report; }); - std::future f_client = std::async([&] { - PirClientConfig config; + std::future<::psi::PirResultReport> f_client = std::async([&] { + ::psi::PirClientConfig config; - config.set_pir_protocol(PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_pir_protocol(::psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); config.set_input_path(client_csv_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); config.set_output_path(pir_result_path); - PirResultReport report = PirClient(ctxs[1], config); + ::psi::PirResultReport report = PirClient(ctxs[1], config); return report; }); - PirResultReport server_report = f_server.get(); - PirResultReport client_report = f_client.get(); + ::psi::PirResultReport server_report = f_server.get(); + ::psi::PirResultReport client_report = f_client.get(); EXPECT_EQ(server_report.data_count(), params.nr); EXPECT_EQ(client_report.data_count(), params.nr); } else { - std::future f_server = std::async([&] { - PirSetupConfig config; + std::future<::psi::PirResultReport> f_server = std::async([&] { + ::psi::PirSetupConfig config; - config.set_pir_protocol(PirProtocol::KEYWORD_PIR_LABELED_PSI); - config.set_store_type(KvStoreType::LEVELDB_KV_STORE); + config.set_pir_protocol(::psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(::psi::KvStoreType::LEVELDB_KV_STORE); config.set_input_path(server_csv_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); @@ -259,32 +257,32 @@ TEST_P(PirTest, Works) { config.set_compressed(params.compressed); config.set_bucket_size(params.bucket_size); - PirResultReport report = PirMemoryServer(ctxs[0], config); + ::psi::PirResultReport report = PirMemoryServer(ctxs[0], config); return report; }); - std::future f_client = std::async([&] { - PirClientConfig config; + std::future<::psi::PirResultReport> f_client = std::async([&] { + ::psi::PirClientConfig config; - config.set_pir_protocol(PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_pir_protocol(::psi::PirProtocol::KEYWORD_PIR_LABELED_PSI); config.set_input_path(client_csv_path); config.mutable_key_columns()->Add(ids.begin(), ids.end()); config.set_output_path(pir_result_path); - PirResultReport report = PirClient(ctxs[1], config); + ::psi::PirResultReport report = PirClient(ctxs[1], config); return report; }); - PirResultReport server_report = f_server.get(); - PirResultReport client_report = f_client.get(); + ::psi::PirResultReport server_report = f_server.get(); + ::psi::PirResultReport client_report = f_client.get(); EXPECT_EQ(server_report.data_count(), params.nr); EXPECT_EQ(client_report.data_count(), params.nr); } - std::shared_ptr<::psi::psi::ILabeledBatchProvider> pir_result_provider = - std::make_shared<::psi::psi::CsvBatchProvider>( + std::shared_ptr<::psi::ILabeledBatchProvider> pir_result_provider = + std::make_shared<::psi::CsvBatchProvider>( pir_result_path, ids, intersection_idx.size(), labels); // read pir_result from csv @@ -307,4 +305,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, PirTest, ); -} // namespace psi::pir +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/psi_params.cc b/psi/apsi/psi_params.cc similarity index 91% rename from psi/psi/core/labeled_psi/psi_params.cc rename to psi/apsi/psi_params.cc index eb80e6b5..4624a65a 100644 --- a/psi/psi/core/labeled_psi/psi_params.cc +++ b/psi/apsi/psi_params.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/labeled_psi/psi_params.h" +#include "psi/apsi/psi_params.h" #include #include @@ -21,7 +21,7 @@ #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -46,7 +46,7 @@ std::vector kSealParams = { {8192, 65537, 0, {56, 56, 30}}, // 14 }; -std::map kPolynomialParams = { +std::map kPolynomialParams = { {20, {0, {1, 2, 5, 8, 9, 10}}}, {35, {5, {1, 2, 3, 4, 5, 6, 18, 30, 42, 54, 60}}}, {36, {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, @@ -94,7 +94,7 @@ std::map kPolynomialParams = { }; } // namespace -yacl::Buffer PsiParamsToBuffer(const apsi::PSIParams &psi_params) { +yacl::Buffer PsiParamsToBuffer(const ::apsi::PSIParams &psi_params) { proto::LabelPsiParamsProto psi_params_proto; psi_params_proto.set_hash_func_count( @@ -132,7 +132,7 @@ yacl::Buffer PsiParamsToBuffer(const apsi::PSIParams &psi_params) { return buffer; } -apsi::PSIParams ParsePsiParamsProto(const yacl::Buffer &buffer) { +::apsi::PSIParams ParsePsiParamsProto(const yacl::Buffer &buffer) { proto::LabelPsiParamsProto psi_params_proto; YACL_ENFORCE(psi_params_proto.ParseFromArray(buffer.data(), buffer.size())); @@ -140,12 +140,12 @@ apsi::PSIParams ParsePsiParamsProto(const yacl::Buffer &buffer) { return ParsePsiParamsProto(psi_params_proto); } -apsi::PSIParams ParsePsiParamsProto( +::apsi::PSIParams ParsePsiParamsProto( const proto::LabelPsiParamsProto &psi_params_proto) { - apsi::PSIParams::ItemParams item_params; - apsi::PSIParams::TableParams table_params; - apsi::PSIParams::QueryParams query_params; - apsi::PSIParams::SEALParams seal_params; + ::apsi::PSIParams::ItemParams item_params; + ::apsi::PSIParams::TableParams table_params; + ::apsi::PSIParams::QueryParams query_params; + ::apsi::PSIParams::SEALParams seal_params; item_params.felts_per_item = psi_params_proto.felts_per_item(); @@ -182,8 +182,8 @@ apsi::PSIParams ParsePsiParamsProto( seal_params.set_coeff_modulus(coeff_modulus); - apsi::PSIParams psi_params(item_params, table_params, query_params, - seal_params); + ::apsi::PSIParams psi_params(item_params, table_params, query_params, + seal_params); return psi_params; } @@ -217,13 +217,13 @@ SEALParams GetSealParams(size_t nr, size_t ns) { return kSealParams[12]; } -apsi::PSIParams GetPsiParams(size_t nr, size_t ns, size_t max_items_per_bin) { +::apsi::PSIParams GetPsiParams(size_t nr, size_t ns, size_t max_items_per_bin) { SEALParams seal_params = GetSealParams(nr, ns); - apsi::PSIParams::ItemParams item_params; - apsi::PSIParams::TableParams table_params; - apsi::PSIParams::QueryParams query_params; - apsi::PSIParams::SEALParams apsi_seal_params; + ::apsi::PSIParams::ItemParams item_params; + ::apsi::PSIParams::TableParams table_params; + ::apsi::PSIParams::QueryParams query_params; + ::apsi::PSIParams::SEALParams apsi_seal_params; size_t hash_size = GetHashTruncateSize(nr, ns); item_params.felts_per_item = std::ceil( @@ -304,10 +304,10 @@ apsi::PSIParams GetPsiParams(size_t nr, size_t ns, size_t max_items_per_bin) { query_params = kPolynomialParams[max_items_per_bin]; } - apsi::PSIParams psi_params(item_params, table_params, query_params, - apsi_seal_params); + ::apsi::PSIParams psi_params(item_params, table_params, query_params, + apsi_seal_params); return psi_params; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/psi_params.h b/psi/apsi/psi_params.h similarity index 73% rename from psi/psi/core/labeled_psi/psi_params.h rename to psi/apsi/psi_params.h index 1742edcb..103489ed 100644 --- a/psi/psi/core/labeled_psi/psi_params.h +++ b/psi/apsi/psi_params.h @@ -23,9 +23,9 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/core/labeled_psi/serializable.pb.h" +#include "psi/apsi/serializable.pb.h" -namespace psi::psi { +namespace psi::apsi { struct SEALParams { size_t poly_modulus_degree; @@ -53,27 +53,27 @@ struct SEALParams { * * @param nr receiver's items size * @param ns sender's items size - * @return apsi::PSIParams + * @return ::apsi::PSIParams */ -apsi::PSIParams GetPsiParams(size_t nr, size_t ns, - size_t max_items_per_bin = 0); +::apsi::PSIParams GetPsiParams(size_t nr, size_t ns, + size_t max_items_per_bin = 0); /** - * @brief Serialize apsi::PSIParams to yacl::Buffer + * @brief Serialize ::apsi::PSIParams to yacl::Buffer * - * @param psi_params apsi::PSIParams + * @param psi_params ::apsi::PSIParams * @return yacl::Buffer */ -yacl::Buffer PsiParamsToBuffer(const apsi::PSIParams &psi_params); +yacl::Buffer PsiParamsToBuffer(const ::apsi::PSIParams &psi_params); /** - * @brief DeSerialize yacl::Buffer to apsi::PSIParams + * @brief DeSerialize yacl::Buffer to ::apsi::PSIParams * * @param buffer PSIParams bytes buffer - * @return apsi::PSIParams + * @return ::apsi::PSIParams */ -apsi::PSIParams ParsePsiParamsProto(const yacl::Buffer &buffer); -apsi::PSIParams ParsePsiParamsProto( +::apsi::PSIParams ParsePsiParamsProto(const yacl::Buffer &buffer); +::apsi::PSIParams ParsePsiParamsProto( const proto::LabelPsiParamsProto &psi_params_proto); -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/receiver.cc b/psi/apsi/receiver.cc similarity index 90% rename from psi/psi/core/labeled_psi/receiver.cc rename to psi/apsi/receiver.cc index 505c3850..1816c8b2 100644 --- a/psi/psi/core/labeled_psi/receiver.cc +++ b/psi/apsi/receiver.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/labeled_psi/receiver.h" +#include "psi/apsi/receiver.h" #include #include @@ -32,12 +32,12 @@ #include "spdlog/spdlog.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/package.h" -#include "psi/psi/core/labeled_psi/padding.h" -#include "psi/psi/core/labeled_psi/psi_params.h" +#include "psi/apsi/package.h" +#include "psi/apsi/padding.h" +#include "psi/apsi/psi_params.h" +#include "psi/ecdh//ecdh_oprf_selector.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -50,7 +50,7 @@ bool HasNZeros(T *ptr, size_t count) { } } // namespace -LabelPsiReceiver::LabelPsiReceiver(const apsi::PSIParams ¶ms, +LabelPsiReceiver::LabelPsiReceiver(const ::apsi::PSIParams ¶ms, bool has_label) : psi_params_(params), has_label_(has_label) { Initialize(); @@ -65,7 +65,7 @@ void LabelPsiReceiver::Initialize() { psi_params_.bins_per_bundle(), psi_params_.bundle_idx_count()); // Initialize the CryptoContext with a new SEALContext - crypto_context_ = apsi::CryptoContext(psi_params_); + crypto_context_ = ::apsi::CryptoContext(psi_params_); // Set up the PowersDag ResetPowersDag(psi_params_.query_params().query_powers); @@ -93,7 +93,7 @@ void LabelPsiReceiver::ResetKeys() { std::uint32_t LabelPsiReceiver::ResetPowersDag( const std::set &source_powers) { // First compute the target powers - std::set target_powers = apsi::util::create_powers_set( + std::set target_powers = ::apsi::util::create_powers_set( psi_params_.query_params().ps_low_degree, psi_params_.table_params().max_items_per_bin); @@ -104,8 +104,8 @@ std::uint32_t LabelPsiReceiver::ResetPowersDag( if (!pd_.is_configured()) { SPDLOG_ERROR( "Failed to configure PowersDag (source_powers:{} target_powers:{})", - apsi::util::to_string(source_powers), - apsi::util::to_string(target_powers)); + ::apsi::util::to_string(source_powers), + ::apsi::util::to_string(target_powers)); YACL_THROW("failed to configure PowersDag"); } @@ -114,7 +114,7 @@ std::uint32_t LabelPsiReceiver::ResetPowersDag( return pd_.depth(); } -apsi::PSIParams LabelPsiReceiver::RequestPsiParams( +::apsi::PSIParams LabelPsiReceiver::RequestPsiParams( size_t items_size, const std::shared_ptr &link_ctx) { yacl::Buffer buffer(&items_size, sizeof(items_size)); @@ -128,17 +128,18 @@ apsi::PSIParams LabelPsiReceiver::RequestPsiParams( return ParsePsiParamsProto(psi_params_buffer); } -std::pair, std::vector> +std::pair, std::vector<::apsi::LabelKey>> LabelPsiReceiver::RequestOPRF( const std::vector &items, const std::shared_ptr &link_ctx) { std::vector blind_items(items.size()); - std::vector> oprf_clients(items.size()); + std::vector> oprf_clients( + items.size()); yacl::parallel_for(0, items.size(), [&](int64_t begin, int64_t end) { for (int idx = begin; idx < end; ++idx) { - oprf_clients[idx] = - CreateEcdhOprfClient(OprfType::Basic, CurveType::CURVE_FOURQ); + oprf_clients[idx] = ecdh::CreateEcdhOprfClient(ecdh::OprfType::Basic, + CurveType::CURVE_FOURQ); oprf_clients[idx]->SetCompareLength(kEccKeySize); blind_items[idx] = oprf_clients[idx]->Blind(items[idx]); @@ -173,8 +174,8 @@ LabelPsiReceiver::RequestOPRF( } }); - std::vector hashed_items(items_oprf.size()); - std::vector label_keys(items_oprf.size()); + std::vector<::apsi::HashedItem> hashed_items(items_oprf.size()); + std::vector<::apsi::LabelKey> label_keys(items_oprf.size()); for (size_t idx = 0; idx < items_oprf.size(); ++idx) { std::memcpy(hashed_items[idx].value().data(), items_oprf[idx].data(), @@ -190,8 +191,8 @@ LabelPsiReceiver::RequestOPRF( std::pair, std::vector> LabelPsiReceiver::RequestQuery( - const std::vector &hashed_items, - const std::vector &label_keys, + const std::vector<::apsi::HashedItem> &hashed_items, + const std::vector<::apsi::LabelKey> &label_keys, const std::shared_ptr &link_ctx) { kuku::KukuTable cuckoo( psi_params_.table_params().table_size, // Size of the hash table @@ -231,7 +232,7 @@ LabelPsiReceiver::RequestQuery( "fill-rate: {}", cuckoo.loc_func_count(), cuckoo.fill_rate()); - apsi::receiver::IndexTranslationTable itt; + ::apsi::receiver::IndexTranslationTable itt; itt.item_count_ = hashed_items.size(); for (size_t item_idx = 0; item_idx < hashed_items.size(); item_idx++) { @@ -241,11 +242,11 @@ LabelPsiReceiver::RequestQuery( } // Set up unencrypted query data - std::vector plain_powers; + std::vector<::apsi::receiver::PlaintextPowers> plain_powers; // prepare_data { - STOPWATCH(apsi::util::recv_stopwatch, + STOPWATCH(::apsi::util::recv_stopwatch, "Receiver::create_query::prepare_data"); for (uint32_t bundle_idx = 0; bundle_idx < psi_params_.bundle_idx_count(); bundle_idx++) { @@ -262,12 +263,12 @@ LabelPsiReceiver::RequestQuery( // Now set up a BitstringView to this item gsl::span item_bytes( reinterpret_cast(item.data()), sizeof(item)); - apsi::BitstringView item_bits( + ::apsi::BitstringView item_bits( item_bytes, psi_params_.item_bit_count()); // Create an algebraic item by breaking up the item into modulo // plain_modulus parts - std::vector alg_item = apsi::util::bits_to_field_elts( + std::vector alg_item = ::apsi::util::bits_to_field_elts( item_bits, psi_params_.seal_params().plain_modulus()); std::copy(alg_item.cbegin(), alg_item.cend(), back_inserter(alg_items)); } @@ -283,7 +284,8 @@ LabelPsiReceiver::RequestQuery( // The very last thing to do is encrypt the plain_powers and consolidate the // matching powers for different bundle indices - std::unordered_map>> + std::unordered_map>> encrypted_powers; // encrypt_data @@ -383,8 +385,8 @@ LabelPsiReceiver::RequestQuery( std::vector> LabelPsiReceiver::ProcessQueryResult( const proto::QueryResultProto &query_result_proto, - const apsi::receiver::IndexTranslationTable &itt, - const std::vector &label_keys) { + const ::apsi::receiver::IndexTranslationTable &itt, + const std::vector<::apsi::LabelKey> &label_keys) { auto seal_context = GetSealContext(); ResultPackage result_package; @@ -404,7 +406,7 @@ LabelPsiReceiver::ProcessQueryResult( gsl::span label_data_span( reinterpret_cast(label_data.data()), label_data.length()); - apsi::SEALObject temp; + ::apsi::SEALObject temp; temp.load(seal_context, label_data_span); result_package.label_result.emplace_back(std::move(temp)); } @@ -521,34 +523,35 @@ LabelPsiReceiver::ProcessQueryResult( SPDLOG_DEBUG("Match found for items[{}] at cuckoo table index {}", item_idx, table_idx); - apsi::Label label; + ::apsi::Label label; if (label_byte_count) { SPDLOG_DEBUG( "Found {} label parts for items[{}]; expecting {}-byte label ", plain_rp.label_result.size(), item_idx, label_byte_count); // Collect the entire label into this vector - apsi::util::AlgLabel alg_label; + ::apsi::util::AlgLabel alg_label; size_t label_offset = seal::util::mul_safe(std::get<1>(i), felts_per_item); for (auto &label_parts : plain_rp.label_result) { - gsl::span label_part( + gsl::span<::apsi::util::felt_t> label_part( label_parts.data() + label_offset, felts_per_item); std::copy(label_part.begin(), label_part.end(), back_inserter(alg_label)); } // Create the label - apsi::EncryptedLabel encrypted_label = apsi::util::dealgebraize_label( - alg_label, received_label_bit_count, - psi_params_.seal_params().plain_modulus()); + ::apsi::EncryptedLabel encrypted_label = + ::apsi::util::dealgebraize_label( + alg_label, received_label_bit_count, + psi_params_.seal_params().plain_modulus()); // Resize down to the effective byte count encrypted_label.resize(effective_label_byte_count); // Decrypt the label - label = apsi::util::decrypt_label( + label = ::apsi::util::decrypt_label( encrypted_label, label_keys[item_idx], nonce_byte_count); } @@ -569,4 +572,4 @@ LabelPsiReceiver::ProcessQueryResult( return match_ids; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/receiver.h b/psi/apsi/receiver.h similarity index 76% rename from psi/psi/core/labeled_psi/receiver.h rename to psi/apsi/receiver.h index 42023485..a6718b2c 100644 --- a/psi/psi/core/labeled_psi/receiver.h +++ b/psi/apsi/receiver.h @@ -29,15 +29,15 @@ #include "seal/seal.h" #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" +#include "psi/ecdh//ecdh_oprf.h" -#include "psi/psi/core/labeled_psi/serializable.pb.h" +#include "psi/apsi/serializable.pb.h" -namespace psi::psi { +namespace psi::apsi { class LabelPsiReceiver { public: - explicit LabelPsiReceiver(const apsi::PSIParams ¶ms, + explicit LabelPsiReceiver(const ::apsi::PSIParams ¶ms, bool has_label = false); /** @@ -45,9 +45,9 @@ class LabelPsiReceiver { * * @param items_size receiver's items size * @param link_ctx link context - * @return apsi::PSIParams + * @return ::apsi::PSIParams */ - static apsi::PSIParams RequestPsiParams( + static ::apsi::PSIParams RequestPsiParams( size_t items_size, const std::shared_ptr &link_ctx); /** @@ -55,12 +55,13 @@ class LabelPsiReceiver { * * @param items receiver's items * @param link_ctx link context - * @return std::pair, - * std::vector> + * @return std::pair, + * std::vector<::apsi::LabelKey>> * * split items's oprf(32B) to HashedItem(16B) and LabelKey(16B) */ - static std::pair, std::vector> + static std::pair, + std::vector<::apsi::LabelKey>> RequestOPRF(const std::vector &items, const std::shared_ptr &link_ctx); @@ -76,8 +77,8 @@ class LabelPsiReceiver { * */ std::pair, std::vector> RequestQuery( - const std::vector &hashed_items, - const std::vector &label_keys, + const std::vector<::apsi::HashedItem> &hashed_items, + const std::vector<::apsi::LabelKey> &label_keys, const std::shared_ptr &link_ctx); /** @@ -88,12 +89,12 @@ class LabelPsiReceiver { /** Returns a reference to the PowersDag configured for this Receiver. */ - const apsi::PowersDag &GetPowersDag() const { return pd_; } + const ::apsi::PowersDag &GetPowersDag() const { return pd_; } /** Returns a reference to the CryptoContext for this Receiver. */ - const apsi::CryptoContext &GetCryptoContext() const { + const ::apsi::CryptoContext &GetCryptoContext() const { return crypto_context_; } @@ -117,20 +118,20 @@ class LabelPsiReceiver { std::vector> ProcessQueryResult( const proto::QueryResultProto &query_result_proto, - const apsi::receiver::IndexTranslationTable &itt, - const std::vector &label_keys); + const ::apsi::receiver::IndexTranslationTable &itt, + const std::vector<::apsi::LabelKey> &label_keys); - apsi::PSIParams psi_params_; + ::apsi::PSIParams psi_params_; - apsi::CryptoContext crypto_context_; + ::apsi::CryptoContext crypto_context_; - apsi::PowersDag pd_; + ::apsi::PowersDag pd_; - apsi::SEALObject relin_keys_; + ::apsi::SEALObject relin_keys_; // NOTE(juhou): we now support zstd compression by default seal::compr_mode_type compr_mode_ = seal::Serialization::compr_mode_default; bool has_label_; }; -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender.cc b/psi/apsi/sender.cc similarity index 89% rename from psi/psi/core/labeled_psi/sender.cc rename to psi/apsi/sender.cc index dd7f9349..09a6c4b6 100644 --- a/psi/psi/core/labeled_psi/sender.cc +++ b/psi/apsi/sender.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/labeled_psi/sender.h" +#include "psi/apsi/sender.h" #include #include @@ -29,21 +29,21 @@ #include "gsl/span" #include "yacl/utils/parallel.h" -#include "psi/psi/core/labeled_psi/package.h" -#include "psi/psi/core/labeled_psi/sender_db.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" +#include "psi/apsi/package.h" +#include "psi/apsi/sender_db.h" +#include "psi/apsi/sender_kvdb.h" -#include "psi/psi/core/labeled_psi/serializable.pb.h" +#include "psi/apsi/serializable.pb.h" -namespace psi::psi { +namespace psi::apsi { namespace { class QueryRequest { public: - QueryRequest(apsi::SEALObject *relin_keys, + QueryRequest(::apsi::SEALObject *relin_keys, std::unordered_map< - uint32_t, std::vector>> + uint32_t, std::vector<::apsi::SEALObject>> &encrypted_powers, const std::shared_ptr &sender_db) { auto seal_context = sender_db->GetSealContext(); @@ -93,12 +93,13 @@ class QueryRequest { using CiphertextPowers = std::vector; -uint32_t reset_powers_dag(apsi::PowersDag *pd, const apsi::PSIParams ¶ms, +uint32_t reset_powers_dag(::apsi::PowersDag *pd, + const ::apsi::PSIParams ¶ms, const std::set &source_powers) { // First compute the target powers std::set target_powers = - apsi::util::create_powers_set(params.query_params().ps_low_degree, - params.table_params().max_items_per_bin); + ::apsi::util::create_powers_set(params.query_params().ps_low_degree, + params.table_params().max_items_per_bin); SPDLOG_DEBUG("target_powers size:{}", target_powers.size()); // Configure the PowersDag @@ -109,8 +110,8 @@ uint32_t reset_powers_dag(apsi::PowersDag *pd, const apsi::PSIParams ¶ms, SPDLOG_INFO( "Failed to configure PowersDag (" "source_powers: {}, target_powers: {}", - apsi::util::to_string(source_powers), - apsi::util::to_string(target_powers)); + ::apsi::util::to_string(source_powers), + ::apsi::util::to_string(target_powers)); YACL_THROW("failed to configure PowersDag"); } SPDLOG_INFO("Configured PowersDag with depth {}", pd->depth()); @@ -122,9 +123,9 @@ uint32_t reset_powers_dag(apsi::PowersDag *pd, const apsi::PSIParams ¶ms, LabelPsiSender::LabelPsiSender(std::shared_ptr sender_db) : sender_db_(std::move(sender_db)) { - apsi::PSIParams params(sender_db_->GetParams()); + ::apsi::PSIParams params(sender_db_->GetParams()); - crypto_context_ = apsi::CryptoContext(sender_db_->GetParams()); + crypto_context_ = ::apsi::CryptoContext(sender_db_->GetParams()); SPDLOG_INFO("begin set PowersDag"); reset_powers_dag(&pd_, params, params.query_params().query_powers); @@ -141,7 +142,7 @@ void LabelPsiSender::RunPsiParams( YACL_ENFORCE(sizeof(nr) == nr_buffer.size()); std::memcpy(&nr, nr_buffer.data(), nr_buffer.size()); - apsi::PSIParams psi_params = GetPsiParams(nr, items_size); + ::apsi::PSIParams psi_params = GetPsiParams(nr, items_size); yacl::Buffer params_buffer = PsiParamsToBuffer(psi_params); @@ -151,7 +152,7 @@ void LabelPsiSender::RunPsiParams( } void LabelPsiSender::RunOPRF( - const std::shared_ptr &oprf_server, + const std::shared_ptr &oprf_server, const std::shared_ptr &link_ctx) { oprf_server->SetCompareLength(kEccKeySize); @@ -188,7 +189,7 @@ void LabelPsiSender::RunOPRF( std::vector> SenderRunQuery( const QueryRequest &query, const std::shared_ptr &sender_db, - const apsi::PowersDag &pd); + const ::apsi::PowersDag &pd); void LabelPsiSender::RunQuery( const std::shared_ptr &link_ctx) { @@ -201,7 +202,7 @@ void LabelPsiSender::RunQuery( auto seal_context = sender_db_->GetSealContext(); - apsi::SEALObject relin_keys; + ::apsi::SEALObject relin_keys; if (seal_context->using_keyswitching()) { auto relin_keys_data = query_proto.relin_keys(); gsl::span relin_keys_data_span( @@ -211,13 +212,14 @@ void LabelPsiSender::RunQuery( relin_keys.load(seal_context, relin_keys_data_span); } - std::unordered_map>> + std::unordered_map>> encrypted_powers; for (int idx = 0; idx < query_proto.encrypted_powers_size(); ++idx) { const proto::EncryptedPowersProto &encrypted_powers_proto = query_proto.encrypted_powers(idx); - std::vector> ciphertexts; + std::vector<::apsi::SEALObject> ciphertexts; ciphertexts.reserve(encrypted_powers_proto.ciphertexts_size()); for (int cipher_idx = 0; @@ -225,7 +227,7 @@ void LabelPsiSender::RunQuery( auto ct = encrypted_powers_proto.ciphertexts(cipher_idx); gsl::span ct_span( reinterpret_cast(ct.data()), ct.length()); - apsi::SEALObject temp; + ::apsi::SEALObject temp; temp.load(seal_context, ct_span); ciphertexts.emplace_back(std::move(temp)); } @@ -241,7 +243,7 @@ void LabelPsiSender::RunQuery( proto::QueryResponseProto response_proto; std::vector> futures; - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; for (auto &result : query_result) { proto::QueryResultProto *result_proto = response_proto.add_results(); @@ -277,26 +279,26 @@ void LabelPsiSender::RunQuery( } void ComputePowers(const std::shared_ptr &sender_db, - const apsi::CryptoContext &crypto_context, + const ::apsi::CryptoContext &crypto_context, std::vector *all_powers, - const apsi::PowersDag &pd, uint32_t bundle_idx, + const ::apsi::PowersDag &pd, uint32_t bundle_idx, seal::MemoryPoolHandle *pool); void ProcessBinBundleCache( const std::shared_ptr &sender_db, - const apsi::CryptoContext &crypto_context, - const std::shared_ptr &bundle, + const ::apsi::CryptoContext &crypto_context, + const std::shared_ptr<::apsi::sender::BinBundle> &bundle, std::vector *all_powers, uint32_t bundle_idx, seal::compr_mode_type compr_mode, seal::MemoryPoolHandle *pool, const std::shared_ptr &result); std::vector> SenderRunQuery( const QueryRequest &query, const std::shared_ptr &sender_db, - const apsi::PowersDag &pd) { + const ::apsi::PowersDag &pd) { // We use a custom SEAL memory that is freed after the query is done auto pool = seal::MemoryManager::GetPool(seal::mm_force_new); - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; // Acquire read lock on SenderDB // auto sender_db = sender_db; @@ -310,11 +312,11 @@ std::vector> SenderRunQuery( // that case query.relin_keys() simply holds an empty seal::RelinKeys // instance. There is no problem with the below call to // CryptoContext::set_evaluator. - apsi::CryptoContext crypto_context(sender_db->GetCryptoContext()); + ::apsi::CryptoContext crypto_context(sender_db->GetCryptoContext()); crypto_context.set_evaluator(query.relin_keys()); // Get the PSIParams - apsi::PSIParams params(sender_db->GetParams()); + ::apsi::PSIParams params(sender_db->GetParams()); uint32_t bundle_idx_count = params.bundle_idx_count(); @@ -371,7 +373,7 @@ std::vector> SenderRunQuery( std::vector> futures; for (size_t cache_idx = 0; cache_idx < cache_count; ++cache_idx) { - std::shared_ptr bundle = + std::shared_ptr<::apsi::sender::BinBundle> bundle = sender_db->GetBinBundleAt(static_cast(bundle_idx), cache_idx); @@ -398,9 +400,9 @@ std::vector> SenderRunQuery( } void ComputePowers(const std::shared_ptr &sender_db, - const apsi::CryptoContext &crypto_context, + const ::apsi::CryptoContext &crypto_context, std::vector *all_powers, - const apsi::PowersDag &pd, uint32_t bundle_idx, + const ::apsi::PowersDag &pd, uint32_t bundle_idx, seal::MemoryPoolHandle *pool) { SPDLOG_DEBUG("Sender::ComputePowers"); @@ -413,7 +415,7 @@ void ComputePowers(const std::shared_ptr &sender_db, CiphertextPowers &powers_at_this_bundle_idx = (*all_powers)[bundle_idx]; bool relinearize = crypto_context.seal_context()->using_keyswitching(); - pd.parallel_apply([&](const apsi::PowersDag::PowersNode &node) { + pd.parallel_apply([&](const ::apsi::PowersDag::PowersNode &node) { if (!node.is_source()) { auto parents = node.parents; seal::Ciphertext prod(*pool); @@ -443,7 +445,7 @@ void ComputePowers(const std::shared_ptr &sender_db, // only for convenience of the indexing; the ciphertext is actually not set or // valid for use. - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; // After computing all powers we will modulus switch down to parameters that // one more level for low powers than for high powers; same choice must be @@ -489,14 +491,14 @@ void ComputePowers(const std::shared_ptr &sender_db, void ProcessBinBundleCache( const std::shared_ptr &sender_db, - const apsi::CryptoContext &crypto_context, - const std::shared_ptr &bundle, + const ::apsi::CryptoContext &crypto_context, + const std::shared_ptr<::apsi::sender::BinBundle> &bundle, std::vector *all_powers, uint32_t bundle_idx, seal::compr_mode_type compr_mode, seal::MemoryPoolHandle *pool, const std::shared_ptr &result) { SPDLOG_DEBUG("Sender::ProcessBinBundleCache"); - std::reference_wrapper cache = + std::reference_wrapper cache = std::cref(bundle->get_cache()); // Package for the result data @@ -509,7 +511,7 @@ void ProcessBinBundleCache( seal::util::safe_cast(sender_db->GetLabelByteCount()); // Compute the matching result and move to rp - const apsi::sender::BatchedPlaintextPolyn &matching_polyn = + const ::apsi::sender::BatchedPlaintextPolyn &matching_polyn = cache.get().batched_matching_polyn; // Determine if we use Paterson-Stockmeyer or not @@ -541,4 +543,4 @@ void ProcessBinBundleCache( } } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender.h b/psi/apsi/sender.h similarity index 83% rename from psi/psi/core/labeled_psi/sender.h rename to psi/apsi/sender.h index 792eb657..c81da652 100644 --- a/psi/psi/core/labeled_psi/sender.h +++ b/psi/apsi/sender.h @@ -24,11 +24,11 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" -#include "psi/psi/core/labeled_psi/psi_params.h" -#include "psi/psi/core/labeled_psi/sender_db.h" +#include "psi/apsi/psi_params.h" +#include "psi/apsi/sender_db.h" +#include "psi/ecdh//ecdh_oprf.h" -namespace psi::psi { +namespace psi::apsi { class LabelPsiSender { public: @@ -49,7 +49,7 @@ class LabelPsiSender { * @param oprf_server * @param link_ctx */ - static void RunOPRF(const std::shared_ptr& oprf_server, + static void RunOPRF(const std::shared_ptr& oprf_server, const std::shared_ptr& link_ctx); /** @@ -62,10 +62,10 @@ class LabelPsiSender { private: std::shared_ptr sender_db_; - apsi::CryptoContext crypto_context_; + ::apsi::CryptoContext crypto_context_; seal::compr_mode_type compr_mode_ = seal::Serialization::compr_mode_default; - apsi::PowersDag pd_; + ::apsi::PowersDag pd_; }; -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_db.cc b/psi/apsi/sender_db.cc similarity index 87% rename from psi/psi/core/labeled_psi/sender_db.cc rename to psi/apsi/sender_db.cc index 9214b99c..b84fe90e 100644 --- a/psi/psi/core/labeled_psi/sender_db.cc +++ b/psi/apsi/sender_db.cc @@ -19,7 +19,8 @@ // we are using our own OPRF, the reason is we wanna make the oprf // switchable between secp256k1, sm2 or other types -// STD +#include "psi/apsi/sender_db.h" + #include #include #include @@ -30,30 +31,24 @@ #include #include -// APSI +#include "absl/strings/escaping.h" #include "apsi/psi_params.h" #include "apsi/thread_pool_mgr.h" #include "apsi/util/db_encoding.h" #include "apsi/util/label_encryptor.h" #include "apsi/util/utils.h" -#include "spdlog/spdlog.h" -#include "yacl/base/exception.h" - -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/sender_db.h" -#include "psi/psi/core/labeled_psi/serialize.h" - -// Kuku #include "kuku/locfunc.h" - -// SEAL -#include "absl/strings/escaping.h" #include "seal/util/common.h" #include "seal/util/streambuf.h" +#include "spdlog/spdlog.h" +#include "yacl/base/exception.h" #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +#include "psi/apsi/serialize.h" +#include "psi/ecdh//ecdh_oprf_selector.h" + +namespace psi::apsi { namespace { @@ -61,13 +56,11 @@ using DurationMillis = std::chrono::duration; } -namespace labeled_psi { - /** Creates and returns the vector of hash functions similarly to how Kuku 2.x sets them internally. */ -std::vector HashFunctions(const apsi::PSIParams ¶ms) { +std::vector HashFunctions(const ::apsi::PSIParams ¶ms) { std::vector result; for (uint32_t i = 0; i < params.table_params().hash_func_count; i++) { result.emplace_back(params.table_params().table_size, @@ -82,7 +75,7 @@ Computes all cuckoo hash table locations for a given item. */ std::unordered_set AllLocations( const std::vector &hash_funcs, - const apsi::HashedItem &item) { + const ::apsi::HashedItem &item) { std::unordered_set result; for (const auto &hf : hash_funcs) { result.emplace(hf(item.get_as().front())); @@ -95,7 +88,7 @@ std::unordered_set AllLocations( Compute the label size in multiples of item-size chunks. */ size_t ComputeLabelSize(size_t label_byte_count, - const apsi::PSIParams ¶ms) { + const ::apsi::PSIParams ¶ms) { return (label_byte_count * 8 + params.item_bit_count() - 1) / params.item_bit_count(); } @@ -118,9 +111,7 @@ std::pair UnpackCuckooIdx(size_t cuckoo_idx, return {bin_idx, bundle_idx}; } -} // namespace labeled_psi - -ISenderDB::ISenderDB(const apsi::PSIParams ¶ms, +ISenderDB::ISenderDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, std::size_t label_byte_count, std::size_t nonce_byte_count, bool compressed) @@ -139,9 +130,9 @@ ISenderDB::ISenderDB(const apsi::PSIParams ¶ms, YACL_THROW("label_byte_count is too large"); } - if (nonce_byte_count_ > apsi::max_nonce_byte_count) { + if (nonce_byte_count_ > ::apsi::max_nonce_byte_count) { SPDLOG_ERROR("Request nonce byte count {} exceeds the maximum ({}) ", - nonce_byte_count_, apsi::max_nonce_byte_count); + nonce_byte_count_, ::apsi::max_nonce_byte_count); YACL_THROW("nonce_byte_count is too large"); } @@ -149,13 +140,13 @@ ISenderDB::ISenderDB(const apsi::PSIParams ¶ms, // this is a labeled SenderDB but may not be safe to use for arbitrary label // changes. if ((label_byte_count_ != 0) && - nonce_byte_count_ < apsi::max_nonce_byte_count) { + nonce_byte_count_ < ::apsi::max_nonce_byte_count) { SPDLOG_WARN( "You have instantiated a labeled SenderDB instance with a nonce byte " "count {} , which is less than the safe default value {} . Updating " "labels for existing items in the SenderDB or removing and reinserting " "items with different labels may leak information about the labels.", - nonce_byte_count_, apsi::max_nonce_byte_count); + nonce_byte_count_, ::apsi::max_nonce_byte_count); } // Set the evaluator. This will be used for BatchedPlaintextPolyn::eval. @@ -164,8 +155,8 @@ ISenderDB::ISenderDB(const apsi::PSIParams ¶ms, oprf_key_.resize(oprf_key.size()); std::memcpy(oprf_key_.data(), oprf_key.data(), oprf_key.size()); - oprf_server_ = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); + oprf_server_ = ecdh::CreateEcdhOprfServer(oprf_key, ecdh::OprfType::Basic, + CurveType::CURVE_FOURQ); oprf_server_->SetCompareLength(kEccKeySize); } @@ -194,4 +185,4 @@ std::vector ISenderDB::GetOprfKey() const { return oprf_key_; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_db.h b/psi/apsi/sender_db.h similarity index 88% rename from psi/psi/core/labeled_psi/sender_db.h rename to psi/apsi/sender_db.h index cb6e8336..ed7089bf 100644 --- a/psi/psi/core/labeled_psi/sender_db.h +++ b/psi/apsi/sender_db.h @@ -17,7 +17,6 @@ #pragma once -// STD #include #include #include @@ -28,27 +27,22 @@ #include #include -// GSL -#include "gsl/span" - -// APSI #include "apsi/bin_bundle.h" #include "apsi/crypto_context.h" #include "apsi/item.h" #include "apsi/psi_params.h" +#include "gsl/span" +#include "seal/plaintext.h" +#include "seal/util/locks.h" +#include "spdlog/spdlog.h" #include "yacl/base/byte_container_view.h" #include "yacl/io/kv/leveldb_kvstore.h" #include "yacl/io/kv/memory_kvstore.h" -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/ecdh//basic_ecdh_oprf.h" +#include "psi/utils/batch_provider.h" -// SEAL -#include "seal/plaintext.h" -#include "seal/util/locks.h" -#include "spdlog/spdlog.h" - -namespace psi::psi { +namespace psi::apsi { /** A SenderDB maintains an in-memory representation of the sender's set of items @@ -78,7 +72,7 @@ class ISenderDB { /** Creates a new SenderDB. */ - ISenderDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + ISenderDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, std::size_t label_byte_count = 0, std::size_t nonce_byte_count = 16, bool compressed = false); @@ -116,18 +110,18 @@ class ISenderDB { /** Returns the bundle at the given bundle index. */ - virtual std::shared_ptr GetBinBundleAt( + virtual std::shared_ptr<::apsi::sender::BinBundle> GetBinBundleAt( std::uint32_t bundle_idx, size_t cache_idx) = 0; /** Returns a reference to the PSI parameters for this SenderDB. */ - virtual const apsi::PSIParams &GetParams() const { return params_; } + virtual const ::apsi::PSIParams &GetParams() const { return params_; } /** Returns a reference to the CryptoContext for this SenderDB. */ - virtual const apsi::CryptoContext &GetCryptoContext() const { + virtual const ::apsi::CryptoContext &GetCryptoContext() const { return crypto_context_; } @@ -175,12 +169,12 @@ class ISenderDB { The PSI parameters define the SEAL parameters, base field, item size, table size, etc. */ - apsi::PSIParams params_; + ::apsi::PSIParams params_; /** Necessary for evaluating polynomials of Plaintexts. */ - apsi::CryptoContext crypto_context_; + ::apsi::CryptoContext crypto_context_; /** A read-write lock to protect the database from modification while in use. @@ -221,21 +215,19 @@ class ISenderDB { Holds the OPRF key for this SenderDB. */ std::vector oprf_key_; - std::unique_ptr oprf_server_; + std::unique_ptr oprf_server_; }; // class SenderDB -namespace labeled_psi { - -std::vector HashFunctions(const apsi::PSIParams ¶ms); +std::vector HashFunctions(const ::apsi::PSIParams ¶ms); std::unordered_set AllLocations( - const std::vector &hash_funcs, const apsi::HashedItem &item); + const std::vector &hash_funcs, + const ::apsi::HashedItem &item); -size_t ComputeLabelSize(size_t label_byte_count, const apsi::PSIParams ¶ms); +size_t ComputeLabelSize(size_t label_byte_count, + const ::apsi::PSIParams ¶ms); std::pair UnpackCuckooIdx(size_t cuckoo_idx, size_t bins_per_bundle); -} // namespace labeled_psi - -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_kvdb.cc b/psi/apsi/sender_kvdb.cc similarity index 87% rename from psi/psi/core/labeled_psi/sender_kvdb.cc rename to psi/apsi/sender_kvdb.cc index 897cc94b..374fe5f6 100644 --- a/psi/psi/core/labeled_psi/sender_kvdb.cc +++ b/psi/apsi/sender_kvdb.cc @@ -35,14 +35,15 @@ #include "apsi/thread_pool_mgr.h" #include "apsi/util/db_encoding.h" #include "apsi/util/label_encryptor.h" +#include "apsi/util/stopwatch.h" #include "apsi/util/utils.h" #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/padding.h" -#include "psi/psi/core/labeled_psi/sender_kvdb.h" -#include "psi/psi/core/labeled_psi/serialize.h" +#include "psi/apsi/padding.h" +#include "psi/apsi/sender_kvdb.h" +#include "psi/apsi/serialize.h" +#include "psi/ecdh//ecdh_oprf_selector.h" // Kuku #include "kuku/locfunc.h" @@ -54,29 +55,29 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi::apsi { namespace { using DurationMillis = std::chrono::duration; -std::vector> PreprocessUnlabeledData( - const apsi::HashedItem &hashed_item, const apsi::PSIParams ¶ms) { +std::vector> PreprocessUnlabeledData( + const ::apsi::HashedItem &hashed_item, const ::apsi::PSIParams ¶ms) { // Some variables we'll need size_t bins_per_item = params.item_params().felts_per_item; size_t item_bit_count = params.item_bit_count(); // Set up Kuku hash functions - auto hash_funcs = labeled_psi::HashFunctions(params); + auto hash_funcs = HashFunctions(params); - std::vector> data_with_indices; + std::vector> data_with_indices; // Serialize the data into field elements - apsi::util::AlgItem alg_item = algebraize_item( + ::apsi::util::AlgItem alg_item = algebraize_item( hashed_item, item_bit_count, params.seal_params().plain_modulus()); // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : labeled_psi::AllLocations(hash_funcs, hashed_item)) { + for (auto location : AllLocations(hash_funcs, hashed_item)) { // The current hash value is an index into a table of Items. In reality // our BinBundles are tables of bins, which contain chunks of items. How // many chunks? bins_per_item many chunks @@ -89,28 +90,29 @@ std::vector> PreprocessUnlabeledData( return data_with_indices; } -std::vector> PreprocessLabeledData( - const std::pair &item_label_pair, - const apsi::PSIParams ¶ms, - const std::vector &hash_funcs) { +std::vector> +PreprocessLabeledData(const std::pair<::apsi::HashedItem, + ::apsi::EncryptedLabel> &item_label_pair, + const ::apsi::PSIParams ¶ms, + const std::vector &hash_funcs) { SPDLOG_DEBUG("Start preprocessing {} labeled items", distance(begin, end)); // Some variables we'll need size_t bins_per_item = params.item_params().felts_per_item; size_t item_bit_count = params.item_bit_count(); - std::vector> data_with_indices; + std::vector> data_with_indices; // Serialize the data into field elements - const apsi::HashedItem &item = item_label_pair.first; - const apsi::EncryptedLabel &label = item_label_pair.second; - apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( + const ::apsi::HashedItem &item = item_label_pair.first; + const ::apsi::EncryptedLabel &label = item_label_pair.second; + ::apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( item, label, item_bit_count, params.seal_params().plain_modulus()); std::set loc_set; // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : labeled_psi::AllLocations(hash_funcs, item)) { + for (auto location : AllLocations(hash_funcs, item)) { // The current hash value is an index into a table of Items. In reality // our BinBundles are tables of bins, which contain chunks of items. How // many chunks? bins_per_item many chunks @@ -139,7 +141,7 @@ void InsertOrAssignWorker( const std::vector> &data_with_indices, std::vector> *bundles_store, std::vector *bundles_store_idx, - const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + const ::apsi::CryptoContext &crypto_context, uint32_t bundle_index, uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, size_t ps_low_degree, bool overwrite, bool compressed) { STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); @@ -148,7 +150,7 @@ void InsertOrAssignWorker( bundle_index, overwrite ? "overwriting existing" : "inserting new"); // Create the bundle set at the given bundle index - std::vector bundle_set; + std::vector<::apsi::sender::BinBundle> bundle_set; // Iteratively insert each item-label pair at the given cuckoo index for (auto &data_with_idx : data_with_indices) { @@ -159,7 +161,7 @@ void InsertOrAssignWorker( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); // If the bundle_idx isn't in the prescribed range, don't try to insert this // data @@ -214,7 +216,7 @@ void InsertOrAssignWorker( // make a new BinBundle and insert the data there if (!written) { // Make a fresh BinBundle and insert - apsi::sender::BinBundle new_bin_bundle( + ::apsi::sender::BinBundle new_bin_bundle( crypto_context, label_size, max_bin_size, ps_low_degree, bins_per_bundle, compressed, false); int res = new_bin_bundle.multi_insert_for_real(data, bin_idx); @@ -261,7 +263,7 @@ void InsertOrAssignWorker( size_t indices_count, std::vector> *bundles_store, std::vector *bundles_store_idx, bool is_labeled, - const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + const ::apsi::CryptoContext &crypto_context, uint32_t bundle_index, uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, size_t ps_low_degree, bool overwrite, bool compressed) { @@ -272,7 +274,7 @@ void InsertOrAssignWorker( bundle_index, overwrite ? "overwriting existing" : "inserting new"); // Create the bundle set at the given bundle index - std::vector bundle_set; + std::vector<::apsi::sender::BinBundle> bundle_set; // Iteratively insert each item-label pair at the given cuckoo index for (size_t i = 0; i < indices_count; ++i) { @@ -282,8 +284,8 @@ void InsertOrAssignWorker( size_t cuckoo_idx; - std::pair datalabel_with_idx; - std::pair data_with_idx; + std::pair<::apsi::util::AlgItemLabel, size_t> datalabel_with_idx; + std::pair<::apsi::util::AlgItem, size_t> data_with_idx; if (is_labeled) { datalabel_with_idx = DeserializeDataLabelWithIndices(std::string_view( @@ -296,12 +298,12 @@ void InsertOrAssignWorker( cuckoo_idx = data_with_idx.second; } - // const apsi::util::AlgItem &data = data_with_idx.first; + // const ::apsi::util::AlgItem &data = data_with_idx.first; // Get the bundle index size_t bin_idx, bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); // If the bundle_idx isn't in the prescribed range, don't try to insert // this data @@ -380,7 +382,7 @@ void InsertOrAssignWorker( // make a new BinBundle and insert the data there if (!written) { // Make a fresh BinBundle and insert - apsi::sender::BinBundle new_bin_bundle( + ::apsi::sender::BinBundle new_bin_bundle( crypto_context, label_size, max_bin_size, ps_low_degree, bins_per_bundle, compressed, false); @@ -447,10 +449,10 @@ void DispatchInsertOrAssign( const std::vector> &data_with_indices, std::vector> *bundles_store, std::vector *bundles_store_idx, - const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + const ::apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, bool overwrite, bool compressed) { - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; // Collect the bundle indices and partition them into thread_count many // partitions. By some uniformity assumption, the number of things to insert @@ -462,7 +464,7 @@ void DispatchInsertOrAssign( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); bundle_indices_set.insert(bundle_idx); } @@ -501,10 +503,10 @@ void DispatchInsertOrAssign( size_t indices_count, const std::set &bundle_indices_set, std::vector> *bundles_store, std::vector *bundles_store_idx, bool is_labeled, - const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + const ::apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, bool overwrite, bool compressed) { - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; std::vector bundle_indices; bundle_indices.reserve(bundle_indices_set.size()); @@ -542,7 +544,7 @@ constexpr char kMemoryStoreFlag[] = "::memory"; } // namespace -SenderKvDB::SenderKvDB(const apsi::PSIParams ¶ms, +SenderKvDB::SenderKvDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, std::string_view kv_store_path, size_t label_byte_count, size_t nonce_byte_count, bool compressed) @@ -634,14 +636,7 @@ void SenderKvDB::clear() { ClearInternal(); } -void SenderKvDB::GenerateCaches() { - STOPWATCH(sender_stopwatch, "SenderDB::GenerateCaches"); - SPDLOG_INFO("Start generating bin bundle caches"); - - SPDLOG_INFO("Finished generating bin bundle caches"); -} - -std::shared_ptr SenderKvDB::GetBinBundleAt( +std::shared_ptr<::apsi::sender::BinBundle> SenderKvDB::GetBinBundleAt( uint32_t bundle_idx, size_t cache_idx) { yacl::Buffer value; @@ -649,8 +644,8 @@ std::shared_ptr SenderKvDB::GetBinBundleAt( YACL_ENFORCE(get_status); - size_t label_size = labeled_psi::ComputeLabelSize( - nonce_byte_count_ + label_byte_count_, params_); + size_t label_size = + ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); uint32_t bins_per_bundle = params_.bins_per_bundle(); uint32_t max_bin_size = params_.table_params().max_items_per_bin; @@ -658,8 +653,8 @@ std::shared_ptr SenderKvDB::GetBinBundleAt( bool compressed = false; - std::shared_ptr load_bin_bundle = - std::make_shared( + std::shared_ptr<::apsi::sender::BinBundle> load_bin_bundle = + std::make_shared<::apsi::sender::BinBundle>( crypto_context_, label_size, max_bin_size, ps_low_degree, bins_per_bundle, compressed, false); @@ -742,35 +737,35 @@ void SenderKvDB::InsertOrAssign( std::vector oprf_out = oprf_server_->FullEvaluate(batch_items); - std::vector>> + std::vector>> data_with_indices_vec; if (IsLabeled()) { data_with_indices_vec.resize(oprf_out.size()); - size_t key_offset_pos = sizeof(apsi::Item::value_type); + size_t key_offset_pos = sizeof(::apsi::Item::value_type); // Set up Kuku hash functions - auto hash_funcs = labeled_psi::HashFunctions(params_); + auto hash_funcs = HashFunctions(params_); yacl::parallel_for(0, oprf_out.size(), [&](int64_t begin, int64_t end) { for (int64_t idx = begin; idx < end; ++idx) { - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; std::memcpy(value.data(), &oprf_out[idx][0], value.size()); - apsi::HashedItem hashed_item(value); + ::apsi::HashedItem hashed_item(value); - apsi::LabelKey key; + ::apsi::LabelKey key; std::memcpy(key.data(), &oprf_out[idx][key_offset_pos], - apsi::label_key_byte_count); + ::apsi::label_key_byte_count); - apsi::Label label_with_padding = + ::apsi::Label label_with_padding = PaddingData(batch_labels[idx], label_byte_count_); - apsi::EncryptedLabel encrypted_label = apsi::util::encrypt_label( + ::apsi::EncryptedLabel encrypted_label = ::apsi::util::encrypt_label( label_with_padding, key, label_byte_count_, nonce_byte_count_); - std::pair item_label_pair = - std::make_pair(hashed_item, encrypted_label); + std::pair<::apsi::HashedItem, ::apsi::EncryptedLabel> + item_label_pair = std::make_pair(hashed_item, encrypted_label); data_with_indices_vec[idx] = PreprocessLabeledData(item_label_pair, params_, hash_funcs); @@ -787,7 +782,7 @@ void SenderKvDB::InsertOrAssign( size_t cuckoo_idx = data_with_indices_vec[i][j].second; size_t bin_idx, bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); bundle_indices_set.insert(bundle_idx); } @@ -797,13 +792,13 @@ void SenderKvDB::InsertOrAssign( } else { for (size_t i = 0; i < oprf_out.size(); ++i) { // - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; std::memcpy(value.data(), &oprf_out[i][0], value.size()); - apsi::HashedItem hashed_item(value); + ::apsi::HashedItem hashed_item(value); - std::vector> data_with_indices = - PreprocessUnlabeledData(hashed_item, params_); + std::vector> + data_with_indices = PreprocessUnlabeledData(hashed_item, params_); for (size_t j = 0; j < data_with_indices.size(); ++j) { std::string indices_buffer = @@ -815,7 +810,7 @@ void SenderKvDB::InsertOrAssign( size_t bin_idx, bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); bundle_indices_set.insert(bundle_idx); } @@ -830,8 +825,8 @@ void SenderKvDB::InsertOrAssign( size_t label_size = 0; if (IsLabeled()) { - label_size = labeled_psi::ComputeLabelSize( - nonce_byte_count_ + label_byte_count_, params_); + label_size = + ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); } DispatchInsertOrAssign( @@ -844,4 +839,4 @@ void SenderKvDB::InsertOrAssign( SPDLOG_INFO("Finished inserting {} items in SenderDB", item_count_); } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_kvdb.h b/psi/apsi/sender_kvdb.h similarity index 93% rename from psi/psi/core/labeled_psi/sender_kvdb.h rename to psi/apsi/sender_kvdb.h index 95f4a2dc..62f51fa1 100644 --- a/psi/psi/core/labeled_psi/sender_kvdb.h +++ b/psi/apsi/sender_kvdb.h @@ -40,16 +40,16 @@ #include "yacl/io/kv/leveldb_kvstore.h" #include "yacl/io/kv/memory_kvstore.h" -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" -#include "psi/psi/core/labeled_psi/sender_db.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/apsi/sender_db.h" +#include "psi/ecdh//basic_ecdh_oprf.h" +#include "psi/utils/batch_provider.h" // SEAL #include "seal/plaintext.h" #include "seal/util/locks.h" #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi::apsi { /** A SenderDB maintains an in-memory representation of the sender's set of items @@ -79,7 +79,7 @@ class SenderKvDB : public ISenderDB { /** Creates a new SenderDB. */ - SenderKvDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + SenderKvDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, std::string_view kv_store_path = "", std::size_t label_byte_count = 0, std::size_t nonce_byte_count = 16, bool compressed = true); @@ -136,7 +136,7 @@ class SenderKvDB : public ISenderDB { /** Returns the bundle at the given bundle index. */ - std::shared_ptr GetBinBundleAt( + std::shared_ptr<::apsi::sender::BinBundle> GetBinBundleAt( std::uint32_t bundle_idx, size_t cache_idx) override; /** @@ -152,8 +152,6 @@ class SenderKvDB : public ISenderDB { private: void ClearInternal(); - void GenerateCaches(); - std::string kv_store_path_; std::shared_ptr meta_info_store_; @@ -161,4 +159,4 @@ class SenderKvDB : public ISenderDB { std::vector bundles_store_idx_; }; // class SenderDB -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_memdb.cc b/psi/apsi/sender_memdb.cc similarity index 86% rename from psi/psi/core/labeled_psi/sender_memdb.cc rename to psi/apsi/sender_memdb.cc index 8e2e35d9..f34263a3 100644 --- a/psi/psi/core/labeled_psi/sender_memdb.cc +++ b/psi/apsi/sender_memdb.cc @@ -34,12 +34,13 @@ #include "apsi/thread_pool_mgr.h" #include "apsi/util/db_encoding.h" #include "apsi/util/label_encryptor.h" +#include "apsi/util/stopwatch.h" #include "apsi/util/utils.h" #include "spdlog/spdlog.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/core/labeled_psi/padding.h" -#include "psi/psi/core/labeled_psi/sender_memdb.h" +#include "psi/apsi/padding.h" +#include "psi/apsi/sender_memdb.h" +#include "psi/ecdh//ecdh_oprf_selector.h" // Kuku #include "kuku/locfunc.h" @@ -50,7 +51,7 @@ #include "seal/util/streambuf.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi::apsi { namespace { @@ -59,12 +60,13 @@ Converts each given Item-Label pair in between the given iterators into its algebraic form, i.e., a sequence of felt-felt pairs. Also computes each Item's cuckoo index. */ -std::vector> PreprocessLabeledData( - const std::vector>::const_iterator begin, - const std::vector< - std::pair>::const_iterator end, - const apsi::PSIParams ¶ms) { +std::vector> +PreprocessLabeledData( + const std::vector>::const_iterator begin, + const std::vector>::const_iterator end, + const ::apsi::PSIParams ¶ms) { STOPWATCH(sender_stopwatch, "preprocess_labeled_data"); SPDLOG_DEBUG("Start preprocessing {} labeled items", distance(begin, end)); @@ -73,30 +75,30 @@ std::vector> PreprocessLabeledData( size_t item_bit_count = params.item_bit_count(); // Set up Kuku hash functions - auto hash_funcs = labeled_psi::HashFunctions(params); + auto hash_funcs = HashFunctions(params); // Calculate the cuckoo indices for each item. Store every pair of // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector // by cuckoo_idx and use the result to parallelize the work of inserting the // items into BinBundles. - std::vector> data_with_indices; + std::vector> data_with_indices; for (auto it = begin; it != end; it++) { - const std::pair &item_label_pair = - *it; + const std::pair<::apsi::HashedItem, ::apsi::EncryptedLabel> + &item_label_pair = *it; // Serialize the data into field elements - const apsi::HashedItem &item = item_label_pair.first; - const apsi::EncryptedLabel &label = item_label_pair.second; - apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( + const ::apsi::HashedItem &item = item_label_pair.first; + const ::apsi::EncryptedLabel &label = item_label_pair.second; + ::apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( item, label, item_bit_count, params.seal_params().plain_modulus()); - std::vector> temp_data; + std::vector> temp_data; std::set loc_set; // Get the cuckoo table locations for this item and add to // data_with_indices - for (auto location : labeled_psi::AllLocations(hash_funcs, item)) { + for (auto location : AllLocations(hash_funcs, item)) { // The current hash value is an index into a table of Items. In // reality our BinBundles are tables of bins, which contain chunks // of items. How many chunks? bins_per_item many chunks @@ -122,10 +124,10 @@ std::vector> PreprocessLabeledData( Converts each given Item into its algebraic form, i.e., a sequence of felt-monostate pairs. Also computes each Item's cuckoo index. */ -std::vector> PreprocessUnlabeledData( - const std::vector::const_iterator begin, - const std::vector::const_iterator end, - const apsi::PSIParams ¶ms) { +std::vector> PreprocessUnlabeledData( + const std::vector<::apsi::HashedItem>::const_iterator begin, + const std::vector<::apsi::HashedItem>::const_iterator end, + const ::apsi::PSIParams ¶ms) { STOPWATCH(sender_stopwatch, "preprocess_unlabeled_data"); SPDLOG_DEBUG("Start preprocessing {} unlabeled items", distance(begin, end)); @@ -134,22 +136,22 @@ std::vector> PreprocessUnlabeledData( size_t item_bit_count = params.item_bit_count(); // Set up Kuku hash functions - auto hash_funcs = labeled_psi::HashFunctions(params); + auto hash_funcs = HashFunctions(params); // Calculate the cuckoo indices for each item. Store every pair of // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector // by cuckoo_idx and use the result to parallelize the work of inserting the // items into BinBundles. - std::vector> data_with_indices; + std::vector> data_with_indices; for (auto it = begin; it != end; it++) { - const apsi::HashedItem &item = *it; + const ::apsi::HashedItem &item = *it; // Serialize the data into field elements - apsi::util::AlgItem alg_item = algebraize_item( + ::apsi::util::AlgItem alg_item = algebraize_item( item, item_bit_count, params.seal_params().plain_modulus()); // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : labeled_psi::AllLocations(hash_funcs, item)) { + for (auto location : AllLocations(hash_funcs, item)) { // The current hash value is an index into a table of Items. In reality // our BinBundles are tables of bins, which contain chunks of items. How // many chunks? bins_per_item many chunks @@ -170,9 +172,9 @@ std::vector> PreprocessUnlabeledData( Converts given Item into its algebraic form, i.e., a sequence of felt-monostate pairs. Also computes the Item's cuckoo index. */ -std::vector> PreprocessUnlabeledData( - const apsi::HashedItem &item, const apsi::PSIParams ¶ms) { - std::vector item_singleton{item}; +std::vector> PreprocessUnlabeledData( + const ::apsi::HashedItem &item, const ::apsi::PSIParams ¶ms) { + std::vector<::apsi::HashedItem> item_singleton{item}; return PreprocessUnlabeledData(item_singleton.begin(), item_singleton.end(), params); } @@ -188,9 +190,9 @@ the labels if it finds an AlgItemLabel that matches the input perfectly. template void InsertOrAssignWorker( const std::vector> &data_with_indices, - std::vector>> + std::vector>> *bin_bundles, - const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + const ::apsi::CryptoContext &crypto_context, uint32_t bundle_index, uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, size_t ps_low_degree, bool overwrite, bool compressed) { STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); @@ -207,7 +209,7 @@ void InsertOrAssignWorker( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); // If the bundle_idx isn't in the prescribed range, don't try to insert this // data @@ -217,7 +219,7 @@ void InsertOrAssignWorker( } // Get the bundle set at the given bundle index - std::vector> &bundle_set = + std::vector> &bundle_set = (*bin_bundles)[bundle_idx]; // Try to insert or overwrite these field elements in an existing BinBundle @@ -265,8 +267,8 @@ void InsertOrAssignWorker( // make a new BinBundle and insert the data there if (!written) { // Make a fresh BinBundle and insert - std::shared_ptr new_bin_bundle = - std::make_shared( + std::shared_ptr<::apsi::sender::BinBundle> new_bin_bundle = + std::make_shared<::apsi::sender::BinBundle>( crypto_context, label_size, max_bin_size, ps_low_degree, bins_per_bundle, compressed, false); int res = new_bin_bundle->multi_insert_for_real(data, bin_idx); @@ -298,12 +300,12 @@ perfectly. template void DispatchInsertOrAssign( const std::vector> &data_with_indices, - std::vector>> + std::vector>> *bin_bundles, - const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + const ::apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, bool overwrite, bool compressed) { - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; // Collect the bundle indices and partition them into thread_count many // partitions. By some uniformity assumption, the number of things to insert @@ -315,7 +317,7 @@ void DispatchInsertOrAssign( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); bundle_indices_set.insert(bundle_idx); } @@ -354,9 +356,9 @@ Removes the given items and corresponding labels from bin_bundles at their respective cuckoo indices. */ void RemoveWorker( - const std::vector> + const std::vector> &data_with_indices, - std::vector>> + std::vector>> *bin_bundles, uint32_t bundle_index, uint32_t bins_per_bundle) { STOPWATCH(sender_stopwatch, "remove_worker"); @@ -369,7 +371,7 @@ void RemoveWorker( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); // If the bundle_idx isn't in the prescribed range, don't try to remove this // data @@ -379,7 +381,7 @@ void RemoveWorker( } // Get the bundle set at the given bundle index - std::vector> &bundle_set = + std::vector> &bundle_set = (*bin_bundles)[bundle_idx]; // Try to remove these field elements from an existing BinBundle at this @@ -418,12 +420,12 @@ Takes algebraized data to be removed, splits it up, and distributes it so that thread_count many threads can all remove in parallel. */ void DispatchRemove( - const std::vector> + const std::vector> &data_with_indices, - std::vector>> + std::vector>> *bin_bundles, uint32_t bins_per_bundle) { - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; // Collect the bundle indices and partition them into thread_count many // partitions. By some uniformity assumption, the number of things to remove @@ -435,7 +437,7 @@ void DispatchRemove( size_t bin_idx; size_t bundle_idx; std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); bundle_indices_set.insert(bundle_idx); } @@ -466,7 +468,7 @@ void DispatchRemove( } // namespace -SenderMemDB::SenderMemDB(const apsi::PSIParams ¶ms, +SenderMemDB::SenderMemDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, size_t label_byte_count, size_t nonce_byte_count, bool compressed) @@ -523,7 +525,7 @@ void SenderMemDB::GenerateCaches() { STOPWATCH(sender_stopwatch, "SenderDB::GenerateCaches"); SPDLOG_INFO("Start generating bin bundle caches"); - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; std::vector> futures; @@ -552,7 +554,7 @@ void SenderMemDB::strip() { memset(oprf_key_.data(), 0, oprf_key_.size()); hashed_items_.clear(); - apsi::ThreadPoolMgr tpm; + ::apsi::ThreadPoolMgr tpm; std::vector> futures; for (auto &bundle_idx : bin_bundles_) { @@ -585,23 +587,23 @@ void SenderMemDB::InsertOrAssign(const std::vector &keys, // First compute the hashes for the input data std::vector oprf_out = oprf_server_->FullEvaluate(keys); - std::vector> hashed_data( - oprf_out.size()); + std::vector> + hashed_data(oprf_out.size()); yacl::parallel_for(0, oprf_out.size(), [&](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; ++i) { - apsi::HashedItem hashed_item; + ::apsi::HashedItem hashed_item; std::memcpy(hashed_item.value().data(), oprf_out[i].data(), hashed_item.value().size()); - apsi::LabelKey key; + ::apsi::LabelKey key; std::memcpy(key.data(), &oprf_out[i][hashed_item.value().size()], key.size()); - apsi::Label label_with_padding = + ::apsi::Label label_with_padding = PaddingData(labels[i], label_byte_count_); - apsi::EncryptedLabel encrypted_label = encrypt_label( + ::apsi::EncryptedLabel encrypted_label = encrypt_label( label_with_padding, key, label_byte_count_, nonce_byte_count_); hashed_data[i] = std::make_pair(hashed_item, encrypted_label); @@ -636,8 +638,8 @@ void SenderMemDB::InsertOrAssign(const std::vector &keys, // Compute the label size; this ceil(effective_label_bit_count / // item_bit_count) - size_t label_size = labeled_psi::ComputeLabelSize( - nonce_byte_count_ + label_byte_count_, params_); + size_t label_size = + ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); auto new_item_count = distance(hashed_data.begin(), new_data_end); auto existing_item_count = distance(new_data_end, hashed_data.end()); @@ -648,8 +650,9 @@ void SenderMemDB::InsertOrAssign(const std::vector &keys, // Break the data into field element representation. Also compute the items' // cuckoo indices. - std::vector> data_with_indices = - PreprocessLabeledData(new_data_end, hashed_data.end(), params_); + std::vector> + data_with_indices = + PreprocessLabeledData(new_data_end, hashed_data.end(), params_); DispatchInsertOrAssign(data_with_indices, &bin_bundles_, crypto_context_, bins_per_bundle, label_size, max_bin_size, @@ -665,8 +668,9 @@ void SenderMemDB::InsertOrAssign(const std::vector &keys, // Process and add the new data. Break the data into field element // representation. Also compute the items' cuckoo indices. - std::vector> data_with_indices = - PreprocessLabeledData(hashed_data.begin(), hashed_data.end(), params_); + std::vector> + data_with_indices = PreprocessLabeledData(hashed_data.begin(), + hashed_data.end(), params_); DispatchInsertOrAssign(data_with_indices, &bin_bundles_, crypto_context_, bins_per_bundle, label_size, max_bin_size, @@ -694,9 +698,9 @@ void SenderMemDB::InsertOrAssign(const std::vector &data) { // First compute the hashes for the input data std::vector oprf_out = oprf_server_->FullEvaluate(data); - std::vector hashed_data; + std::vector<::apsi::HashedItem> hashed_data; for (const auto &out : oprf_out) { - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; std::memcpy(value.data(), out.data(), value.size()); hashed_data.emplace_back(value); @@ -727,7 +731,7 @@ void SenderMemDB::InsertOrAssign(const std::vector &data) { // Break the new data down into its field element representation. Also compute // the items' cuckoo indices. - std::vector> data_with_indices = + std::vector> data_with_indices = PreprocessUnlabeledData(hashed_data.begin(), hashed_data.end(), params_); // Dispatch the insertion @@ -744,7 +748,7 @@ void SenderMemDB::InsertOrAssign(const std::vector &data) { SPDLOG_INFO("Finished inserting {} items in SenderDB", data.size()); } -void SenderMemDB::remove(const std::vector &data) { +void SenderMemDB::remove(const std::vector<::apsi::Item> &data) { if (stripped_) { SPDLOG_ERROR("Cannot remove data from a stripped SenderDB"); YACL_THROW("failed to remove data"); @@ -762,9 +766,9 @@ void SenderMemDB::remove(const std::vector &data) { data[i].value().size()); } std::vector oprf_out = oprf_server_->FullEvaluate(data_str); - std::vector hashed_data; + std::vector<::apsi::HashedItem> hashed_data; for (const auto &out : oprf_out) { - apsi::Item::value_type value{}; + ::apsi::Item::value_type value{}; std::memcpy(value.data(), out.data(), value.size()); hashed_data.emplace_back(value); @@ -797,7 +801,7 @@ void SenderMemDB::remove(const std::vector &data) { // Break the data down into its field element representation. Also compute the // items' cuckoo indices. - std::vector> data_with_indices = + std::vector> data_with_indices = PreprocessUnlabeledData(hashed_data.begin(), hashed_data.end(), params_); // Dispatch the removal @@ -810,7 +814,7 @@ void SenderMemDB::remove(const std::vector &data) { SPDLOG_INFO("Finished removing {} items from SenderDB", data.size()); } -bool SenderMemDB::HasItem(const apsi::Item &item) const { +bool SenderMemDB::HasItem(const ::apsi::Item &item) const { if (stripped_) { SPDLOG_ERROR( "Cannot retrieve the presence of an item from a stripped SenderDB"); @@ -823,7 +827,7 @@ bool SenderMemDB::HasItem(const apsi::Item &item) const { item_str.reserve(item.value().size()); std::memcpy(item_str.data(), item.value().data(), item.value().size()); std::string oprf_out = oprf_server_->FullEvaluate(item_str); - apsi::HashedItem hashed_item; + ::apsi::HashedItem hashed_item; std::memcpy(hashed_item.value().data(), oprf_out.data(), hashed_item.value().size()); @@ -833,7 +837,7 @@ bool SenderMemDB::HasItem(const apsi::Item &item) const { return hashed_items_.find(hashed_item) != hashed_items_.end(); } -apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { +::apsi::Label SenderMemDB::GetLabel(const ::apsi::Item &item) const { if (stripped_) { SPDLOG_ERROR("Cannot retrieve a label from a stripped SenderDB"); YACL_THROW("failed to retrieve label"); @@ -845,8 +849,8 @@ apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { } // First compute the hash for the input item - apsi::HashedItem hashed_item; - apsi::LabelKey key; + ::apsi::HashedItem hashed_item; + ::apsi::LabelKey key; // tie(hashed_item, key) = OPRFSender::GetItemHash(item, oprf_key_); std::string item_str; @@ -871,7 +875,7 @@ apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { // Preprocess a single element. This algebraizes the item and gives back its // field element representation as well as its cuckoo hash. We only read one // of the locations because the labels are the same in each location. - apsi::util::AlgItem alg_item; + ::apsi::util::AlgItem alg_item; size_t cuckoo_idx; std::tie(alg_item, cuckoo_idx) = PreprocessUnlabeledData(hashed_item, params_)[0]; @@ -879,11 +883,10 @@ apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { // Now figure out where to look to get the label size_t bin_idx; size_t bundle_idx; - std::tie(bin_idx, bundle_idx) = - labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + std::tie(bin_idx, bundle_idx) = UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); // Retrieve the algebraic labels from one of the BinBundles at this index - const std::vector> &bundle_set = + const std::vector> &bundle_set = bin_bundles_[bundle_idx]; std::vector alg_label; bool got_labels = false; @@ -906,7 +909,7 @@ apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { } // All good. Now just reconstruct the big label from its split-up parts - apsi::EncryptedLabel encrypted_label = dealgebraize_label( + ::apsi::EncryptedLabel encrypted_label = dealgebraize_label( alg_label, alg_label.size() * static_cast(params_.item_bit_count_per_felt()), params_.seal_params().plain_modulus()); @@ -956,9 +959,9 @@ void SenderMemDB::SetData( GenerateCaches(); } -std::shared_ptr SenderMemDB::GetBinBundleAt( +std::shared_ptr<::apsi::sender::BinBundle> SenderMemDB::GetBinBundleAt( std::uint32_t bundle_idx, size_t cache_idx) { return bin_bundles_[bundle_idx][cache_idx]; } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/sender_memdb.h b/psi/apsi/sender_memdb.h similarity index 89% rename from psi/psi/core/labeled_psi/sender_memdb.h rename to psi/apsi/sender_memdb.h index 2d8dd057..7b00b6fd 100644 --- a/psi/psi/core/labeled_psi/sender_memdb.h +++ b/psi/apsi/sender_memdb.h @@ -38,15 +38,15 @@ #include "apsi/psi_params.h" #include "yacl/base/byte_container_view.h" -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" -#include "psi/psi/core/labeled_psi/sender_db.h" +#include "psi/apsi/sender_db.h" +#include "psi/ecdh//basic_ecdh_oprf.h" // SEAL #include "seal/plaintext.h" #include "seal/util/locks.h" #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi::apsi { /** A SenderDB maintains an in-memory representation of the sender's set of items @@ -76,7 +76,7 @@ class SenderMemDB : public ISenderDB { /** Creates a new SenderDB. */ - SenderMemDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + SenderMemDB(const ::apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, std::size_t label_byte_count = 0, std::size_t nonce_byte_count = 16, bool compressed = true); @@ -153,39 +153,39 @@ class SenderMemDB : public ISenderDB { /** Removes the given data from the database, using at most thread_count threads. */ - void remove(const std::vector &data); + void remove(const std::vector<::apsi::Item> &data); /** Removes the given (hashed) item from the database. */ - void remove(const apsi::Item &data) { - std::vector data_singleton{data}; + void remove(const ::apsi::Item &data) { + std::vector<::apsi::Item> data_singleton{data}; remove(data_singleton); } /** Returns whether the given item has been inserted in the SenderDB. */ - bool HasItem(const apsi::Item &item) const; + bool HasItem(const ::apsi::Item &item) const; /** Returns the label associated to the given item in the database. Throws std::invalid_argument if the item does not appear in the database. */ - apsi::Label GetLabel(const apsi::Item &item) const; + ::apsi::Label GetLabel(const ::apsi::Item &item) const; /** Returns a set of cache references corresponding to the bundles at the given bundle index. Even though this function returns a vector, the order has no significance. This function is meant for internal use. */ - std::shared_ptr GetBinBundleAt( + std::shared_ptr<::apsi::sender::BinBundle> GetBinBundleAt( std::uint32_t bundle_idx, size_t cache_idx) override; /** Returns a reference to a set of item hashes already existing in the SenderDB. */ - const std::unordered_set &GetHashedItems() const { + const std::unordered_set<::apsi::HashedItem> &GetHashedItems() const { return hashed_items_; } @@ -212,16 +212,16 @@ class SenderMemDB : public ISenderDB { /** The set of all items that have been inserted into the database */ - std::unordered_set hashed_items_; + std::unordered_set<::apsi::HashedItem> hashed_items_; /** All the BinBundles in the database, indexed by bundle index. The set (represented by a vector internally) at bundle index i contains all the BinBundles with bundle index i. */ - std::vector>> + std::vector>> bin_bundles_; }; // class SenderDB -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/labeled_psi/serializable.proto b/psi/apsi/serializable.proto similarity index 98% rename from psi/psi/core/labeled_psi/serializable.proto rename to psi/apsi/serializable.proto index b7519000..45b1fff9 100644 --- a/psi/psi/core/labeled_psi/serializable.proto +++ b/psi/apsi/serializable.proto @@ -16,7 +16,7 @@ syntax = "proto3"; -package psi.psi.proto; +package psi.proto; message SealParamsProto { uint32 poly_modulus_degree = 1; diff --git a/psi/psi/core/labeled_psi/serialize.h b/psi/apsi/serialize.h similarity index 75% rename from psi/psi/core/labeled_psi/serialize.h rename to psi/apsi/serialize.h index 7af467b8..318760ec 100644 --- a/psi/psi/core/labeled_psi/serialize.h +++ b/psi/apsi/serialize.h @@ -22,11 +22,11 @@ #include "apsi/util/db_encoding.h" #include "spdlog/spdlog.h" -#include "psi/psi/core/labeled_psi/serializable.pb.h" +#include "psi/apsi/serializable.pb.h" -namespace psi::psi { +namespace psi::apsi { -inline void SerializeAlgItem(const apsi::util::AlgItem& alg_items, +inline void SerializeAlgItem(const ::apsi::util::AlgItem& alg_items, proto::AlgItemProto* proto) { for (auto& alg_item : alg_items) { proto->add_item(alg_item); @@ -34,7 +34,7 @@ inline void SerializeAlgItem(const apsi::util::AlgItem& alg_items, } inline proto::AlgItemProto SerializeAlgItem( - const apsi::util::AlgItem& alg_items) { + const ::apsi::util::AlgItem& alg_items) { proto::AlgItemProto proto; SerializeAlgItem(alg_items, &proto); @@ -43,7 +43,7 @@ inline proto::AlgItemProto SerializeAlgItem( } inline std::string SerializeAlgItemToString( - const apsi::util::AlgItem& alg_items) { + const ::apsi::util::AlgItem& alg_items) { proto::AlgItemProto proto = SerializeAlgItem(alg_items); std::string item_string(proto.ByteSizeLong(), '\0'); @@ -52,9 +52,9 @@ inline std::string SerializeAlgItemToString( return item_string; } -inline apsi::util::AlgItem DeserializeAlgItem( +inline ::apsi::util::AlgItem DeserializeAlgItem( const proto::AlgItemProto& proto) { - apsi::util::AlgItem alg_items; + ::apsi::util::AlgItem alg_items; alg_items.resize(proto.item_size()); for (int i = 0; i < proto.item_size(); ++i) { @@ -64,8 +64,8 @@ inline apsi::util::AlgItem DeserializeAlgItem( return alg_items; } -inline apsi::util::AlgItem DeserializeAlgItem(const absl::string_view& buf) { - apsi::util::AlgItem alg_items; +inline ::apsi::util::AlgItem DeserializeAlgItem(const absl::string_view& buf) { + ::apsi::util::AlgItem alg_items; proto::AlgItemProto proto; proto.ParseFromArray(buf.data(), buf.length()); @@ -73,7 +73,7 @@ inline apsi::util::AlgItem DeserializeAlgItem(const absl::string_view& buf) { } inline void SerializeAlgItemLabel( - const apsi::util::AlgItemLabel& item_label_pair, + const ::apsi::util::AlgItemLabel& item_label_pair, proto::AlgItemLabelProto* proto) { for (size_t i = 0; i < item_label_pair.size(); ++i) { proto::AlgItemLabelPairProto* pair_proto = proto->add_item_label(); @@ -86,7 +86,7 @@ inline void SerializeAlgItemLabel( } inline proto::AlgItemLabelProto SerializeAlgItemLabel( - const apsi::util::AlgItemLabel& item_label_pair) { + const ::apsi::util::AlgItemLabel& item_label_pair) { proto::AlgItemLabelProto proto; SerializeAlgItemLabel(item_label_pair, &proto); @@ -94,7 +94,7 @@ inline proto::AlgItemLabelProto SerializeAlgItemLabel( } inline std::string SerializeAlgItemLabelToString( - const apsi::util::AlgItemLabel& item_label_pair) { + const ::apsi::util::AlgItemLabel& item_label_pair) { proto::AlgItemLabelProto proto = SerializeAlgItemLabel(item_label_pair); std::string item_string(proto.ByteSizeLong(), '\0'); @@ -103,17 +103,17 @@ inline std::string SerializeAlgItemLabelToString( return item_string; } -inline apsi::util::AlgItemLabel DeserializeAlgItemLabel( +inline ::apsi::util::AlgItemLabel DeserializeAlgItemLabel( const proto::AlgItemLabelProto& proto) { - apsi::util::AlgItemLabel item_label_pair; + ::apsi::util::AlgItemLabel item_label_pair; for (int i = 0; i < proto.item_label_size(); ++i) { auto pair_proto = proto.item_label(i); auto label_data = pair_proto.label_data(); - std::vector labels(label_data.size() / - sizeof(apsi::util::felt_t)); + std::vector<::apsi::util::felt_t> labels(label_data.size() / + sizeof(::apsi::util::felt_t)); std::memcpy(labels.data(), label_data.data(), label_data.size()); item_label_pair.emplace_back(pair_proto.item(), labels); @@ -122,7 +122,7 @@ inline apsi::util::AlgItemLabel DeserializeAlgItemLabel( return item_label_pair; } -inline apsi::util::AlgItemLabel DeserializeAlgItemLabel( +inline ::apsi::util::AlgItemLabel DeserializeAlgItemLabel( const absl::string_view& buf) { proto::AlgItemLabelProto proto; proto.ParseFromArray(buf.data(), buf.size()); @@ -131,7 +131,7 @@ inline apsi::util::AlgItemLabel DeserializeAlgItemLabel( } inline std::string SerializeDataWithIndices( - const std::pair& data_with_indices) { + const std::pair<::apsi::util::AlgItem, size_t>& data_with_indices) { proto::DataWithIndicesProto proto; proto::AlgItemProto* item_proto = new proto::AlgItemProto(); @@ -146,18 +146,18 @@ inline std::string SerializeDataWithIndices( return item_string; } -inline std::pair DeserializeDataWithIndices( +inline std::pair<::apsi::util::AlgItem, size_t> DeserializeDataWithIndices( const absl::string_view& buf) { proto::DataWithIndicesProto proto; proto.ParseFromArray(buf.data(), buf.size()); - apsi::util::AlgItem alg_item = DeserializeAlgItem(proto.data()); + ::apsi::util::AlgItem alg_item = DeserializeAlgItem(proto.data()); return std::make_pair(alg_item, proto.index()); } inline std::string SerializeDataLabelWithIndices( - const std::pair& data_with_indices) { + const std::pair<::apsi::util::AlgItemLabel, size_t>& data_with_indices) { proto::DataLabelWithIndicesProto proto; proto::AlgItemLabelProto* item_proto = new proto::AlgItemLabelProto(); @@ -172,14 +172,14 @@ inline std::string SerializeDataLabelWithIndices( return item_string; } -inline std::pair +inline std::pair<::apsi::util::AlgItemLabel, size_t> DeserializeDataLabelWithIndices(const absl::string_view& buf) { proto::DataLabelWithIndicesProto proto; proto.ParseFromArray(buf.data(), buf.size()); - apsi::util::AlgItemLabel alg_item = DeserializeAlgItemLabel(proto.data()); + ::apsi::util::AlgItemLabel alg_item = DeserializeAlgItemLabel(proto.data()); return std::make_pair(alg_item, proto.index()); } -} // namespace psi::psi +} // namespace psi::apsi diff --git a/psi/psi/core/bc22_psi/BUILD.bazel b/psi/bc22/BUILD.bazel similarity index 90% rename from psi/psi/core/bc22_psi/BUILD.bazel rename to psi/bc22/BUILD.bazel index 102e9a98..6824f1aa 100644 --- a/psi/psi/core/bc22_psi/BUILD.bazel +++ b/psi/bc22/BUILD.bazel @@ -25,8 +25,8 @@ psi_cc_library( deps = [ ":emp_vole", ":generalized_cuckoo_hash", - "//psi/psi/core:communication", - "//psi/psi/utils:serialize", + "//psi/utils:communication", + "//psi/utils:serialize", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@yacl//yacl/base:exception", @@ -51,9 +51,9 @@ psi_cc_library( hdrs = ["emp_vole.h"], copts = AES_COPT_FLAGS, deps = [ - "//psi/psi/core:communication", - "//psi/psi/utils:emp_io_adapter", - "//psi/psi/utils:serialize", + "//psi/utils:communication", + "//psi/utils:emp_io_adapter", + "//psi/utils:serialize", "@com_github_emptoolkit_emp_zk//:emp-zk", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -79,7 +79,7 @@ psi_cc_library( srcs = ["generalized_cuckoo_hash.cc"], hdrs = ["generalized_cuckoo_hash.h"], deps = [ - "//psi/psi/core:cuckoo_index", + "//psi/utils:cuckoo_index", "@com_google_absl//absl/strings", "@yacl//yacl/base:byte_container_view", "@yacl//yacl/base:exception", @@ -101,8 +101,8 @@ psi_cc_test( ) psi_cc_binary( - name = "bc22_psi_bench", - srcs = ["bc22_psi_bench.cc"], + name = "bc22_psi_benchmark", + srcs = ["bc22_psi_benchmark.cc"], deps = [ ":bc22_psi", "@com_github_google_benchmark//:benchmark_main", diff --git a/psi/psi/core/bc22_psi/bc22_psi.cc b/psi/bc22/bc22_psi.cc similarity index 98% rename from psi/psi/core/bc22_psi/bc22_psi.cc rename to psi/bc22/bc22_psi.cc index d178c166..11de3f5a 100644 --- a/psi/psi/core/bc22_psi/bc22_psi.cc +++ b/psi/bc22/bc22_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/bc22_psi.h" +#include "psi/bc22/bc22_psi.h" #include #include @@ -27,12 +27,12 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/bc22_psi/emp_vole.h" -#include "psi/psi/utils/serialize.h" +#include "psi/bc22/emp_vole.h" +#include "psi/utils/serialize.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi { +namespace psi::bc22 { namespace { @@ -557,4 +557,4 @@ void Bc22PcgPsi::PcgPsiRecvOprf(absl::Span items, SPDLOG_INFO("after compute intersection"); } -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/bc22_psi.h b/psi/bc22/bc22_psi.h similarity index 93% rename from psi/psi/core/bc22_psi/bc22_psi.h rename to psi/bc22/bc22_psi.h index 007e7959..1094102e 100644 --- a/psi/psi/core/bc22_psi/bc22_psi.h +++ b/psi/bc22/bc22_psi.h @@ -21,11 +21,11 @@ #include "absl/types/span.h" #include "yacl/base/int128.h" -#include "psi/psi/core/bc22_psi/generalized_cuckoo_hash.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/utils/serialize.h" +#include "psi/bc22/generalized_cuckoo_hash.h" +#include "psi/utils/communication.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi::bc22 { // PSI from Pseudorandom Correlation Generators // https://eprint.iacr.org/2022/334 @@ -88,4 +88,4 @@ class Bc22PcgPsi { std::vector results_; }; -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/bc22_psi_bench.cc b/psi/bc22/bc22_psi_benchmark.cc similarity index 97% rename from psi/psi/core/bc22_psi/bc22_psi_bench.cc rename to psi/bc22/bc22_psi_benchmark.cc index 167114db..d0c23246 100644 --- a/psi/psi/core/bc22_psi/bc22_psi_bench.cc +++ b/psi/bc22/bc22_psi_benchmark.cc @@ -23,9 +23,9 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/bc22_psi/bc22_psi.h" +#include "psi/bc22/bc22_psi.h" -namespace psi::psi { +namespace psi::bc22 { namespace { @@ -144,4 +144,4 @@ BENCHMARK(BM_PcgPsi) ->Arg(8 << 20) ->Arg(16 << 20); -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/bc22_psi_test.cc b/psi/bc22/bc22_psi_test.cc similarity index 97% rename from psi/psi/core/bc22_psi/bc22_psi_test.cc rename to psi/bc22/bc22_psi_test.cc index 3b22bcf1..de4af50a 100644 --- a/psi/psi/core/bc22_psi/bc22_psi_test.cc +++ b/psi/bc22/bc22_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/bc22_psi.h" +#include "psi/bc22/bc22_psi.h" #include #include @@ -26,7 +26,7 @@ #include "yacl/link/test_util.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi::bc22 { namespace { @@ -118,4 +118,4 @@ TEST_P(PcgPsiTest, Works) { INSTANTIATE_TEST_SUITE_P(Works_Instances, PcgPsiTest, testing::Values(10000, 100000, 1000000)); -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/emp_vole.cc b/psi/bc22/emp_vole.cc similarity index 97% rename from psi/psi/core/bc22_psi/emp_vole.cc rename to psi/bc22/emp_vole.cc index 8a7b449f..794ca967 100644 --- a/psi/psi/core/bc22_psi/emp_vole.cc +++ b/psi/bc22/emp_vole.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/emp_vole.h" +#include "psi/bc22/emp_vole.h" #include #include "spdlog/spdlog.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi { +namespace psi::bc22 { WolverineVole::WolverineVole(PsiRoleType psi_role, std::shared_ptr link_ctx) @@ -143,4 +143,4 @@ WolverineVoleFieldType EvaluatePolynomial( return EvaluatePolynomial(coeffs, block_x, high_coeff); } -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/emp_vole.h b/psi/bc22/emp_vole.h similarity index 93% rename from psi/psi/core/bc22_psi/emp_vole.h rename to psi/bc22/emp_vole.h index 2d14cbba..0d7ed048 100644 --- a/psi/psi/core/bc22_psi/emp_vole.h +++ b/psi/bc22/emp_vole.h @@ -25,11 +25,11 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/utils/emp_io_adapter.h" -#include "psi/psi/utils/serialize.h" +#include "psi/utils/communication.h" +#include "psi/utils/emp_io_adapter.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi::bc22 { inline constexpr size_t kVoleSilentOTThreads = 1; @@ -89,4 +89,4 @@ WolverineVoleFieldType EvaluatePolynomial( absl::Span coeffs, WolverineVoleFieldType x, WolverineVoleFieldType high_coeff); -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/emp_vole_test.cc b/psi/bc22/emp_vole_test.cc similarity index 96% rename from psi/psi/core/bc22_psi/emp_vole_test.cc rename to psi/bc22/emp_vole_test.cc index b7314d7e..bd938f94 100644 --- a/psi/psi/core/bc22_psi/emp_vole_test.cc +++ b/psi/bc22/emp_vole_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/emp_vole.h" +#include "psi/bc22/emp_vole.h" #include #include @@ -25,9 +25,9 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/serialize.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi::bc22 { class EmpVoleTest : public testing::TestWithParam {}; @@ -109,4 +109,4 @@ TEST(EmpVoleTest, PolynomialTest) { } } -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/generalized_cuckoo_hash.cc b/psi/bc22/generalized_cuckoo_hash.cc similarity index 98% rename from psi/psi/core/bc22_psi/generalized_cuckoo_hash.cc rename to psi/bc22/generalized_cuckoo_hash.cc index 61ea2fa1..5dc7cec1 100644 --- a/psi/psi/core/bc22_psi/generalized_cuckoo_hash.cc +++ b/psi/bc22/generalized_cuckoo_hash.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/generalized_cuckoo_hash.h" +#include "psi/bc22/generalized_cuckoo_hash.h" #include #include @@ -23,7 +23,7 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi::bc22 { namespace { @@ -259,4 +259,4 @@ void SimpleHashTable::Insert(absl::Span items) { } } -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/generalized_cuckoo_hash.h b/psi/bc22/generalized_cuckoo_hash.h similarity index 97% rename from psi/psi/core/bc22_psi/generalized_cuckoo_hash.h rename to psi/bc22/generalized_cuckoo_hash.h index d025432c..3594fbc6 100644 --- a/psi/psi/core/bc22_psi/generalized_cuckoo_hash.h +++ b/psi/bc22/generalized_cuckoo_hash.h @@ -25,9 +25,9 @@ #include "yacl/base/int128.h" #include "yacl/crypto/base/hash/hash_utils.h" -#include "psi/psi/core/cuckoo_index.h" +#include "psi/utils/cuckoo_index.h" -namespace psi::psi { +namespace psi::bc22 { // GeneralizedCuckooHash options // now support (2,2), (3,2) gch @@ -120,4 +120,4 @@ class SimpleHashTable : public IPsiHashTable { std::vector conflict_idx_; }; -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/core/bc22_psi/generalized_cuckoo_hash_test.cc b/psi/bc22/generalized_cuckoo_hash_test.cc similarity index 97% rename from psi/psi/core/bc22_psi/generalized_cuckoo_hash_test.cc rename to psi/bc22/generalized_cuckoo_hash_test.cc index a31c2b6a..e97b6c47 100644 --- a/psi/psi/core/bc22_psi/generalized_cuckoo_hash_test.cc +++ b/psi/bc22/generalized_cuckoo_hash_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/bc22_psi/generalized_cuckoo_hash.h" +#include "psi/bc22/generalized_cuckoo_hash.h" #include #include @@ -21,7 +21,7 @@ #include "gtest/gtest.h" #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi::bc22 { namespace { @@ -125,4 +125,4 @@ TEST(GchTest, CuckooHashTest) { SPDLOG_INFO("conflict_idx: {}", conflict_idx.size()); } -} // namespace psi::psi +} // namespace psi::bc22 diff --git a/psi/psi/cryptor/BUILD.bazel b/psi/cryptor/BUILD.bazel similarity index 99% rename from psi/psi/cryptor/BUILD.bazel rename to psi/cryptor/BUILD.bazel index 2e222b34..4405fb58 100644 --- a/psi/psi/cryptor/BUILD.bazel +++ b/psi/cryptor/BUILD.bazel @@ -98,7 +98,7 @@ psi_cc_library( ":sm2_cryptor", ":sodium_curve25519_cryptor", "//psi/proto:psi_v2_cc_proto", - "//psi/psi:prelude", + "//psi:prelude", "@yacl//yacl/utils:platform_utils", ] + select({ "@platforms//cpu:x86_64": [ diff --git a/psi/psi/cryptor/cryptor_selector.cc b/psi/cryptor/cryptor_selector.cc similarity index 90% rename from psi/psi/cryptor/cryptor_selector.cc rename to psi/cryptor/cryptor_selector.cc index 145845ca..34e3bcc6 100644 --- a/psi/psi/cryptor/cryptor_selector.cc +++ b/psi/cryptor/cryptor_selector.cc @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/cryptor_selector.h" +#include "psi/cryptor/cryptor_selector.h" #include #include "spdlog/spdlog.h" #include "yacl/utils/platform_utils.h" -#include "psi/psi/cryptor/fourq_cryptor.h" -#include "psi/psi/cryptor/sm2_cryptor.h" -#include "psi/psi/cryptor/sodium_curve25519_cryptor.h" -#include "psi/psi/prelude.h" +#include "psi/cryptor/fourq_cryptor.h" +#include "psi/cryptor/sm2_cryptor.h" +#include "psi/cryptor/sodium_curve25519_cryptor.h" +#include "psi/prelude.h" #ifdef __x86_64__ -#include "psi/psi/cryptor/ipp_ecc_cryptor.h" +#include "psi/cryptor/ipp_ecc_cryptor.h" #endif -namespace psi::psi { +namespace psi { namespace { @@ -118,4 +118,4 @@ std::unique_ptr CreateEccCryptor(CurveType type) { return cryptor; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/cryptor_selector.h b/psi/cryptor/cryptor_selector.h similarity index 88% rename from psi/psi/cryptor/cryptor_selector.h rename to psi/cryptor/cryptor_selector.h index 76d7c4f6..95966c24 100644 --- a/psi/psi/cryptor/cryptor_selector.h +++ b/psi/cryptor/cryptor_selector.h @@ -16,10 +16,10 @@ #include -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { std::unique_ptr CreateEccCryptor(CurveType type); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ecc_cryptor.cc b/psi/cryptor/ecc_cryptor.cc similarity index 97% rename from psi/psi/cryptor/ecc_cryptor.cc rename to psi/cryptor/ecc_cryptor.cc index c4b6c0e4..ec123894 100644 --- a/psi/psi/cryptor/ecc_cryptor.cc +++ b/psi/cryptor/ecc_cryptor.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" #include #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi { namespace { std::string CreateFlattenEccBuffer(const std::vector& items, @@ -107,4 +107,4 @@ std::vector HashInputs(const std::shared_ptr& cryptor, return ret; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ecc_cryptor.h b/psi/cryptor/ecc_cryptor.h similarity index 98% rename from psi/psi/cryptor/ecc_cryptor.h rename to psi/cryptor/ecc_cryptor.h index ad9d93f7..27b8adb8 100644 --- a/psi/psi/cryptor/ecc_cryptor.h +++ b/psi/cryptor/ecc_cryptor.h @@ -28,7 +28,7 @@ #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { inline constexpr int kEccKeySize = 32; @@ -80,4 +80,4 @@ std::string HashInput(const std::shared_ptr& cryptor, std::vector HashInputs(const std::shared_ptr& cryptor, const std::vector& items); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ecc_utils.h b/psi/cryptor/ecc_utils.h similarity index 99% rename from psi/psi/cryptor/ecc_utils.h rename to psi/cryptor/ecc_utils.h index 622ca1e4..fe8ab2c2 100644 --- a/psi/psi/cryptor/ecc_utils.h +++ b/psi/cryptor/ecc_utils.h @@ -25,7 +25,7 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi { inline constexpr size_t kEcPointCompressLength = 33; inline constexpr size_t kEc256KeyLength = 32; @@ -262,4 +262,4 @@ struct EcPointSt { ECPointPtr point_ptr; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ecc_utils_test.cc b/psi/cryptor/ecc_utils_test.cc similarity index 92% rename from psi/psi/cryptor/ecc_utils_test.cc rename to psi/cryptor/ecc_utils_test.cc index 611396e9..159feee5 100644 --- a/psi/psi/cryptor/ecc_utils_test.cc +++ b/psi/cryptor/ecc_utils_test.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/ecc_utils.h" +#include "psi/cryptor/ecc_utils.h" #include "gtest/gtest.h" -namespace psi::psi::test { +namespace psi::test { TEST(EcPointStTest, HashToCurveWorks) { EcGroupSt curve(NID_sm2); @@ -35,4 +35,4 @@ TEST(EcPointStTest, HashToCurveWorks) { } } -} // namespace psi::psi::test +} // namespace psi::test diff --git a/psi/psi/cryptor/fourq_cryptor.cc b/psi/cryptor/fourq_cryptor.cc similarity index 96% rename from psi/psi/cryptor/fourq_cryptor.cc rename to psi/cryptor/fourq_cryptor.cc index e9481e16..35d7fb4e 100644 --- a/psi/psi/cryptor/fourq_cryptor.cc +++ b/psi/cryptor/fourq_cryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/fourq_cryptor.h" +#include "psi/cryptor/fourq_cryptor.h" #include @@ -21,7 +21,7 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi { void FourQEccCryptor::EccMask(absl::Span batch_points, absl::Span dest_points) const { @@ -72,4 +72,4 @@ std::vector FourQEccCryptor::HashToCurve( return ret; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/fourq_cryptor.h b/psi/cryptor/fourq_cryptor.h similarity index 92% rename from psi/psi/cryptor/fourq_cryptor.h rename to psi/cryptor/fourq_cryptor.h index 3adaff5e..6d91243c 100644 --- a/psi/psi/cryptor/fourq_cryptor.h +++ b/psi/cryptor/fourq_cryptor.h @@ -18,9 +18,9 @@ #include "openssl/rand.h" #include "yacl/base/exception.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { class FourQEccCryptor : public IEccCryptor { public: @@ -36,4 +36,4 @@ class FourQEccCryptor : public IEccCryptor { std::vector HashToCurve(absl::Span input) const override; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/cryptor/fpga_ecc_cryptor.h b/psi/cryptor/fpga_ecc_cryptor.h similarity index 90% rename from psi/psi/cryptor/fpga_ecc_cryptor.h rename to psi/cryptor/fpga_ecc_cryptor.h index 92156073..afffece8 100644 --- a/psi/psi/cryptor/fpga_ecc_cryptor.h +++ b/psi/cryptor/fpga_ecc_cryptor.h @@ -14,9 +14,9 @@ #pragma once -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { class FPGAEccCryptor : public IEccCryptor { public: @@ -25,4 +25,4 @@ class FPGAEccCryptor : public IEccCryptor { absl::Span dest_points) const override; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/cryptor/hash_to_curve_elligator2.cc b/psi/cryptor/hash_to_curve_elligator2.cc similarity index 99% rename from psi/psi/cryptor/hash_to_curve_elligator2.cc rename to psi/cryptor/hash_to_curve_elligator2.cc index 2c11a079..b22844e5 100644 --- a/psi/psi/cryptor/hash_to_curve_elligator2.cc +++ b/psi/cryptor/hash_to_curve_elligator2.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/hash_to_curve_elligator2.h" +#include "psi/cryptor/hash_to_curve_elligator2.h" #include #include @@ -24,7 +24,7 @@ #include "yacl/crypto/base/hash/ssl_hash.h" #include "yacl/math/mpint/mp_int.h" -namespace psi::psi { +namespace psi { namespace { @@ -705,4 +705,4 @@ std::vector HashToCurveElligator2(yacl::ByteContainerView buffer, return PointClearCofactorProjective(px1); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/hash_to_curve_elligator2.h b/psi/cryptor/hash_to_curve_elligator2.h similarity index 96% rename from psi/psi/cryptor/hash_to_curve_elligator2.h rename to psi/cryptor/hash_to_curve_elligator2.h index 40cd733a..24e7241d 100644 --- a/psi/psi/cryptor/hash_to_curve_elligator2.h +++ b/psi/cryptor/hash_to_curve_elligator2.h @@ -19,7 +19,7 @@ #include "yacl/base/byte_container_view.h" -namespace psi::psi { +namespace psi { // RFC9380 Hashing to Elliptic Curves // https://datatracker.ietf.org/doc/rfc9380/ @@ -36,4 +36,4 @@ std::vector HashToCurveElligator2( const std::string &dst = "SECRETFLOW-V01-CS02-with-curve25519_XMD:SHA-512_ELL2_RO_"); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/hash_to_curve_elligator2_test.cc b/psi/cryptor/hash_to_curve_elligator2_test.cc similarity index 95% rename from psi/psi/cryptor/hash_to_curve_elligator2_test.cc rename to psi/cryptor/hash_to_curve_elligator2_test.cc index 8563c50a..f36cbd31 100644 --- a/psi/psi/cryptor/hash_to_curve_elligator2_test.cc +++ b/psi/cryptor/hash_to_curve_elligator2_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/hash_to_curve_elligator2.h" +#include "psi/cryptor/hash_to_curve_elligator2.h" #include #include @@ -21,7 +21,7 @@ #include "gtest/gtest.h" #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi { namespace { @@ -57,4 +57,4 @@ TEST(Elligator2Test, HashToCurve) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ipp_ecc_cryptor.cc b/psi/cryptor/ipp_ecc_cryptor.cc similarity index 94% rename from psi/psi/cryptor/ipp_ecc_cryptor.cc rename to psi/cryptor/ipp_ecc_cryptor.cc index 0a05ca10..103efb35 100644 --- a/psi/psi/cryptor/ipp_ecc_cryptor.cc +++ b/psi/cryptor/ipp_ecc_cryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/ipp_ecc_cryptor.h" +#include "psi/cryptor/ipp_ecc_cryptor.h" #include #include @@ -20,9 +20,9 @@ #include "crypto_mb/x25519.h" #include "yacl/utils/parallel.h" -#include "psi/psi/cryptor/hash_to_curve_elligator2.h" +#include "psi/cryptor/hash_to_curve_elligator2.h" -namespace psi::psi { +namespace psi { void IppEccCryptor::EccMask(absl::Span batch_points, absl::Span dest_points) const { @@ -79,4 +79,4 @@ std::vector IppElligator2Cryptor::HashToCurve( return HashToCurveElligator2(item_data); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/ipp_ecc_cryptor.h b/psi/cryptor/ipp_ecc_cryptor.h similarity index 93% rename from psi/psi/cryptor/ipp_ecc_cryptor.h rename to psi/cryptor/ipp_ecc_cryptor.h index 03479899..4d064265 100644 --- a/psi/psi/cryptor/ipp_ecc_cryptor.h +++ b/psi/cryptor/ipp_ecc_cryptor.h @@ -20,9 +20,9 @@ #include "openssl/rand.h" #include "yacl/base/exception.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { class IppEccCryptor : public IEccCryptor { public: @@ -41,4 +41,4 @@ class IppElligator2Cryptor : public IppEccCryptor { absl::Span item_data) const override; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/sm2_cryptor.cc b/psi/cryptor/sm2_cryptor.cc similarity index 95% rename from psi/psi/cryptor/sm2_cryptor.cc rename to psi/cryptor/sm2_cryptor.cc index ab70c6fc..4a39be46 100644 --- a/psi/psi/cryptor/sm2_cryptor.cc +++ b/psi/cryptor/sm2_cryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/sm2_cryptor.h" +#include "psi/cryptor/sm2_cryptor.h" #include "absl/types/span.h" #include "openssl/bn.h" @@ -20,9 +20,9 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/cryptor/ecc_utils.h" +#include "psi/cryptor/ecc_utils.h" -namespace psi::psi { +namespace psi { void Sm2Cryptor::EccMask(absl::Span batch_points, absl::Span dest_points) const { @@ -85,4 +85,4 @@ std::vector Sm2Cryptor::HashToCurve( return out; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/sm2_cryptor.h b/psi/cryptor/sm2_cryptor.h similarity index 95% rename from psi/psi/cryptor/sm2_cryptor.h rename to psi/cryptor/sm2_cryptor.h index 49bb0158..0bc02656 100644 --- a/psi/psi/cryptor/sm2_cryptor.h +++ b/psi/cryptor/sm2_cryptor.h @@ -20,9 +20,9 @@ #include "openssl/evp.h" #include "yacl/base/exception.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { class Sm2Cryptor : public IEccCryptor { public: explicit Sm2Cryptor(CurveType type = CurveType::CURVE_SM2) @@ -63,4 +63,4 @@ class Sm2Cryptor : public IEccCryptor { CurveType curve_type_ = CurveType::CURVE_SM2; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/sm2_cryptor_test.cc b/psi/cryptor/sm2_cryptor_test.cc similarity index 97% rename from psi/psi/cryptor/sm2_cryptor_test.cc rename to psi/cryptor/sm2_cryptor_test.cc index 15965ae2..aed99f3b 100644 --- a/psi/psi/cryptor/sm2_cryptor_test.cc +++ b/psi/cryptor/sm2_cryptor_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/sm2_cryptor.h" +#include "psi/cryptor/sm2_cryptor.h" #include #include @@ -23,7 +23,7 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/crypto/tools/prg.h" -namespace psi::psi { +namespace psi { struct TestParams { size_t items_size; @@ -101,4 +101,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{50, CurveType::CURVE_SECP256K1}, TestParams{100, CurveType::CURVE_SECP256K1})); -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/cryptor/sodium_curve25519_cryptor.cc b/psi/cryptor/sodium_curve25519_cryptor.cc similarity index 94% rename from psi/psi/cryptor/sodium_curve25519_cryptor.cc rename to psi/cryptor/sodium_curve25519_cryptor.cc index f886a2b7..cd6687c7 100644 --- a/psi/psi/cryptor/sodium_curve25519_cryptor.cc +++ b/psi/cryptor/sodium_curve25519_cryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/cryptor/sodium_curve25519_cryptor.h" +#include "psi/cryptor/sodium_curve25519_cryptor.h" extern "C" { #include "sodium.h" @@ -23,9 +23,9 @@ extern "C" { #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/cryptor/hash_to_curve_elligator2.h" +#include "psi/cryptor/hash_to_curve_elligator2.h" -namespace psi::psi { +namespace psi { void SodiumCurve25519Cryptor::EccMask(absl::Span batch_points, absl::Span dest_points) const { @@ -79,4 +79,4 @@ std::vector SodiumElligator2Cryptor::HashToCurve( return HashToCurveElligator2(item_data); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/cryptor/sodium_curve25519_cryptor.h b/psi/cryptor/sodium_curve25519_cryptor.h similarity index 95% rename from psi/psi/cryptor/sodium_curve25519_cryptor.h rename to psi/cryptor/sodium_curve25519_cryptor.h index 2e59e47f..74895643 100644 --- a/psi/psi/cryptor/sodium_curve25519_cryptor.h +++ b/psi/cryptor/sodium_curve25519_cryptor.h @@ -22,9 +22,9 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi { // // avx2 asm code reference: @@ -68,4 +68,4 @@ class SodiumElligator2Cryptor : public SodiumCurve25519Cryptor { absl::Span item_data) const override; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/ecdh/BUILD.bazel b/psi/ecdh/BUILD.bazel new file mode 100644 index 00000000..c7418fed --- /dev/null +++ b/psi/ecdh/BUILD.bazel @@ -0,0 +1,237 @@ +# Copyright 2022 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_library( + name = "ecdh_oprf", + srcs = ["ecdh_oprf.cc"], + hdrs = ["ecdh_oprf.h"], + # Openssl::libcrypto requires `dlopen`... + linkopts = ["-ldl"], + deps = [ + "//psi/cryptor:ecc_cryptor", + "@com_github_openssl_openssl//:openssl", + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:byte_container_view", + "@yacl//yacl/base:exception", + "@yacl//yacl/utils:parallel", + ], +) + +psi_cc_library( + name = "ecdh_oprf_selector", + srcs = ["ecdh_oprf_selector.cc"], + hdrs = ["ecdh_oprf_selector.h"], + deps = [ + ":basic_ecdh_oprf", + "@yacl//yacl/utils:platform_utils", + ], +) + +psi_cc_library( + name = "basic_ecdh_oprf", + srcs = ["basic_ecdh_oprf.cc"], + hdrs = ["basic_ecdh_oprf.h"], + defines = [ + "__LINUX__", + ] + select({ + "@bazel_tools//src/conditions:linux_x86_64": [ + "_AMD64_", + "_ASM_", + ], + "@bazel_tools//src/conditions:darwin_arm64": [ + "_ARM64_", + ], + "//conditions:default": [ + "_AMD64_", + ], + }), + deps = [ + ":ecdh_oprf", + "//psi/cryptor:ecc_utils", + "//psi/cryptor:sm2_cryptor", + "@com_github_microsoft_apsi//:apsi", + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:exception", + "@yacl//yacl/crypto/base/hash:blake3", + "@yacl//yacl/crypto/base/hash:hash_utils", + "@yacl//yacl/utils:parallel", + ], +) + +psi_cc_test( + name = "basic_ecdh_oprf_test", + srcs = ["basic_ecdh_oprf_test.cc"], + deps = [ + ":ecdh_oprf_selector", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/crypto/utils:rand", + ], +) + +psi_cc_library( + name = "ecdh_psi", + srcs = ["ecdh_psi.cc"], + hdrs = ["ecdh_psi.h"], + deps = [ + ":ecdh_logger", + "//psi/cryptor:cryptor_selector", + "//psi/utils:batch_provider", + "//psi/utils:communication", + "//psi/utils:ec_point_store", + "//psi/utils:recovery", + "@com_google_absl//absl/strings", + "@yacl//yacl/link", + "@yacl//yacl/utils:parallel", + ], +) + +psi_cc_test( + name = "ecdh_psi_test", + srcs = ["ecdh_psi_test.cc"], + deps = [ + ":ecdh_psi", + "//psi/utils:test_utils", + ], +) + +psi_cc_binary( + name = "ecdh_psi_benchmark", + srcs = ["ecdh_psi_benchmark.cc"], + deps = [ + ":ecdh_psi", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "ecdh_3pc_psi", + srcs = ["ecdh_3pc_psi.cc"], + hdrs = ["ecdh_3pc_psi.h"], + deps = [ + ":ecdh_psi", + ], +) + +psi_cc_test( + name = "ecdh_3pc_psi_test", + srcs = ["ecdh_3pc_psi_test.cc"], + deps = [ + ":ecdh_3pc_psi", + "//psi/utils:test_utils", + ], +) + +psi_cc_binary( + name = "ecdh_3pc_psi_benchmark", + srcs = ["ecdh_3pc_psi_benchmark.cc"], + deps = [ + ":ecdh_3pc_psi", + "//psi/utils:test_utils", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "ecdh_oprf_psi", + srcs = ["ecdh_oprf_psi.cc"], + hdrs = ["ecdh_oprf_psi.h"], + deps = [ + ":ecdh_oprf_selector", + "//psi/utils:batch_provider", + "//psi/utils:communication", + "//psi/utils:ec_point_store", + "//psi/utils:ub_psi_cache", + "@com_google_absl//absl/strings", + "@yacl//yacl/base:exception", + "@yacl//yacl/link", + "@yacl//yacl/utils:parallel", + ], +) + +psi_cc_test( + name = "ecdh_oprf_psi_test", + srcs = ["ecdh_oprf_psi_test.cc"], + deps = [ + ":ecdh_oprf_psi", + "//psi/utils:test_utils", + "@boost//:uuid", + "@com_google_absl//absl/time", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/crypto/utils:rand", + "@yacl//yacl/utils:scope_guard", + ], +) + +psi_cc_library( + name = "ecdh_logger", + hdrs = ["ecdh_logger.h"], + deps = [ + "//psi/cryptor:ecc_cryptor", + "@yacl//yacl/base:exception", + ], +) + +psi_cc_library( + name = "common", + hdrs = ["common.h"], +) + +psi_cc_library( + name = "receiver", + srcs = ["receiver.cc"], + hdrs = ["receiver.h"], + deps = [ + ":common", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], +) + +psi_cc_library( + name = "sender", + srcs = ["sender.cc"], + hdrs = ["sender.h"], + deps = [ + ":common", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], +) + +psi_cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":ecdh_oprf_psi", + "//psi:interface", + "//psi/utils:sync", + ], +) + +psi_cc_library( + name = "server", + srcs = ["server.cc"], + hdrs = ["server.h"], + deps = [ + ":ecdh_oprf_psi", + "//psi:interface", + "//psi/utils:ec", + "//psi/utils:sync", + ], +) diff --git a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf.cc b/psi/ecdh/basic_ecdh_oprf.cc similarity index 99% rename from psi/psi/core/ecdh_oprf/basic_ecdh_oprf.cc rename to psi/ecdh/basic_ecdh_oprf.cc index 25d22638..397512a2 100644 --- a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf.cc +++ b/psi/ecdh/basic_ecdh_oprf.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" +#include "psi/ecdh/basic_ecdh_oprf.h" #include #include @@ -23,9 +23,9 @@ #include "yacl/crypto/base/hash/blake3.h" #include "yacl/crypto/base/hash/hash_utils.h" -#include "psi/psi/cryptor/ecc_utils.h" +#include "psi/cryptor/ecc_utils.h" -namespace psi::psi { +namespace psi::ecdh { namespace { // use 96bit as the final compare value @@ -419,4 +419,4 @@ size_t FourQBasicEcdhOprfClient::GetEcPointLength() const { return kEccKeySize; } -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h b/psi/ecdh/basic_ecdh_oprf.h similarity index 96% rename from psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h rename to psi/ecdh/basic_ecdh_oprf.h index d43749a3..a5d2ce55 100644 --- a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h +++ b/psi/ecdh/basic_ecdh_oprf.h @@ -26,9 +26,9 @@ #include "yacl/base/exception.h" #include "yacl/crypto/base/hash/hash_interface.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" -#include "psi/psi/cryptor/ecc_cryptor.h" -#include "psi/psi/cryptor/sm2_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/sm2_cryptor.h" +#include "psi/ecdh/ecdh_oprf.h" // 2HashDH Oprf // F_k(x) = H2(x, H1(x)^k) @@ -41,7 +41,7 @@ // server H2(x, H1(x)^sk) // client H2(y, (H1(y)^r)^sk^(1/r))=H2(y, H1(y)^sk) -namespace psi::psi { +namespace psi::ecdh { class BasicEcdhOprfServer : public IEcdhOprfServer { public: @@ -177,4 +177,4 @@ class FourQBasicEcdhOprfClient : public IEcdhOprfClient { yacl::crypto::HashAlgorithm hash_type_ = yacl::crypto::HashAlgorithm::BLAKE3; }; -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf_test.cc b/psi/ecdh/basic_ecdh_oprf_test.cc similarity index 94% rename from psi/psi/core/ecdh_oprf/basic_ecdh_oprf_test.cc rename to psi/ecdh/basic_ecdh_oprf_test.cc index 9d0dc62f..093ad980 100644 --- a/psi/psi/core/ecdh_oprf/basic_ecdh_oprf_test.cc +++ b/psi/ecdh/basic_ecdh_oprf_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" +#include "psi/ecdh/basic_ecdh_oprf.h" #include #include @@ -28,9 +28,9 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" +#include "psi/ecdh/ecdh_oprf_selector.h" -namespace psi::psi { +namespace psi::ecdh { struct TestParams { size_t items_size; CurveType type = CurveType::CURVE_SECP256K1; @@ -89,4 +89,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{10, CurveType::CURVE_FOURQ}, TestParams{50, CurveType::CURVE_FOURQ})); -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/client.cc b/psi/ecdh/client.cc similarity index 95% rename from psi/psi/ecdh/client.cc rename to psi/ecdh/client.cc index 5f64263a..28769916 100644 --- a/psi/psi/ecdh/client.cc +++ b/psi/ecdh/client.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/ecdh/client.h" +#include "psi/ecdh/client.h" #include -#include "psi/psi/bucket_psi.h" -#include "psi/psi/utils/arrow_csv_batch_provider.h" -#include "psi/psi/utils/sync.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/utils/arrow_csv_batch_provider.h" +#include "psi/utils/sync.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { EcdhUbPsiClient::EcdhUbPsiClient(const v2::UbPsiConfig& config, std::shared_ptr lctx) @@ -123,4 +123,4 @@ void EcdhUbPsiClient::Online() { selected_keys, results, false, false, false); } -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/client.h b/psi/ecdh/client.h similarity index 89% rename from psi/psi/ecdh/client.h rename to psi/ecdh/client.h index 904243e6..771db612 100644 --- a/psi/psi/ecdh/client.h +++ b/psi/ecdh/client.h @@ -13,12 +13,12 @@ // limitations under the License. #pragma once -#include "psi/psi/core/ecdh_oprf_psi.h" -#include "psi/psi/interface.h" +#include "psi/ecdh/ecdh_oprf_psi.h" +#include "psi/interface.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { class EcdhUbPsiClient final : public AbstractUbPsiClient { public: @@ -41,4 +41,4 @@ class EcdhUbPsiClient final : public AbstractUbPsiClient { EcdhOprfPsiOptions psi_options_; }; -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/common.h b/psi/ecdh/common.h similarity index 92% rename from psi/psi/ecdh/common.h rename to psi/ecdh/common.h index d374566c..b9fddc43 100644 --- a/psi/psi/ecdh/common.h +++ b/psi/ecdh/common.h @@ -13,9 +13,9 @@ // limitations under the License. #pragma once -namespace psi::psi::ecdh { +namespace psi::ecdh { // Default bin num of HashBucketEcPointStores for ec points. constexpr int kDefaultBinNum = 64; -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_3pc_psi.cc b/psi/ecdh/ecdh_3pc_psi.cc similarity index 98% rename from psi/psi/core/ecdh_3pc_psi.cc rename to psi/ecdh/ecdh_3pc_psi.cc index 017f6faf..a395e33d 100644 --- a/psi/psi/core/ecdh_3pc_psi.cc +++ b/psi/ecdh/ecdh_3pc_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_3pc_psi.h" +#include "psi/ecdh/ecdh_3pc_psi.h" #include #include @@ -27,9 +27,9 @@ #include "yacl/link/link.h" #include "yacl/utils/serialize.h" -#include "psi/psi/cryptor/cryptor_selector.h" +#include "psi/cryptor/cryptor_selector.h" -namespace psi::psi { +namespace psi::ecdh { EcdhP2PExtendCtx::EcdhP2PExtendCtx(const EcdhPsiOptions& options) : EcdhPsiContext(options) {} @@ -364,4 +364,4 @@ size_t ShuffleEcdh3PcPsi::GetPartnersPsiPeerRank() { } } -} // namespace psi::psi \ No newline at end of file +} // namespace psi::ecdh \ No newline at end of file diff --git a/psi/psi/core/ecdh_3pc_psi.h b/psi/ecdh/ecdh_3pc_psi.h similarity index 97% rename from psi/psi/core/ecdh_3pc_psi.h rename to psi/ecdh/ecdh_3pc_psi.h index a36befe9..29899a7d 100644 --- a/psi/psi/core/ecdh_3pc_psi.h +++ b/psi/ecdh/ecdh_3pc_psi.h @@ -18,10 +18,10 @@ #include #include -#include "psi/psi/core/communication.h" -#include "psi/psi/core/ecdh_psi.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/utils/communication.h" -namespace psi::psi { +namespace psi::ecdh { class EcdhP2PExtendCtx : public EcdhPsiContext { public: @@ -159,4 +159,4 @@ class ShuffleEcdh3PcPsi { std::vector private_key_; }; -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_3pc_psi_bench.cc b/psi/ecdh/ecdh_3pc_psi_benchmark.cc similarity index 78% rename from psi/psi/core/ecdh_3pc_psi_bench.cc rename to psi/ecdh/ecdh_3pc_psi_benchmark.cc index 4a3242e5..e66de392 100644 --- a/psi/psi/core/ecdh_3pc_psi_bench.cc +++ b/psi/ecdh/ecdh_3pc_psi_benchmark.cc @@ -19,22 +19,22 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/ecdh_3pc_psi.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/ecdh/ecdh_3pc_psi.h" +#include "psi/utils/test_utils.h" static void BM_Ecdh3PcPsi(benchmark::State& state) { for (auto _ : state) { state.PauseTiming(); size_t n = state.range(0); - auto alice_items = psi::psi::test::CreateRangeItems(1, n); - auto bob_items = psi::psi::test::CreateRangeItems(2, n); - auto carol_items = psi::psi::test::CreateRangeItems(3, n); + auto alice_items = psi::test::CreateRangeItems(1, n); + auto bob_items = psi::test::CreateRangeItems(2, n); + auto carol_items = psi::test::CreateRangeItems(3, n); auto contexts = yacl::link::test::SetupWorld(3); // simple runner auto psi_func = - [&](const std::shared_ptr& handler, + [&](const std::shared_ptr& handler, const std::vector& items, std::vector* results) { std::vector masked_master_items; @@ -58,24 +58,24 @@ static void BM_Ecdh3PcPsi(benchmark::State& state) { std::vector bob_res; std::vector carol_res; auto alice_func = std::async([&]() { - psi::psi::ShuffleEcdh3PcPsi::Options opts; + psi::ecdh::ShuffleEcdh3PcPsi::Options opts; opts.link_ctx = contexts[0]; opts.master_rank = 0; - auto op = std::make_shared(opts); + auto op = std::make_shared(opts); return psi_func(op, alice_items, &alice_res); }); auto bob_func = std::async([&]() { - psi::psi::ShuffleEcdh3PcPsi::Options opts; + psi::ecdh::ShuffleEcdh3PcPsi::Options opts; opts.link_ctx = contexts[1]; opts.master_rank = 0; - auto op = std::make_shared(opts); + auto op = std::make_shared(opts); return psi_func(op, bob_items, &bob_res); }); auto carol_func = std::async([&]() { - psi::psi::ShuffleEcdh3PcPsi::Options opts; + psi::ecdh::ShuffleEcdh3PcPsi::Options opts; opts.link_ctx = contexts[2]; opts.master_rank = 0; - auto op = std::make_shared(opts); + auto op = std::make_shared(opts); return psi_func(op, carol_items, &carol_res); }); diff --git a/psi/psi/core/ecdh_3pc_psi_test.cc b/psi/ecdh/ecdh_3pc_psi_test.cc similarity index 78% rename from psi/psi/core/ecdh_3pc_psi_test.cc rename to psi/ecdh/ecdh_3pc_psi_test.cc index c22805ce..d5001c6c 100644 --- a/psi/psi/core/ecdh_3pc_psi_test.cc +++ b/psi/ecdh/ecdh_3pc_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_3pc_psi.h" +#include "psi/ecdh/ecdh_3pc_psi.h" #include #include @@ -24,7 +24,7 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" struct TestParams { std::vector items_a; @@ -32,7 +32,7 @@ struct TestParams { std::vector items_c; }; -namespace psi::psi::test { +namespace psi::ecdh { class Ecdh3PcPsiTest : public testing::TestWithParam {}; @@ -99,7 +99,8 @@ TEST_P(Ecdh3PcPsiTest, PartnersPsi) { size_t master_rank = alice_rank; - auto intersection_std_bc = GetIntersection(params.items_b, params.items_c); + auto intersection_std_bc = + test::GetIntersection(params.items_b, params.items_c); std::shared_ptr ecdh_3pc_psi_master; std::shared_ptr ecdh_3pc_psi_master_next; @@ -153,10 +154,12 @@ TEST_P(Ecdh3PcPsiTest, Works) { size_t master_rank = alice_rank; - auto intersection_std_ab = GetIntersection(params.items_a, params.items_b); - auto intersection_std_bc = GetIntersection(params.items_b, params.items_c); + auto intersection_std_ab = + test::GetIntersection(params.items_a, params.items_b); + auto intersection_std_bc = + test::GetIntersection(params.items_b, params.items_c); auto intersection_std_abc = - GetIntersection(intersection_std_ab, params.items_c); + test::GetIntersection(intersection_std_ab, params.items_c); std::shared_ptr ecdh_3pc_psi_master; std::shared_ptr ecdh_3pc_psi_master_next; @@ -216,32 +219,34 @@ TEST_P(Ecdh3PcPsiTest, Works) { INSTANTIATE_TEST_SUITE_P( Works_Instances, Ecdh3PcPsiTest, - testing::Values( - TestParams{{"a", "b"}, {"b", "c"}, {"b", "d"}}, // - TestParams{{"a", "b"}, {"b", "c"}, {"b", "d"}}, // - - TestParams{{"a", "b"}, {"b", "c"}, {"c", "d"}}, // - // - TestParams{{"a", "b"}, {"c", "d"}, {"d", "e"}}, // - TestParams{{"a", "b"}, {"c", "d"}, {"e", "f"}}, // - - // - TestParams{{}, {"a"}, {}}, // - TestParams{{"a"}, {}, {}}, // - TestParams{{}, {}, {"a"}}, // - // - // less than one batch - TestParams{CreateRangeItems(0, 4095), CreateRangeItems(1, 4095), - CreateRangeItems(2, 4095)}, // - - // exactly one batch - TestParams{CreateRangeItems(0, 4096), CreateRangeItems(1, 4096), - CreateRangeItems(2, 4096)}, // - // more than one batch - TestParams{CreateRangeItems(0, 8193), CreateRangeItems(5, 8193), - CreateRangeItems(10, 8193)}, // - // - TestParams{{}, {}, {}} // - )); - -} // namespace psi::psi::test \ No newline at end of file + testing::Values(TestParams{{"a", "b"}, {"b", "c"}, {"b", "d"}}, // + TestParams{{"a", "b"}, {"b", "c"}, {"b", "d"}}, // + + TestParams{{"a", "b"}, {"b", "c"}, {"c", "d"}}, // + // + TestParams{{"a", "b"}, {"c", "d"}, {"d", "e"}}, // + TestParams{{"a", "b"}, {"c", "d"}, {"e", "f"}}, // + + // + TestParams{{}, {"a"}, {}}, // + TestParams{{"a"}, {}, {}}, // + TestParams{{}, {}, {"a"}}, // + // + // less than one batch + TestParams{test::CreateRangeItems(0, 4095), + test::CreateRangeItems(1, 4095), + test::CreateRangeItems(2, 4095)}, // + + // exactly one batch + TestParams{test::CreateRangeItems(0, 4096), + test::CreateRangeItems(1, 4096), + test::CreateRangeItems(2, 4096)}, // + // more than one batch + TestParams{test::CreateRangeItems(0, 8193), + test::CreateRangeItems(5, 8193), + test::CreateRangeItems(10, 8193)}, // + // + TestParams{{}, {}, {}} // + )); + +} // namespace psi::ecdh diff --git a/psi/ecdh/ecdh_logger.h b/psi/ecdh/ecdh_logger.h new file mode 100644 index 00000000..320109dc --- /dev/null +++ b/psi/ecdh/ecdh_logger.h @@ -0,0 +1,41 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "psi/cryptor/ecc_cryptor.h" + +namespace psi::ecdh { + +enum EcdhStage { MaskSelf, MaskPeer, RecvDualMaskedSelf }; + +// EcdhLogger is for internal debug purposes. +class EcdhLogger { + public: + virtual ~EcdhLogger() = default; + EcdhLogger() = default; + + // For RecvDualMaskedSelf, output should be left empty. + virtual void Log(EcdhStage stage, + const std::array& secret_key, + size_t start_idx, const std::vector& input, + const std::vector& output = {}) = 0; +}; + +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/ecdh_oprf.cc b/psi/ecdh/ecdh_oprf.cc similarity index 96% rename from psi/psi/core/ecdh_oprf/ecdh_oprf.cc rename to psi/ecdh/ecdh_oprf.cc index 8a2f9097..0919350f 100644 --- a/psi/psi/core/ecdh_oprf/ecdh_oprf.cc +++ b/psi/ecdh/ecdh_oprf.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" +#include "psi/ecdh/ecdh_oprf.h" #include #include @@ -20,7 +20,7 @@ #include "yacl/utils/parallel.h" -namespace psi::psi { +namespace psi::ecdh { std::vector IEcdhOprfServer::Evaluate( absl::Span blinded_elements) const { @@ -91,4 +91,4 @@ std::vector IEcdhOprfClient::Finalize( return output; } -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/ecdh_oprf.h b/psi/ecdh/ecdh_oprf.h similarity index 98% rename from psi/psi/core/ecdh_oprf/ecdh_oprf.h rename to psi/ecdh/ecdh_oprf.h index 57318927..1e35affa 100644 --- a/psi/psi/core/ecdh_oprf/ecdh_oprf.h +++ b/psi/ecdh/ecdh_oprf.h @@ -27,9 +27,9 @@ #include "yacl/base/byte_container_view.h" #include "yacl/base/exception.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" -namespace psi::psi { +namespace psi::ecdh { enum class OprfType { Basic, @@ -192,4 +192,4 @@ class IEcdhOprfClient : public IEcdhOprf { absl::Span evaluated_element) const; }; -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf_psi.cc b/psi/ecdh/ecdh_oprf_psi.cc similarity index 98% rename from psi/psi/core/ecdh_oprf_psi.cc rename to psi/ecdh/ecdh_oprf_psi.cc index 0dab356f..381fbaa0 100644 --- a/psi/psi/core/ecdh_oprf_psi.cc +++ b/psi/ecdh/ecdh_oprf_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf_psi.h" +#include "psi/ecdh/ecdh_oprf_psi.h" #include @@ -28,11 +28,12 @@ #include "yacl/utils/parallel.h" #include "yacl/utils/serialize.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/cryptor/ecc_utils.h" -#include "psi/psi/utils/ub_psi_cache.h" +#include "psi/cryptor/ecc_utils.h" +#include "psi/ecdh/ecdh_oprf_selector.h" +#include "psi/utils/communication.h" +#include "psi/utils/ub_psi_cache.h" -namespace psi::psi { +namespace psi::ecdh { size_t EcdhOprfPsiServer::FullEvaluateAndSend( const std::shared_ptr& batch_provider, @@ -635,4 +636,4 @@ void EcdhOprfPsiClient::SendIntersectionMaskedItems( return; } -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf_psi.h b/psi/ecdh/ecdh_oprf_psi.h similarity index 96% rename from psi/psi/core/ecdh_oprf_psi.h rename to psi/ecdh/ecdh_oprf_psi.h index dd2048ed..0b35014c 100644 --- a/psi/psi/core/ecdh_oprf_psi.h +++ b/psi/ecdh/ecdh_oprf_psi.h @@ -24,11 +24,11 @@ #include "yacl/base/byte_container_view.h" #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/ub_psi_cache.h" +#include "psi/ecdh/ecdh_oprf.h" +#include "psi/ecdh/ecdh_oprf_selector.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/ec_point_store.h" +#include "psi/utils/ub_psi_cache.h" // basic ecdh-oprf based psi // reference: @@ -60,7 +60,7 @@ // ======================================================= // Intersection // -namespace psi::psi { +namespace psi::ecdh { // send queque capacity inline constexpr size_t kQueueCapacity = 32; @@ -220,4 +220,4 @@ class EcdhOprfPsiClient { size_t ec_point_length_; }; -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf_psi_test.cc b/psi/ecdh/ecdh_oprf_psi_test.cc similarity index 97% rename from psi/psi/core/ecdh_oprf_psi_test.cc rename to psi/ecdh/ecdh_oprf_psi_test.cc index a9ed7045..ec26aa1c 100644 --- a/psi/psi/core/ecdh_oprf_psi_test.cc +++ b/psi/ecdh/ecdh_oprf_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf_psi.h" +#include "psi/ecdh/ecdh_oprf_psi.h" #include #include @@ -34,13 +34,13 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/ecdh/ecdh_oprf_selector.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/ec_point_store.h" +#include "psi/utils/io.h" +#include "psi/utils/test_utils.h" -namespace psi::psi { +namespace psi::ecdh { namespace { void WriteCsvFile(const std::string &file_name, @@ -329,4 +329,4 @@ INSTANTIATE_TEST_SUITE_P( ) // ); -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/ecdh_oprf_selector.cc b/psi/ecdh/ecdh_oprf_selector.cc similarity index 96% rename from psi/psi/core/ecdh_oprf/ecdh_oprf_selector.cc rename to psi/ecdh/ecdh_oprf_selector.cc index 88cc2f2e..dd310b68 100644 --- a/psi/psi/core/ecdh_oprf/ecdh_oprf_selector.cc +++ b/psi/ecdh/ecdh_oprf_selector.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h" +#include "psi/ecdh/ecdh_oprf_selector.h" #include "yacl/utils/platform_utils.h" -#include "psi/psi/core/ecdh_oprf/basic_ecdh_oprf.h" -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" +#include "psi/ecdh/basic_ecdh_oprf.h" +#include "psi/ecdh/ecdh_oprf.h" -namespace psi::psi { +namespace psi::ecdh { std::unique_ptr CreateEcdhOprfServer( yacl::ByteContainerView private_key, OprfType oprf_type, @@ -181,4 +181,4 @@ std::unique_ptr CreateEcdhOprfClient( return client; } -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h b/psi/ecdh/ecdh_oprf_selector.h similarity index 92% rename from psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h rename to psi/ecdh/ecdh_oprf_selector.h index 59050d6f..040d2af1 100644 --- a/psi/psi/core/ecdh_oprf/ecdh_oprf_selector.h +++ b/psi/ecdh/ecdh_oprf_selector.h @@ -16,9 +16,9 @@ #include -#include "psi/psi/core/ecdh_oprf/ecdh_oprf.h" +#include "psi/ecdh/ecdh_oprf.h" -namespace psi::psi { +namespace psi::ecdh { std::unique_ptr CreateEcdhOprfServer( yacl::ByteContainerView private_key, OprfType oprf_type, @@ -34,4 +34,4 @@ std::unique_ptr CreateEcdhOprfClient( yacl::ByteContainerView private_key, OprfType oprf_type, CurveType curve_type); -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_psi.cc b/psi/ecdh/ecdh_psi.cc similarity index 92% rename from psi/psi/core/ecdh_psi.cc rename to psi/ecdh/ecdh_psi.cc index f9264ccc..8cd98700 100644 --- a/psi/psi/core/ecdh_psi.cc +++ b/psi/ecdh/ecdh_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_psi.h" +#include "psi/ecdh/ecdh_psi.h" #include #include @@ -24,11 +24,10 @@ #include "yacl/utils/parallel.h" #include "yacl/utils/serialize.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/utils/batch_provider.h" -namespace psi::psi { +namespace psi::ecdh { constexpr int kLogBatchInterval = 10; @@ -95,9 +94,10 @@ void EcdhPsiContext::MaskSelf( } std::vector masked_items; + std::vector hashed_masked_items; if (!batch_items.empty()) { - masked_items = Mask(options_.ecc_cryptor, - HashInputs(options_.ecc_cryptor, batch_items)); + hashed_masked_items = HashInputs(options_.ecc_cryptor, batch_items); + masked_items = Mask(options_.ecc_cryptor, hashed_masked_items); } // Send x^a. const auto tag = fmt::format("ECDHPSI:X^A:{}", batch_count); @@ -110,6 +110,11 @@ void EcdhPsiContext::MaskSelf( } break; } + + if (options_.ecdh_logger) { + options_.ecdh_logger->Log(EcdhStage::MaskSelf, options_.private_key, + item_count, hashed_masked_items, masked_items); + } item_count += batch_items.size(); ++batch_count; @@ -168,6 +173,10 @@ void EcdhPsiContext::MaskPeer( } break; } + if (options_.ecdh_logger) { + options_.ecdh_logger->Log(EcdhStage::MaskPeer, options_.private_key, + item_count, peer_items, dual_masked_peers); + } item_count += peer_items.size(); batch_count++; @@ -184,6 +193,7 @@ void EcdhPsiContext::RecvDualMaskedSelf( return; } + size_t item_count = 0; // Receive x^a^b. size_t batch_count = 0; while (true) { @@ -191,6 +201,10 @@ void EcdhPsiContext::RecvDualMaskedSelf( std::vector masked_items; const auto tag = fmt::format("ECDHPSI:X^A^B:{}", batch_count); RecvDualMaskedBatch(&masked_items, batch_count, tag); + if (options_.ecdh_logger) { + options_.ecdh_logger->Log(EcdhStage::RecvDualMaskedSelf, + options_.private_key, item_count, masked_items); + } for (auto& item : masked_items) { self_ec_point_store->Save(std::move(item)); } @@ -206,6 +220,8 @@ void EcdhPsiContext::RecvDualMaskedSelf( self_ec_point_store->ItemCount()); } } + + item_count += masked_items.size(); batch_count++; // Call the hook. @@ -407,6 +423,11 @@ std::vector RunEcdhPsi( options.target_rank = target_rank; options.batch_size = batch_size; + std::array key_array{}; + std::memcpy(key_array.data(), &options.ecc_cryptor->GetPrivateKey()[0], + kEccKeySize); + options.private_key = key_array; + auto self_ec_point_store = std::make_shared(); auto peer_ec_point_store = std::make_shared(); auto batch_provider = @@ -433,4 +454,4 @@ std::vector RunEcdhPsi( return ret; } -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_psi.h b/psi/ecdh/ecdh_psi.h similarity index 92% rename from psi/psi/core/ecdh_psi.h rename to psi/ecdh/ecdh_psi.h index 9e010aa0..6d819bc1 100644 --- a/psi/psi/core/ecdh_psi.h +++ b/psi/ecdh/ecdh_psi.h @@ -23,15 +23,16 @@ #include "yacl/link/link.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/cryptor/ecc_cryptor.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/recovery.h" +#include "psi/cryptor/ecc_cryptor.h" +#include "psi/ecdh/ecdh_logger.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/communication.h" +#include "psi/utils/ec_point_store.h" +#include "psi/utils/recovery.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi { +namespace psi::ecdh { using FinishBatchHook = std::function; @@ -74,6 +75,9 @@ struct EcdhPsiOptions { // Optional RecoveryManager to save checkpoints. std::shared_ptr recovery_manager = nullptr; + + std::array private_key; + std::shared_ptr ecdh_logger = nullptr; }; // batch handler for 2-party ecdh psi @@ -157,4 +161,4 @@ std::vector RunEcdhPsi( CurveType curve = CurveType::CURVE_25519, size_t batch_size = kEcdhPsiBatchSize); -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/core/ecdh_psi_bench.cc b/psi/ecdh/ecdh_psi_benchmark.cc similarity index 82% rename from psi/psi/core/ecdh_psi_bench.cc rename to psi/ecdh/ecdh_psi_benchmark.cc index 09abde7f..de981f58 100644 --- a/psi/psi/core/ecdh_psi_bench.cc +++ b/psi/ecdh/ecdh_psi_benchmark.cc @@ -20,10 +20,10 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/ec_point_store.h" namespace { @@ -35,16 +35,16 @@ std::vector CreateRangeItems(size_t begin, size_t size) { return ret; } -std::optional GetOverrideCurveType() { +std::optional GetOverrideCurveType() { if (const auto* env = std::getenv("OVERRIDE_CURVE")) { if (std::strcmp(env, "25519") == 0) { - return psi::psi::CurveType::CURVE_25519; + return psi::CurveType::CURVE_25519; } if (std::strcmp(env, "FOURQ") == 0) { - return psi::psi::CurveType::CURVE_FOURQ; + return psi::CurveType::CURVE_FOURQ; } if (std::strcmp(env, "ELLIGATOR2") == 0) { - return psi::psi::CurveType::CURVE_25519_ELLIGATOR2; + return psi::CurveType::CURVE_25519_ELLIGATOR2; } } return {}; @@ -64,9 +64,9 @@ static void BM_EcdhPsi(benchmark::State& state) { const std::vector& items, size_t target_rank) -> std::vector { const auto curve = GetOverrideCurveType(); - return psi::psi::RunEcdhPsi( + return psi::ecdh::RunEcdhPsi( ctx, items, target_rank, - curve.has_value() ? *curve : psi::psi::CurveType::CURVE_25519); + curve.has_value() ? *curve : psi::CurveType::CURVE_25519); }; state.ResumeTiming(); diff --git a/psi/psi/core/ecdh_psi_test.cc b/psi/ecdh/ecdh_psi_test.cc similarity index 97% rename from psi/psi/core/ecdh_psi_test.cc rename to psi/ecdh/ecdh_psi_test.cc index c4a87863..271e4c9e 100644 --- a/psi/psi/core/ecdh_psi_test.cc +++ b/psi/ecdh/ecdh_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/ecdh_psi.h" +#include "psi/ecdh/ecdh_psi.h" #include #include @@ -22,13 +22,13 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" struct TestParams { std::vector items_a; std::vector items_b; size_t target_rank; - psi::psi::CurveType curve_type = psi::psi::CurveType::CURVE_25519; + psi::CurveType curve_type = psi::CurveType::CURVE_25519; }; namespace std { @@ -40,7 +40,7 @@ std::ostream& operator<<(std::ostream& out, const TestParams& params) { } // namespace std -namespace psi::psi { +namespace psi::ecdh { TEST(EcdhPsiTestFailed, TargetRankMismatched) { for (std::pair ranks : std::vector>{ @@ -171,4 +171,4 @@ INSTANTIATE_TEST_SUITE_P( CurveType::CURVE_FOURQ} // )); -} // namespace psi::psi +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/receiver.cc b/psi/ecdh/receiver.cc similarity index 94% rename from psi/psi/ecdh/receiver.cc rename to psi/ecdh/receiver.cc index 987a0287..b076d126 100644 --- a/psi/psi/ecdh/receiver.cc +++ b/psi/ecdh/receiver.cc @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/ecdh/receiver.h" +#include "psi/ecdh/receiver.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "yacl/base/exception.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/ecdh/common.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/sync.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/ecdh/common.h" +#include "psi/trace_categories.h" +#include "psi/utils/sync.h" #include "psi/proto/psi.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { EcdhPsiReceiver::EcdhPsiReceiver(const v2::PsiConfig &config, std::shared_ptr lctx) @@ -142,4 +142,4 @@ void EcdhPsiReceiver::PostProcess() { SPDLOG_INFO("[EcdhPsiReceiver::PostProcess] end"); } -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/receiver.h b/psi/ecdh/receiver.h similarity index 87% rename from psi/psi/ecdh/receiver.h rename to psi/ecdh/receiver.h index 3e1ca4b0..82a9bbcc 100644 --- a/psi/psi/ecdh/receiver.h +++ b/psi/ecdh/receiver.h @@ -13,13 +13,13 @@ // limitations under the License. #pragma once -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/interface.h" -#include "psi/psi/utils/arrow_csv_batch_provider.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/interface.h" +#include "psi/utils/arrow_csv_batch_provider.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { class EcdhPsiReceiver final : public AbstractPsiReceiver { public: @@ -45,4 +45,4 @@ class EcdhPsiReceiver final : public AbstractPsiReceiver { std::shared_ptr peer_ec_point_store_; }; -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/sender.cc b/psi/ecdh/sender.cc similarity index 94% rename from psi/psi/ecdh/sender.cc rename to psi/ecdh/sender.cc index 94b50503..cbcb4f87 100644 --- a/psi/psi/ecdh/sender.cc +++ b/psi/ecdh/sender.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/ecdh/sender.h" +#include "psi/ecdh/sender.h" #include @@ -21,14 +21,14 @@ #include "yacl/base/exception.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/ecdh/common.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/sync.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/ecdh/common.h" +#include "psi/trace_categories.h" +#include "psi/utils/sync.h" #include "psi/proto/psi.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { EcdhPsiSender::EcdhPsiSender(const v2::PsiConfig &config, std::shared_ptr lctx) @@ -143,4 +143,4 @@ void EcdhPsiSender::PostProcess() { SPDLOG_INFO("[EcdhPsiSender::PostProcess] end"); } -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/sender.h b/psi/ecdh/sender.h similarity index 87% rename from psi/psi/ecdh/sender.h rename to psi/ecdh/sender.h index cd6259ee..92a54551 100644 --- a/psi/psi/ecdh/sender.h +++ b/psi/ecdh/sender.h @@ -13,13 +13,13 @@ // limitations under the License. #pragma once -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/interface.h" -#include "psi/psi/utils/arrow_csv_batch_provider.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/interface.h" +#include "psi/utils/arrow_csv_batch_provider.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { class EcdhPsiSender final : public AbstractPsiSender { public: @@ -45,4 +45,4 @@ class EcdhPsiSender final : public AbstractPsiSender { std::shared_ptr peer_ec_point_store_; }; -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/server.cc b/psi/ecdh/server.cc similarity index 96% rename from psi/psi/ecdh/server.cc rename to psi/ecdh/server.cc index 40fd6b0b..3a9f8950 100644 --- a/psi/psi/ecdh/server.cc +++ b/psi/ecdh/server.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/ecdh/server.h" +#include "psi/ecdh/server.h" -#include "psi/psi/utils/arrow_csv_batch_provider.h" -#include "psi/psi/utils/ec.h" -#include "psi/psi/utils/sync.h" +#include "psi/utils/arrow_csv_batch_provider.h" +#include "psi/utils/ec.h" +#include "psi/utils/sync.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { EcdhUbPsiServer::EcdhUbPsiServer(const v2::UbPsiConfig &config, std::shared_ptr lctx) @@ -149,4 +149,4 @@ void EcdhUbPsiServer::Online() { report_.set_intersection_count(-1); } -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/ecdh/server.h b/psi/ecdh/server.h similarity index 89% rename from psi/psi/ecdh/server.h rename to psi/ecdh/server.h index 24a12cba..7b63bf35 100644 --- a/psi/psi/ecdh/server.h +++ b/psi/ecdh/server.h @@ -13,12 +13,12 @@ // limitations under the License. #pragma once -#include "psi/psi/core/ecdh_oprf_psi.h" -#include "psi/psi/interface.h" +#include "psi/ecdh/ecdh_oprf_psi.h" +#include "psi/interface.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::ecdh { +namespace psi::ecdh { class EcdhUbPsiServer final : public AbstractUbPsiServer { public: @@ -41,4 +41,4 @@ class EcdhUbPsiServer final : public AbstractUbPsiServer { EcdhOprfPsiOptions psi_options_; }; -} // namespace psi::psi::ecdh +} // namespace psi::ecdh diff --git a/psi/psi/factory.cc b/psi/factory.cc similarity index 87% rename from psi/psi/factory.cc rename to psi/factory.cc index df0c6c0a..82d76961 100644 --- a/psi/psi/factory.cc +++ b/psi/factory.cc @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/factory.h" +#include "psi/factory.h" #include #include "yacl/base/exception.h" -#include "psi/psi/ecdh/client.h" -#include "psi/psi/ecdh/receiver.h" -#include "psi/psi/ecdh/sender.h" -#include "psi/psi/ecdh/server.h" -#include "psi/psi/kkrt/receiver.h" -#include "psi/psi/kkrt/sender.h" -#include "psi/psi/rr22/receiver.h" -#include "psi/psi/rr22/sender.h" +#include "psi/ecdh/client.h" +#include "psi/ecdh/receiver.h" +#include "psi/ecdh/sender.h" +#include "psi/ecdh/server.h" +#include "psi/kkrt/receiver.h" +#include "psi/kkrt/sender.h" +#include "psi/rr22/receiver.h" +#include "psi/rr22/sender.h" -namespace psi::psi { +namespace psi { std::unique_ptr createPsiParty( const v2::PsiConfig& config, std::shared_ptr lctx) { @@ -79,4 +79,4 @@ std::unique_ptr createUbPsiParty( } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/factory.h b/psi/factory.h similarity index 92% rename from psi/psi/factory.h rename to psi/factory.h index ac6943db..6d2df3ab 100644 --- a/psi/psi/factory.h +++ b/psi/factory.h @@ -15,11 +15,11 @@ #include -#include "psi/psi/interface.h" +#include "psi/interface.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { std::unique_ptr createPsiParty( const v2::PsiConfig& config, std::shared_ptr lctx); @@ -27,4 +27,4 @@ std::unique_ptr createPsiParty( std::unique_ptr createUbPsiParty( const v2::UbPsiConfig& config, std::shared_ptr lctx); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/interface.cc b/psi/interface.cc similarity index 97% rename from psi/psi/interface.cc rename to psi/interface.cc index 2aa3dceb..b534c660 100644 --- a/psi/psi/interface.cc +++ b/psi/interface.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/interface.h" +#include "psi/interface.h" #include @@ -26,17 +26,17 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/prelude.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/advanced_join.h" -#include "psi/psi/utils/csv_checker.h" -#include "psi/psi/utils/key.h" -#include "psi/psi/utils/sync.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/prelude.h" +#include "psi/trace_categories.h" +#include "psi/utils/advanced_join.h" +#include "psi/utils/csv_checker.h" +#include "psi/utils/key.h" +#include "psi/utils/sync.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { namespace { @@ -411,4 +411,4 @@ AbstractUbPsiClient::AbstractUbPsiClient( const v2::UbPsiConfig &config, std::shared_ptr lctx) : AbstractUbPsiParty(config, v2::Role::ROLE_CLIENT, std::move(lctx)) {} -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/interface.h b/psi/interface.h similarity index 96% rename from psi/psi/interface.h rename to psi/interface.h index af2b2488..93fefe1a 100644 --- a/psi/psi/interface.h +++ b/psi/interface.h @@ -23,13 +23,13 @@ #include "yacl/link/algorithm/barrier.h" #include "yacl/link/link.h" -#include "psi/psi/utils/advanced_join.h" -#include "psi/psi/utils/index_store.h" -#include "psi/psi/utils/recovery.h" +#include "psi/utils/advanced_join.h" +#include "psi/utils/index_store.h" +#include "psi/utils/recovery.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { class AbstractPsiParty { public: @@ -171,4 +171,4 @@ class AbstractUbPsiClient : public AbstractUbPsiParty { std::shared_ptr lctx); }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/kkrt/BUILD.bazel b/psi/kkrt/BUILD.bazel new file mode 100644 index 00000000..194ccf5f --- /dev/null +++ b/psi/kkrt/BUILD.bazel @@ -0,0 +1,88 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_library( + name = "kkrt_psi", + srcs = ["kkrt_psi.cc"], + hdrs = ["kkrt_psi.h"], + deps = [ + "//psi/utils:communication", + "//psi/utils:cuckoo_index", + "//psi/utils:serialize", + "@com_google_absl//absl/strings", + "@yacl//yacl/crypto/base/hash:hash_utils", + "@yacl//yacl/crypto/primitives/ot:base_ot", + "@yacl//yacl/crypto/primitives/ot:iknp_ote", + "@yacl//yacl/crypto/primitives/ot:kkrt_ote", + "@yacl//yacl/crypto/utils:rand", + "@yacl//yacl/link", + ], +) + +psi_cc_test( + name = "kkrt_psi_test", + srcs = ["kkrt_psi_test.cc"], + deps = [ + ":kkrt_psi", + "@yacl//yacl/crypto/base/hash:hash_utils", + ], +) + +psi_cc_binary( + name = "kkrt_psi_benchmark", + srcs = ["kkrt_psi_benchmark.cc"], + deps = [ + ":kkrt_psi", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "common", + srcs = ["common.cc"], + hdrs = ["common.h"], + deps = [ + "//psi/proto:psi_v2_cc_proto", + "//psi/utils:bucket", + "//psi/utils:recovery", + ], +) + +psi_cc_library( + name = "receiver", + srcs = ["receiver.cc"], + hdrs = ["receiver.h"], + deps = [ + ":common", + ":kkrt_psi", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], +) + +psi_cc_library( + name = "sender", + srcs = ["sender.cc"], + hdrs = ["sender.h"], + deps = [ + ":common", + ":kkrt_psi", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], +) diff --git a/psi/psi/kkrt/common.cc b/psi/kkrt/common.cc similarity index 88% rename from psi/psi/kkrt/common.cc rename to psi/kkrt/common.cc index 4455721f..7dda8b8b 100644 --- a/psi/psi/kkrt/common.cc +++ b/psi/kkrt/common.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/kkrt/common.h" +#include "psi/kkrt/common.h" -#include "psi/psi/utils/bucket.h" +#include "psi/utils/bucket.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { void CommonInit(const std::string& key_hash_digest, v2::PsiConfig* config, RecoveryManager* recovery_manager) { @@ -30,4 +30,4 @@ void CommonInit(const std::string& key_hash_digest, v2::PsiConfig* config, } } -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/kkrt/common.h b/psi/kkrt/common.h similarity index 89% rename from psi/psi/kkrt/common.h rename to psi/kkrt/common.h index 70ec2abe..78d66147 100644 --- a/psi/psi/kkrt/common.h +++ b/psi/kkrt/common.h @@ -15,11 +15,11 @@ #include -#include "psi/psi/utils/recovery.h" +#include "psi/utils/recovery.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { // For KkrtOt constexpr size_t kDefaultNumOt = 512; @@ -27,4 +27,4 @@ constexpr size_t kDefaultNumOt = 512; void CommonInit(const std::string& key_hash_digest, v2::PsiConfig* config, RecoveryManager* recovery_manager); -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/core/kkrt_psi.cc b/psi/kkrt/kkrt_psi.cc similarity index 98% rename from psi/psi/core/kkrt_psi.cc rename to psi/kkrt/kkrt_psi.cc index 7544adc0..9fd59130 100644 --- a/psi/psi/core/kkrt_psi.cc +++ b/psi/kkrt/kkrt_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/kkrt_psi.h" +#include "psi/kkrt/kkrt_psi.h" #include #include @@ -29,11 +29,11 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/core/cuckoo_index.h" -#include "psi/psi/utils/serialize.h" +#include "psi/utils/communication.h" +#include "psi/utils/cuckoo_index.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi::kkrt { namespace { // constexpr size_t kPsiDataBatchSize = 1 << 10; @@ -415,4 +415,4 @@ std::vector KkrtPsiRecv( return ret_intersection; } -} // namespace psi::psi +} // namespace psi::kkrt diff --git a/psi/psi/core/kkrt_psi.h b/psi/kkrt/kkrt_psi.h similarity index 98% rename from psi/psi/core/kkrt_psi.h rename to psi/kkrt/kkrt_psi.h index ce551308..160bf516 100644 --- a/psi/psi/core/kkrt_psi.h +++ b/psi/kkrt/kkrt_psi.h @@ -33,7 +33,7 @@ // PSZ18 Scalable private set intersection based on ot extension // https://eprint.iacr.org/2016/930.pdf // -namespace psi::psi { +namespace psi::kkrt { struct KkrtPsiOptions { // batch size the receiver send corrections @@ -90,4 +90,4 @@ inline std::vector KkrtPsiRecv( return KkrtPsiRecv(link_ctx, kkrt_psi_options, ot_send, items_hash); } -} // namespace psi::psi +} // namespace psi::kkrt diff --git a/psi/psi/core/kkrt_psi_bench.cc b/psi/kkrt/kkrt_psi_benchmark.cc similarity index 88% rename from psi/psi/core/kkrt_psi_bench.cc rename to psi/kkrt/kkrt_psi_benchmark.cc index 14833a5a..52c8d298 100644 --- a/psi/psi/core/kkrt_psi_bench.cc +++ b/psi/kkrt/kkrt_psi_benchmark.cc @@ -20,7 +20,7 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/kkrt_psi.h" +#include "psi/kkrt/kkrt_psi.h" namespace { std::vector CreateRangeItems(size_t begin, size_t size) { @@ -34,15 +34,15 @@ std::vector CreateRangeItems(size_t begin, size_t size) { void KkrtPsiSend(const std::shared_ptr& link_ctx, const std::vector& items_hash) { - auto ot_recv = psi::psi::GetKkrtOtSenderOptions(link_ctx, 512); - return psi::psi::KkrtPsiSend(link_ctx, ot_recv, items_hash); + auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); + return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); } std::vector KkrtPsiRecv( const std::shared_ptr& link_ctx, const std::vector& items_hash) { - auto ot_send = psi::psi::GetKkrtOtReceiverOptions(link_ctx, 512); - return psi::psi::KkrtPsiRecv(link_ctx, ot_send, items_hash); + auto ot_send = psi::kkrt::GetKkrtOtReceiverOptions(link_ctx, 512); + return psi::kkrt::KkrtPsiRecv(link_ctx, ot_send, items_hash); } } // namespace diff --git a/psi/psi/core/kkrt_psi_test.cc b/psi/kkrt/kkrt_psi_test.cc similarity index 98% rename from psi/psi/core/kkrt_psi_test.cc rename to psi/kkrt/kkrt_psi_test.cc index 1d368a42..75228803 100644 --- a/psi/psi/core/kkrt_psi_test.cc +++ b/psi/kkrt/kkrt_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/kkrt_psi.h" +#include "psi/kkrt/kkrt_psi.h" #include #include @@ -29,7 +29,7 @@ struct TestParams { std::vector items_b; }; -namespace psi::psi { +namespace psi::kkrt { void KkrtPsiSend(const std::shared_ptr& link_ctx, const std::vector& items_hash) { @@ -131,4 +131,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{{}, {}} // )); -} // namespace psi::psi +} // namespace psi::kkrt diff --git a/psi/psi/kkrt/receiver.cc b/psi/kkrt/receiver.cc similarity index 93% rename from psi/psi/kkrt/receiver.cc rename to psi/kkrt/receiver.cc index 84de36eb..76f3facf 100644 --- a/psi/psi/kkrt/receiver.cc +++ b/psi/kkrt/receiver.cc @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/kkrt/receiver.h" +#include "psi/kkrt/receiver.h" #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/core/kkrt_psi.h" -#include "psi/psi/kkrt/common.h" -#include "psi/psi/prelude.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/bucket.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/kkrt/common.h" +#include "psi/kkrt/kkrt_psi.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/prelude.h" +#include "psi/trace_categories.h" +#include "psi/utils/bucket.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { KkrtPsiReceiver::KkrtPsiReceiver(const v2::PsiConfig& config, std::shared_ptr lctx) @@ -176,4 +176,4 @@ void KkrtPsiReceiver::PostProcess() { SPDLOG_INFO("[KkrtPsiReceiver::PostProcess] end"); } -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/kkrt/receiver.h b/psi/kkrt/receiver.h similarity index 90% rename from psi/psi/kkrt/receiver.h rename to psi/kkrt/receiver.h index d31001b2..e7e3e336 100644 --- a/psi/psi/kkrt/receiver.h +++ b/psi/kkrt/receiver.h @@ -15,12 +15,12 @@ #include "yacl/crypto/primitives/ot/ot_store.h" -#include "psi/psi/interface.h" -#include "psi/psi/utils/hash_bucket_cache.h" +#include "psi/interface.h" +#include "psi/utils/hash_bucket_cache.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { class KkrtPsiReceiver final : public AbstractPsiReceiver { public: @@ -45,4 +45,4 @@ class KkrtPsiReceiver final : public AbstractPsiReceiver { std::unique_ptr ot_send_; }; -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/kkrt/sender.cc b/psi/kkrt/sender.cc similarity index 92% rename from psi/psi/kkrt/sender.cc rename to psi/kkrt/sender.cc index 667b97e4..f3e366c4 100644 --- a/psi/psi/kkrt/sender.cc +++ b/psi/kkrt/sender.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/kkrt/sender.h" +#include "psi/kkrt/sender.h" #include #include @@ -21,16 +21,16 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/core/kkrt_psi.h" -#include "psi/psi/kkrt/common.h" -#include "psi/psi/prelude.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/bucket.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/kkrt/common.h" +#include "psi/kkrt/kkrt_psi.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/prelude.h" +#include "psi/trace_categories.h" +#include "psi/utils/bucket.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { KkrtPsiSender::KkrtPsiSender(const v2::PsiConfig& config, std::shared_ptr lctx) @@ -170,4 +170,4 @@ void KkrtPsiSender::PostProcess() { SPDLOG_INFO("[KkrtPsiSender::PostProcess] end"); } -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/kkrt/sender.h b/psi/kkrt/sender.h similarity index 89% rename from psi/psi/kkrt/sender.h rename to psi/kkrt/sender.h index 525dc82c..cbacb0df 100644 --- a/psi/psi/kkrt/sender.h +++ b/psi/kkrt/sender.h @@ -15,12 +15,12 @@ #include "yacl/crypto/primitives/ot/ot_store.h" -#include "psi/psi/interface.h" -#include "psi/psi/utils/hash_bucket_cache.h" +#include "psi/interface.h" +#include "psi/utils/hash_bucket_cache.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::kkrt { +namespace psi::kkrt { class KkrtPsiSender final : public AbstractPsiSender { public: @@ -45,4 +45,4 @@ class KkrtPsiSender final : public AbstractPsiSender { std::unique_ptr ot_recv_; }; -} // namespace psi::psi::kkrt +} // namespace psi::kkrt diff --git a/psi/psi/launch.cc b/psi/launch.cc similarity index 96% rename from psi/psi/launch.cc rename to psi/launch.cc index ddd7149f..abd7cbdd 100644 --- a/psi/psi/launch.cc +++ b/psi/launch.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/launch.h" +// perfetto usage is adapted from +// https://github.com/google/perfetto/blob/master/examples/sdk/example.cc + +#include "psi/launch.h" #include @@ -24,11 +27,10 @@ #include "perfetto.h" #include "spdlog/spdlog.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/factory.h" -#include "psi/psi/trace_categories.h" +#include "psi/factory.h" +#include "psi/trace_categories.h" -namespace psi::psi { +namespace psi { namespace { void InitializePerfetto() { @@ -163,4 +165,4 @@ PsiResultReport RunLegacyPsi(const BucketPsiConfig& bucket_psi_config, return bucket_psi.Run(progress_callbacks, callbacks_interval_ms); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/launch.h b/psi/launch.h similarity index 94% rename from psi/psi/launch.h rename to psi/launch.h index 292cc5ee..2bc60a76 100644 --- a/psi/psi/launch.h +++ b/psi/launch.h @@ -19,12 +19,12 @@ #include "yacl/link/context.h" -#include "psi/psi/bucket_psi.h" +#include "psi/legacy/bucket_psi.h" #include "psi/proto/psi.pb.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { PsiResultReport RunLegacyPsi(const BucketPsiConfig& bucket_psi_config, const std::shared_ptr& lctx, @@ -37,4 +37,4 @@ PsiResultReport RunPsi(const v2::PsiConfig& psi_config, PsiResultReport RunUbPsi(const v2::UbPsiConfig& ub_psi_config, const std::shared_ptr& lctx); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/BUILD.bazel b/psi/legacy/BUILD.bazel similarity index 57% rename from psi/psi/operator/BUILD.bazel rename to psi/legacy/BUILD.bazel index a8b4f93b..68080aa0 100644 --- a/psi/psi/operator/BUILD.bazel +++ b/psi/legacy/BUILD.bazel @@ -22,7 +22,7 @@ psi_cc_library( hdrs = ["base_operator.h"], deps = [ "//psi/proto:psi_cc_proto", - "//psi/psi/utils:sync", + "//psi/utils:sync", "@yacl//yacl/link", ], ) @@ -34,7 +34,7 @@ psi_cc_library( deps = [ ":base_operator", ":factory", - "//psi/psi/core:ecdh_3pc_psi", + "//psi/ecdh:ecdh_3pc_psi", ], alwayslink = True, ) @@ -46,7 +46,7 @@ psi_cc_library( deps = [ ":base_operator", ":factory", - "//psi/psi/core:kkrt_psi", + "//psi/kkrt:kkrt_psi", "@yacl//yacl/utils:parallel", ], alwayslink = True, @@ -60,7 +60,7 @@ psi_cc_library( ":base_operator", ":factory", ":kkrt_2party_psi", - "//psi/psi/core:ecdh_psi", + "//psi/ecdh:ecdh_psi", "@yacl//yacl/utils:parallel", ], alwayslink = True, @@ -71,7 +71,7 @@ psi_cc_test( srcs = ["nparty_psi_test.cc"], deps = [ ":nparty_psi", - "//psi/psi/utils:test_utils", + "//psi/utils:test_utils", ], ) @@ -93,7 +93,7 @@ psi_cc_library( deps = [ ":base_operator", ":factory", - "//psi/psi/core/bc22_psi", + "//psi/bc22:bc22_psi", ], alwayslink = True, ) @@ -105,7 +105,7 @@ psi_cc_library( deps = [ ":base_operator", ":factory", - "//psi/psi/core/dp_psi", + "//psi/legacy/dp_psi", ], alwayslink = True, ) @@ -117,7 +117,7 @@ psi_cc_library( deps = [ ":base_operator", ":factory", - "//psi/psi/core/vole_psi:rr22_psi", + "//psi/rr22:rr22_psi", "@yacl//yacl/utils:parallel", ], alwayslink = True, @@ -134,3 +134,88 @@ psi_cc_library( ":rr22_2party_psi", ], ) + +psi_cc_library( + name = "memory_psi", + srcs = ["memory_psi.cc"], + hdrs = [ + "memory_psi.h", + ], + deps = [ + ":factory", + ":operator", + "//psi:prelude", + "//psi/ecdh:ecdh_psi", + "//psi/proto:psi_cc_proto", + "//psi/utils:sync", + ], +) + +psi_cc_test( + name = "memory_psi_test", + srcs = ["memory_psi_test.cc"], + deps = [ + ":memory_psi", + "//psi/utils:test_utils", + ], +) + +psi_cc_library( + name = "bucket_ub_psi", + srcs = ["bucket_ub_psi.cc"], + hdrs = [ + "bucket_ub_psi.h", + ], + deps = [ + "//psi:prelude", + "//psi/ecdh:ecdh_oprf_psi", + "//psi/proto:psi_cc_proto", + "//psi/utils:batch_provider", + "//psi/utils:csv_checker", + "//psi/utils:csv_header_analyzer", + "//psi/utils:ec", + "//psi/utils:ec_point_store", + "//psi/utils:progress", + "//psi/utils:sync", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@yacl//yacl/utils:scope_guard", + ], +) + +psi_cc_library( + name = "bucket_psi", + srcs = ["bucket_psi.cc"], + hdrs = [ + "bucket_psi.h", + ], + deps = [ + ":bucket_ub_psi", + ":memory_psi", + "//psi:prelude", + "//psi/proto:psi_cc_proto", + "//psi/utils:batch_provider", + "//psi/utils:csv_checker", + "//psi/utils:csv_header_analyzer", + "//psi/utils:ec_point_store", + "@boost//:uuid", + ], +) + +psi_cc_test( + name = "bucket_psi_test", + srcs = ["bucket_psi_test.cc"], + deps = [ + ":bucket_psi", + "@yacl//yacl/utils:scope_guard", + ], +) + +psi_cc_test( + name = "bucket_ub_psi_test", + srcs = ["bucket_ub_psi_test.cc"], + deps = [ + ":bucket_psi", + "//psi/utils:test_utils", + ], +) diff --git a/psi/psi/operator/base_operator.cc b/psi/legacy/base_operator.cc similarity index 87% rename from psi/psi/operator/base_operator.cc rename to psi/legacy/base_operator.cc index 21d898b0..5186e959 100644 --- a/psi/psi/operator/base_operator.cc +++ b/psi/legacy/base_operator.cc @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" #include #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi { PsiBaseOperator::PsiBaseOperator(std::shared_ptr link_ctx) : link_ctx_(std::move(link_ctx)) {} @@ -39,4 +39,4 @@ std::vector PsiBaseOperator::Run( return res; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/base_operator.h b/psi/legacy/base_operator.h similarity index 96% rename from psi/psi/operator/base_operator.h rename to psi/legacy/base_operator.h index 98dd6d57..4dbd599a 100644 --- a/psi/psi/operator/base_operator.h +++ b/psi/legacy/base_operator.h @@ -21,7 +21,7 @@ #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { class PsiBaseOperator { public: @@ -40,4 +40,4 @@ class PsiBaseOperator { std::shared_ptr link_ctx_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/bc22_2party_psi.cc b/psi/legacy/bc22_2party_psi.cc similarity index 88% rename from psi/psi/operator/bc22_2party_psi.cc rename to psi/legacy/bc22_2party_psi.cc index a45f4fdd..4e942f0d 100644 --- a/psi/psi/operator/bc22_2party_psi.cc +++ b/psi/legacy/bc22_2party_psi.cc @@ -11,14 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/bc22_2party_psi.h" +#include "psi/legacy/bc22_2party_psi.h" #include -#include "psi/psi/core/bc22_psi/bc22_psi.h" -#include "psi/psi/operator/factory.h" +#include "psi/bc22/bc22_psi.h" +#include "psi/legacy/factory.h" -namespace psi::psi { +namespace psi { Bc22PcgPsiOperator::Options Bc22PcgPsiOperator::ParseConfig( const MemoryPsiConfig& config, @@ -34,7 +34,7 @@ std::vector Bc22PcgPsiOperator::OnRun( auto role = link_ctx_->Rank() == options_.receiver_rank ? PsiRoleType::Receiver : PsiRoleType::Sender; - Bc22PcgPsi pcg_psi(link_ctx_, role); + bc22::Bc22PcgPsi pcg_psi(link_ctx_, role); pcg_psi.RunPsi(inputs); if (role == PsiRoleType::Receiver) { return pcg_psi.GetIntersection(); @@ -56,4 +56,4 @@ REGISTER_OPERATOR(BC22_PSI_2PC, CreateOperator); } // namespace -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/operator/bc22_2party_psi.h b/psi/legacy/bc22_2party_psi.h similarity index 92% rename from psi/psi/operator/bc22_2party_psi.h rename to psi/legacy/bc22_2party_psi.h index 15034686..d474e856 100644 --- a/psi/psi/operator/bc22_2party_psi.h +++ b/psi/legacy/bc22_2party_psi.h @@ -14,9 +14,9 @@ #pragma once -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" -namespace psi::psi { +namespace psi { class Bc22PcgPsiOperator : public PsiBaseOperator { public: @@ -40,4 +40,4 @@ class Bc22PcgPsiOperator : public PsiBaseOperator { Options options_; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/bucket_psi.cc b/psi/legacy/bucket_psi.cc similarity index 96% rename from psi/psi/bucket_psi.cc rename to psi/legacy/bucket_psi.cc index 6e41a252..0d09d52b 100644 --- a/psi/psi/bucket_psi.cc +++ b/psi/legacy/bucket_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/bucket_psi.h" +#include "psi/legacy/bucket_psi.h" #include @@ -34,19 +34,18 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/serialize.h" -#include "psi/psi/bucket_ub_psi.h" -#include "psi/psi/core/ecdh_oprf_psi.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/csv_header_analyzer.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/legacy/bucket_ub_psi.h" +#include "psi/prelude.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/csv_header_analyzer.h" +#include "psi/utils/ec_point_store.h" +#include "psi/utils/io.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi { namespace { @@ -427,7 +426,7 @@ std::vector BucketPsi::RunPsi(std::shared_ptr& progress, self_items_count); if (config_.psi_type() == PsiType::ECDH_PSI_2PC) { - EcdhPsiOptions psi_options; + ecdh::EcdhPsiOptions psi_options; if (config_.curve_type() == CurveType::CURVE_INVALID_TYPE) { YACL_THROW("Unsupported curve type"); } @@ -471,8 +470,8 @@ std::vector BucketPsi::RunPsi(std::shared_ptr& progress, } // Launch ECDH-PSI core. - RunEcdhPsi(psi_options, batch_provider, self_ec_point_store, - peer_ec_point_store); + ecdh::RunEcdhPsi(psi_options, batch_provider, self_ec_point_store, + peer_ec_point_store); std::vector results; results = @@ -598,4 +597,4 @@ void GetResultIndices(const std::vector& item_data_list, } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/bucket_psi.h b/psi/legacy/bucket_psi.h similarity index 95% rename from psi/psi/bucket_psi.h rename to psi/legacy/bucket_psi.h index cb273788..c8c6745d 100644 --- a/psi/psi/bucket_psi.h +++ b/psi/legacy/bucket_psi.h @@ -26,16 +26,16 @@ #include "yacl/link/link.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/memory_psi.h" -#include "psi/psi/utils/csv_checker.h" -#include "psi/psi/utils/hash_bucket_cache.h" -#include "psi/psi/utils/index_store.h" -#include "psi/psi/utils/key.h" -#include "psi/psi/utils/progress.h" +#include "psi/legacy/memory_psi.h" +#include "psi/utils/csv_checker.h" +#include "psi/utils/hash_bucket_cache.h" +#include "psi/utils/index_store.h" +#include "psi/utils/key.h" +#include "psi/utils/progress.h" #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { using ProgressCallbacks = std::function; @@ -149,4 +149,4 @@ class BucketPsi { std::unique_ptr mem_psi_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/bucket_psi_test.cc b/psi/legacy/bucket_psi_test.cc similarity index 99% rename from psi/psi/bucket_psi_test.cc rename to psi/legacy/bucket_psi_test.cc index d3785640..7d276920 100644 --- a/psi/psi/bucket_psi_test.cc +++ b/psi/legacy/bucket_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/bucket_psi.h" +#include "psi/legacy/bucket_psi.h" #include #include @@ -23,9 +23,9 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { namespace { struct TestParams { @@ -656,4 +656,4 @@ TEST(UnbalancedPsiTest, EcdhOprfUnbalanced) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/bucket_ub_psi.cc b/psi/legacy/bucket_ub_psi.cc similarity index 87% rename from psi/psi/bucket_ub_psi.cc rename to psi/legacy/bucket_ub_psi.cc index 8c9f50a0..8a3147d3 100644 --- a/psi/psi/bucket_ub_psi.cc +++ b/psi/legacy/bucket_ub_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/bucket_ub_psi.h" +#include "psi/legacy/bucket_ub_psi.h" #include #include @@ -20,13 +20,13 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/ec.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/prelude.h" +#include "psi/utils/ec.h" +#include "psi/utils/io.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi { namespace { @@ -63,7 +63,7 @@ std::vector GetItemsByIndices( std::pair, size_t> UbPsi( BucketPsiConfig config, std::shared_ptr lctx) { - EcdhOprfPsiOptions psi_options; + ecdh::EcdhOprfPsiOptions psi_options; psi_options.link0 = lctx; if (config.psi_type() == PsiType::ECDH_OPRF_UB_PSI_2PC_GEN_CACHE) { @@ -143,12 +143,13 @@ std::pair, size_t> UbPsi( // generate cache std::pair, size_t> UbPsiServerGenCache( BucketPsiConfig config, std::shared_ptr /*lctx*/, - const EcdhOprfPsiOptions& psi_options) { + const ecdh::EcdhOprfPsiOptions& psi_options) { std::vector server_private_key = ReadEcSecretKeyFile(config.ecdh_secret_key_path()); - std::shared_ptr dh_oprf_psi_server = - std::make_shared(psi_options, server_private_key); + std::shared_ptr dh_oprf_psi_server = + std::make_shared(psi_options, + server_private_key); std::vector selected_fields; selected_fields.insert(selected_fields.end(), @@ -175,9 +176,9 @@ std::pair, size_t> UbPsiServerGenCache( // transfer cache std::pair, size_t> UbPsiClientTransferCache( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options) { - std::shared_ptr ub_psi_client_transfer_cache = - std::make_shared(psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options) { + std::shared_ptr ub_psi_client_transfer_cache = + std::make_shared(psi_options); auto peer_ec_point_store = std::make_shared( config.preprocess_path(), false, "peer", false); @@ -197,9 +198,9 @@ std::pair, size_t> UbPsiClientTransferCache( std::pair, size_t> UbPsiServerTransferCache( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options) { - std::shared_ptr ub_psi_server_transfer_cache = - std::make_shared(psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options) { + std::shared_ptr ub_psi_server_transfer_cache = + std::make_shared(psi_options); std::shared_ptr batch_provider = std::make_shared( @@ -221,11 +222,11 @@ std::pair, size_t> UbPsiServerTransferCache( // online with shuffling std::pair, size_t> UbPsiClientShuffleOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir) { + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir) { std::vector private_key = yacl::crypto::RandBytes(kEccKeySize); - std::shared_ptr ub_psi_client_shuffle_online = - std::make_shared(psi_options, private_key); + std::shared_ptr ub_psi_client_shuffle_online = + std::make_shared(psi_options, private_key); std::vector selected_fields; selected_fields.insert(selected_fields.end(), @@ -275,12 +276,14 @@ std::pair, size_t> UbPsiClientShuffleOnline( std::pair, size_t> UbPsiServerShuffleOnline( BucketPsiConfig config, std::shared_ptr /*lctx*/, - const EcdhOprfPsiOptions& psi_options, const std::string& /*tmp_dir*/) { + const ecdh::EcdhOprfPsiOptions& psi_options, + const std::string& /*tmp_dir*/) { std::vector server_private_key = ReadEcSecretKeyFile(config.ecdh_secret_key_path()); - std::shared_ptr ub_psi_server_shuffle_online = - std::make_shared(psi_options, server_private_key); + std::shared_ptr ub_psi_server_shuffle_online = + std::make_shared(psi_options, + server_private_key); ub_psi_server_shuffle_online->RecvBlindAndShuffleSendEvaluate(); @@ -300,18 +303,19 @@ std::pair, size_t> UbPsiServerShuffleOnline( // offline std::pair, size_t> UbPsiClientOffline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options) { + const ecdh::EcdhOprfPsiOptions& psi_options) { return UbPsiClientTransferCache(config, lctx, psi_options); } std::pair, size_t> UbPsiServerOffline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options) { + const ecdh::EcdhOprfPsiOptions& psi_options) { std::vector server_private_key = ReadEcSecretKeyFile(config.ecdh_secret_key_path()); - std::shared_ptr dh_oprf_psi_server_offline = - std::make_shared(psi_options, server_private_key); + std::shared_ptr dh_oprf_psi_server_offline = + std::make_shared(psi_options, + server_private_key); std::vector selected_fields; selected_fields.insert(selected_fields.end(), @@ -338,9 +342,9 @@ std::pair, size_t> UbPsiServerOffline( // online std::pair, size_t> UbPsiClientOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir) { - std::shared_ptr dh_oprf_psi_client_online = - std::make_shared(psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir) { + std::shared_ptr dh_oprf_psi_client_online = + std::make_shared(psi_options); std::vector selected_fields; selected_fields.insert(selected_fields.end(), @@ -411,12 +415,14 @@ std::pair, size_t> UbPsiClientOnline( std::pair, size_t> UbPsiServerOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& /*tmp_dir*/) { + const ecdh::EcdhOprfPsiOptions& psi_options, + const std::string& /*tmp_dir*/) { std::vector server_private_key = ReadEcSecretKeyFile(config.ecdh_secret_key_path()); - std::shared_ptr dh_oprf_psi_server_online = - std::make_shared(psi_options, server_private_key); + std::shared_ptr dh_oprf_psi_server_online = + std::make_shared(psi_options, + server_private_key); dh_oprf_psi_server_online->RecvBlindAndSendEvaluate(); @@ -457,4 +463,4 @@ std::pair, size_t> UbPsiServerOnline( return std::make_pair(results, 0); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/bucket_ub_psi.h b/psi/legacy/bucket_ub_psi.h similarity index 76% rename from psi/psi/bucket_ub_psi.h rename to psi/legacy/bucket_ub_psi.h index bbc5cf0e..fa2aa141 100644 --- a/psi/psi/bucket_ub_psi.h +++ b/psi/legacy/bucket_ub_psi.h @@ -22,49 +22,49 @@ #include "yacl/base/exception.h" #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_oprf_psi.h" +#include "psi/ecdh/ecdh_oprf_psi.h" #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { std::pair, size_t> UbPsi( BucketPsiConfig config, std::shared_ptr lctx); std::pair, size_t> UbPsiServerGenCache( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options); std::pair, size_t> UbPsiClientTransferCache( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options); std::pair, size_t> UbPsiServerTransferCache( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options); std::pair, size_t> UbPsiClientShuffleOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); std::pair, size_t> UbPsiServerShuffleOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); std::pair, size_t> UbPsiClientOffline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options); std::pair, size_t> UbPsiServerOffline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options); + const ecdh::EcdhOprfPsiOptions& psi_options); std::pair, size_t> UbPsiClientOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); std::pair, size_t> UbPsiServerOnline( BucketPsiConfig config, std::shared_ptr lctx, - const EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); + const ecdh::EcdhOprfPsiOptions& psi_options, const std::string& tmp_dir); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/bucket_ub_psi_test.cc b/psi/legacy/bucket_ub_psi_test.cc similarity index 98% rename from psi/psi/bucket_ub_psi_test.cc rename to psi/legacy/bucket_ub_psi_test.cc index 5e011417..6669e3d6 100644 --- a/psi/psi/bucket_ub_psi_test.cc +++ b/psi/legacy/bucket_ub_psi_test.cc @@ -22,11 +22,11 @@ #include "yacl/link/test_util.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/utils/io.h" +#include "psi/utils/test_utils.h" -namespace psi::psi { +namespace psi { namespace { @@ -262,4 +262,4 @@ INSTANTIATE_TEST_SUITE_P( ) // ); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/dp_2party_psi.cc b/psi/legacy/dp_2party_psi.cc similarity index 78% rename from psi/psi/operator/dp_2party_psi.cc rename to psi/legacy/dp_2party_psi.cc index 2c594718..215cabcf 100644 --- a/psi/psi/operator/dp_2party_psi.cc +++ b/psi/legacy/dp_2party_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/dp_2party_psi.h" +#include "psi/legacy/dp_2party_psi.h" #include #include @@ -23,14 +23,14 @@ #include "yacl/link/link.h" #include "yacl/utils/serialize.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/operator/factory.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/legacy/factory.h" -namespace psi::psi { +namespace psi { DpPsiOperator::DpPsiOperator(const std::shared_ptr& lctx, - const DpPsiOptions& options, size_t receiver_rank, - CurveType curve_type) + const dp_psi::DpPsiOptions& options, + size_t receiver_rank, CurveType curve_type) : PsiBaseOperator(lctx), dp_options_(options), receiver_rank_(receiver_rank), @@ -45,15 +45,16 @@ std::vector DpPsiOperator::OnRun( size_t bob_sub_sample_size = 0; if (receiver_rank_ == link_ctx_->Rank()) { - std::vector dp_psi_result = RunDpEcdhPsiBob( + std::vector dp_psi_result = dp_psi::RunDpEcdhPsiBob( dp_options_, link_ctx_, inputs, &bob_sub_sample_size, curve_type_); for (auto index : dp_psi_result) { res.emplace_back(inputs[index]); } } else { - RunDpEcdhPsiAlice(dp_options_, link_ctx_, inputs, &alice_sub_sample_size, - &alice_up_sample_size, curve_type_); + dp_psi::RunDpEcdhPsiAlice(dp_options_, link_ctx_, inputs, + &alice_sub_sample_size, &alice_up_sample_size, + curve_type_); } return res; @@ -71,7 +72,7 @@ std::unique_ptr CreateOperator( epsilon = config.dppsi_params().epsilon(); } - DpPsiOptions dp_options(bob_sub_sampling, epsilon); + dp_psi::DpPsiOptions dp_options(bob_sub_sampling, epsilon); if (config.curve_type() != CurveType::CURVE_INVALID_TYPE) { return std::make_unique( @@ -86,4 +87,4 @@ REGISTER_OPERATOR(DP_PSI_2PC, CreateOperator); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/dp_2party_psi.h b/psi/legacy/dp_2party_psi.h similarity index 81% rename from psi/psi/operator/dp_2party_psi.h rename to psi/legacy/dp_2party_psi.h index f475784b..e79a0081 100644 --- a/psi/psi/operator/dp_2party_psi.h +++ b/psi/legacy/dp_2party_psi.h @@ -17,24 +17,24 @@ #include #include -#include "psi/psi/core/dp_psi/dp_psi.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" +#include "psi/legacy/dp_psi/dp_psi.h" -namespace psi::psi { +namespace psi { class DpPsiOperator : public PsiBaseOperator { public: DpPsiOperator(const std::shared_ptr& lctx, - const DpPsiOptions& options, size_t receiver_rank, + const dp_psi::DpPsiOptions& options, size_t receiver_rank, CurveType curve_type = CurveType::CURVE_25519); std::vector OnRun( const std::vector& inputs) override final; private: - DpPsiOptions dp_options_; + dp_psi::DpPsiOptions dp_options_; size_t receiver_rank_; CurveType curve_type_ = CurveType::CURVE_25519; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/dp_psi/BUILD.bazel b/psi/legacy/dp_psi/BUILD.bazel similarity index 77% rename from psi/psi/core/dp_psi/BUILD.bazel rename to psi/legacy/dp_psi/BUILD.bazel index 13f596a8..f78e704c 100644 --- a/psi/psi/core/dp_psi/BUILD.bazel +++ b/psi/legacy/dp_psi/BUILD.bazel @@ -16,23 +16,32 @@ load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") package(default_visibility = ["//visibility:public"]) +psi_cc_library( + name = "dp_psi_utils", + srcs = [ + "dp_psi_utils.cc", + ], + hdrs = [ + "dp_psi_utils.h", + ], +) + psi_cc_library( name = "dp_psi", srcs = [ "dp_psi.cc", - "dp_psi_utils.cc", ], hdrs = [ "dp_psi.h", - "dp_psi_utils.h", ], deps = [ - "//psi/psi/core:ecdh_3pc_psi", - "//psi/psi/core:ecdh_psi", - "//psi/psi/cryptor:cryptor_selector", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:ec_point_store", - "//psi/psi/utils:serialize", + ":dp_psi_utils", + "//psi/cryptor:cryptor_selector", + "//psi/ecdh:ecdh_3pc_psi", + "//psi/ecdh:ecdh_psi", + "//psi/utils:batch_provider", + "//psi/utils:ec_point_store", + "//psi/utils:serialize", "@com_google_absl//absl/strings", "@yacl//yacl/base:exception", "@yacl//yacl/crypto/utils:rand", @@ -51,8 +60,8 @@ psi_cc_test( ) psi_cc_binary( - name = "dp_psi_bench", - srcs = ["dp_psi_bench.cc"], + name = "dp_psi_benchmark", + srcs = ["dp_psi_benchmark.cc"], deps = [ ":dp_psi", "@com_github_google_benchmark//:benchmark_main", @@ -61,8 +70,8 @@ psi_cc_binary( ) psi_cc_binary( - name = "dp_psi_payload_bench", - srcs = ["dp_psi_payload_bench.cc"], + name = "dp_psi_payload_benchmark", + srcs = ["dp_psi_payload_benchmark.cc"], deps = [ ":dp_psi", "@com_github_google_benchmark//:benchmark_main", diff --git a/psi/psi/core/dp_psi/dp_psi.cc b/psi/legacy/dp_psi/dp_psi.cc similarity index 94% rename from psi/psi/core/dp_psi/dp_psi.cc rename to psi/legacy/dp_psi/dp_psi.cc index e7a1d6ea..782b06d9 100644 --- a/psi/psi/core/dp_psi/dp_psi.cc +++ b/psi/legacy/dp_psi/dp_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/dp_psi/dp_psi.h" +#include "psi/legacy/dp_psi/dp_psi.h" #include #include @@ -24,17 +24,17 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/core/dp_psi/dp_psi_utils.h" -#include "psi/psi/core/ecdh_3pc_psi.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/serialize.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/ecdh/ecdh_3pc_psi.h" +#include "psi/legacy/dp_psi/dp_psi_utils.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/communication.h" +#include "psi/utils/ec_point_store.h" +#include "psi/utils/serialize.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi { +namespace psi::dp_psi { namespace { @@ -118,13 +118,13 @@ size_t RunDpEcdhPsiAlice(const DpPsiOptions& dp_psi_options, "alice items_size: {}, down_sampling_rate: {}, up_sampling_rate: {}", items.size(), dp_psi_options.p2, dp_psi_options.q); - EcdhPsiOptions options; + ecdh::EcdhPsiOptions options; options.link_ctx = link_ctx; options.ecc_cryptor = CreateEccCryptor(curve); options.target_rank = link_ctx->Rank(); - EcdhP2PExtendCtx psi_ctx(options); + ecdh::EcdhP2PExtendCtx psi_ctx(options); std::future f_mask_self_a = std::async([&] { return psi_ctx.MaskSelf(batch_provider); }); @@ -239,14 +239,14 @@ std::vector RunDpEcdhPsiBob( auto peer_ec_point_store = std::make_shared(); - EcdhPsiOptions options; + ecdh::EcdhPsiOptions options; // set options options.ecc_cryptor = CreateEccCryptor(curve); options.link_ctx = link_ctx; options.target_rank = link_ctx->Rank(); - EcdhP2PExtendCtx psi_ctx(options); + ecdh::EcdhP2PExtendCtx psi_ctx(options); std::future f_mask_peer_b = std::async([&] { return psi_ctx.MaskPeer(peer_ec_point_store); }); @@ -310,4 +310,4 @@ std::vector RunDpEcdhPsiBob( return dp_intersection_idx; } -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi.h b/psi/legacy/dp_psi/dp_psi.h similarity index 96% rename from psi/psi/core/dp_psi/dp_psi.h rename to psi/legacy/dp_psi/dp_psi.h index c8170b2a..7880791b 100644 --- a/psi/psi/core/dp_psi/dp_psi.h +++ b/psi/legacy/dp_psi/dp_psi.h @@ -23,9 +23,9 @@ #include "spdlog/spdlog.h" #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_psi.h" +#include "psi/ecdh/ecdh_psi.h" -namespace psi::psi { +namespace psi::dp_psi { // bernoulli distribution probability for sub/up samples struct DpPsiOptions { @@ -83,4 +83,4 @@ std::vector RunDpEcdhPsiBob( const std::vector& items, size_t* sub_sample_size, CurveType curve = CurveType::CURVE_25519); -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi_bench.cc b/psi/legacy/dp_psi/dp_psi_benchmark.cc similarity index 98% rename from psi/psi/core/dp_psi/dp_psi_bench.cc rename to psi/legacy/dp_psi/dp_psi_benchmark.cc index ae71659c..023f2c8f 100644 --- a/psi/psi/core/dp_psi/dp_psi_bench.cc +++ b/psi/legacy/dp_psi/dp_psi_benchmark.cc @@ -24,9 +24,9 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/dp_psi/dp_psi.h" +#include "psi/legacy/dp_psi/dp_psi.h" -namespace psi::psi { +namespace psi::dp_psi { namespace { @@ -197,4 +197,4 @@ BENCHMARK(BM_DpPsi) ->Arg(64 << 20) ->Arg(128 << 20); -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi_payload_bench.cc b/psi/legacy/dp_psi/dp_psi_payload_benchmark.cc similarity index 98% rename from psi/psi/core/dp_psi/dp_psi_payload_bench.cc rename to psi/legacy/dp_psi/dp_psi_payload_benchmark.cc index 7548e5f3..a3e3adfc 100644 --- a/psi/psi/core/dp_psi/dp_psi_payload_bench.cc +++ b/psi/legacy/dp_psi/dp_psi_payload_benchmark.cc @@ -26,10 +26,10 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/dp_psi/dp_psi.h" -#include "psi/psi/core/dp_psi/dp_psi_utils.h" +#include "psi/legacy/dp_psi/dp_psi.h" +#include "psi/legacy/dp_psi/dp_psi_utils.h" -namespace psi::psi { +namespace psi::dp_psi { namespace { @@ -394,4 +394,4 @@ BENCHMARK(BM_DpPsi) ->Args({1 << 20, 4}) ->Args({1 << 20, 5}); -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi_test.cc b/psi/legacy/dp_psi/dp_psi_test.cc similarity index 97% rename from psi/psi/core/dp_psi/dp_psi_test.cc rename to psi/legacy/dp_psi/dp_psi_test.cc index ad5b656d..bcfab2f0 100644 --- a/psi/psi/core/dp_psi/dp_psi_test.cc +++ b/psi/legacy/dp_psi/dp_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/dp_psi/dp_psi.h" +#include "psi/legacy/dp_psi/dp_psi.h" #include #include @@ -22,7 +22,7 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -namespace psi::psi { +namespace psi::dp_psi { namespace { @@ -140,4 +140,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, DpPsiTest, ) // ); -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi_utils.cc b/psi/legacy/dp_psi/dp_psi_utils.cc similarity index 98% rename from psi/psi/core/dp_psi/dp_psi_utils.cc rename to psi/legacy/dp_psi/dp_psi_utils.cc index 3a651198..e9934291 100644 --- a/psi/psi/core/dp_psi/dp_psi_utils.cc +++ b/psi/legacy/dp_psi/dp_psi_utils.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/dp_psi/dp_psi_utils.h" +#include "psi/legacy/dp_psi/dp_psi_utils.h" #include #include @@ -21,7 +21,7 @@ #include #include -namespace psi::psi { +namespace psi::dp_psi { double ComputeEpsilon2(size_t n, double epsilon) { double epsilon1; @@ -169,4 +169,4 @@ double CalibrateAnalyticGaussianMechanism(double epsilon, double delta, return sigma; } -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/core/dp_psi/dp_psi_utils.h b/psi/legacy/dp_psi/dp_psi_utils.h similarity index 95% rename from psi/psi/core/dp_psi/dp_psi_utils.h rename to psi/legacy/dp_psi/dp_psi_utils.h index 2b99785b..8d1db322 100644 --- a/psi/psi/core/dp_psi/dp_psi_utils.h +++ b/psi/legacy/dp_psi/dp_psi_utils.h @@ -18,7 +18,7 @@ #include #include -namespace psi::psi { +namespace psi::dp_psi { inline constexpr double kEpsilonPsi = 4; inline constexpr double kErrorRate = 1.e-12; @@ -34,4 +34,4 @@ inline double ComputePSubKeep(double epsilon2) { double CalibrateAnalyticGaussianMechanism(double epsilon, double delta, double GS, double tol = kErrorRate); -} // namespace psi::psi +} // namespace psi::dp_psi diff --git a/psi/psi/operator/ecdh_3party_psi.cc b/psi/legacy/ecdh_3party_psi.cc similarity index 91% rename from psi/psi/operator/ecdh_3party_psi.cc rename to psi/legacy/ecdh_3party_psi.cc index 0251c5b3..62474ebc 100644 --- a/psi/psi/operator/ecdh_3party_psi.cc +++ b/psi/legacy/ecdh_3party_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/ecdh_3party_psi.h" +#include "psi/legacy/ecdh_3party_psi.h" #include #include @@ -27,14 +27,14 @@ #include "yacl/link/link.h" #include "yacl/utils/serialize.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/operator/factory.h" +#include "psi/cryptor/cryptor_selector.h" +#include "psi/legacy/factory.h" namespace { constexpr uint32_t kLinkRecvTimeout = 30 * 60 * 1000; } // namespace -namespace psi::psi { +namespace psi { Ecdh3PartyPsiOperator::Options Ecdh3PartyPsiOperator::ParseConfig( const MemoryPsiConfig& config, @@ -52,14 +52,14 @@ Ecdh3PartyPsiOperator::Ecdh3PartyPsiOperator(const Options& options) : PsiBaseOperator(options.link_ctx), options_(options), handler_(nullptr) { options_.link_ctx->SetRecvTimeout(kLinkRecvTimeout); - ShuffleEcdh3PcPsi::Options opts; + ecdh::ShuffleEcdh3PcPsi::Options opts; opts.link_ctx = options_.link_ctx; opts.master_rank = options_.master_rank; opts.batch_size = options_.batch_size; opts.dual_mask_size = options_.dual_mask_size; opts.curve_type = options_.curve_type; - handler_ = std::make_shared(opts); + handler_ = std::make_shared(opts); } std::vector Ecdh3PartyPsiOperator::OnRun( @@ -94,4 +94,4 @@ REGISTER_OPERATOR(ECDH_PSI_3PC, CreateOperator); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/ecdh_3party_psi.h b/psi/legacy/ecdh_3party_psi.h similarity index 95% rename from psi/psi/operator/ecdh_3party_psi.h rename to psi/legacy/ecdh_3party_psi.h index f137b7ba..1500f1cd 100644 --- a/psi/psi/operator/ecdh_3party_psi.h +++ b/psi/legacy/ecdh_3party_psi.h @@ -19,10 +19,10 @@ #include #include -#include "psi/psi/core/ecdh_3pc_psi.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/ecdh/ecdh_3pc_psi.h" +#include "psi/legacy/base_operator.h" -namespace psi::psi { +namespace psi { // // 3party ecdh psi algorithm. @@ -91,7 +91,7 @@ class Ecdh3PartyPsiOperator : public PsiBaseOperator { private: Options options_; - std::shared_ptr handler_; + std::shared_ptr handler_; }; // @@ -112,4 +112,4 @@ std::vector RunShuffleEcdh3PartyPsi( CurveType curve_type = CurveType::CURVE_25519, size_t batch_size = kEcdhPsiBatchSize); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/factory.h b/psi/legacy/factory.h similarity index 96% rename from psi/psi/operator/factory.h rename to psi/legacy/factory.h index a74c7b30..ace8c7b7 100644 --- a/psi/psi/operator/factory.h +++ b/psi/legacy/factory.h @@ -21,11 +21,11 @@ #include "yacl/base/exception.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { using OperatorCreator = std::function( const MemoryPsiConfig& config, @@ -77,4 +77,4 @@ class OperatorRegistrar { #define REGISTER_OPERATOR(type, creator) \ static OperatorRegistrar registrar__##type##__object(#type, creator); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/kkrt_2party_psi.cc b/psi/legacy/kkrt_2party_psi.cc similarity index 80% rename from psi/psi/operator/kkrt_2party_psi.cc rename to psi/legacy/kkrt_2party_psi.cc index 6bca6deb..a0d04c52 100644 --- a/psi/psi/operator/kkrt_2party_psi.cc +++ b/psi/legacy/kkrt_2party_psi.cc @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/kkrt_2party_psi.h" +#include "psi/legacy/kkrt_2party_psi.h" #include #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/operator/factory.h" +#include "psi/kkrt/kkrt_psi.h" +#include "psi/legacy/factory.h" -namespace psi::psi { +namespace psi { KkrtPsiOperator::Options KkrtPsiOperator::ParseConfig( const MemoryPsiConfig& config, @@ -42,16 +43,18 @@ std::vector KkrtPsiOperator::OnRun( }); if (options_.receiver_rank == link_ctx_->Rank()) { - auto ot_send = GetKkrtOtReceiverOptions(options_.link_ctx, options_.num_ot); + auto ot_send = + kkrt::GetKkrtOtReceiverOptions(options_.link_ctx, options_.num_ot); std::vector kkrt_psi_result = - KkrtPsiRecv(options_.link_ctx, ot_send, items_hash); + kkrt::KkrtPsiRecv(options_.link_ctx, ot_send, items_hash); for (auto index : kkrt_psi_result) { res.emplace_back(inputs[index]); } } else { - auto ot_recv = GetKkrtOtSenderOptions(options_.link_ctx, options_.num_ot); - KkrtPsiSend(options_.link_ctx, ot_recv, items_hash); + auto ot_recv = + kkrt::GetKkrtOtSenderOptions(options_.link_ctx, options_.num_ot); + kkrt::KkrtPsiSend(options_.link_ctx, ot_recv, items_hash); } return res; @@ -70,4 +73,4 @@ REGISTER_OPERATOR(KKRT_PSI_2PC, CreateOperator); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/kkrt_2party_psi.h b/psi/legacy/kkrt_2party_psi.h similarity index 90% rename from psi/psi/operator/kkrt_2party_psi.h rename to psi/legacy/kkrt_2party_psi.h index 17b024fd..eee73588 100644 --- a/psi/psi/operator/kkrt_2party_psi.h +++ b/psi/legacy/kkrt_2party_psi.h @@ -19,10 +19,9 @@ #include #include -#include "psi/psi/core/kkrt_psi.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" -namespace psi::psi { +namespace psi { class KkrtPsiOperator : public PsiBaseOperator { public: @@ -44,4 +43,4 @@ class KkrtPsiOperator : public PsiBaseOperator { Options options_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/memory_psi.cc b/psi/legacy/memory_psi.cc similarity index 90% rename from psi/psi/memory_psi.cc rename to psi/legacy/memory_psi.cc index f81a5252..9722a682 100644 --- a/psi/psi/memory_psi.cc +++ b/psi/legacy/memory_psi.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/memory_psi.h" +#include "psi/legacy/memory_psi.h" #include "spdlog/spdlog.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/operator/factory.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/sync.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/legacy/factory.h" +#include "psi/prelude.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi { MemoryPsi::MemoryPsi(MemoryPsiConfig config, std::shared_ptr lctx) @@ -97,9 +97,9 @@ std::vector MemoryPsi::EcdhPsi( } if (config_.curve_type() != CurveType::CURVE_INVALID_TYPE) { - return RunEcdhPsi(lctx_, inputs, target_rank, config_.curve_type()); + return ecdh::RunEcdhPsi(lctx_, inputs, target_rank, config_.curve_type()); } - return RunEcdhPsi(lctx_, inputs, target_rank); + return ecdh::RunEcdhPsi(lctx_, inputs, target_rank); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/memory_psi.h b/psi/legacy/memory_psi.h similarity index 96% rename from psi/psi/memory_psi.h rename to psi/legacy/memory_psi.h index 3365badf..38450d31 100644 --- a/psi/psi/memory_psi.h +++ b/psi/legacy/memory_psi.h @@ -23,7 +23,7 @@ #include "psi/proto/psi.pb.h" -namespace psi::psi { +namespace psi { class MemoryPsi { public: @@ -44,4 +44,4 @@ class MemoryPsi { std::shared_ptr lctx_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/memory_psi_test.cc b/psi/legacy/memory_psi_test.cc similarity index 98% rename from psi/psi/memory_psi_test.cc rename to psi/legacy/memory_psi_test.cc index e5166e0d..f10b8f1f 100644 --- a/psi/psi/memory_psi_test.cc +++ b/psi/legacy/memory_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/memory_psi.h" +#include "psi/legacy/memory_psi.h" #include #include @@ -24,9 +24,9 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" -namespace psi::psi { +namespace psi { namespace { struct MemoryTaskTestParams { @@ -206,4 +206,4 @@ INSTANTIATE_TEST_SUITE_P(FailedWorks_Instances, MemoryTaskPsiTestFailedTest, FailedTestParams{3, 4, PsiType::INVALID_PSI_TYPE})); -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/core/polynomial/BUILD.bazel b/psi/legacy/mini_psi/BUILD.bazel similarity index 50% rename from psi/psi/core/polynomial/BUILD.bazel rename to psi/legacy/mini_psi/BUILD.bazel index 17c72c5b..8fa5903a 100644 --- a/psi/psi/core/polynomial/BUILD.bazel +++ b/psi/legacy/mini_psi/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2022 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test") +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") package(default_visibility = ["//visibility:public"]) @@ -39,3 +39,42 @@ psi_cc_test( "@yacl//yacl/crypto/tools:prg", ], ) + +psi_cc_library( + name = "mini_psi", + srcs = ["mini_psi.cc"], + hdrs = ["mini_psi.h"], + defines = ["CURVE25519_DONNA"], + deps = [ + ":polynomial", + "//psi/utils:batch_provider", + "//psi/utils:communication", + "//psi/utils:cuckoo_index", + "//psi/utils:serialize", + "//psi/utils:test_utils", + "@com_github_floodyberry_curve25519_donna//:curve25519_donna", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@yacl//yacl/crypto/base/hash:hash_utils", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/link", + "@yacl//yacl/utils:parallel", + ], +) + +psi_cc_test( + name = "mini_psi_test", + srcs = ["mini_psi_test.cc"], + deps = [ + ":mini_psi", + ], +) + +psi_cc_binary( + name = "mini_psi_demo", + srcs = ["mini_psi_demo.cc"], + deps = [ + ":mini_psi", + "//psi/ecdh:ecdh_psi", + ], +) diff --git a/psi/psi/core/mini_psi.cc b/psi/legacy/mini_psi/mini_psi.cc similarity index 97% rename from psi/psi/core/mini_psi.cc rename to psi/legacy/mini_psi/mini_psi.cc index 55a4aeae..6e21347c 100644 --- a/psi/psi/core/mini_psi.cc +++ b/psi/legacy/mini_psi/mini_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/mini_psi.h" +#include "psi/legacy/mini_psi/mini_psi.h" #include #include @@ -37,13 +37,13 @@ extern "C" { #include "yacl/crypto/tools/prg.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/core/cuckoo_index.h" -#include "psi/psi/core/polynomial/polynomial.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/serialize.h" +#include "psi/legacy/mini_psi/polynomial.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/communication.h" +#include "psi/utils/cuckoo_index.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi::mini_psi { namespace { @@ -119,7 +119,7 @@ struct MiniPsiSendCtx { yacl::parallel_for(0, items.size(), [&](int64_t begin, int64_t end) { for (int64_t idx = begin; idx < end; ++idx) { - polynomial_eval_values[idx] = ::psi::psi::EvalPolynomial( + polynomial_eval_values[idx] = ::psi::mini_psi::EvalPolynomial( polynomial_coeff, absl::string_view(items_hash[idx]), prime256_str); std::array ideal_permutation; @@ -246,7 +246,7 @@ struct MiniPsiRecvCtx { // ToDo: now use newton Polynomial Interpolation, need optimize to fft // polynomial_coeff = - ::psi::psi::InterpolatePolynomial(poly_x, poly_y, prime256_str); + ::psi::mini_psi::InterpolatePolynomial(poly_x, poly_y, prime256_str); } void SendPolynomialCoeff( @@ -621,4 +621,4 @@ std::vector MiniPsiRecvBatch( return ret; } -} // namespace psi::psi \ No newline at end of file +} // namespace psi::mini_psi diff --git a/psi/psi/core/mini_psi.h b/psi/legacy/mini_psi/mini_psi.h similarity index 96% rename from psi/psi/core/mini_psi.h rename to psi/legacy/mini_psi/mini_psi.h index 195aa391..e8b03433 100644 --- a/psi/psi/core/mini_psi.h +++ b/psi/legacy/mini_psi/mini_psi.h @@ -22,7 +22,7 @@ #include "absl/types/span.h" #include "yacl/link/link.h" -namespace psi::psi { +namespace psi::mini_psi { // // Compact and Malicious Private Set Intersection for Small Sets @@ -44,4 +44,5 @@ void MiniPsiSendBatch(const std::shared_ptr& link_ctx, std::vector MiniPsiRecvBatch( const std::shared_ptr& link_ctx, const std::vector& items); -} // namespace psi::psi \ No newline at end of file + +} // namespace psi::mini_psi diff --git a/psi/psi/core/mini_psi_demo.cc b/psi/legacy/mini_psi/mini_psi_demo.cc similarity index 91% rename from psi/psi/core/mini_psi_demo.cc rename to psi/legacy/mini_psi/mini_psi_demo.cc index 0bdb05f8..4a4af883 100644 --- a/psi/psi/core/mini_psi_demo.cc +++ b/psi/legacy/mini_psi/mini_psi_demo.cc @@ -24,8 +24,8 @@ #include "gflags/gflags.h" #include "spdlog/spdlog.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/core/mini_psi.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/legacy/mini_psi/mini_psi.h" DEFINE_int32(role, 1, "sender:0, receiver: 1"); DEFINE_int32(rank, 0, "self rank 0/1"); @@ -130,18 +130,14 @@ std::shared_ptr CreateLinks(const std::string& local_addr, } // namespace -// -// script generate_psi.py -// used to generate test data 18 digits id 50% intersection rate -// // psi demo // -- sender -// ./bazel-bin/psi/psi/core/mini_psi_demo --in ./100m/psi_1.csv --local +// ./bazel-bin/psi/psi/ecdh/mini_psi_demo --in ./100m/psi_1.csv --local // "127.0.0.1:1234" --remote "127.0.0.1:2222" --rank 0 --role 0 --protocol // semi-honest // // -- receiver -// ./bazel-bin/psi/psi/core/mini_psi_demo --in ./100m/psi_1.csv --local +// ./bazel-bin/psi/psi/ecdh/mini_psi_demo --in ./100m/psi_1.csv --local // "127.0.0.1:1234" --remote "127.0.0.1:2222" --rank 1 --role 1 --protocol // semi-honest // @@ -167,7 +163,7 @@ int main(int argc, char** argv) { std::vector intersection; if (FLAGS_protocol == "semi-honest") { - intersection = psi::psi::RunEcdhPsi(link_ctx, items, 1); + intersection = psi::ecdh::RunEcdhPsi(link_ctx, items, 1); if (FLAGS_rank == 1) { SPDLOG_INFO("intersection size:{}", intersection.size()); @@ -175,9 +171,9 @@ int main(int argc, char** argv) { } } else if (FLAGS_protocol == "malicious") { if (FLAGS_role == 0) { - psi::psi::MiniPsiSendBatch(link_ctx, items); + psi::mini_psi::MiniPsiSendBatch(link_ctx, items); } else if (FLAGS_role == 1) { - intersection = psi::psi::MiniPsiRecvBatch(link_ctx, items); + intersection = psi::mini_psi::MiniPsiRecvBatch(link_ctx, items); SPDLOG_INFO("intersection size:{}", intersection.size()); WriteCsvData(file_name, intersection); } diff --git a/psi/psi/core/mini_psi_test.cc b/psi/legacy/mini_psi/mini_psi_test.cc similarity index 95% rename from psi/psi/core/mini_psi_test.cc rename to psi/legacy/mini_psi/mini_psi_test.cc index acf2cbac..e05489c1 100644 --- a/psi/psi/core/mini_psi_test.cc +++ b/psi/legacy/mini_psi/mini_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/mini_psi.h" +#include "psi/legacy/mini_psi/mini_psi.h" #include #include @@ -25,7 +25,7 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" struct TestParams { std::vector items_a; @@ -33,7 +33,7 @@ struct TestParams { bool batch = false; }; -namespace psi::psi { +namespace psi::mini_psi { class MiniPsiTest : public testing::TestWithParam {}; TEST_P(MiniPsiTest, Works) { @@ -85,4 +85,4 @@ INSTANTIATE_TEST_SUITE_P( // TestParams{{}, {}} // )); -} // namespace psi::psi +} // namespace psi::mini_psi diff --git a/psi/psi/core/polynomial/polynomial.cc b/psi/legacy/mini_psi/polynomial.cc similarity index 98% rename from psi/psi/core/polynomial/polynomial.cc rename to psi/legacy/mini_psi/polynomial.cc index a9d407bb..22414468 100644 --- a/psi/psi/core/polynomial/polynomial.cc +++ b/psi/legacy/mini_psi/polynomial.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/polynomial/polynomial.h" +#include "psi/legacy/mini_psi/polynomial.h" #include #include @@ -20,7 +20,7 @@ #include "openssl/bn.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi::mini_psi { namespace { class BNDeleter { @@ -237,4 +237,4 @@ std::vector InterpolatePolynomial( return res; } -} // namespace psi::psi +} // namespace psi::mini_psi diff --git a/psi/psi/core/polynomial/polynomial.h b/psi/legacy/mini_psi/polynomial.h similarity index 96% rename from psi/psi/core/polynomial/polynomial.h rename to psi/legacy/mini_psi/polynomial.h index 9428944d..90596225 100644 --- a/psi/psi/core/polynomial/polynomial.h +++ b/psi/legacy/mini_psi/polynomial.h @@ -20,7 +20,7 @@ #include "absl/strings/string_view.h" -namespace psi::psi { +namespace psi::mini_psi { // for big num std::string EvalPolynomial(const std::vector &coeff, @@ -41,4 +41,4 @@ std::vector InterpolatePolynomial( const std::vector &poly_x, const std::vector &poly_y, std::string_view p_str); -} // namespace psi::psi +} // namespace psi::mini_psi diff --git a/psi/psi/core/polynomial/polynomial_test.cc b/psi/legacy/mini_psi/polynomial_test.cc similarity index 96% rename from psi/psi/core/polynomial/polynomial_test.cc rename to psi/legacy/mini_psi/polynomial_test.cc index cfecea4f..6d4f05ac 100644 --- a/psi/psi/core/polynomial/polynomial_test.cc +++ b/psi/legacy/mini_psi/polynomial_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/polynomial/polynomial.h" +#include "psi/legacy/mini_psi/polynomial.h" #include #include @@ -34,7 +34,7 @@ constexpr uint32_t kBnByteSize = 32; } // namespace -namespace psi::psi { +namespace psi::mini_psi { // test 256b big num polynomial interpolate and eval class PolynomialBnTest : public testing::TestWithParam {}; @@ -83,4 +83,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, PolynomialBnTest, TestParams{1025} // )); -} // namespace psi::psi +} // namespace psi::mini_psi diff --git a/psi/psi/operator/nparty_psi.cc b/psi/legacy/nparty_psi.cc similarity index 94% rename from psi/psi/operator/nparty_psi.cc rename to psi/legacy/nparty_psi.cc index 30fdcb75..60c6c204 100644 --- a/psi/psi/operator/nparty_psi.cc +++ b/psi/legacy/nparty_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/nparty_psi.h" +#include "psi/legacy/nparty_psi.h" #include #include @@ -25,16 +25,16 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/communication.h" -#include "psi/psi/operator/factory.h" -#include "psi/psi/operator/kkrt_2party_psi.h" -#include "psi/psi/utils/serialize.h" +#include "psi/legacy/factory.h" +#include "psi/legacy/kkrt_2party_psi.h" +#include "psi/utils/communication.h" +#include "psi/utils/serialize.h" namespace { constexpr size_t kSyncRecvWaitTimeoutMs = 60L * 60 * 1000; } // namespace -namespace psi::psi { +namespace psi { NpartyPsiOperator::Options NpartyPsiOperator::ParseConfig( const MemoryPsiConfig& config, @@ -166,11 +166,11 @@ std::vector NpartyPsiOperator::Run2PartyPsi( auto link_ctx = CreateP2PLinkCtx("2partypsi", options_.link_ctx, peer_rank); if (options_.psi_proto == PsiProtocol::Ecdh) { - return RunEcdhPsi(link_ctx, items, - target_rank == options_.link_ctx->Rank() - ? link_ctx->Rank() - : link_ctx->NextRank(), - options_.curve_type, options_.batch_size); + return ecdh::RunEcdhPsi(link_ctx, items, + target_rank == options_.link_ctx->Rank() + ? link_ctx->Rank() + : link_ctx->NextRank(), + options_.curve_type, options_.batch_size); } else if (options_.psi_proto == PsiProtocol::Kkrt) { KkrtPsiOperator::Options opts; opts.link_ctx = link_ctx; @@ -259,4 +259,4 @@ REGISTER_OPERATOR(KKRT_PSI_NPC, CreateOperator); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/nparty_psi.h b/psi/legacy/nparty_psi.h similarity index 95% rename from psi/psi/operator/nparty_psi.h rename to psi/legacy/nparty_psi.h index 859159fc..fda43e4a 100644 --- a/psi/psi/operator/nparty_psi.h +++ b/psi/legacy/nparty_psi.h @@ -19,10 +19,10 @@ #include "yacl/link/link.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/ecdh/ecdh_psi.h" +#include "psi/legacy/base_operator.h" -namespace psi::psi { +namespace psi { // use 2-party psi to get n-party PSI // put master rank at 0 position // ascending sort other rank by items size, @@ -81,4 +81,4 @@ class NpartyPsiOperator : public PsiBaseOperator { Options options_; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/operator/nparty_psi_test.cc b/psi/legacy/nparty_psi_test.cc similarity index 97% rename from psi/psi/operator/nparty_psi_test.cc rename to psi/legacy/nparty_psi_test.cc index 20d1b51e..238f4aec 100644 --- a/psi/psi/operator/nparty_psi_test.cc +++ b/psi/legacy/nparty_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/nparty_psi.h" +#include "psi/legacy/nparty_psi.h" #include #include @@ -26,9 +26,9 @@ #include "yacl/base/exception.h" #include "yacl/link/test_util.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" -namespace psi::psi { +namespace psi { namespace { struct NPartyTestParams { @@ -146,4 +146,4 @@ INSTANTIATE_TEST_SUITE_P( // NPartyTestParams{{0, 0}, 0, NpartyPsiOperator::PsiProtocol::Kkrt})); -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/operator/rr22_2party_psi.cc b/psi/legacy/rr22_2party_psi.cc similarity index 89% rename from psi/psi/operator/rr22_2party_psi.cc rename to psi/legacy/rr22_2party_psi.cc index 821519d4..edb13d19 100644 --- a/psi/psi/operator/rr22_2party_psi.cc +++ b/psi/legacy/rr22_2party_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/operator/rr22_2party_psi.h" +#include "psi/legacy/rr22_2party_psi.h" #include @@ -20,12 +20,12 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/operator/factory.h" -#include "psi/psi/utils/sync.h" +#include "psi/legacy/factory.h" +#include "psi/utils/sync.h" using DurationMillis = std::chrono::duration; -namespace psi::psi { +namespace psi { Rr22PsiOperator::Options Rr22PsiOperator::ParseConfig( const MemoryPsiConfig& config, @@ -75,8 +75,8 @@ std::vector Rr22PsiOperator::OnRun( const auto psi_core_start = std::chrono::system_clock::now(); if (options_.receiver_rank == link_ctx_->Rank()) { - std::vector rr22_psi_result = - Rr22PsiReceiver(options_.rr22_options, options_.link_ctx, items_hash); + std::vector rr22_psi_result = rr22::Rr22PsiReceiverInternal( + options_.rr22_options, options_.link_ctx, items_hash); const auto psi_core_end = std::chrono::system_clock::now(); const DurationMillis psi_core_duration = psi_core_end - psi_core_start; @@ -89,7 +89,8 @@ std::vector Rr22PsiOperator::OnRun( result.push_back(inputs[index]); } } else { - Rr22PsiSender(options_.rr22_options, options_.link_ctx, items_hash); + rr22::Rr22PsiSenderInternal(options_.rr22_options, options_.link_ctx, + items_hash); const auto psi_core_end = std::chrono::system_clock::now(); const DurationMillis psi_core_duration = psi_core_end - psi_core_start; @@ -115,7 +116,7 @@ std::unique_ptr CreateLowCommOperator( const std::shared_ptr& lctx) { auto options = Rr22PsiOperator::ParseConfig(config, lctx); - options.rr22_options.mode = Rr22PsiMode::LowCommMode; + options.rr22_options.mode = rr22::Rr22PsiMode::LowCommMode; return std::make_unique(options); } @@ -125,7 +126,7 @@ std::unique_ptr CreateMaliciousOperator( const std::shared_ptr& lctx) { auto options = Rr22PsiOperator::ParseConfig(config, lctx); - options.rr22_options.mode = Rr22PsiMode::FastMode; + options.rr22_options.mode = rr22::Rr22PsiMode::FastMode; options.rr22_options.malicious = true; options.rr22_options.code_type = yacl::crypto::CodeType::ExAcc7; @@ -141,4 +142,4 @@ REGISTER_OPERATOR(RR22_MALICIOUS_PSI_2PC, CreateMaliciousOperator); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/operator/rr22_2party_psi.h b/psi/legacy/rr22_2party_psi.h similarity index 85% rename from psi/psi/operator/rr22_2party_psi.h rename to psi/legacy/rr22_2party_psi.h index 8210caf0..fe0c79fb 100644 --- a/psi/psi/operator/rr22_2party_psi.h +++ b/psi/legacy/rr22_2party_psi.h @@ -18,17 +18,17 @@ #include #include -#include "psi/psi/core/vole_psi/rr22_psi.h" -#include "psi/psi/operator/base_operator.h" +#include "psi/legacy/base_operator.h" +#include "psi/rr22/rr22_psi.h" -namespace psi::psi { +namespace psi { class Rr22PsiOperator : public PsiBaseOperator { public: struct Options { std::shared_ptr link_ctx; size_t receiver_rank = 0; - Rr22PsiOptions rr22_options = Rr22PsiOptions(40, 0, true); + rr22::Rr22PsiOptions rr22_options = rr22::Rr22PsiOptions(40, 0, true); }; static Options ParseConfig(const MemoryPsiConfig& config, @@ -43,4 +43,4 @@ class Rr22PsiOperator : public PsiBaseOperator { Options options_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/main.cc b/psi/main.cc index 602ea782..27a61a4c 100644 --- a/psi/main.cc +++ b/psi/main.cc @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// perfetto usage is adapted from -// https://github.com/google/perfetto/blob/master/examples/sdk/example.cc - #include #include "gflags/gflags.h" @@ -22,7 +19,7 @@ #include "spdlog/spdlog.h" #include "psi/kuscia_adapter.h" -#include "psi/psi/launch.h" +#include "psi/launch.h" #include "psi/version.h" #include "psi/proto/entry.pb.h" @@ -86,13 +83,13 @@ int main(int argc, char* argv[]) { lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank); } - psi::psi::PsiResultReport report; + psi::PsiResultReport report; if (launch_config.has_legacy_psi_config()) { - report = psi::psi::RunLegacyPsi(launch_config.legacy_psi_config(), lctx); + report = psi::RunLegacyPsi(launch_config.legacy_psi_config(), lctx); } else if (launch_config.has_psi_config()) { - report = psi::psi::RunPsi(launch_config.psi_config(), lctx); + report = psi::RunPsi(launch_config.psi_config(), lctx); } else if (launch_config.has_ub_psi_config()) { - report = psi::psi::RunUbPsi(launch_config.ub_psi_config(), lctx); + report = psi::RunUbPsi(launch_config.ub_psi_config(), lctx); } else { SPDLOG_WARN("No runtime config is provided."); } diff --git a/psi/psi/prelude.h b/psi/prelude.h similarity index 63% rename from psi/psi/prelude.h rename to psi/prelude.h index a4ad8c7e..38ea6d0d 100644 --- a/psi/psi/prelude.h +++ b/psi/prelude.h @@ -22,26 +22,24 @@ namespace fmt { template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter - : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter { -}; +struct formatter : ostream_formatter {}; } // namespace fmt diff --git a/psi/proto/entry.proto b/psi/proto/entry.proto index 73d8fc91..92ada2ee 100644 --- a/psi/proto/entry.proto +++ b/psi/proto/entry.proto @@ -30,13 +30,13 @@ message LaunchConfig { oneof runtime_config { // Please check at psi.proto. - psi.BucketPsiConfig legacy_psi_config = 3; + BucketPsiConfig legacy_psi_config = 3; // Please check at psi_v2.proto. - psi.v2.PsiConfig psi_config = 4; + v2.PsiConfig psi_config = 4; // Please check at psi_v2.proto. - psi.v2.UbPsiConfig ub_psi_config = 5; + v2.UbPsiConfig ub_psi_config = 5; // TODO(junfeng): add PIR config here. } diff --git a/psi/proto/pir.proto b/psi/proto/pir.proto index b4540170..d2e5f3c3 100644 --- a/psi/proto/pir.proto +++ b/psi/proto/pir.proto @@ -16,7 +16,7 @@ syntax = "proto3"; -package psi.pir; +package psi; // The kv-store type of pir. enum KvStoreType { diff --git a/psi/proto/psi.proto b/psi/proto/psi.proto index ef08b7ba..fbd654ed 100644 --- a/psi/proto/psi.proto +++ b/psi/proto/psi.proto @@ -16,7 +16,7 @@ syntax = "proto3"; -package psi.psi; +package psi; // The algorithm type of psi. enum PsiType { diff --git a/psi/proto/psi_v2.proto b/psi/proto/psi_v2.proto index ec3d55a8..545dd0c3 100644 --- a/psi/proto/psi_v2.proto +++ b/psi/proto/psi_v2.proto @@ -17,7 +17,7 @@ syntax = "proto3"; import "psi/proto/psi.proto"; -package psi.psi.v2; +package psi.v2; // Role of parties. enum Role { @@ -64,7 +64,7 @@ enum Protocol { // Configs for ECDH protocol. message EcdhConfig { - .psi.psi.CurveType curve = 1; + .psi.CurveType curve = 1; } // Configs for KKRT protocol diff --git a/psi/psi/BUILD.bazel b/psi/psi/BUILD.bazel deleted file mode 100644 index d52a3c91..00000000 --- a/psi/psi/BUILD.bazel +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "memory_psi", - srcs = ["memory_psi.cc"], - hdrs = [ - "memory_psi.h", - ], - deps = [ - ":prelude", - "//psi/proto:psi_cc_proto", - "//psi/psi/core:ecdh_psi", - "//psi/psi/operator", - "//psi/psi/operator:factory", - "//psi/psi/utils:sync", - ], -) - -psi_cc_library( - name = "prelude", - hdrs = [ - "prelude.h", - ], - deps = [ - "//psi/proto:psi_cc_proto", - "//psi/proto:psi_v2_cc_proto", - ], -) - -psi_cc_test( - name = "memory_psi_test", - srcs = ["memory_psi_test.cc"], - deps = [ - ":memory_psi", - "//psi/psi/utils:test_utils", - ], -) - -psi_cc_library( - name = "bucket_ub_psi", - srcs = ["bucket_ub_psi.cc"], - hdrs = [ - "bucket_ub_psi.h", - ], - deps = [ - ":prelude", - "//psi/proto:psi_cc_proto", - "//psi/psi/core:ecdh_oprf_psi", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:csv_checker", - "//psi/psi/utils:csv_header_analyzer", - "//psi/psi/utils:ec", - "//psi/psi/utils:ec_point_store", - "//psi/psi/utils:progress", - "//psi/psi/utils:sync", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@yacl//yacl/utils:scope_guard", - ], -) - -psi_cc_library( - name = "bucket_psi", - srcs = ["bucket_psi.cc"], - hdrs = [ - "bucket_psi.h", - ], - deps = [ - ":bucket_ub_psi", - ":memory_psi", - ":prelude", - "//psi/proto:psi_cc_proto", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:csv_checker", - "//psi/psi/utils:csv_header_analyzer", - "//psi/psi/utils:ec_point_store", - "@boost//:uuid", - ], -) - -psi_cc_test( - name = "bucket_psi_test", - srcs = ["bucket_psi_test.cc"], - deps = [ - ":bucket_psi", - "@yacl//yacl/utils:scope_guard", - ], -) - -psi_cc_test( - name = "bucket_ub_psi_test", - srcs = ["bucket_ub_psi_test.cc"], - deps = [ - ":bucket_psi", - "//psi/psi/utils:test_utils", - ], -) - -psi_cc_library( - name = "interface", - srcs = ["interface.cc"], - hdrs = ["interface.h"], - deps = [ - ":trace_categories", - "//psi/proto:psi_v2_cc_proto", - "//psi/psi:bucket_psi", - "//psi/psi/utils:advanced_join", - "//psi/psi/utils:index_store", - "//psi/psi/utils:recovery", - "@boost//:uuid", - "@com_github_google_perfetto//:perfetto", - "@com_google_absl//absl/status", - "@yacl//yacl/link", - ], -) - -psi_cc_library( - name = "factory", - srcs = ["factory.cc"], - hdrs = ["factory.h"], - deps = [ - "//psi/psi/ecdh:client", - "//psi/psi/ecdh:receiver", - "//psi/psi/ecdh:sender", - "//psi/psi/ecdh:server", - "//psi/psi/kkrt:receiver", - "//psi/psi/kkrt:sender", - "//psi/psi/rr22:receiver", - "//psi/psi/rr22:sender", - "@yacl//yacl/base:exception", - ], -) - -psi_cc_library( - name = "launch", - srcs = ["launch.cc"], - hdrs = ["launch.h"], - deps = [ - ":bucket_psi", - ":factory", - ":trace_categories", - "@boost//:algorithm", - "@boost//:uuid", - ], -) - -psi_cc_library( - name = "trace_categories", - srcs = ["trace_categories.cc"], - hdrs = ["trace_categories.h"], - deps = [ - "@com_github_google_perfetto//:perfetto", - ], -) - -psi_cc_test( - name = "psi_test", - srcs = ["psi_test.cc"], - deps = [ - ":factory", - "//psi/psi/utils:arrow_csv_batch_provider", - "@boost//:uuid", - "@yacl//yacl/utils:scope_guard", - ], -) diff --git a/psi/psi/benchmark/BUILD.bazel b/psi/psi/benchmark/BUILD.bazel deleted file mode 100644 index de805003..00000000 --- a/psi/psi/benchmark/BUILD.bazel +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_binary") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_binary( - name = "standalone_bench", - srcs = [ - "standalone_bench.cc", - "standalone_bench.h", - ], - deps = [ - "//psi/psi/core:ecdh_oprf_psi", - "//psi/psi/core:ecdh_psi", - "//psi/psi/core:kkrt_psi", - "//psi/psi/core:mini_psi", - "//psi/psi/core/bc22_psi", - "//psi/psi/utils:test_utils", - "@boost//:uuid", - "@com_github_google_benchmark//:benchmark", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@yacl//yacl/link:test_util", - "@yacl//yacl/utils:scope_guard", - ], -) - -psi_cc_binary( - name = "mparty_bench", - srcs = [ - "mparty_bench.cc", - "mparty_bench.h", - ], - deps = [ - "//psi/proto:psi_cc_proto", - "//psi/psi/core:ecdh_oprf_psi", - "//psi/psi/core:ecdh_psi", - "//psi/psi/core:kkrt_psi", - "//psi/psi/core:mini_psi", - "//psi/psi/core/bc22_psi", - "//psi/psi/utils:test_utils", - "@boost//:uuid", - "@com_github_google_benchmark//:benchmark", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@yacl//yacl/base:int128", - "@yacl//yacl/crypto/base/hash:hash_utils", - "@yacl//yacl/link:factory", - "@yacl//yacl/utils:scope_guard", - ], -) diff --git a/psi/psi/benchmark/mparty_bench.cc b/psi/psi/benchmark/mparty_bench.cc deleted file mode 100644 index 4b8a3872..00000000 --- a/psi/psi/benchmark/mparty_bench.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/psi/benchmark/mparty_bench.h" - -#include -#include - -#include "gflags/gflags.h" - -DEFINE_uint32(rank, 0, "self rank, starts with 0"); - -DEFINE_string(parties, "", - "server list, format: host1:port1[,host2:port2, ...]"); - -namespace psi::psi::bench { - -void DefaultPsiArguments(benchmark::internal::Benchmark* b) { - b->Args({1 << 18}) - ->Args({1 << 20}) - ->Args({1 << 22}) - ->Args({1 << 24}) - ->Args({1000000}) - ->Args({5000000}) - ->Args({10000000}) - ->Iterations(1) - ->Unit(benchmark::kSecond); -} - -// register benchmarks with arguments -BM_REGISTER_ALL_PSI(DefaultPsiArguments); -// -// Equivalent to the following: -// -// BM_REGISTER_ECDH_PSI(DefaultPsiArguments); -// BM_REGISTER_ECDH_OPRF_PSI(DefaultPsiArguments); -// BM_REGISTER_KKRT_PSI(DefaultPsiArguments); -// BM_REGISTER_BC22_PSI(DefaultPsiArguments); -// BM_REGISTER_MINI_PSI(DefaultPsiArguments); - -} // namespace psi::psi::bench - -namespace { -void PreparePsiBench(const uint32_t rank, const std::string& parties) { - std::vector host_ips; - if (parties.empty()) { - // default ips for semi2k - host_ips = absl::StrSplit(psi::psi::bench::kTwoPartyHosts, ','); - } else { - host_ips = absl::StrSplit(parties, ','); - } - YACL_ENFORCE(host_ips.size() == 2); - - yacl::link::ContextDesc lctx_desc; - for (size_t i = 0; i < 2; i++) { - const std::string id = fmt::format("party{}", i); - lctx_desc.parties.push_back({id, host_ips[i]}); - benchmark::AddCustomContext(fmt::format("Benchmark Party-{} IP", i), - host_ips[i]); - } - - // setup bench_lctx and link - yacl::link::FactoryBrpc factory; - psi::psi::bench::PsiBench::bench_lctx = - factory.CreateContext(lctx_desc, rank); - psi::psi::bench::PsiBench::bench_lctx->ConnectToMesh(); -} -} // namespace - -// the main function -int main(int argc, char** argv) { - gflags::AllowCommandLineReparsing(); - gflags::ParseCommandLineFlags(&argc, &argv, true); - - try { - PreparePsiBench(FLAGS_rank, FLAGS_parties); - - // these entries are from BENCHMARK_MAIN - // ::benchmark::Initialize(&argc, argv); // remove all benchmark flags - // if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; - ::benchmark::RunSpecifiedBenchmarks(); - ::benchmark::Shutdown(); - - // sync close - psi::psi::bench::PsiBench::bench_lctx->WaitLinkTaskFinish(); - } catch (std::exception& e) { - exit(EXIT_FAILURE); - } - - return 0; -} diff --git a/psi/psi/benchmark/mparty_bench.h b/psi/psi/benchmark/mparty_bench.h deleted file mode 100644 index 732355f8..00000000 --- a/psi/psi/benchmark/mparty_bench.h +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "benchmark/benchmark.h" -#include "boost/uuid/uuid.hpp" -#include "boost/uuid/uuid_generators.hpp" -#include "boost/uuid/uuid_io.hpp" -#include "yacl/base/exception.h" -#include "yacl/base/int128.h" -#include "yacl/utils/scope_guard.h" - -#include "psi/psi/core/bc22_psi/bc22_psi.h" -#include "psi/psi/core/ecdh_oprf_psi.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/core/kkrt_psi.h" -#include "psi/psi/core/mini_psi.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/test_utils.h" - -namespace psi::psi::bench { - -namespace { - -void WriteCsvFile(const std::string& file_name, - const std::vector& items) { - auto out = io::BuildOutputStream(io::FileIoOptions(file_name)); - out->Write("id\n"); - for (const auto& data : items) { - out->Write(fmt::format("{}\n", data)); - } - out->Close(); -} - -} // namespace - -const char kTwoPartyHosts[] = "127.0.0.1:9540,127.0.0.1:9541"; - -class PsiBench : public benchmark::Fixture { - public: - static std::shared_ptr bench_lctx; - PsiBench() { - spdlog::set_level(spdlog::level::off); // turn off spdlog - } -}; - -std::shared_ptr PsiBench::bench_lctx = nullptr; - -#define PSI_BM_DEFINE_ECDH_TYPE(CurveType) \ - BENCHMARK_DEFINE_F(PsiBench, EcdhPsi_##CurveType) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto items = psi::test::CreateRangeItems(bench_lctx->Rank(), numel); \ - const auto curve = psi::test::GetOverrideCurveType(); \ - \ - state.ResumeTiming(); \ - \ - psi::RunEcdhPsi(bench_lctx, items, 0, \ - curve.has_value() ? *curve : (CurveType)); \ - } \ - } - -#define PSI_BM_DEFINE_ECDH() \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_25519); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_FOURQ); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_SM2); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_SECP256K1); - -PSI_BM_DEFINE_ECDH() - -#define PSI_BM_DEFINE_ECDH_OPRF_FULL(CurveType) \ - BENCHMARK_DEFINE_F(PsiBench, EcdhPsiOprf_##CurveType) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto items = psi::test::CreateRangeItems(bench_lctx->Rank(), numel); \ - \ - /* We let bob obtains the final result */ \ - state.ResumeTiming(); \ - if (bench_lctx->Rank() == 0) { \ - EcdhOprfPsiOptions options; \ - options.curve_type = (CurveType); \ - options.link0 = bench_lctx; \ - options.link1 = bench_lctx->Spawn(); \ - auto offline_proc = EcdhOprfPsiServer(options); \ - const auto sk = offline_proc.GetPrivateKey(); \ - auto online_proc = EcdhOprfPsiServer(options, sk); \ - \ - /* offline: init */ \ - boost::uuids::random_generator uuid_generator; \ - auto uuid_str = boost::uuids::to_string(uuid_generator()); \ - /* server input */ \ - auto server_input_path = \ - std::filesystem::path(fmt::format("server-input-{}", uuid_str)); \ - \ - /* server output */ \ - auto server_tmp_cache_path = \ - std::filesystem::path(fmt::format("tmp-cache-{}", uuid_str)); \ - /* register remove of temp file. */ \ - ON_SCOPE_EXIT([&] { \ - std::error_code ec; \ - std::filesystem::remove(server_input_path, ec); \ - if (ec.value() != 0) { \ - SPDLOG_WARN("can not remove tmp file: {}, msg: {}", \ - server_input_path.c_str(), ec.message()); \ - } \ - std::filesystem::remove(server_tmp_cache_path, ec); \ - if (ec.value() != 0) { \ - SPDLOG_WARN("can not remove tmp file: {}, msg: {}", \ - server_tmp_cache_path.c_str(), ec.message()); \ - } \ - }); \ - \ - WriteCsvFile(server_input_path.string(), items); \ - std::vector cloumn_ids = {"id"}; \ - std::shared_ptr item_provider = \ - std::make_shared( \ - server_input_path.string(), cloumn_ids, kEcdhOprfPsiBatchSize, \ - 100000, true); \ - \ - std::shared_ptr ub_cache = std::make_shared( \ - server_tmp_cache_path.string(), offline_proc.GetCompareLength(), \ - cloumn_ids); \ - \ - offline_proc.FullEvaluate(item_provider, ub_cache); \ - \ - /* offline: finalize */ \ - std::shared_ptr batch_provider = \ - std::make_shared( \ - server_tmp_cache_path.string(), kEcdhOprfPsiBatchSize, \ - offline_proc.GetCompareLength()); \ - offline_proc.SendFinalEvaluatedItems(batch_provider); \ - \ - /* online */ \ - online_proc.RecvBlindAndSendEvaluate(); \ - \ - } else { \ - EcdhOprfPsiOptions options; \ - options.curve_type = (CurveType); \ - options.link0 = bench_lctx; \ - options.link1 = bench_lctx->Spawn(); \ - auto self_ec_point_store = std::make_shared(); \ - auto peer_ec_point_store = std::make_shared(); \ - auto offline_proc = EcdhOprfPsiClient(options); \ - auto online_proc = EcdhOprfPsiClient(options); \ - \ - /* offline: recv and evaluate */ \ - offline_proc.RecvFinalEvaluatedItems(peer_ec_point_store); \ - \ - /* online */ \ - auto proc_send = std::async([&] { \ - auto item_provider = std::make_shared( \ - items, kEcdhOprfPsiBatchSize); \ - online_proc.SendBlindedItems(item_provider); \ - }); \ - \ - auto proc_recv = std::async( \ - [&] { online_proc.RecvEvaluatedItems(self_ec_point_store); }); \ - \ - proc_send.get(); \ - proc_recv.get(); \ - \ - /* online: finalize */ \ - auto& peer_results = peer_ec_point_store->content(); \ - auto& self_results = self_ec_point_store->content(); \ - std::sort(peer_results.begin(), peer_results.end()); \ - \ - std::vector final_result; \ - for (size_t i = 0; i < self_results.size(); i++) { \ - if (std::binary_search(peer_results.begin(), peer_results.end(), \ - self_results[i])) { \ - final_result.push_back(std::to_string(i + 1)); \ - } \ - } \ - } \ - } \ - } - -#define PSI_BM_DEFINE_ECDH_OPRF() \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_25519); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_FOURQ); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_SM2); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_SECP256K1); - -PSI_BM_DEFINE_ECDH_OPRF() - -BENCHMARK_DEFINE_F(PsiBench, KkrtPsi) -(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t numel = state.range(0); - auto items = psi::test::CreateItemHashes(bench_lctx->Rank(), numel); - - state.ResumeTiming(); - - if (bench_lctx->Rank() == 0) { /* Sender */ - auto ot_recv = psi::GetKkrtOtSenderOptions(bench_lctx, 512); - psi::KkrtPsiSend(bench_lctx, ot_recv, items); - } else { /* Receiver */ - auto ot_send = psi::GetKkrtOtReceiverOptions(bench_lctx, 512); - psi::KkrtPsiRecv(bench_lctx, ot_send, items); - } - } -} - -BENCHMARK_DEFINE_F(PsiBench, Bc22Psi) -(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t numel = state.range(0); - auto items = psi::test::CreateRangeItems(bench_lctx->Rank(), numel); - - state.ResumeTiming(); - - if (bench_lctx->Rank() == 0) { /* Sender */ - Bc22PcgPsi party(bench_lctx, PsiRoleType::Sender); - party.RunPsi(items); - } else { /* Receiver */ - Bc22PcgPsi party(bench_lctx, PsiRoleType::Receiver); - party.RunPsi(items); - party.GetIntersection(); - } - } -} - -#define PSI_BM_DEFINE_MINI_TYPE(IsBatch) \ - BENCHMARK_DEFINE_F(PsiBench, MiniPsi##_##IsBatch) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto items = psi::test::CreateRangeItems(bench_lctx->Rank(), numel); \ - \ - state.ResumeTiming(); \ - if (bench_lctx->Rank() == 0) { /* Sender */ \ - psi::MiniPsiSend(bench_lctx, items); \ - } else { /* Receiver */ \ - psi::MiniPsiRecv(bench_lctx, items); \ - } \ - } \ - } - -#define PSI_BM_DEFINE_MINI() \ - PSI_BM_DEFINE_MINI_TYPE(NoBatch) \ - PSI_BM_DEFINE_MINI_TYPE(Batch) -PSI_BM_DEFINE_MINI() - -#define PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CurveType, Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, PsiType##_##CurveType)->Apply(Arguments); - -#define PSI_BM_REGISTER_CURVE_PSI(PsiType, Arguments) \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_25519, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_FOURQ, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_SM2, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_SECP256K1, Arguments); - -#define BM_REGISTER_ECDH_PSI(Arguments) \ - PSI_BM_REGISTER_CURVE_PSI(EcdhPsi, Arguments) - -#define BM_REGISTER_ECDH_OPRF_PSI(Arguments) \ - /* Currently, ECDH OPRF does not support Curve25518 donna */ \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_FOURQ, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_SM2, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_SECP256K1, Arguments); - -#define BM_REGISTER_KKRT_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, KkrtPsi)->Apply(Arguments); - -#define BM_REGISTER_BC22_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, Bc22Psi)->Apply(Arguments); - -#define BM_REGISTER_MINI_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, MiniPsi_NoBatch)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(PsiBench, MiniPsi_Batch)->Apply(Arguments); - -#define BM_REGISTER_ALL_PSI(Arguments) \ - BM_REGISTER_ECDH_PSI(Arguments); \ - BM_REGISTER_ECDH_OPRF_PSI(Arguments); \ - BM_REGISTER_KKRT_PSI(Arguments); \ - BM_REGISTER_BC22_PSI(Arguments); \ - BM_REGISTER_MINI_PSI(Arguments); - -} // namespace psi::psi::bench diff --git a/psi/psi/benchmark/standalone_bench.cc b/psi/psi/benchmark/standalone_bench.cc deleted file mode 100644 index 232c8b3b..00000000 --- a/psi/psi/benchmark/standalone_bench.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/psi/benchmark/standalone_bench.h" - -namespace psi::psi::bench { - -void DefaultPsiArguments(benchmark::internal::Benchmark* b) { - b->Args({1 << 18}) - ->Args({1 << 20}) - ->Args({1 << 22}) - ->Args({1 << 24}) - ->Args({1000000}) - ->Args({5000000}) - ->Args({10000000}) - ->Iterations(1) - ->Unit(benchmark::kSecond); -} - -// register benchmarks with arguments -BM_REGISTER_ALL_PSI(DefaultPsiArguments); -// -// Equivalent to the following: -// -// BM_REGISTER_ECDH_PSI(DefaultPsiArguments); -// BM_REGISTER_ECDH_OPRF_PSI(DefaultPsiArguments); -// BM_REGISTER_KKRT_PSI(DefaultPsiArguments); -// BM_REGISTER_BC22_PSI(DefaultPsiArguments); -// BM_REGISTER_MINI_PSI(DefaultPsiArguments); - -} // namespace psi::psi::bench - -// the main function -BENCHMARK_MAIN(); diff --git a/psi/psi/benchmark/standalone_bench.h b/psi/psi/benchmark/standalone_bench.h deleted file mode 100644 index 75bfc062..00000000 --- a/psi/psi/benchmark/standalone_bench.h +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "benchmark/benchmark.h" -#include "boost/uuid/uuid.hpp" -#include "boost/uuid/uuid_generators.hpp" -#include "boost/uuid/uuid_io.hpp" -#include "yacl/base/exception.h" -#include "yacl/base/int128.h" -#include "yacl/link/test_util.h" -#include "yacl/utils/scope_guard.h" - -#include "psi/psi/core/bc22_psi/bc22_psi.h" -#include "psi/psi/core/ecdh_oprf_psi.h" -#include "psi/psi/core/ecdh_psi.h" -#include "psi/psi/core/kkrt_psi.h" -#include "psi/psi/core/mini_psi.h" -#include "psi/psi/cryptor/cryptor_selector.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/ec_point_store.h" -#include "psi/psi/utils/test_utils.h" - -namespace psi::psi::bench { - -namespace { - -void WriteCsvFile(const std::string& file_name, - const std::vector& items) { - auto out = io::BuildOutputStream(io::FileIoOptions(file_name)); - out->Write("id\n"); - for (const auto& data : items) { - out->Write(fmt::format("{}\n", data)); - } - out->Close(); -} - -} // namespace - -class PsiBench : public benchmark::Fixture { - public: - PsiBench() { - spdlog::set_level(spdlog::level::off); // turn off spdlog - } -}; - -#define PSI_BM_DEFINE_ECDH_TYPE(CurveType) \ - BENCHMARK_DEFINE_F(PsiBench, EcdhPsi_##CurveType) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto a_items = psi::test::CreateRangeItems(1, numel); \ - auto b_items = psi::test::CreateRangeItems(2, numel); \ - auto ctxs = yacl::link::test::SetupWorld(2); \ - auto proc = [](const std::shared_ptr& ctx, \ - const std::vector& items, \ - size_t target_rank) -> std::vector { \ - const auto curve = psi::test::GetOverrideCurveType(); \ - return psi::RunEcdhPsi(ctx, items, target_rank, \ - curve.has_value() ? *curve : (CurveType)); \ - }; \ - \ - state.ResumeTiming(); \ - \ - auto fa = std::async(proc, ctxs[0], a_items, 0); \ - auto fb = std::async(proc, ctxs[1], b_items, 0); \ - \ - auto results_a = fa.get(); \ - auto results_b = fb.get(); \ - } \ - } - -#define PSI_BM_DEFINE_ECDH() \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_25519); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_FOURQ); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_SM2); \ - PSI_BM_DEFINE_ECDH_TYPE(CURVE_SECP256K1); - -PSI_BM_DEFINE_ECDH() - -#define ECDH_OPRF_SENDER_OFFLINE() \ - { \ - /* offline: init */ \ - boost::uuids::random_generator uuid_generator; \ - auto uuid_str = boost::uuids::to_string(uuid_generator()); \ - /* server input */ \ - auto server_input_path = \ - std::filesystem::path(fmt::format("server-input-{}", uuid_str)); \ - \ - /* server output */ \ - auto server_tmp_cache_path = \ - std::filesystem::path(fmt::format("tmp-cache-{}", uuid_str)); \ - /* register remove of temp file. */ \ - ON_SCOPE_EXIT([&] { \ - std::error_code ec; \ - std::filesystem::remove(server_input_path, ec); \ - if (ec.value() != 0) { \ - SPDLOG_WARN("can not remove tmp file: {}, msg: {}", \ - server_input_path.c_str(), ec.message()); \ - } \ - std::filesystem::remove(server_tmp_cache_path, ec); \ - if (ec.value() != 0) { \ - SPDLOG_WARN("can not remove tmp file: {}, msg: {}", \ - server_tmp_cache_path.c_str(), ec.message()); \ - } \ - }); \ - \ - WriteCsvFile(server_input_path.string(), items); \ - std::vector cloumn_ids = {"id"}; \ - std::shared_ptr item_provider = \ - std::make_shared( \ - server_input_path.string(), cloumn_ids, kEcdhOprfPsiBatchSize, \ - 100000, true); \ - \ - std::shared_ptr ub_cache = std::make_shared( \ - server_tmp_cache_path.string(), offline_proc.GetCompareLength(), \ - cloumn_ids); \ - offline_proc.FullEvaluate(item_provider, ub_cache); \ - \ - /* offline: finalize */ \ - std::shared_ptr batch_provider = \ - std::make_shared(server_tmp_cache_path.string(), \ - kEcdhOprfPsiBatchSize, \ - offline_proc.GetCompareLength()); \ - offline_proc.SendFinalEvaluatedItems(batch_provider); \ - } - -#define ECDH_OPRF_SENDER_ONLINE() \ - { /* online */ \ - online_proc.RecvBlindAndSendEvaluate(); \ - } - -#define ECDH_OPRF_RECEIVER_OFFLINE() \ - { \ - /* offline */ \ - auto peer_ec_point_store = std::make_shared(); \ - offline_proc.RecvFinalEvaluatedItems(peer_ec_point_store); \ - } - -#define ECDH_OPRF_RECEIVER_ONLINE() \ - { /* online */ \ - auto self_ec_point_store = std::make_shared(); \ - auto proc_send = std::async([&] { \ - auto item_provider = \ - std::make_shared(items, kEcdhOprfPsiBatchSize); \ - online_proc.SendBlindedItems(item_provider); \ - }); \ - \ - auto proc_recv = std::async( \ - [&] { online_proc.RecvEvaluatedItems(self_ec_point_store); }); \ - \ - proc_send.get(); \ - proc_recv.get(); \ - } - -#define PSI_BM_DEFINE_ECDH_OPRF_FULL(CurveType) \ - BENCHMARK_DEFINE_F(PsiBench, EcdhPsiOprf_##CurveType) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto a_items = psi::test::CreateRangeItems(1, numel); \ - auto b_items = psi::test::CreateRangeItems(2, numel); \ - auto ctxs = yacl::link::test::SetupWorld(2); \ - \ - /* We let bob obtains the final result */ \ - auto a_proc = [](const std::shared_ptr& ctx, \ - const std::vector& items) { \ - EcdhOprfPsiOptions options; \ - options.curve_type = (CurveType); \ - options.link0 = ctx; \ - options.link1 = ctx->Spawn(); \ - \ - /* Offline Phase */ \ - auto offline_proc = EcdhOprfPsiServer(options); \ - ECDH_OPRF_SENDER_OFFLINE() \ - \ - /* Online Phase */ \ - const auto sk = offline_proc.GetPrivateKey(); \ - auto online_proc = EcdhOprfPsiServer(options, sk); \ - ECDH_OPRF_SENDER_ONLINE() \ - }; \ - \ - auto b_proc = [](const std::shared_ptr& ctx, \ - const std::vector& items) { \ - EcdhOprfPsiOptions options; \ - options.curve_type = (CurveType); \ - options.link0 = ctx; \ - options.link1 = ctx->Spawn(); \ - \ - /* Offline Phase */ \ - auto offline_proc = EcdhOprfPsiClient(options); \ - ECDH_OPRF_RECEIVER_OFFLINE() \ - \ - /* Online Phase */ \ - auto online_proc = EcdhOprfPsiClient(options); \ - ECDH_OPRF_RECEIVER_ONLINE() \ - }; \ - \ - state.ResumeTiming(); \ - \ - auto fa = std::async(a_proc, ctxs[0], a_items); \ - auto fb = std::async(b_proc, ctxs[1], b_items); \ - fa.get(); \ - fb.get(); \ - } \ - } - -#define PSI_BM_DEFINE_ECDH_OPRF() \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_25519); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_FOURQ); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_SM2); \ - PSI_BM_DEFINE_ECDH_OPRF_FULL(CURVE_SECP256K1); - -PSI_BM_DEFINE_ECDH_OPRF() - -BENCHMARK_DEFINE_F(PsiBench, KkrtPsi) -(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t numel = state.range(0); - auto a_items = psi::test::CreateItemHashes(1, numel); - auto b_items = psi::test::CreateItemHashes(2, numel); - auto ctxs = yacl::link::test::SetupWorld(2); - - /* Sender */ - auto a_proc = [](const std::shared_ptr& ctx, - const std::vector& items) { - auto ot_recv = psi::GetKkrtOtSenderOptions(ctx, 512); - psi::KkrtPsiSend(ctx, ot_recv, items); - }; - - /* Receiver */ - auto b_proc = [](const std::shared_ptr& ctx, - const std::vector& items) { - auto ot_send = psi::GetKkrtOtReceiverOptions(ctx, 512); - return psi::KkrtPsiRecv(ctx, ot_send, items); - }; - - state.ResumeTiming(); - - auto fa = std::async(a_proc, ctxs[0], a_items); - auto fb = std::async(b_proc, ctxs[1], b_items); - - fa.get(); - auto results = fb.get(); - } -} - -BENCHMARK_DEFINE_F(PsiBench, Bc22Psi) -(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t numel = state.range(0); - auto a_items = psi::test::CreateRangeItems(1, numel); - auto b_items = psi::test::CreateRangeItems(2, numel); - auto ctxs = yacl::link::test::SetupWorld(2); - - Bc22PcgPsi sender(ctxs[0], PsiRoleType::Sender); - Bc22PcgPsi receiver(ctxs[1], PsiRoleType::Receiver); - - /* Sender */ - auto a_proc = [](const std::shared_ptr& ctx, - const std::vector& items) { - Bc22PcgPsi party(ctx, PsiRoleType::Sender); - party.RunPsi(items); - }; - - /* Receiver */ - auto b_proc = [](const std::shared_ptr& ctx, - const std::vector& items) { - Bc22PcgPsi party(ctx, PsiRoleType::Receiver); - party.RunPsi(items); - return party.GetIntersection(); - }; - - state.ResumeTiming(); - - auto fa = std::async(a_proc, ctxs[0], a_items); - auto fb = std::async(b_proc, ctxs[1], b_items); - - fa.get(); - auto results = fb.get(); - } -} - -#define PSI_BM_DEFINE_MINI_TYPE(IsBatch) \ - BENCHMARK_DEFINE_F(PsiBench, MiniPsi##_##IsBatch) \ - (benchmark::State & state) { \ - for (auto _ : state) { \ - state.PauseTiming(); \ - size_t numel = state.range(0); \ - auto a_items = psi::test::CreateRangeItems(1, numel); \ - auto b_items = psi::test::CreateRangeItems(2, numel); \ - auto ctxs = yacl::link::test::SetupWorld(2); \ - \ - /* Sender */ \ - auto a_proc = [](const std::shared_ptr& ctx, \ - const std::vector& items) { \ - psi::MiniPsiSend(ctx, items); \ - }; \ - \ - /* Receiver */ \ - auto b_proc = [](const std::shared_ptr& ctx, \ - const std::vector& items) { \ - psi::MiniPsiRecv(ctx, items); \ - }; \ - \ - state.ResumeTiming(); \ - \ - auto fa = std::async(a_proc, ctxs[0], a_items); \ - auto fb = std::async(b_proc, ctxs[1], b_items); \ - \ - fa.get(); \ - fb.get(); \ - } \ - } - -#define PSI_BM_DEFINE_MINI() \ - PSI_BM_DEFINE_MINI_TYPE(NoBatch) \ - PSI_BM_DEFINE_MINI_TYPE(Batch) -PSI_BM_DEFINE_MINI() - -#define PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CurveType, Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, PsiType##_##CurveType)->Apply(Arguments); - -#define PSI_BM_REGISTER_CURVE_PSI(PsiType, Arguments) \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_25519, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_FOURQ, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_SM2, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(PsiType, CURVE_SECP256K1, Arguments); - -#define BM_REGISTER_ECDH_PSI(Arguments) \ - PSI_BM_REGISTER_CURVE_PSI(EcdhPsi, Arguments) - -#define BM_REGISTER_ECDH_OPRF_PSI(Arguments) \ - /* Currently, ECDH OPRF does not support Curve25518 donna */ \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_FOURQ, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_SM2, Arguments); \ - PSI_BM_REGISTER_CURVE_PSI_TYPE(EcdhPsiOprf, CURVE_SECP256K1, Arguments); - -#define BM_REGISTER_KKRT_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, KkrtPsi)->Apply(Arguments); - -#define BM_REGISTER_BC22_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, Bc22Psi)->Apply(Arguments); - -#define BM_REGISTER_MINI_PSI(Arguments) \ - BENCHMARK_REGISTER_F(PsiBench, MiniPsi_NoBatch)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(PsiBench, MiniPsi_Batch)->Apply(Arguments); - -#define BM_REGISTER_ALL_PSI(Arguments) \ - BM_REGISTER_ECDH_PSI(Arguments); \ - BM_REGISTER_ECDH_OPRF_PSI(Arguments); \ - BM_REGISTER_KKRT_PSI(Arguments); \ - BM_REGISTER_BC22_PSI(Arguments); \ - BM_REGISTER_MINI_PSI(Arguments); - -} // namespace psi::psi::bench diff --git a/psi/psi/core/BUILD.bazel b/psi/psi/core/BUILD.bazel deleted file mode 100644 index dba9de94..00000000 --- a/psi/psi/core/BUILD.bazel +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "communication", - srcs = ["communication.cc"], - hdrs = ["communication.h"], - deps = [ - ":ic_protocol_psi_cc_proto", - "//psi/psi/utils:serialize", - "@yacl//yacl/base:exception", - "@yacl//yacl/link", - ], -) - -cc_proto_library( - name = "ic_protocol_psi_cc_proto", - deps = ["@org_interconnection//interconnection/runtime:ecdh_psi"], -) - -psi_cc_library( - name = "ecdh_psi", - srcs = ["ecdh_psi.cc"], - hdrs = ["ecdh_psi.h"], - deps = [ - ":communication", - "//psi/psi/cryptor:cryptor_selector", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:ec_point_store", - "//psi/psi/utils:recovery", - "@com_google_absl//absl/strings", - "@yacl//yacl/link", - "@yacl//yacl/utils:parallel", - ], -) - -psi_cc_test( - name = "ecdh_psi_test", - srcs = ["ecdh_psi_test.cc"], - deps = [ - ":ecdh_psi", - "//psi/psi/utils:test_utils", - ], -) - -psi_cc_binary( - name = "ecdh_psi_bench", - srcs = ["ecdh_psi_bench.cc"], - deps = [ - ":ecdh_psi", - "@com_github_google_benchmark//:benchmark_main", - ], -) - -psi_cc_library( - name = "ecdh_3pc_psi", - srcs = ["ecdh_3pc_psi.cc"], - hdrs = ["ecdh_3pc_psi.h"], - deps = [ - ":ecdh_psi", - ], -) - -psi_cc_test( - name = "ecdh_3pc_psi_test", - srcs = ["ecdh_3pc_psi_test.cc"], - deps = [ - ":ecdh_3pc_psi", - "//psi/psi/utils:test_utils", - ], -) - -psi_cc_binary( - name = "ecdh_3pc_psi_bench", - srcs = ["ecdh_3pc_psi_bench.cc"], - deps = [ - ":ecdh_3pc_psi", - "//psi/psi/utils:test_utils", - "@com_github_google_benchmark//:benchmark_main", - ], -) - -psi_cc_library( - name = "cuckoo_index", - srcs = ["cuckoo_index.cc"], - hdrs = ["cuckoo_index.h"], - linkopts = ["-lm"], - deps = [ - "@com_google_absl//absl/types:span", - "@yacl//yacl/base:exception", - "@yacl//yacl/base:int128", - ], -) - -psi_cc_test( - name = "cuckoo_index_test", - srcs = ["cuckoo_index_test.cc"], - deps = [ - ":cuckoo_index", - "@yacl//yacl/crypto/utils:rand", - ], -) - -psi_cc_library( - name = "kkrt_psi", - srcs = ["kkrt_psi.cc"], - hdrs = ["kkrt_psi.h"], - deps = [ - ":communication", - ":cuckoo_index", - "//psi/psi/utils:serialize", - "@com_google_absl//absl/strings", - "@yacl//yacl/crypto/base/hash:hash_utils", - "@yacl//yacl/crypto/primitives/ot:base_ot", - "@yacl//yacl/crypto/primitives/ot:iknp_ote", - "@yacl//yacl/crypto/primitives/ot:kkrt_ote", - "@yacl//yacl/crypto/utils:rand", - "@yacl//yacl/link", - ], -) - -psi_cc_test( - name = "kkrt_psi_test", - srcs = ["kkrt_psi_test.cc"], - deps = [ - ":kkrt_psi", - "@yacl//yacl/crypto/base/hash:hash_utils", - ], -) - -psi_cc_binary( - name = "kkrt_psi_bench", - srcs = ["kkrt_psi_bench.cc"], - deps = [ - ":kkrt_psi", - "@com_github_google_benchmark//:benchmark_main", - ], -) - -psi_cc_library( - name = "ecdh_oprf_psi", - srcs = ["ecdh_oprf_psi.cc"], - hdrs = ["ecdh_oprf_psi.h"], - deps = [ - ":communication", - "//psi/psi/core/ecdh_oprf:ecdh_oprf_selector", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:ec_point_store", - "//psi/psi/utils:ub_psi_cache", - "@com_google_absl//absl/strings", - "@yacl//yacl/base:exception", - "@yacl//yacl/link", - "@yacl//yacl/utils:parallel", - ], -) - -psi_cc_test( - name = "ecdh_oprf_psi_test", - srcs = ["ecdh_oprf_psi_test.cc"], - deps = [ - ":ecdh_oprf_psi", - "//psi/psi/utils:test_utils", - "@boost//:uuid", - "@com_google_absl//absl/time", - "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/crypto/utils:rand", - "@yacl//yacl/utils:scope_guard", - ], -) - -psi_cc_library( - name = "mini_psi", - srcs = ["mini_psi.cc"], - hdrs = ["mini_psi.h"], - defines = ["CURVE25519_DONNA"], - deps = [ - ":communication", - ":cuckoo_index", - "//psi/psi/core/polynomial", - "//psi/psi/utils:batch_provider", - "//psi/psi/utils:serialize", - "//psi/psi/utils:test_utils", - "@com_github_floodyberry_curve25519_donna//:curve25519_donna", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@yacl//yacl/crypto/base/hash:hash_utils", - "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/link", - "@yacl//yacl/utils:parallel", - ], -) - -psi_cc_test( - name = "mini_psi_test", - srcs = ["mini_psi_test.cc"], - deps = [ - ":mini_psi", - ], -) - -psi_cc_binary( - name = "mini_psi_demo", - srcs = ["mini_psi_demo.cc"], - deps = [ - ":ecdh_psi", - ":kkrt_psi", - ":mini_psi", - ], -) diff --git a/psi/psi/core/ecdh_oprf/BUILD.bazel b/psi/psi/core/ecdh_oprf/BUILD.bazel deleted file mode 100644 index 9b7db976..00000000 --- a/psi/psi/core/ecdh_oprf/BUILD.bazel +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "ecdh_oprf", - srcs = ["ecdh_oprf.cc"], - hdrs = ["ecdh_oprf.h"], - # Openssl::libcrypto requires `dlopen`... - linkopts = ["-ldl"], - deps = [ - "//psi/psi/cryptor:ecc_cryptor", - "@com_github_openssl_openssl//:openssl", - "@com_google_absl//absl/types:span", - "@yacl//yacl/base:byte_container_view", - "@yacl//yacl/base:exception", - "@yacl//yacl/utils:parallel", - ], -) - -psi_cc_library( - name = "basic_ecdh_oprf", - srcs = ["basic_ecdh_oprf.cc"], - hdrs = ["basic_ecdh_oprf.h"], - defines = [ - "__LINUX__", - ] + select({ - "@bazel_tools//src/conditions:linux_x86_64": [ - "_AMD64_", - "_ASM_", - ], - "@bazel_tools//src/conditions:darwin_arm64": [ - "_ARM64_", - ], - "//conditions:default": [ - "_AMD64_", - ], - }), - deps = [ - ":ecdh_oprf", - "//psi/psi/cryptor:ecc_utils", - "//psi/psi/cryptor:sm2_cryptor", - "@com_github_microsoft_apsi//:apsi", - "@com_google_absl//absl/types:span", - "@yacl//yacl/base:exception", - "@yacl//yacl/crypto/base/hash:blake3", - "@yacl//yacl/crypto/base/hash:hash_utils", - "@yacl//yacl/utils:parallel", - ], -) - -psi_cc_library( - name = "ecdh_oprf_selector", - srcs = ["ecdh_oprf_selector.cc"], - hdrs = ["ecdh_oprf_selector.h"], - deps = [ - ":basic_ecdh_oprf", - "@yacl//yacl/utils:platform_utils", - ], -) - -psi_cc_test( - name = "basic_ecdh_oprf_test", - srcs = ["basic_ecdh_oprf_test.cc"], - deps = [ - ":ecdh_oprf_selector", - "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/crypto/utils:rand", - ], -) diff --git a/psi/psi/core/generate_psi.py b/psi/psi/core/generate_psi.py deleted file mode 100644 index 06263db7..00000000 --- a/psi/psi/core/generate_psi.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from random import randint -from random import sample -import csv -import sys - - -def random_with_N_digits(n): - range_start = 10 ** (n - 1) - range_end = (10**n) - 1 - return randint(range_start, range_end) - - -row_list = [] -len1 = 10**2 -len2 = 10 -len3 = 10 -len4 = 10 - -if len(sys.argv) > 1: - len1 = int(sys.argv[1]) - len2 = int(len1 / 2) - -if len(sys.argv) > 2: - len3 = int(sys.argv[2]) - -len4 = int(len3 / 2) -print(len1, len2) - - -for i in range(len1): - data_list = [random_with_N_digits(18)] - row_list.append(data_list) - -row_list2 = sample(row_list, len2) -for i in range(len2, len1): - data_list = [random_with_N_digits(18)] - row_list2.append(data_list) - -row_list3 = sample(row_list, len4) -for i in range(len4, len3): - data_list = [random_with_N_digits(18)] - row_list3.append(data_list) - -print(len(row_list2)) -print(len(row_list3)) - -with open('psi_1.csv', 'w', newline='') as file: - writer = csv.writer(file) - writer.writerow(["id"]) - writer.writerows(row_list) - -with open('psi_2.csv', 'w', newline='') as file: - writer = csv.writer(file) - writer.writerow(["id"]) - writer.writerows(row_list2) - -with open('psi_3.csv', 'w', newline='') as file: - writer = csv.writer(file) - writer.writerow(["id"]) - writer.writerows(row_list3) diff --git a/psi/psi/ecdh/BUILD.bazel b/psi/psi/ecdh/BUILD.bazel deleted file mode 100644 index 6ae08fe2..00000000 --- a/psi/psi/ecdh/BUILD.bazel +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "common", - hdrs = ["common.h"], -) - -psi_cc_library( - name = "receiver", - srcs = ["receiver.cc"], - hdrs = ["receiver.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) - -psi_cc_library( - name = "sender", - srcs = ["sender.cc"], - hdrs = ["sender.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) - -psi_cc_library( - name = "client", - srcs = ["client.cc"], - hdrs = ["client.h"], - deps = [ - "//psi/psi:interface", - "//psi/psi/core:ecdh_oprf_psi", - "//psi/psi/utils:sync", - ], -) - -psi_cc_library( - name = "server", - srcs = ["server.cc"], - hdrs = ["server.h"], - deps = [ - "//psi/psi:interface", - "//psi/psi/core:ecdh_oprf_psi", - "//psi/psi/utils:ec", - "//psi/psi/utils:sync", - ], -) diff --git a/psi/psi/kkrt/BUILD.bazel b/psi/psi/kkrt/BUILD.bazel deleted file mode 100644 index fd6fb0e5..00000000 --- a/psi/psi/kkrt/BUILD.bazel +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "common", - srcs = ["common.cc"], - hdrs = ["common.h"], - deps = [ - "//psi/proto:psi_v2_cc_proto", - "//psi/psi/utils:bucket", - "//psi/psi/utils:recovery", - ], -) - -psi_cc_library( - name = "receiver", - srcs = ["receiver.cc"], - hdrs = ["receiver.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/core:kkrt_psi", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) - -psi_cc_library( - name = "sender", - srcs = ["sender.cc"], - hdrs = ["sender.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/core:kkrt_psi", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) diff --git a/psi/psi/rr22/BUILD.bazel b/psi/psi/rr22/BUILD.bazel deleted file mode 100644 index dacf8435..00000000 --- a/psi/psi/rr22/BUILD.bazel +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "common", - srcs = ["common.cc"], - hdrs = ["common.h"], - deps = [ - "//psi/proto:psi_v2_cc_proto", - "//psi/psi/core/vole_psi:rr22_psi", - "//psi/psi/utils:bucket", - "//psi/psi/utils:recovery", - ], -) - -psi_cc_library( - name = "receiver", - srcs = ["receiver.cc"], - hdrs = ["receiver.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) - -psi_cc_library( - name = "sender", - srcs = ["sender.cc"], - hdrs = ["sender.h"], - deps = [ - ":common", - "//psi/psi:interface", - "//psi/psi/utils:arrow_csv_batch_provider", - ], -) diff --git a/psi/psi/psi_test.cc b/psi/psi_test.cc similarity index 99% rename from psi/psi/psi_test.cc rename to psi/psi_test.cc index ca6b2a8d..c78d05b5 100644 --- a/psi/psi/psi_test.cc +++ b/psi/psi_test.cc @@ -35,13 +35,13 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/psi/factory.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/io.h" +#include "psi/factory.h" +#include "psi/prelude.h" +#include "psi/utils/io.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { namespace { struct TestTable { @@ -702,4 +702,4 @@ INSTANTIATE_TEST_SUITE_P( v2::PsiConfig::ADVANCED_JOIN_TYPE_INNER_JOIN}))); } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/vole_psi/BUILD.bazel b/psi/rr22/BUILD.bazel similarity index 74% rename from psi/psi/core/vole_psi/BUILD.bazel rename to psi/rr22/BUILD.bazel index 3e15f8db..2e864a47 100644 --- a/psi/psi/core/vole_psi/BUILD.bazel +++ b/psi/rr22/BUILD.bazel @@ -45,8 +45,8 @@ psi_cc_library( hdrs = ["rr22_oprf.h"], deps = [ ":davis_meyer_hash", - "//psi/psi/core/vole_psi/okvs:aes_crhash", - "//psi/psi/core/vole_psi/okvs:baxos", + "//psi/rr22/okvs:aes_crhash", + "//psi/rr22/okvs:baxos", "@yacl//yacl/base:buffer", "@yacl//yacl/crypto/primitives/vole:silent_vole", "@yacl//yacl/crypto/tools:prg", @@ -63,7 +63,29 @@ psi_cc_test( srcs = ["rr22_oprf_test.cc"], deps = [ ":rr22_oprf", - "//psi/psi/utils:test_utils", + "//psi/utils:test_utils", + ], +) + +psi_cc_library( + name = "sparsehash_config", + hdrs = ["sparseconfig.h"], + include_prefix = "sparsehash/internal", + visibility = ["//visibility:public"], +) + +psi_cc_library( + name = "rr22_utils", + srcs = ["rr22_utils.cc"], + hdrs = ["rr22_utils.h"], + deps = [ + "//psi/rr22/okvs:galois128", + "//psi/rr22/okvs:simple_index", + "@com_github_ridiculousfish_libdivide//:libdivide", + "@com_github_sparsehash_sparsehash//:sparsehash", + "@yacl//yacl/base:buffer", + "@yacl//yacl/base:int128", + "@yacl//yacl/link", ], ) @@ -74,7 +96,7 @@ psi_cc_library( deps = [ ":rr22_oprf", ":rr22_utils", - "//psi/psi/utils:sync", + "//psi/utils:sync", ], ) @@ -88,34 +110,46 @@ psi_cc_test( ) psi_cc_binary( - name = "rr22_psi_bench", - srcs = ["rr22_psi_bench.cc"], + name = "rr22_psi_benchmark", + srcs = ["rr22_psi_benchmark.cc"], deps = [ ":rr22_psi", - "//psi/psi/utils:test_utils", + "//psi/utils:test_utils", "@com_github_google_benchmark//:benchmark_main", "@yacl//yacl/crypto/utils:rand", ], ) psi_cc_library( - name = "rr22_utils", - srcs = ["rr22_utils.cc"], - hdrs = ["rr22_utils.h"], + name = "common", + srcs = ["common.cc"], + hdrs = ["common.h"], deps = [ - "//psi/psi/core/vole_psi/okvs:galois128", - "//psi/psi/core/vole_psi/okvs:simple_index", - "@com_github_ridiculousfish_libdivide//:libdivide", - "@com_github_sparsehash_sparsehash//:sparsehash", - "@yacl//yacl/base:buffer", - "@yacl//yacl/base:int128", - "@yacl//yacl/link", + ":rr22_psi", + "//psi/proto:psi_v2_cc_proto", + "//psi/utils:bucket", + "//psi/utils:recovery", ], ) psi_cc_library( - name = "sparsehash_config", - hdrs = ["sparseconfig.h"], - include_prefix = "sparsehash/internal", - visibility = ["//visibility:public"], + name = "receiver", + srcs = ["receiver.cc"], + hdrs = ["receiver.h"], + deps = [ + ":common", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], +) + +psi_cc_library( + name = "sender", + srcs = ["sender.cc"], + hdrs = ["sender.h"], + deps = [ + ":common", + "//psi:interface", + "//psi/utils:arrow_csv_batch_provider", + ], ) diff --git a/psi/psi/rr22/common.cc b/psi/rr22/common.cc similarity index 91% rename from psi/psi/rr22/common.cc rename to psi/rr22/common.cc index 75327255..ceb151c8 100644 --- a/psi/psi/rr22/common.cc +++ b/psi/rr22/common.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/rr22/common.h" +#include "psi/rr22/common.h" #include "omp.h" -#include "psi/psi/utils/bucket.h" +#include "psi/utils/bucket.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { Rr22PsiOptions GenerateRr22PsiOptions(bool low_comm_mode) { Rr22PsiOptions options(kDefaultSSP, omp_get_num_procs(), kDefaultCompress); @@ -40,4 +40,4 @@ void CommonInit(const std::string& key_hash_digest, v2::PsiConfig* config, } } -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/rr22/common.h b/psi/rr22/common.h similarity index 84% rename from psi/psi/rr22/common.h rename to psi/rr22/common.h index 1d5f0dfe..0c505ea1 100644 --- a/psi/psi/rr22/common.h +++ b/psi/rr22/common.h @@ -15,13 +15,13 @@ #include -#include "psi/psi/core/vole_psi/rr22_oprf.h" -#include "psi/psi/core/vole_psi/rr22_psi.h" -#include "psi/psi/utils/recovery.h" +#include "psi/rr22/rr22_oprf.h" +#include "psi/rr22/rr22_psi.h" +#include "psi/utils/recovery.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { // Statistical security parameter constexpr size_t kDefaultSSP = 40; @@ -34,4 +34,4 @@ Rr22PsiOptions GenerateRr22PsiOptions(bool low_comm_mode); void CommonInit(const std::string& key_hash_digest, v2::PsiConfig* config, RecoveryManager* recovery_manager); -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/davis_meyer_hash.cc b/psi/rr22/davis_meyer_hash.cc similarity index 94% rename from psi/psi/core/vole_psi/davis_meyer_hash.cc rename to psi/rr22/davis_meyer_hash.cc index 2945b1e6..66469f3a 100644 --- a/psi/psi/core/vole_psi/davis_meyer_hash.cc +++ b/psi/rr22/davis_meyer_hash.cc @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/davis_meyer_hash.h" +#include "psi/rr22/davis_meyer_hash.h" #include "yacl/crypto/base/aes/aes_opt.h" #include "yacl/crypto/base/block_cipher/symmetric_crypto.h" #include "yacl/utils/platform_utils.h" -namespace psi::psi { +namespace psi::rr22 { namespace { // vector aes batch size -[[maybe_unused]] constexpr size_t kEncBatch = 8; +constexpr size_t kEncBatch = 8; } // namespace @@ -82,4 +82,4 @@ void DavisMeyerHash(absl::Span key, } } -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/davis_meyer_hash.h b/psi/rr22/davis_meyer_hash.h similarity index 96% rename from psi/psi/core/vole_psi/davis_meyer_hash.h rename to psi/rr22/davis_meyer_hash.h index 4d504513..3ca48370 100644 --- a/psi/psi/core/vole_psi/davis_meyer_hash.h +++ b/psi/rr22/davis_meyer_hash.h @@ -17,7 +17,7 @@ #include "absl/types/span.h" #include "yacl/base/int128.h" -namespace psi::psi { +namespace psi::rr22 { // // Reference: @@ -34,4 +34,4 @@ void DavisMeyerHash(absl::Span key, absl::Span value, absl::Span outputs); -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/davis_meyer_hash_test.cc b/psi/rr22/davis_meyer_hash_test.cc similarity index 93% rename from psi/psi/core/vole_psi/davis_meyer_hash_test.cc rename to psi/rr22/davis_meyer_hash_test.cc index 4d9d752c..818e259e 100644 --- a/psi/psi/core/vole_psi/davis_meyer_hash_test.cc +++ b/psi/rr22/davis_meyer_hash_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/davis_meyer_hash.h" +#include "psi/rr22/davis_meyer_hash.h" #include "gtest/gtest.h" #include "spdlog/spdlog.h" #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi { +namespace psi::rr22 { TEST(DavisMeyerHashTest, Works) { uint128_t seed = yacl::crypto::SecureRandU128(); @@ -46,4 +46,4 @@ TEST(DavisMeyerHashTest, Works) { } } -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/okvs/BUILD.bazel b/psi/rr22/okvs/BUILD.bazel similarity index 100% rename from psi/psi/core/vole_psi/okvs/BUILD.bazel rename to psi/rr22/okvs/BUILD.bazel diff --git a/psi/psi/core/vole_psi/okvs/aes_crhash.cc b/psi/rr22/okvs/aes_crhash.cc similarity index 96% rename from psi/psi/core/vole_psi/okvs/aes_crhash.cc rename to psi/rr22/okvs/aes_crhash.cc index 0258762c..f4250466 100644 --- a/psi/psi/core/vole_psi/okvs/aes_crhash.cc +++ b/psi/rr22/okvs/aes_crhash.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/aes_crhash.h" +#include "psi/rr22/okvs/aes_crhash.h" #include #include "spdlog/spdlog.h" #include "yacl/utils/parallel.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { namespace { @@ -93,4 +93,4 @@ uint128_t AesCrHash::Hash(uint128_t input) const { return output; } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/aes_crhash.h b/psi/rr22/okvs/aes_crhash.h similarity index 96% rename from psi/psi/core/vole_psi/okvs/aes_crhash.h rename to psi/rr22/okvs/aes_crhash.h index 8355c380..baa0792e 100644 --- a/psi/psi/core/vole_psi/okvs/aes_crhash.h +++ b/psi/rr22/okvs/aes_crhash.h @@ -20,7 +20,7 @@ // Correlation robust hash function. // H(x) = AES(x) + x. -namespace psi::psi::okvs { +namespace psi::rr22::okvs { class AesCrHash : public yacl::crypto::SymmetricCrypto { public: @@ -41,4 +41,4 @@ class AesCrHash : public yacl::crypto::SymmetricCrypto { uint128_t Hash(uint128_t input) const; }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/aes_crhash_test.cc b/psi/rr22/okvs/aes_crhash_test.cc similarity index 92% rename from psi/psi/core/vole_psi/okvs/aes_crhash_test.cc rename to psi/rr22/okvs/aes_crhash_test.cc index b7c03ac6..f42963f4 100644 --- a/psi/psi/core/vole_psi/okvs/aes_crhash_test.cc +++ b/psi/rr22/okvs/aes_crhash_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/aes_crhash.h" +#include "psi/rr22/okvs/aes_crhash.h" #include #include @@ -22,9 +22,9 @@ #include "spdlog/spdlog.h" #include "yacl/crypto/tools/prg.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { class AesCrHashTest : public testing::TestWithParam {}; @@ -58,4 +58,4 @@ TEST_P(AesCrHashTest, Works) { INSTANTIATE_TEST_SUITE_P(Works_Instances, AesCrHashTest, testing::Values(1, 2, 5, 10, 20)); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/baxos.cc b/psi/rr22/okvs/baxos.cc similarity index 99% rename from psi/psi/core/vole_psi/okvs/baxos.cc rename to psi/rr22/okvs/baxos.cc index ad6039c7..a4281d6f 100644 --- a/psi/psi/core/vole_psi/okvs/baxos.cc +++ b/psi/rr22/okvs/baxos.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/baxos.h" +#include "psi/rr22/okvs/baxos.h" #include #include @@ -22,9 +22,9 @@ #include "spdlog/spdlog.h" -#include "psi/psi/core/vole_psi/okvs/simple_index.h" +#include "psi/rr22/okvs/simple_index.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { uint64_t Baxos::GetBinSize(uint64_t num_bins, uint64_t num_balls, uint64_t stat_sec_param) { @@ -678,4 +678,4 @@ void Baxos::ImplParDecode(absl::Span inputs, PxVector& values, } } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/baxos.h b/psi/rr22/okvs/baxos.h similarity index 95% rename from psi/psi/core/vole_psi/okvs/baxos.h rename to psi/rr22/okvs/baxos.h index 9fb714f1..02d8cdd5 100644 --- a/psi/psi/core/vole_psi/okvs/baxos.h +++ b/psi/rr22/okvs/baxos.h @@ -20,12 +20,12 @@ #include "absl/types/span.h" -#include "psi/psi/core/vole_psi/okvs/dense_mtx.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" -#include "psi/psi/core/vole_psi/okvs/paxos.h" -#include "psi/psi/core/vole_psi/okvs/paxos_utils.h" +#include "psi/rr22/okvs/dense_mtx.h" +#include "psi/rr22/okvs/galois128.h" +#include "psi/rr22/okvs/paxos.h" +#include "psi/rr22/okvs/paxos_utils.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { // a binned version of paxos. Internally calls paxos. class Baxos { @@ -160,4 +160,4 @@ class Baxos { } }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/baxos_test.cc b/psi/rr22/okvs/baxos_test.cc similarity index 95% rename from psi/psi/core/vole_psi/okvs/baxos_test.cc rename to psi/rr22/okvs/baxos_test.cc index 8aa09414..ae28aa56 100644 --- a/psi/psi/core/vole_psi/okvs/baxos_test.cc +++ b/psi/rr22/okvs/baxos_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/baxos.h" +#include "psi/rr22/okvs/baxos.h" #include #include @@ -22,7 +22,7 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { class BaxosTest : public testing::TestWithParam {}; @@ -71,4 +71,4 @@ TEST_P(BaxosTest, WORKS) { INSTANTIATE_TEST_SUITE_P(Works_Instances, BaxosTest, testing::Values(16, 32, 64, 128, 2048)); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/dense_mtx.cc b/psi/rr22/okvs/dense_mtx.cc similarity index 96% rename from psi/psi/core/vole_psi/okvs/dense_mtx.cc rename to psi/rr22/okvs/dense_mtx.cc index b7658feb..0c9e38ce 100644 --- a/psi/psi/core/vole_psi/okvs/dense_mtx.cc +++ b/psi/rr22/okvs/dense_mtx.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/dense_mtx.h" +#include "psi/rr22/okvs/dense_mtx.h" #include -namespace psi::psi::okvs { +namespace psi::rr22::okvs { void DenseMtx::Row::swap(const Row& r) { YACL_ENFORCE(mtx.cols() == r.mtx.cols()); @@ -131,4 +131,4 @@ DenseMtx DenseMtx::Invert() const { return Inv; } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/dense_mtx.h b/psi/rr22/okvs/dense_mtx.h similarity index 98% rename from psi/psi/core/vole_psi/okvs/dense_mtx.h rename to psi/rr22/okvs/dense_mtx.h index 096a0d27..81c00625 100644 --- a/psi/psi/core/vole_psi/okvs/dense_mtx.h +++ b/psi/rr22/okvs/dense_mtx.h @@ -19,9 +19,9 @@ #include "yacl/base/exception.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { // A class to reference a specific bit. class BitReference { @@ -273,4 +273,4 @@ class DenseMtx { std::ostream& operator<<(std::ostream& o, const DenseMtx& H); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/galois128.cc b/psi/rr22/okvs/galois128.cc similarity index 96% rename from psi/psi/core/vole_psi/okvs/galois128.cc rename to psi/rr22/okvs/galois128.cc index 800d6501..181f2979 100644 --- a/psi/psi/core/vole_psi/okvs/galois128.cc +++ b/psi/rr22/okvs/galois128.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" #include @@ -24,7 +24,7 @@ #include "cpu_features/cpuinfo_x86.h" #endif -namespace psi::psi::okvs { +namespace psi::rr22::okvs { // namespace { @@ -202,11 +202,11 @@ Galois128 Galois128::Inv() const { return result; } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs namespace std { -std::ostream& operator<<(std::ostream& os, psi::psi::okvs::Galois128 x) { +std::ostream& operator<<(std::ostream& os, psi::rr22::okvs::Galois128 x) { return os << absl::BytesToHexString( absl::string_view((const char*)x.data(), 16)); } diff --git a/psi/psi/core/vole_psi/okvs/galois128.h b/psi/rr22/okvs/galois128.h similarity index 95% rename from psi/psi/core/vole_psi/okvs/galois128.h rename to psi/rr22/okvs/galois128.h index 326cce9f..d839a101 100644 --- a/psi/psi/core/vole_psi/okvs/galois128.h +++ b/psi/rr22/okvs/galois128.h @@ -26,7 +26,7 @@ // Galois field 2^128 // polynoimal : x^127+x^7+x^2+x+1 -namespace psi::psi::okvs { +namespace psi::rr22::okvs { using Galois128Type = std::variant; @@ -93,10 +93,10 @@ class Galois128 { uint128_t cc_gf128Mul(const uint128_t a, const uint128_t b); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs namespace std { -std::ostream& operator<<(std::ostream& os, psi::psi::okvs::Galois128 x); +std::ostream& operator<<(std::ostream& os, psi::rr22::okvs::Galois128 x); } diff --git a/psi/psi/core/vole_psi/okvs/galois128_test.cc b/psi/rr22/okvs/galois128_test.cc similarity index 95% rename from psi/psi/core/vole_psi/okvs/galois128_test.cc rename to psi/rr22/okvs/galois128_test.cc index 5c2c9421..074d7113 100644 --- a/psi/psi/core/vole_psi/okvs/galois128_test.cc +++ b/psi/rr22/okvs/galois128_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" #include @@ -23,7 +23,7 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { namespace { @@ -76,4 +76,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{yacl::crypto::FastRandU64(), yacl::crypto::FastRandU64()})); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos.cc b/psi/rr22/okvs/paxos.cc similarity index 99% rename from psi/psi/core/vole_psi/okvs/paxos.cc rename to psi/rr22/okvs/paxos.cc index 88d302b5..8823c75f 100644 --- a/psi/psi/core/vole_psi/okvs/paxos.cc +++ b/psi/rr22/okvs/paxos.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/paxos.h" +#include "psi/rr22/okvs/paxos.h" #include #include @@ -20,7 +20,7 @@ #include "yacl/base/exception.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { namespace { @@ -1464,4 +1464,4 @@ template class Paxos; template class Paxos; template class Paxos; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos.h b/psi/rr22/okvs/paxos.h similarity index 96% rename from psi/psi/core/vole_psi/okvs/paxos.h rename to psi/rr22/okvs/paxos.h index 3f6ebc19..ad857338 100644 --- a/psi/psi/core/vole_psi/okvs/paxos.h +++ b/psi/rr22/okvs/paxos.h @@ -22,12 +22,12 @@ #include "libdivide.h" #include "yacl/utils/platform_utils.h" -#include "psi/psi/core/vole_psi/okvs/dense_mtx.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" -#include "psi/psi/core/vole_psi/okvs/paxos_hash.h" -#include "psi/psi/core/vole_psi/okvs/paxos_utils.h" +#include "psi/rr22/okvs/dense_mtx.h" +#include "psi/rr22/okvs/galois128.h" +#include "psi/rr22/okvs/paxos_hash.h" +#include "psi/rr22/okvs/paxos_utils.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { struct PaxosParam { // the type of dense columns. @@ -223,4 +223,4 @@ class Paxos : public PaxosParam { bool add_to_decode_ = false; }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos_hash.cc b/psi/rr22/okvs/paxos_hash.cc similarity index 99% rename from psi/psi/core/vole_psi/okvs/paxos_hash.cc rename to psi/rr22/okvs/paxos_hash.cc index 65b84b9d..b48ca735 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_hash.cc +++ b/psi/rr22/okvs/paxos_hash.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/paxos_hash.h" +#include "psi/rr22/okvs/paxos_hash.h" #include #include "yacl/base/int128.h" #include "yacl/utils/platform_utils.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { template void PaxosHash::mod32(uint64_t* vals, uint64_t mod_idx) const { @@ -315,4 +315,4 @@ template struct PaxosHash; template struct PaxosHash; template struct PaxosHash; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos_hash.h b/psi/rr22/okvs/paxos_hash.h similarity index 97% rename from psi/psi/core/vole_psi/okvs/paxos_hash.h rename to psi/rr22/okvs/paxos_hash.h index 5dba9f17..79d4b163 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_hash.h +++ b/psi/rr22/okvs/paxos_hash.h @@ -23,10 +23,10 @@ #include "yacl/math/gadget.h" #include "yacl/utils/platform_utils.h" -#include "psi/psi/core/vole_psi/okvs/aes_crhash.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/aes_crhash.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { namespace { @@ -192,4 +192,4 @@ struct PaxosHash { void BuildRow(const uint128_t& hash, absl::Span rows) const; }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos_hash_test.cc b/psi/rr22/okvs/paxos_hash_test.cc similarity index 93% rename from psi/psi/core/vole_psi/okvs/paxos_hash_test.cc rename to psi/rr22/okvs/paxos_hash_test.cc index 1a5949bc..ee40c0ab 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_hash_test.cc +++ b/psi/rr22/okvs/paxos_hash_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/paxos_hash.h" +#include "psi/rr22/okvs/paxos_hash.h" #include "gtest/gtest.h" #include "spdlog/spdlog.h" #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { class PaxosHashTest : public testing::TestWithParam {}; @@ -60,4 +60,4 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, PaxosHashTest, testing::Values(yacl::MakeUint128(0x1234, 0x5678), yacl::crypto::FastRandU128())); -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos_test.cc b/psi/rr22/okvs/paxos_test.cc similarity index 96% rename from psi/psi/core/vole_psi/okvs/paxos_test.cc rename to psi/rr22/okvs/paxos_test.cc index ba8c7c0d..8290cbb3 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_test.cc +++ b/psi/rr22/okvs/paxos_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/paxos.h" +#include "psi/rr22/okvs/paxos.h" #include "absl/strings/escaping.h" #include "gtest/gtest.h" #include "spdlog/spdlog.h" #include "yacl/crypto/tools/prg.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { TEST(PaxosTest, SolveTest) { for (auto dt : @@ -88,4 +88,4 @@ TEST(PaxosTest, SolveTest) { } } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/paxos_utils.cc b/psi/rr22/okvs/paxos_utils.cc similarity index 88% rename from psi/psi/core/vole_psi/okvs/paxos_utils.cc rename to psi/rr22/okvs/paxos_utils.cc index 6e5a3ddc..4eede9e7 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_utils.cc +++ b/psi/rr22/okvs/paxos_utils.cc @@ -12,6 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/paxos_utils.h" +#include "psi/rr22/okvs/paxos_utils.h" -namespace psi::psi::okvs {} +namespace psi::rr22::okvs {} diff --git a/psi/psi/core/vole_psi/okvs/paxos_utils.h b/psi/rr22/okvs/paxos_utils.h similarity index 98% rename from psi/psi/core/vole_psi/okvs/paxos_utils.h rename to psi/rr22/okvs/paxos_utils.h index 84f45534..b79a3142 100644 --- a/psi/psi/core/vole_psi/okvs/paxos_utils.h +++ b/psi/rr22/okvs/paxos_utils.h @@ -22,9 +22,9 @@ #include "absl/types/span.h" #include "yacl/crypto/tools/prg.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { // An efficient data structure for tracking the weight of the // paxos columns (node), which excludes the rows which have @@ -322,4 +322,4 @@ struct PxVector { static Helper DefaultHelper() { return {}; } }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/simple_index.cc b/psi/rr22/okvs/simple_index.cc similarity index 99% rename from psi/psi/core/vole_psi/okvs/simple_index.cc rename to psi/rr22/okvs/simple_index.cc index 24337af3..15883d34 100644 --- a/psi/psi/core/vole_psi/okvs/simple_index.cc +++ b/psi/rr22/okvs/simple_index.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/okvs/simple_index.h" +#include "psi/rr22/okvs/simple_index.h" #include #include @@ -24,7 +24,7 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -namespace psi::psi::okvs { +namespace psi::rr22::okvs { namespace { // template @@ -964,4 +964,4 @@ uint64_t SimpleIndex::GetBinSize(uint64_t num_bins, uint64_t num_balls, return B; } -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/core/vole_psi/okvs/simple_index.h b/psi/rr22/okvs/simple_index.h similarity index 93% rename from psi/psi/core/vole_psi/okvs/simple_index.h rename to psi/rr22/okvs/simple_index.h index 6904c32a..3f677c20 100644 --- a/psi/psi/core/vole_psi/okvs/simple_index.h +++ b/psi/rr22/okvs/simple_index.h @@ -16,7 +16,7 @@ #include -namespace psi::psi::okvs { +namespace psi::rr22::okvs { class SimpleIndex { public: @@ -24,4 +24,4 @@ class SimpleIndex { uint64_t stat_sec_param, bool approx = true); }; -} // namespace psi::psi::okvs +} // namespace psi::rr22::okvs diff --git a/psi/psi/rr22/receiver.cc b/psi/rr22/receiver.cc similarity index 93% rename from psi/psi/rr22/receiver.cc rename to psi/rr22/receiver.cc index 1125002d..73a76435 100644 --- a/psi/psi/rr22/receiver.cc +++ b/psi/rr22/receiver.cc @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/rr22/receiver.h" +#include "psi/rr22/receiver.h" #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/crypto/utils/rand.h" #include "yacl/utils/parallel.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/prelude.h" -#include "psi/psi/rr22/common.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/bucket.h" -#include "psi/psi/utils/serialize.h" -#include "psi/psi/utils/sync.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/prelude.h" +#include "psi/rr22/common.h" +#include "psi/trace_categories.h" +#include "psi/utils/bucket.h" +#include "psi/utils/serialize.h" +#include "psi/utils/sync.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { Rr22PsiReceiver::Rr22PsiReceiver(const v2::PsiConfig &config, std::shared_ptr lctx) @@ -143,7 +143,7 @@ void Rr22PsiReceiver::Online() { } std::vector rr22_psi_result = - ::psi::psi::Rr22PsiReceiver(rr22_options, lctx_, items_hash); + Rr22PsiReceiverInternal(rr22_options, lctx_, items_hash); res.reserve(rr22_psi_result.size()); for (auto index : rr22_psi_result) { @@ -186,4 +186,4 @@ void Rr22PsiReceiver::PostProcess() { SPDLOG_INFO("[Rr22PsiReceiver::PostProcess] end"); } -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/rr22/receiver.h b/psi/rr22/receiver.h similarity index 89% rename from psi/psi/rr22/receiver.h rename to psi/rr22/receiver.h index 7fe18f12..75a21d1a 100644 --- a/psi/psi/rr22/receiver.h +++ b/psi/rr22/receiver.h @@ -13,12 +13,12 @@ // limitations under the License. #pragma once -#include "psi/psi/interface.h" -#include "psi/psi/utils/hash_bucket_cache.h" +#include "psi/interface.h" +#include "psi/utils/hash_bucket_cache.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { class Rr22PsiReceiver final : public AbstractPsiReceiver { public: @@ -41,4 +41,4 @@ class Rr22PsiReceiver final : public AbstractPsiReceiver { std::unique_ptr input_bucket_store_; }; -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_oprf.cc b/psi/rr22/rr22_oprf.cc similarity index 99% rename from psi/psi/core/vole_psi/rr22_oprf.cc rename to psi/rr22/rr22_oprf.cc index ae0c82a1..d02bd162 100644 --- a/psi/psi/core/vole_psi/rr22_oprf.cc +++ b/psi/rr22/rr22_oprf.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/rr22_oprf.h" +#include "psi/rr22/rr22_oprf.h" #include #include @@ -27,10 +27,10 @@ #include "yacl/math/f2k/f2k.h" #include "yacl/utils/parallel.h" -#include "psi/psi/core/vole_psi/davis_meyer_hash.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/davis_meyer_hash.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi { +namespace psi::rr22 { namespace { @@ -649,4 +649,4 @@ void Rr22OprfReceiver::RecvLowComm( oprf_eval_proc.get(); } -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_oprf.h b/psi/rr22/rr22_oprf.h similarity index 98% rename from psi/psi/core/vole_psi/rr22_oprf.h rename to psi/rr22/rr22_oprf.h index fb2ea8f8..201416f3 100644 --- a/psi/psi/core/vole_psi/rr22_oprf.h +++ b/psi/rr22/rr22_oprf.h @@ -21,7 +21,7 @@ #include "yacl/crypto/primitives/vole/silent_vole.h" #include "yacl/link/context.h" -#include "psi/psi/core/vole_psi/okvs/baxos.h" +#include "psi/rr22/okvs/baxos.h" // Reference: // Blazing Fast PSI from Improved OKVS and Subfield VOLE @@ -31,7 +31,7 @@ // VOLE-PSI: Fast OPRF and Circuit-PSI from Vector-OLE // https://eprint.iacr.org/2021/266.pdf // 3.2 Malicious Secure Oblivious PRF. -namespace psi::psi { +namespace psi::rr22 { enum class Rr22PsiMode { FastMode, @@ -181,4 +181,4 @@ class Rr22OprfReceiver : public Rr22Oprf { private: }; -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_oprf_test.cc b/psi/rr22/rr22_oprf_test.cc similarity index 96% rename from psi/psi/core/vole_psi/rr22_oprf_test.cc rename to psi/rr22/rr22_oprf_test.cc index 2b74a6ed..6bcb2b21 100644 --- a/psi/psi/core/vole_psi/rr22_oprf_test.cc +++ b/psi/rr22/rr22_oprf_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/rr22_oprf.h" +#include "psi/rr22/rr22_oprf.h" #include #include @@ -21,9 +21,9 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" +#include "psi/rr22/okvs/galois128.h" -namespace psi::psi { +namespace psi::rr22 { namespace { @@ -159,4 +159,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{15, Rr22PsiMode::LowCommMode}, TestParams{15, Rr22PsiMode::FastMode, true})); -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_psi.cc b/psi/rr22/rr22_psi.cc similarity index 91% rename from psi/psi/core/vole_psi/rr22_psi.cc rename to psi/rr22/rr22_psi.cc index b4aa1d6c..5b087286 100644 --- a/psi/psi/core/vole_psi/rr22_psi.cc +++ b/psi/rr22/rr22_psi.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/rr22_psi.h" +#include "psi/rr22/rr22_psi.h" #include #include @@ -23,12 +23,12 @@ #include "sparsehash/dense_hash_map" #include "yacl/base/byte_container_view.h" -#include "psi/psi/core/vole_psi/okvs/galois128.h" -#include "psi/psi/core/vole_psi/rr22_oprf.h" -#include "psi/psi/core/vole_psi/rr22_utils.h" -#include "psi/psi/utils/sync.h" +#include "psi/rr22/okvs/galois128.h" +#include "psi/rr22/rr22_oprf.h" +#include "psi/rr22/rr22_utils.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi::rr22 { namespace { @@ -49,9 +49,9 @@ constexpr size_t kRr22OprfBinSize = 1 << 14; } // namespace -void Rr22PsiSender(const Rr22PsiOptions& options, - const std::shared_ptr& lctx, - const std::vector& inputs) { +void Rr22PsiSenderInternal(const Rr22PsiOptions& options, + const std::shared_ptr& lctx, + const std::vector& inputs) { YACL_ENFORCE(lctx->WorldSize() == 2); // Gather Items Size @@ -131,7 +131,7 @@ void Rr22PsiSender(const Rr22PsiOptions& options, SPDLOG_INFO("send rr22 oprf finished"); } -std::vector Rr22PsiReceiver( +std::vector Rr22PsiReceiverInternal( const Rr22PsiOptions& options, const std::shared_ptr& lctx, const std::vector& inputs) { @@ -176,4 +176,4 @@ std::vector Rr22PsiReceiver( return indices; } -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_psi.h b/psi/rr22/rr22_psi.h similarity index 86% rename from psi/psi/core/vole_psi/rr22_psi.h rename to psi/rr22/rr22_psi.h index a3603dd6..62837089 100644 --- a/psi/psi/core/vole_psi/rr22_psi.h +++ b/psi/rr22/rr22_psi.h @@ -23,13 +23,13 @@ #include "yacl/base/int128.h" #include "yacl/link/context.h" -#include "psi/psi/core/vole_psi/rr22_oprf.h" +#include "psi/rr22/rr22_oprf.h" // [RR22] Blazing Fast PSI from Improved OKVS and Subfield VOLE, CCS 2022 // https://eprint.iacr.org/2022/320 // okvs code reference https://github.com/Visa-Research/volepsi -namespace psi::psi { +namespace psi::rr22 { struct Rr22PsiOptions { Rr22PsiOptions(size_t ssp_params, size_t num_threads_params, @@ -64,15 +64,15 @@ struct Rr22PsiOptions { yacl::crypto::CodeType code_type = yacl::crypto::CodeType::Silver5; }; -void Rr22PsiSender(const Rr22PsiOptions& options, - const std::shared_ptr& lctx, - const std::vector& inputs); +void Rr22PsiSenderInternal(const Rr22PsiOptions& options, + const std::shared_ptr& lctx, + const std::vector& inputs); // return psi result indices, // indices are not sorted; need to be sorted by caller. -std::vector Rr22PsiReceiver( +std::vector Rr22PsiReceiverInternal( const Rr22PsiOptions& options, const std::shared_ptr& lctx, const std::vector& inputs); -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_psi_bench.cc b/psi/rr22/rr22_psi_benchmark.cc similarity index 90% rename from psi/psi/core/vole_psi/rr22_psi_bench.cc rename to psi/rr22/rr22_psi_benchmark.cc index b9295197..4c273d00 100644 --- a/psi/psi/core/vole_psi/rr22_psi_bench.cc +++ b/psi/rr22/rr22_psi_benchmark.cc @@ -26,7 +26,7 @@ #include "yacl/link/context.h" #include "yacl/link/test_util.h" -#include "psi/psi/core/vole_psi/rr22_psi.h" +#include "psi/rr22/rr22_psi.h" namespace { @@ -108,19 +108,21 @@ static void BM_Rr22FastPsi(benchmark::State& state) { state.ResumeTiming(); - psi::psi::Rr22PsiOptions psi_options(40, thread_num, true); + psi::rr22::Rr22PsiOptions psi_options(40, thread_num, true); if (mode == 1) { - psi_options.mode = psi::psi::Rr22PsiMode::LowCommMode; + psi_options.mode = psi::rr22::Rr22PsiMode::LowCommMode; } else if (mode == 2) { - psi_options.mode = psi::psi::Rr22PsiMode::FastMode; + psi_options.mode = psi::rr22::Rr22PsiMode::FastMode; psi_options.malicious = true; } - auto psi_sender_proc = std::async( - [&] { psi::psi::Rr22PsiSender(psi_options, lctxs[0], inputs_a); }); + auto psi_sender_proc = std::async([&] { + psi::rr22::Rr22PsiSenderInternal(psi_options, lctxs[0], inputs_a); + }); auto psi_receiver_proc = std::async([&] { - return psi::psi::Rr22PsiReceiver(psi_options, lctxs[1], inputs_b); + return psi::rr22::Rr22PsiReceiverInternal(psi_options, lctxs[1], + inputs_b); }); psi_sender_proc.get(); diff --git a/psi/psi/core/vole_psi/rr22_psi_test.cc b/psi/rr22/rr22_psi_test.cc similarity index 91% rename from psi/psi/core/vole_psi/rr22_psi_test.cc rename to psi/rr22/rr22_psi_test.cc index 1be5563d..6c12ba96 100644 --- a/psi/psi/core/vole_psi/rr22_psi_test.cc +++ b/psi/rr22/rr22_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/rr22_psi.h" +#include "psi/rr22/rr22_psi.h" #include #include @@ -23,7 +23,7 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/link/test_util.h" -namespace psi::psi { +namespace psi::rr22 { namespace { @@ -85,10 +85,10 @@ TEST_P(Rr22PsiTest, CorrectTest) { psi_options.mode = params.mode; psi_options.malicious = params.malicious; - auto psi_sender_proc = - std::async([&] { Rr22PsiSender(psi_options, lctxs[0], inputs_a); }); + auto psi_sender_proc = std::async( + [&] { Rr22PsiSenderInternal(psi_options, lctxs[0], inputs_a); }); auto psi_receiver_proc = std::async( - [&] { return Rr22PsiReceiver(psi_options, lctxs[1], inputs_b); }); + [&] { return Rr22PsiReceiverInternal(psi_options, lctxs[1], inputs_b); }); psi_sender_proc.get(); std::vector indices_psi = psi_receiver_proc.get(); @@ -113,4 +113,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{35, Rr22PsiMode::LowCommMode}, TestParams{35, Rr22PsiMode::FastMode, true})); -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_utils.cc b/psi/rr22/rr22_utils.cc similarity index 96% rename from psi/psi/core/vole_psi/rr22_utils.cc rename to psi/rr22/rr22_utils.cc index 4c85b6ee..f6f41399 100644 --- a/psi/psi/core/vole_psi/rr22_utils.cc +++ b/psi/rr22/rr22_utils.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/vole_psi/rr22_utils.h" +#include "psi/rr22/rr22_utils.h" #include #include @@ -23,10 +23,10 @@ #include "libdivide.h" #include "sparsehash/dense_hash_map" -#include "psi/psi/core/vole_psi/okvs/galois128.h" -#include "psi/psi/core/vole_psi/okvs/simple_index.h" +#include "psi/rr22/okvs/galois128.h" +#include "psi/rr22/okvs/simple_index.h" -namespace psi::psi { +namespace psi::rr22 { namespace { @@ -192,4 +192,4 @@ std::vector GetIntersection( return indices; } -} // namespace psi::psi +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/rr22_utils.h b/psi/rr22/rr22_utils.h similarity index 97% rename from psi/psi/core/vole_psi/rr22_utils.h rename to psi/rr22/rr22_utils.h index 19fbe928..0fb73b77 100644 --- a/psi/psi/core/vole_psi/rr22_utils.h +++ b/psi/rr22/rr22_utils.h @@ -21,7 +21,7 @@ #include "yacl/base/int128.h" #include "yacl/link/context.h" -namespace psi::psi { +namespace psi::rr22 { std::vector GetIntersection( absl::Span self_oprfs, size_t peer_items_num, diff --git a/psi/psi/rr22/sender.cc b/psi/rr22/sender.cc similarity index 92% rename from psi/psi/rr22/sender.cc rename to psi/rr22/sender.cc index 580131f3..8ad7ab98 100644 --- a/psi/psi/rr22/sender.cc +++ b/psi/rr22/sender.cc @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/rr22/sender.h" +#include "psi/rr22/sender.h" #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/parallel.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/core/vole_psi/rr22_psi.h" -#include "psi/psi/rr22/common.h" -#include "psi/psi/trace_categories.h" -#include "psi/psi/utils/bucket.h" -#include "psi/psi/utils/sync.h" +#include "psi/legacy/bucket_psi.h" +#include "psi/rr22/common.h" +#include "psi/rr22/rr22_psi.h" +#include "psi/trace_categories.h" +#include "psi/utils/bucket.h" +#include "psi/utils/sync.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { Rr22PsiSender::Rr22PsiSender(const v2::PsiConfig &config, std::shared_ptr lctx) @@ -129,7 +129,7 @@ void Rr22PsiSender::Online() { } }); - ::psi::psi::Rr22PsiSender(rr22_options, lctx_, items_hash); + Rr22PsiSenderInternal(rr22_options, lctx_, items_hash); }); SyncWait(lctx_, &run_f); @@ -165,4 +165,4 @@ void Rr22PsiSender::PostProcess() { SPDLOG_INFO("[Rr22PsiSender::PostProcess] end"); } -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/rr22/sender.h b/psi/rr22/sender.h similarity index 89% rename from psi/psi/rr22/sender.h rename to psi/rr22/sender.h index 580b27c2..ea9ff4b0 100644 --- a/psi/psi/rr22/sender.h +++ b/psi/rr22/sender.h @@ -13,12 +13,12 @@ // limitations under the License. #pragma once -#include "psi/psi/interface.h" -#include "psi/psi/utils/hash_bucket_cache.h" +#include "psi/interface.h" +#include "psi/utils/hash_bucket_cache.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi::rr22 { +namespace psi::rr22 { class Rr22PsiSender final : public AbstractPsiSender { public: @@ -41,4 +41,4 @@ class Rr22PsiSender final : public AbstractPsiSender { std::unique_ptr input_bucket_store_; }; -} // namespace psi::psi::rr22 +} // namespace psi::rr22 diff --git a/psi/psi/core/vole_psi/sparseconfig.h b/psi/rr22/sparseconfig.h similarity index 100% rename from psi/psi/core/vole_psi/sparseconfig.h rename to psi/rr22/sparseconfig.h diff --git a/psi/pir/BUILD.bazel b/psi/seal_pir/BUILD.bazel similarity index 75% rename from psi/pir/BUILD.bazel rename to psi/seal_pir/BUILD.bazel index fee06bf9..1f65800b 100644 --- a/psi/pir/BUILD.bazel +++ b/psi/seal_pir/BUILD.bazel @@ -26,7 +26,7 @@ psi_cc_library( ], deps = [ ":seal_pir_utils", - ":serializable_cc_proto", + "//psi/seal_pir:serializable_cc_proto", "@com_github_microsoft_seal//:seal", "@com_github_openssl_openssl//:openssl", "@yacl//yacl/base:byte_container_view", @@ -52,7 +52,7 @@ psi_cc_library( hdrs = ["seal_mpir.h"], deps = [ ":seal_pir", - "//psi/psi/core:cuckoo_index", + "//psi/utils:cuckoo_index", "@yacl//yacl/crypto/base/block_cipher:symmetric_crypto", ], ) @@ -80,31 +80,7 @@ psi_cc_test( srcs = ["seal_mpir_test.cc"], deps = [ ":seal_mpir", - "//psi/psi/cryptor:sodium_curve25519_cryptor", + "//psi/cryptor:sodium_curve25519_cryptor", "@com_google_absl//absl/strings", ], ) - -psi_cc_library( - name = "pir", - srcs = ["pir.cc"], - hdrs = ["pir.h"], - deps = [ - "//psi/proto:pir_cc_proto", - "//psi/psi/core:cuckoo_index", - "//psi/psi/core/labeled_psi", - "//psi/psi/utils:serialize", - "//psi/psi/utils:sync", - "@yacl//yacl/crypto/base/block_cipher:symmetric_crypto", - ], -) - -psi_cc_test( - name = "pir_test", - srcs = ["pir_test.cc"], - deps = [ - ":pir", - "//psi/psi/utils:io", - "@yacl//yacl/utils:scope_guard", - ], -) diff --git a/psi/pir/seal_mpir.cc b/psi/seal_pir/seal_mpir.cc similarity index 96% rename from psi/pir/seal_mpir.cc rename to psi/seal_pir/seal_mpir.cc index 19f8ebb8..f99c16d0 100644 --- a/psi/pir/seal_mpir.cc +++ b/psi/seal_pir/seal_mpir.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/seal_mpir.h" +#include "psi/seal_pir/seal_mpir.h" #include #include @@ -22,9 +22,9 @@ #include "spdlog/spdlog.h" -#include "psi/pir/serializable.pb.h" +#include "psi/seal_pir/serializable.pb.h" -namespace psi::pir { +namespace psi::seal_pir { void MultiQueryServer::GenerateSimpleHash() { std::vector query_index_hash( @@ -44,7 +44,7 @@ void MultiQueryServer::GenerateSimpleHash() { for (size_t idx = 0; idx < query_options_.seal_options.element_number; ++idx) { - ::psi::psi::CuckooIndex::HashRoom itemHash(query_index_hash[idx]); + ::psi::CuckooIndex::HashRoom itemHash(query_index_hash[idx]); std::vector bin_idx(query_options_.cuckoo_hash_number); for (size_t j = 0; j < query_options_.cuckoo_hash_number; ++j) { @@ -169,7 +169,7 @@ void MultiQueryClient::GenerateSimpleHashMap() { for (size_t idx = 0; idx < query_options_.seal_options.element_number; ++idx) { - ::psi::psi::CuckooIndex::HashRoom itemHash(query_index_hash[idx]); + ::psi::CuckooIndex::HashRoom itemHash(query_index_hash[idx]); std::vector bin_idx(query_options_.cuckoo_hash_number); for (size_t j = 0; j < query_options_.cuckoo_hash_number; ++j) { @@ -208,7 +208,7 @@ std::vector MultiQueryClient::GenerateBatchQueryIndex( } }); - ::psi::psi::CuckooIndex cuckoo_index(cuckoo_params_); + ::psi::CuckooIndex cuckoo_index(cuckoo_params_); cuckoo_index.Insert(query_index_hash); auto ck_bins = cuckoo_index.bins(); @@ -275,7 +275,7 @@ std::vector> MultiQueryClient::DoMultiPirQuery( query_proto_vec[idx]->set_query_size(0); query_proto_vec[idx]->set_start_pos(0); for (auto &query_cipher : query_ciphers) { - ::psi::pir::CiphertextsProto *ciphers_proto = + ::psi::seal_pir::CiphertextsProto *ciphers_proto = query_proto_vec[idx]->add_query_cipher(); for (size_t k = 0; k < query_cipher.size(); ++k) { @@ -343,4 +343,4 @@ std::vector> MultiQueryClient::DoMultiPirQuery( return answers; } -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_mpir.h b/psi/seal_pir/seal_mpir.h similarity index 94% rename from psi/pir/seal_mpir.h rename to psi/seal_pir/seal_mpir.h index 2633ebeb..bbabf9a9 100644 --- a/psi/pir/seal_mpir.h +++ b/psi/seal_pir/seal_mpir.h @@ -23,13 +23,13 @@ #include "yacl/crypto/base/block_cipher/symmetric_crypto.h" #include "yacl/utils/parallel.h" -#include "psi/pir/seal_pir.h" -#include "psi/psi/core/cuckoo_index.h" +#include "psi/seal_pir/seal_pir.h" +#include "psi/utils/cuckoo_index.h" -namespace psi::pir { +namespace psi::seal_pir { struct MultiQueryOptions { - ::psi::pir::SealPirOptions seal_options; + ::psi::seal_pir::SealPirOptions seal_options; size_t batch_number = 0; size_t cuckoo_hash_number = 3; }; @@ -74,7 +74,7 @@ class MultiQuery { } MultiQueryOptions query_options_; - ::psi::psi::CuckooIndex::Options cuckoo_params_; + ::psi::CuckooIndex::Options cuckoo_params_; std::array oracle_seed_; size_t max_bin_item_size_ = 0; @@ -91,7 +91,7 @@ class MultiQueryServer : public MultiQuery { GenerateSimpleHash(); - ::psi::pir::SealPirOptions pir_options{ + ::psi::seal_pir::SealPirOptions pir_options{ query_options_.seal_options.poly_modulus_degree, max_bin_item_size_, query_options_.seal_options.element_size}; @@ -164,4 +164,4 @@ class MultiQueryClient : public MultiQuery { std::unique_ptr pir_client_; }; -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_mpir_test.cc b/psi/seal_pir/seal_mpir_test.cc similarity index 91% rename from psi/pir/seal_mpir_test.cc rename to psi/seal_pir/seal_mpir_test.cc index e8ba3b7c..997439ce 100644 --- a/psi/pir/seal_mpir_test.cc +++ b/psi/seal_pir/seal_mpir_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/seal_mpir.h" +#include "psi/seal_pir/seal_mpir.h" #include #include @@ -23,9 +23,9 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -#include "psi/psi/cryptor/sodium_curve25519_cryptor.h" +#include "psi/cryptor/sodium_curve25519_cryptor.h" -namespace psi::pir { +namespace psi::seal_pir { namespace { struct TestParams { size_t batch_number; @@ -83,8 +83,7 @@ TEST_P(SealMultiPirTest, Works) { // size_t batch_number = 256; double factor = 1.5; size_t hash_num = 3; - ::psi::psi::CuckooIndex::Options cuckoo_params{batch_number, 0, hash_num, - factor}; + ::psi::CuckooIndex::Options cuckoo_params{batch_number, 0, hash_num, factor}; std::vector query_index = GenerateQueryIndex(batch_number, element_number); @@ -107,14 +106,16 @@ TEST_P(SealMultiPirTest, Works) { EXPECT_EQ(seed_server, seed_client); - ::psi::pir::MultiQueryOptions options{ + ::psi::seal_pir::MultiQueryOptions options{ {params.poly_degree, element_number, element_size}, batch_number}; SPDLOG_INFO("element_number:{}", options.seal_options.element_number); - ::psi::pir::MultiQueryServer mpir_server(options, cuckoo_params, seed_server); + ::psi::seal_pir::MultiQueryServer mpir_server(options, cuckoo_params, + seed_server); - ::psi::pir::MultiQueryClient mpir_client(options, cuckoo_params, seed_client); + ::psi::seal_pir::MultiQueryClient mpir_client(options, cuckoo_params, + seed_client); // server setup data mpir_server.SetDatabase(db_bytes); @@ -177,4 +178,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{64, 10000, 20}) // ); -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_pir.cc b/psi/seal_pir/seal_pir.cc similarity index 99% rename from psi/pir/seal_pir.cc rename to psi/seal_pir/seal_pir.cc index 43b1e429..b8958524 100644 --- a/psi/pir/seal_pir.cc +++ b/psi/seal_pir/seal_pir.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/seal_pir.h" +#include "psi/seal_pir/seal_pir.h" #include #include @@ -24,7 +24,7 @@ #include "yacl/base/exception.h" #include "yacl/utils/parallel.h" -namespace psi::pir { +namespace psi::seal_pir { namespace { // Number of coefficients needed to represent a database element @@ -219,7 +219,7 @@ void SealPir::SetPirParams(size_t element_number, size_t element_size) { std::string SealPir::SerializePlaintexts( const std::vector &plains) { - psi::pir::PlaintextsProto plains_proto; + psi::seal_pir::PlaintextsProto plains_proto; for (const auto &plain : plains) { std::string plain_bytes = SerializeSealObject(plain); @@ -231,7 +231,7 @@ std::string SealPir::SerializePlaintexts( std::vector SealPir::DeSerializePlaintexts( const std::string &plaintext_bytes, bool safe_load) { - psi::pir::PlaintextsProto plains_proto; + psi::seal_pir::PlaintextsProto plains_proto; plains_proto.ParseFromArray(plaintext_bytes.data(), plaintext_bytes.length()); std::vector plains(plains_proto.data_size()); @@ -248,7 +248,7 @@ std::vector SealPir::DeSerializePlaintexts( yacl::Buffer SealPir::SerializeCiphertexts( const std::vector &ciphers) { - psi::pir::CiphertextsProto ciphers_proto; + psi::seal_pir::CiphertextsProto ciphers_proto; for (const auto &cipher : ciphers) { std::string cipher_bytes = SerializeSealObject(cipher); @@ -1168,4 +1168,4 @@ std::vector SealPirClient::DoPirQuery( return query_reply_data; } -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_pir.h b/psi/seal_pir/seal_pir.h similarity index 98% rename from psi/pir/seal_pir.h rename to psi/seal_pir/seal_pir.h index 2c3a90b7..0a75b29b 100644 --- a/psi/pir/seal_pir.h +++ b/psi/seal_pir/seal_pir.h @@ -23,11 +23,11 @@ #include "yacl/base/byte_container_view.h" #include "yacl/link/link.h" -#include "psi/pir/seal_pir_utils.h" +#include "psi/seal_pir/seal_pir_utils.h" -#include "psi/pir/serializable.pb.h" +#include "psi/seal_pir/serializable.pb.h" -namespace psi::pir { +namespace psi::seal_pir { // // SealPIR paper: @@ -270,4 +270,4 @@ class SealPirClient : public SealPir { // set friend class friend class SealPirServer; }; -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_pir_test.cc b/psi/seal_pir/seal_pir_test.cc similarity index 90% rename from psi/pir/seal_pir_test.cc rename to psi/seal_pir/seal_pir_test.cc index a8849711..65a2fc26 100644 --- a/psi/pir/seal_pir_test.cc +++ b/psi/seal_pir/seal_pir_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/seal_pir.h" +#include "psi/seal_pir/seal_pir.h" #include #include @@ -22,7 +22,7 @@ #include "spdlog/spdlog.h" #include "yacl/link/test_util.h" -namespace psi::pir { +namespace psi::seal_pir { namespace { struct TestParams { size_t element_number; @@ -58,17 +58,17 @@ TEST_P(SealPirTest, Works) { std::vector db_data = GenerateDbData(params); - psi::pir::SealPirOptions options{n, params.element_number, - params.element_size, params.query_size}; + psi::seal_pir::SealPirOptions options{n, params.element_number, + params.element_size, params.query_size}; - psi::pir::SealPirClient client(options); + psi::seal_pir::SealPirClient client(options); std::shared_ptr plaintext_store = std::make_shared(); #ifdef DEC_DEBUG_ - psi::pir::SealPirServer server(options, client, plaintext_store); + psi::seal_pir::SealPirServer server(options, client, plaintext_store); #else - psi::pir::SealPirServer server(options, plaintext_store); + psi::seal_pir::SealPirServer server(options, plaintext_store); #endif // === server setup @@ -134,4 +134,4 @@ INSTANTIATE_TEST_SUITE_P( TestParams{3000, 288, 1000}) // ); -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_pir_utils.cc b/psi/seal_pir/seal_pir_utils.cc similarity index 95% rename from psi/pir/seal_pir_utils.cc rename to psi/seal_pir/seal_pir_utils.cc index ecf4830f..de0acb8d 100644 --- a/psi/pir/seal_pir_utils.cc +++ b/psi/seal_pir/seal_pir_utils.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/pir/seal_pir_utils.h" +#include "psi/seal_pir/seal_pir_utils.h" #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -namespace psi::pir { +namespace psi::seal_pir { std::vector MemoryDbElementProvider::ReadElement(size_t index) { YACL_ENFORCE(index < items_.size()); @@ -59,4 +59,4 @@ std::vector MemoryDbPlaintextStore::ReadPlaintexts( return std::move(db_vec_[sub_db_index]); } -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/seal_pir_utils.h b/psi/seal_pir/seal_pir_utils.h similarity index 98% rename from psi/pir/seal_pir_utils.h rename to psi/seal_pir/seal_pir_utils.h index be1c8c67..e649cc50 100644 --- a/psi/pir/seal_pir_utils.h +++ b/psi/seal_pir/seal_pir_utils.h @@ -20,7 +20,7 @@ #include "seal/seal.h" -namespace psi::pir { +namespace psi::seal_pir { // Interface which read db data. class IDbElementProvider { public: @@ -86,4 +86,4 @@ class MemoryDbPlaintextStore : public IDbPlaintextStore { std::vector> db_vec_; }; -} // namespace psi::pir +} // namespace psi::seal_pir diff --git a/psi/pir/serializable.proto b/psi/seal_pir/serializable.proto similarity index 98% rename from psi/pir/serializable.proto rename to psi/seal_pir/serializable.proto index e0e19d45..fc8ef1b4 100644 --- a/psi/pir/serializable.proto +++ b/psi/seal_pir/serializable.proto @@ -16,7 +16,7 @@ syntax = "proto3"; -package psi.pir; +package psi.seal_pir; message PlaintextsProto { repeated bytes data = 1; diff --git a/psi/psi/trace_categories.cc b/psi/trace_categories.cc similarity index 94% rename from psi/psi/trace_categories.cc rename to psi/trace_categories.cc index 4365e262..9d1b889f 100644 --- a/psi/psi/trace_categories.cc +++ b/psi/trace_categories.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/trace_categories.h" +#include "psi/trace_categories.h" // Reserves internal static storage for our tracing categories. PERFETTO_TRACK_EVENT_STATIC_STORAGE(); diff --git a/psi/psi/trace_categories.h b/psi/trace_categories.h similarity index 100% rename from psi/psi/trace_categories.h rename to psi/trace_categories.h diff --git a/psi/psi/utils/BUILD.bazel b/psi/utils/BUILD.bazel similarity index 88% rename from psi/psi/utils/BUILD.bazel rename to psi/utils/BUILD.bazel index 4ba4d29f..62ba621e 100644 --- a/psi/psi/utils/BUILD.bazel +++ b/psi/utils/BUILD.bazel @@ -250,8 +250,8 @@ psi_cc_library( hdrs = ["advanced_join.h"], deps = [ ":key", + "//psi:prelude", "//psi/proto:psi_v2_cc_proto", - "//psi/psi:prelude", "@boost//:uuid", "@org_apache_arrow//:arrow", "@yacl//yacl/base:exception", @@ -313,8 +313,8 @@ psi_cc_library( ], deps = [ ":io", + "//psi/cryptor:ecc_cryptor", "//psi/proto:psi_v2_cc_proto", - "//psi/psi/cryptor:ecc_cryptor", "@yacl//yacl/base:exception", "@yacl//yacl/link", ], @@ -325,7 +325,7 @@ psi_cc_test( srcs = ["recovery_test.cc"], deps = [ ":recovery", - "//psi/psi/cryptor:cryptor_selector", + "//psi/cryptor:cryptor_selector", ], ) @@ -338,8 +338,8 @@ psi_cc_library( ":index_store", ":recovery", ":sync", + "//psi:prelude", "//psi/proto:psi_v2_cc_proto", - "//psi/psi:prelude", ], ) @@ -375,6 +375,44 @@ psi_cc_library( hdrs = ["ec.h"], deps = [ ":io", - "//psi/psi/cryptor:ecc_cryptor", + "//psi/cryptor:ecc_cryptor", + ], +) + +cc_proto_library( + name = "ic_protocol_psi_cc_proto", + deps = ["@org_interconnection//interconnection/runtime:ecdh_psi"], +) + +psi_cc_library( + name = "communication", + srcs = ["communication.cc"], + hdrs = ["communication.h"], + deps = [ + ":ic_protocol_psi_cc_proto", + ":serialize", + "@yacl//yacl/base:exception", + "@yacl//yacl/link", + ], +) + +psi_cc_library( + name = "cuckoo_index", + srcs = ["cuckoo_index.cc"], + hdrs = ["cuckoo_index.h"], + linkopts = ["-lm"], + deps = [ + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:exception", + "@yacl//yacl/base:int128", + ], +) + +psi_cc_test( + name = "cuckoo_index_test", + srcs = ["cuckoo_index_test.cc"], + deps = [ + ":cuckoo_index", + "@yacl//yacl/crypto/utils:rand", ], ) diff --git a/psi/psi/utils/advanced_join.cc b/psi/utils/advanced_join.cc similarity index 99% rename from psi/psi/utils/advanced_join.cc rename to psi/utils/advanced_join.cc index 5354ae35..83d3490b 100644 --- a/psi/psi/utils/advanced_join.cc +++ b/psi/utils/advanced_join.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/advanced_join.h" +#include "psi/utils/advanced_join.h" #include #include @@ -33,12 +33,12 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/key.h" +#include "psi/prelude.h" +#include "psi/utils/key.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { constexpr char kAdvancedJoinKeyCount[] = "psi_advanced_join_cnt"; constexpr char kAdvancedJoinFirstIndex[] = "psi_advanced_join_first_index"; @@ -1025,4 +1025,4 @@ void AdvancedJoinGenerateResult(const AdvancedJoinConfig& config) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/advanced_join.h b/psi/utils/advanced_join.h similarity index 98% rename from psi/psi/utils/advanced_join.h rename to psi/utils/advanced_join.h index cc7bcdd1..dc93066b 100644 --- a/psi/psi/utils/advanced_join.h +++ b/psi/utils/advanced_join.h @@ -23,7 +23,7 @@ #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { struct AdvancedJoinConfig { // Origin input file. @@ -95,4 +95,4 @@ void AdvancedJoinSync(const std::shared_ptr& link_ctx, // Generate result at config.output_path void AdvancedJoinGenerateResult(const AdvancedJoinConfig& config); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/advanced_join_test.cc b/psi/utils/advanced_join_test.cc similarity index 99% rename from psi/psi/utils/advanced_join_test.cc rename to psi/utils/advanced_join_test.cc index af49cd42..74443d1f 100644 --- a/psi/psi/utils/advanced_join_test.cc +++ b/psi/utils/advanced_join_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/advanced_join.h" +#include "psi/utils/advanced_join.h" #include #include @@ -23,7 +23,7 @@ #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { void SaveFile(const std::string& path, const std::string& content) { std::ofstream file; @@ -651,4 +651,4 @@ NA,NA,NA expected_output_content_sender); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/arrow_csv_batch_provider.cc b/psi/utils/arrow_csv_batch_provider.cc similarity index 95% rename from psi/psi/utils/arrow_csv_batch_provider.cc rename to psi/utils/arrow_csv_batch_provider.cc index c8cc47a3..d6a11027 100644 --- a/psi/psi/utils/arrow_csv_batch_provider.cc +++ b/psi/utils/arrow_csv_batch_provider.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/arrow_csv_batch_provider.h" +#include "psi/utils/arrow_csv_batch_provider.h" #include @@ -21,9 +21,9 @@ #include "arrow/datum.h" #include "spdlog/spdlog.h" -#include "psi/psi/utils/key.h" +#include "psi/utils/key.h" -namespace psi::psi { +namespace psi { ArrowCsvBatchProvider::ArrowCsvBatchProvider( const std::string& file_path, const std::vector& keys, @@ -103,4 +103,4 @@ void ArrowCsvBatchProvider::Init() { .ValueOrDie(); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/arrow_csv_batch_provider.h b/psi/utils/arrow_csv_batch_provider.h similarity index 94% rename from psi/psi/utils/arrow_csv_batch_provider.h rename to psi/utils/arrow_csv_batch_provider.h index d960b18c..89bf6fe4 100644 --- a/psi/psi/utils/arrow_csv_batch_provider.h +++ b/psi/utils/arrow_csv_batch_provider.h @@ -20,9 +20,9 @@ #include "arrow/csv/api.h" #include "arrow/io/api.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/utils/batch_provider.h" -namespace psi::psi { +namespace psi { class ArrowCsvBatchProvider : public IBasicBatchProvider { public: @@ -58,4 +58,4 @@ class ArrowCsvBatchProvider : public IBasicBatchProvider { std::vector> arrays_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/arrow_csv_batch_provider_test.cc b/psi/utils/arrow_csv_batch_provider_test.cc similarity index 96% rename from psi/psi/utils/arrow_csv_batch_provider_test.cc rename to psi/utils/arrow_csv_batch_provider_test.cc index 24bdd5ee..2ba09464 100644 --- a/psi/psi/utils/arrow_csv_batch_provider_test.cc +++ b/psi/utils/arrow_csv_batch_provider_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/arrow_csv_batch_provider.h" +#include "psi/utils/arrow_csv_batch_provider.h" #include #include #include "gtest/gtest.h" -namespace psi::psi { +namespace psi { namespace { constexpr auto content = R"csv(id1,id2,id3 @@ -77,4 +77,4 @@ TEST(ArrowCsvBatchProvider, works) { } } // namespace -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/batch_provider.cc b/psi/utils/batch_provider.cc similarity index 99% rename from psi/psi/utils/batch_provider.cc rename to psi/utils/batch_provider.cc index f5b42194..4d0400e0 100644 --- a/psi/psi/utils/batch_provider.cc +++ b/psi/utils/batch_provider.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/batch_provider.h" +#include "psi/utils/batch_provider.h" #include #include @@ -26,9 +26,9 @@ #include "yacl/base/exception.h" #include "yacl/crypto/utils/rand.h" -#include "psi/psi/utils/key.h" +#include "psi/utils/key.h" -namespace psi::psi { +namespace psi { MemoryBatchProvider::MemoryBatchProvider(const std::vector& items, size_t batch_size, @@ -357,4 +357,4 @@ void SimpleShuffledBatchProvider::ReadAndShuffle(size_t read_index, } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/batch_provider.h b/psi/utils/batch_provider.h similarity index 97% rename from psi/psi/utils/batch_provider.h rename to psi/utils/batch_provider.h index df980f15..017232be 100644 --- a/psi/psi/utils/batch_provider.h +++ b/psi/utils/batch_provider.h @@ -23,10 +23,10 @@ #include #include -#include "psi/psi/utils/csv_header_analyzer.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/csv_header_analyzer.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { /// Interface which produce batch of strings. class IBatchProvider { @@ -178,4 +178,4 @@ class SimpleShuffledBatchProvider : public IShuffledBatchProvider { bool file_end_flag_ = false; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/bucket.cc b/psi/utils/bucket.cc similarity index 95% rename from psi/psi/utils/bucket.cc rename to psi/utils/bucket.cc index 69d18073..0005a3d3 100644 --- a/psi/psi/utils/bucket.cc +++ b/psi/utils/bucket.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/bucket.h" +#include "psi/utils/bucket.h" -#include "psi/psi/prelude.h" -#include "psi/psi/utils/sync.h" +#include "psi/prelude.h" +#include "psi/utils/sync.h" -namespace psi::psi { +namespace psi { std::optional> PrepareBucketData( v2::Protocol protocol, size_t bucket_idx, @@ -105,4 +105,4 @@ void HandleBucketResultByReceiver( writer->Commit(); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/bucket.h b/psi/utils/bucket.h similarity index 89% rename from psi/psi/utils/bucket.h rename to psi/utils/bucket.h index b500d39d..409e4e18 100644 --- a/psi/psi/utils/bucket.h +++ b/psi/utils/bucket.h @@ -17,11 +17,11 @@ #include "yacl/link/link.h" -#include "psi/psi/utils/hash_bucket_cache.h" -#include "psi/psi/utils/index_store.h" -#include "psi/psi/utils/recovery.h" +#include "psi/utils/hash_bucket_cache.h" +#include "psi/utils/index_store.h" +#include "psi/utils/recovery.h" -namespace psi::psi { +namespace psi { // Default bucket size when not provided. constexpr uint64_t kDefaultBucketSize = 1 << 20; @@ -41,4 +41,4 @@ void HandleBucketResultByReceiver( const std::vector& result_list, IndexWriter* writer); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/communication.cc b/psi/utils/communication.cc similarity index 96% rename from psi/psi/core/communication.cc rename to psi/utils/communication.cc index aad4bd91..c498cf48 100644 --- a/psi/psi/core/communication.cc +++ b/psi/utils/communication.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/communication.h" +#include "psi/utils/communication.h" #include "spdlog/spdlog.h" #include "yacl/base/exception.h" #include "interconnection/runtime/ecdh_psi.pb.h" -namespace psi::psi { +namespace psi { std::shared_ptr CreateP2PLinkCtx( const std::string& id_prefix, @@ -73,4 +73,4 @@ PsiDataBatch IcPsiBatchSerializer::Deserialize(yacl::ByteContainerView buf) { return batch; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/communication.h b/psi/utils/communication.h similarity index 97% rename from psi/psi/core/communication.h rename to psi/utils/communication.h index 97d4df4b..838b1028 100644 --- a/psi/psi/core/communication.h +++ b/psi/utils/communication.h @@ -20,9 +20,9 @@ #include "yacl/base/buffer.h" #include "yacl/link/link.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi { +namespace psi { // I prefer 4096. inline constexpr size_t kEcdhPsiBatchSize = 4096; @@ -101,4 +101,4 @@ class IcPsiBatchSerializer { std::shared_ptr CreateP2PLinkCtx( const std::string& id_prefix, const std::shared_ptr& link_ctx, size_t peer_rank); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_checker.cc b/psi/utils/csv_checker.cc similarity index 98% rename from psi/psi/utils/csv_checker.cc rename to psi/utils/csv_checker.cc index a60ccdad..288d4eac 100644 --- a/psi/psi/utils/csv_checker.cc +++ b/psi/utils/csv_checker.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/csv_checker.h" +#include "psi/utils/csv_checker.h" #include #include @@ -37,10 +37,10 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/utils/scope_guard.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/key.h" +#include "psi/utils/io.h" +#include "psi/utils/key.h" -namespace psi::psi { +namespace psi { namespace { // Check if the first line starts with BOM(Byte Order Mark). bool CheckIfBOMExists(const std::string& file_path) { @@ -291,4 +291,4 @@ CheckCsvReport CheckCsv(const std::string& input_file_path, return report; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_checker.h b/psi/utils/csv_checker.h similarity index 96% rename from psi/psi/utils/csv_checker.h rename to psi/utils/csv_checker.h index f914a399..bf9094c8 100644 --- a/psi/psi/utils/csv_checker.h +++ b/psi/utils/csv_checker.h @@ -17,7 +17,7 @@ #include #include -namespace psi::psi { +namespace psi { // TODO(junfeng): replace CsvChecker with CheckCsv. class CsvChecker { @@ -50,4 +50,4 @@ CheckCsvReport CheckCsv(const std::string& input_file_path, const std::vector& keys, bool check_duplicates, bool generate_key_hash_digest); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_checker_test.cc b/psi/utils/csv_checker_test.cc similarity index 98% rename from psi/psi/utils/csv_checker_test.cc rename to psi/utils/csv_checker_test.cc index 00afc96e..2a5232b7 100644 --- a/psi/psi/utils/csv_checker_test.cc +++ b/psi/utils/csv_checker_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/csv_checker.h" +#include "psi/utils/csv_checker.h" #include @@ -21,9 +21,9 @@ #include "gtest/gtest.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { namespace { struct TestParams { @@ -284,4 +284,4 @@ TEST(CheckCsvTest, works) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_header_analyzer.h b/psi/utils/csv_header_analyzer.h similarity index 99% rename from psi/psi/utils/csv_header_analyzer.h rename to psi/utils/csv_header_analyzer.h index 3df7b1af..cd545437 100644 --- a/psi/psi/utils/csv_header_analyzer.h +++ b/psi/utils/csv_header_analyzer.h @@ -24,7 +24,7 @@ #include "absl/strings/str_split.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi { // TODO(junfeng): rm this class and replace usage with CsvHeaderParser. class CsvHeaderAnalyzer { @@ -126,4 +126,4 @@ class CsvHeaderAnalyzer { std::string header_line_; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/utils/csv_header_parser.cc b/psi/utils/csv_header_parser.cc similarity index 96% rename from psi/psi/utils/csv_header_parser.cc rename to psi/utils/csv_header_parser.cc index 1419c7b0..3c1b021e 100644 --- a/psi/psi/utils/csv_header_parser.cc +++ b/psi/utils/csv_header_parser.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/csv_header_parser.h" +#include "psi/utils/csv_header_parser.h" #include #include @@ -22,7 +22,7 @@ #include "arrow/io/api.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi { CsvHeaderParser::CsvHeaderParser(const std::string& path) : path_(path) { YACL_ENFORCE(std::filesystem::exists(path_), "Input file {} doesn't exist.", @@ -66,4 +66,4 @@ std::vector CsvHeaderParser::target_indices( return indices; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_header_parser.h b/psi/utils/csv_header_parser.h similarity index 96% rename from psi/psi/utils/csv_header_parser.h rename to psi/utils/csv_header_parser.h index 5d552bae..30e3deb0 100644 --- a/psi/psi/utils/csv_header_parser.h +++ b/psi/utils/csv_header_parser.h @@ -20,7 +20,7 @@ #include #include -namespace psi::psi { +namespace psi { // Just another version of CsvHeaderAnalyzer based on Apache Arrow. class CsvHeaderParser { @@ -37,4 +37,4 @@ class CsvHeaderParser { std::unordered_map key_index_map_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/csv_header_parser_test.cc b/psi/utils/csv_header_parser_test.cc similarity index 94% rename from psi/psi/utils/csv_header_parser_test.cc rename to psi/utils/csv_header_parser_test.cc index be03e3e2..181509a1 100644 --- a/psi/psi/utils/csv_header_parser_test.cc +++ b/psi/utils/csv_header_parser_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/csv_header_parser.h" +#include "psi/utils/csv_header_parser.h" #include #include @@ -21,7 +21,7 @@ #include "gtest/gtest.h" -namespace psi::psi { +namespace psi { constexpr auto csv_content = R"csv(id1,id2,y1 1,"b","y1_1" @@ -54,4 +54,4 @@ TEST(CsvHeaderParserTest, Works) { { std::filesystem::remove(csv_path); } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/cuckoo_index.cc b/psi/utils/cuckoo_index.cc similarity index 98% rename from psi/psi/core/cuckoo_index.cc rename to psi/utils/cuckoo_index.cc index 0bce63f6..c42530cb 100644 --- a/psi/psi/core/cuckoo_index.cc +++ b/psi/utils/cuckoo_index.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/cuckoo_index.h" +#include "psi/utils/cuckoo_index.h" #include #include -namespace psi::psi { +namespace psi { CuckooIndex::CuckooIndex(const Options& options) : options_(options) { bins_.resize(options_.NumBins()); @@ -147,4 +147,4 @@ uint8_t CuckooIndex::MinCollidingHashIdx(uint64_t bin_index) const { return -1; } -} // namespace psi::psi \ No newline at end of file +} // namespace psi diff --git a/psi/psi/core/cuckoo_index.h b/psi/utils/cuckoo_index.h similarity index 99% rename from psi/psi/core/cuckoo_index.h rename to psi/utils/cuckoo_index.h index 8f93d629..46130d05 100644 --- a/psi/psi/core/cuckoo_index.h +++ b/psi/utils/cuckoo_index.h @@ -20,7 +20,7 @@ #include "yacl/base/exception.h" #include "yacl/base/int128.h" -namespace psi::psi { +namespace psi { // CuckooIndex does not want to be a container like `unordered_map` which // provides CRUD interfaces. Instead, CuckooIndex aims to decide the location @@ -153,4 +153,4 @@ class CuckooIndex { std::vector hashes_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/core/cuckoo_index_test.cc b/psi/utils/cuckoo_index_test.cc similarity index 97% rename from psi/psi/core/cuckoo_index_test.cc rename to psi/utils/cuckoo_index_test.cc index f0d4e1e5..e4501f01 100644 --- a/psi/psi/core/cuckoo_index_test.cc +++ b/psi/utils/cuckoo_index_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/core/cuckoo_index.h" +#include "psi/utils/cuckoo_index.h" #include #include @@ -20,7 +20,7 @@ #include "gtest/gtest.h" #include "yacl/crypto/utils/rand.h" -namespace psi::psi { +namespace psi { class CuckooIndexTest : public testing::TestWithParam {}; @@ -82,4 +82,4 @@ TEST(CuckooIndexTest, Bad_SmallScaleFactor) { ASSERT_THROW(cuckoo_index.Insert(absl::MakeSpan(inputs)), yacl::Exception); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ec.cc b/psi/utils/ec.cc similarity index 89% rename from psi/psi/utils/ec.cc rename to psi/utils/ec.cc index af013237..ac67cde9 100644 --- a/psi/psi/utils/ec.cc +++ b/psi/utils/ec.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/ec.h" +#include "psi/utils/ec.h" #include -#include "psi/psi/cryptor/ecc_cryptor.h" -#include "psi/psi/utils/io.h" +#include "psi/cryptor/ecc_cryptor.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { std::vector ReadEcSecretKeyFile(const std::string& file_path) { size_t file_byte_size = 0; @@ -41,4 +41,4 @@ std::vector ReadEcSecretKeyFile(const std::string& file_path) { return secret_key; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ec.h b/psi/utils/ec.h similarity index 93% rename from psi/psi/utils/ec.h rename to psi/utils/ec.h index e6d02b6c..2ea801a3 100644 --- a/psi/psi/utils/ec.h +++ b/psi/utils/ec.h @@ -17,8 +17,8 @@ #include #include -namespace psi::psi { +namespace psi { std::vector ReadEcSecretKeyFile(const std::string& file_path); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ec_point_store.cc b/psi/utils/ec_point_store.cc similarity index 98% rename from psi/psi/utils/ec_point_store.cc rename to psi/utils/ec_point_store.cc index 8308e70c..ecd41a19 100644 --- a/psi/psi/utils/ec_point_store.cc +++ b/psi/utils/ec_point_store.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/ec_point_store.h" +#include "psi/utils/ec_point_store.h" #include @@ -25,9 +25,9 @@ #include "absl/strings/escaping.h" #include "spdlog/spdlog.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/utils/batch_provider.h" -namespace psi::psi { +namespace psi { void MemoryEcPointStore::Save(std::string ciphertext) { store_.push_back(std::move(ciphertext)); @@ -294,4 +294,4 @@ std::vector GetIndicesByItems( return indices; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ec_point_store.h b/psi/utils/ec_point_store.h similarity index 96% rename from psi/psi/utils/ec_point_store.h rename to psi/utils/ec_point_store.h index d4440f37..7970f665 100644 --- a/psi/psi/utils/ec_point_store.h +++ b/psi/utils/ec_point_store.h @@ -24,10 +24,10 @@ #include "yacl/link/link.h" -#include "psi/psi/utils/hash_bucket_cache.h" -#include "psi/psi/utils/index_store.h" +#include "psi/utils/hash_bucket_cache.h" +#include "psi/utils/index_store.h" -namespace psi::psi { +namespace psi { class IEcPointStore { public: @@ -154,4 +154,4 @@ std::pair, std::vector> FinalizeAndComputeIndices(const std::shared_ptr& self, const std::shared_ptr& peer, size_t batch_size); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/emp_io_adapter.cc b/psi/utils/emp_io_adapter.cc similarity index 99% rename from psi/psi/utils/emp_io_adapter.cc rename to psi/utils/emp_io_adapter.cc index b0f5df27..b7f67f72 100644 --- a/psi/psi/utils/emp_io_adapter.cc +++ b/psi/utils/emp_io_adapter.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/emp_io_adapter.h" +#include "psi/utils/emp_io_adapter.h" #include diff --git a/psi/psi/utils/emp_io_adapter.h b/psi/utils/emp_io_adapter.h similarity index 100% rename from psi/psi/utils/emp_io_adapter.h rename to psi/utils/emp_io_adapter.h diff --git a/psi/psi/utils/emp_io_adapter_test.cc b/psi/utils/emp_io_adapter_test.cc similarity index 98% rename from psi/psi/utils/emp_io_adapter_test.cc rename to psi/utils/emp_io_adapter_test.cc index e10897b5..06a521bb 100644 --- a/psi/psi/utils/emp_io_adapter_test.cc +++ b/psi/utils/emp_io_adapter_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/emp_io_adapter.h" +#include "psi/utils/emp_io_adapter.h" #include #include diff --git a/psi/psi/utils/hash_bucket_cache.cc b/psi/utils/hash_bucket_cache.cc similarity index 95% rename from psi/psi/utils/hash_bucket_cache.cc rename to psi/utils/hash_bucket_cache.cc index e61f3ca9..acea9fa4 100644 --- a/psi/psi/utils/hash_bucket_cache.cc +++ b/psi/utils/hash_bucket_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/hash_bucket_cache.h" +#include "psi/utils/hash_bucket_cache.h" #include #include @@ -22,9 +22,9 @@ #include "absl/strings/str_split.h" #include "spdlog/spdlog.h" -#include "psi/psi/utils/arrow_csv_batch_provider.h" +#include "psi/utils/arrow_csv_batch_provider.h" -namespace psi::psi { +namespace psi { HashBucketCache::HashBucketCache(std::string target_dir, uint32_t bucket_num, bool use_scoped_tmp_dir) @@ -96,4 +96,4 @@ std::unique_ptr CreateCacheFromCsv( return bucket_cache; } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/hash_bucket_cache.h b/psi/utils/hash_bucket_cache.h similarity index 94% rename from psi/psi/utils/hash_bucket_cache.h rename to psi/utils/hash_bucket_cache.h index 1e7f18bd..2ebe2e58 100644 --- a/psi/psi/utils/hash_bucket_cache.h +++ b/psi/utils/hash_bucket_cache.h @@ -23,10 +23,10 @@ #include "fmt/format.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/io.h" -#include "psi/psi/utils/multiplex_disk_cache.h" +#include "psi/utils/io.h" +#include "psi/utils/multiplex_disk_cache.h" -namespace psi::psi { +namespace psi { class HashBucketCache { public: @@ -80,4 +80,4 @@ std::unique_ptr CreateCacheFromCsv( const std::string& cache_dir, uint32_t bucket_num, uint32_t read_batch_size = 4096, bool use_scoped_tmp_dir = true); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/index_store.cc b/psi/utils/index_store.cc similarity index 98% rename from psi/psi/utils/index_store.cc rename to psi/utils/index_store.cc index 5855a1f9..47b8922f 100644 --- a/psi/psi/utils/index_store.cc +++ b/psi/utils/index_store.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/index_store.h" +#include "psi/utils/index_store.h" #include #include @@ -24,7 +24,7 @@ #include "arrow/csv/options.h" #include "spdlog/spdlog.h" -namespace psi::psi { +namespace psi { IndexWriter::IndexWriter(const std::filesystem::path& path, size_t cache_size, bool trunc) @@ -157,4 +157,4 @@ std::optional IndexReader::GetNext() { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/index_store.h b/psi/utils/index_store.h similarity index 97% rename from psi/psi/utils/index_store.h rename to psi/utils/index_store.h index d8a773c4..2d735545 100644 --- a/psi/psi/utils/index_store.h +++ b/psi/utils/index_store.h @@ -25,7 +25,7 @@ #include "arrow/ipc/api.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi { constexpr char kIdx[] = "psi_index"; @@ -94,4 +94,4 @@ class IndexReader { std::shared_ptr array_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/index_store_test.cc b/psi/utils/index_store_test.cc similarity index 97% rename from psi/psi/utils/index_store_test.cc rename to psi/utils/index_store_test.cc index eb26614a..c8410d5e 100644 --- a/psi/psi/utils/index_store_test.cc +++ b/psi/utils/index_store_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/index_store.h" +#include "psi/utils/index_store.h" #include #include @@ -21,7 +21,7 @@ #include "gtest/gtest.h" -namespace psi::psi { +namespace psi { class IndexStoreTest : public ::testing::Test { protected: @@ -126,4 +126,4 @@ TEST_F(IndexStoreTest, Empty) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/io.cc b/psi/utils/io.cc similarity index 97% rename from psi/psi/utils/io.cc rename to psi/utils/io.cc index 0c4a7392..4802ee9f 100644 --- a/psi/psi/utils/io.cc +++ b/psi/utils/io.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" #include #include @@ -23,7 +23,7 @@ #include "yacl/io/stream/file_io.h" #include "yacl/io/stream/mem_io.h" -namespace psi::psi::io { +namespace psi::io { std::unique_ptr BuildInputStream(const std::any& io_options) { std::unique_ptr is; @@ -86,4 +86,4 @@ std::unique_ptr BuildWriter(const std::any& io_options, return ret; } -} // namespace psi::psi::io +} // namespace psi::io diff --git a/psi/psi/utils/io.h b/psi/utils/io.h similarity index 98% rename from psi/psi/utils/io.h rename to psi/utils/io.h index ded01690..1effc2aa 100644 --- a/psi/psi/utils/io.h +++ b/psi/utils/io.h @@ -21,7 +21,7 @@ #include "yacl/io/rw/reader.h" #include "yacl/io/rw/writer.h" #include "yacl/io/stream/interface.h" -namespace psi::psi::io { +namespace psi::io { using Schema = yacl::io::Schema; @@ -103,4 +103,4 @@ std::unique_ptr BuildInputStream(const std::any& io_options); // !!! PLS manually call Close before release OutputStream !!! std::unique_ptr BuildOutputStream(const std::any& io_options); -} // namespace psi::psi::io +} // namespace psi::io diff --git a/psi/psi/utils/key.cc b/psi/utils/key.cc similarity index 95% rename from psi/psi/utils/key.cc rename to psi/utils/key.cc index 5abc1ce7..000c1877 100644 --- a/psi/psi/utils/key.cc +++ b/psi/utils/key.cc @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/key.h" +#include "psi/utils/key.h" #include #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/csv_header_parser.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/csv_header_parser.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { // Multiple-Key out-of-core sort. // Out-of-core support reference: @@ -98,4 +98,4 @@ std::string KeysJoin(const std::vector& keys) { return absl::StrJoin(keys, ","); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/key.h b/psi/utils/key.h similarity index 97% rename from psi/psi/utils/key.h rename to psi/utils/key.h index 17349641..515923d2 100644 --- a/psi/psi/utils/key.h +++ b/psi/utils/key.h @@ -19,7 +19,7 @@ #include "absl/strings/string_view.h" -namespace psi::psi { +namespace psi { // Multiple-Key out-of-core sort. // Out-of-core support reference: @@ -39,4 +39,4 @@ void MultiKeySort(const std::string& in_csv, const std::string& out_csv, // join keys with "," std::string KeysJoin(const std::vector& keys); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/multiplex_disk_cache.cc b/psi/utils/multiplex_disk_cache.cc similarity index 96% rename from psi/psi/utils/multiplex_disk_cache.cc rename to psi/utils/multiplex_disk_cache.cc index d57f1bd4..3e5be711 100644 --- a/psi/psi/utils/multiplex_disk_cache.cc +++ b/psi/utils/multiplex_disk_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/multiplex_disk_cache.h" +#include "psi/utils/multiplex_disk_cache.h" #include @@ -26,7 +26,7 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi { bool ScopedTempDir::CreateUniqueTempDirUnderPath( const std::filesystem::path& parent_path) { @@ -81,4 +81,4 @@ std::unique_ptr MultiplexDiskCache::CreateInputStream( return io::BuildInputStream(io::FileIoOptions(GetPath(index))); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/multiplex_disk_cache.h b/psi/utils/multiplex_disk_cache.h similarity index 96% rename from psi/psi/utils/multiplex_disk_cache.h rename to psi/utils/multiplex_disk_cache.h index 21f4e6f4..f67e1a10 100644 --- a/psi/psi/utils/multiplex_disk_cache.h +++ b/psi/utils/multiplex_disk_cache.h @@ -22,9 +22,9 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { class ScopedTempDir { public: @@ -68,4 +68,4 @@ class MultiplexDiskCache { std::unique_ptr scoped_temp_dir_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/multiplex_disk_cache_test.cc b/psi/utils/multiplex_disk_cache_test.cc similarity index 95% rename from psi/psi/utils/multiplex_disk_cache_test.cc rename to psi/utils/multiplex_disk_cache_test.cc index 2413d4f5..5399f158 100644 --- a/psi/psi/utils/multiplex_disk_cache_test.cc +++ b/psi/utils/multiplex_disk_cache_test.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/multiplex_disk_cache.h" +#include "psi/utils/multiplex_disk_cache.h" #include #include "gtest/gtest.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { std::size_t GetFileCntInDirectory(const std::filesystem::path& path) { using std::filesystem::directory_iterator; @@ -82,4 +82,4 @@ TEST_F(MultiplexDiskCacheTest, UseScopedTmpDirOff) { EXPECT_EQ(10, GetFileCntInDirectory(cache_path)); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/progress.cc b/psi/utils/progress.cc similarity index 97% rename from psi/psi/utils/progress.cc rename to psi/utils/progress.cc index d830f469..e81e9e91 100644 --- a/psi/psi/utils/progress.cc +++ b/psi/utils/progress.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/progress.h" +#include "psi/utils/progress.h" #include #include #include "fmt/format.h" -namespace psi::psi { +namespace psi { Progress::Progress(std::string description) : description_(description), @@ -142,4 +142,4 @@ std::shared_ptr Progress::NextSubProgress( return p; } -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/utils/progress.h b/psi/utils/progress.h similarity index 97% rename from psi/psi/utils/progress.h rename to psi/utils/progress.h index 40dc5bb7..5d277dcc 100644 --- a/psi/psi/utils/progress.h +++ b/psi/utils/progress.h @@ -20,7 +20,7 @@ #include #include -namespace psi::psi { +namespace psi { class Progress { public: @@ -80,4 +80,4 @@ class Progress { std::atomic_bool done_; }; -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/utils/progress_test.cc b/psi/utils/progress_test.cc similarity index 96% rename from psi/psi/utils/progress_test.cc rename to psi/utils/progress_test.cc index fb1cdca5..23cc995c 100644 --- a/psi/psi/utils/progress_test.cc +++ b/psi/utils/progress_test.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/progress.h" +#include "psi/utils/progress.h" #include "gtest/gtest.h" -namespace psi::psi { +namespace psi { class ProgressTest : public ::testing::Test {}; @@ -82,4 +82,4 @@ TEST_F(ProgressTest, ProgressMultiPhase_Normal) { EXPECT_EQ(data.description, "Step2, 100%"); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/recovery.cc b/psi/utils/recovery.cc similarity index 98% rename from psi/psi/utils/recovery.cc rename to psi/utils/recovery.cc index 3561652c..3df9ea56 100644 --- a/psi/psi/utils/recovery.cc +++ b/psi/utils/recovery.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/recovery.h" +#include "psi/utils/recovery.h" #include @@ -23,11 +23,11 @@ #include "google/protobuf/util/message_differencer.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/io.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { v2::RecoveryCheckpoint LoadRecoveryCheckpointFromFile( const std::filesystem::path& path) { @@ -249,4 +249,4 @@ void RecoveryManager::MarkPostProcessEnd() { } } -} // namespace psi::psi \ No newline at end of file +} // namespace psi \ No newline at end of file diff --git a/psi/psi/utils/recovery.h b/psi/utils/recovery.h similarity index 97% rename from psi/psi/utils/recovery.h rename to psi/utils/recovery.h index 41d752ec..341b548c 100644 --- a/psi/psi/utils/recovery.h +++ b/psi/utils/recovery.h @@ -21,11 +21,11 @@ #include "yacl/link/link.h" -#include "psi/psi/cryptor/ecc_cryptor.h" +#include "psi/cryptor/ecc_cryptor.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { v2::RecoveryCheckpoint LoadRecoveryCheckpointFromFile( const std::filesystem::path& path); @@ -120,4 +120,4 @@ class RecoveryManager { uint64_t parsed_bucket_count_from_peer_ = 0; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/recovery_test.cc b/psi/utils/recovery_test.cc similarity index 96% rename from psi/psi/utils/recovery_test.cc rename to psi/utils/recovery_test.cc index 9da0b9c5..d8906fab 100644 --- a/psi/psi/utils/recovery_test.cc +++ b/psi/utils/recovery_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/recovery.h" +#include "psi/utils/recovery.h" #include @@ -21,11 +21,11 @@ #include "gtest/gtest.h" -#include "psi/psi/cryptor/cryptor_selector.h" +#include "psi/cryptor/cryptor_selector.h" #include "psi/proto/psi_v2.pb.h" -namespace psi::psi { +namespace psi { class RecoveryTest : public ::testing::Test { protected: @@ -125,4 +125,4 @@ TEST_F(RecoveryTest, Mark) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/resource.cc b/psi/utils/resource.cc similarity index 95% rename from psi/psi/utils/resource.cc rename to psi/utils/resource.cc index 6de1b0b8..eb52b26f 100644 --- a/psi/psi/utils/resource.cc +++ b/psi/utils/resource.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/resource.h" +#include "psi/utils/resource.h" #include #include @@ -21,7 +21,7 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -namespace psi::psi { +namespace psi { std::string ReadProcSelfStatusByKey(const std::string& key) { std::string ret; @@ -53,4 +53,4 @@ size_t ReadVMxFromProcSelfStatus(const std::string& key) { size_t GetPeakKbMemUsage() { return ReadVMxFromProcSelfStatus("VmHWM"); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/resource.h b/psi/utils/resource.h similarity index 94% rename from psi/psi/utils/resource.h rename to psi/utils/resource.h index 8dbf2b3c..2d1ee790 100644 --- a/psi/psi/utils/resource.h +++ b/psi/utils/resource.h @@ -18,7 +18,7 @@ #include "yacl/base/exception.h" -namespace psi::psi { +namespace psi { /* * VmHWM from /proc/self/status @@ -26,4 +26,4 @@ namespace psi::psi { */ size_t GetPeakKbMemUsage(); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/serializable.proto b/psi/utils/serializable.proto similarity index 97% rename from psi/psi/utils/serializable.proto rename to psi/utils/serializable.proto index 46e4fc27..4d4b1c57 100644 --- a/psi/psi/utils/serializable.proto +++ b/psi/utils/serializable.proto @@ -16,7 +16,7 @@ syntax = "proto3"; -package psi.psi.proto; +package psi.proto; message SizeProto { uint64 input_size = 1; diff --git a/psi/psi/utils/serialize.h b/psi/utils/serialize.h similarity index 94% rename from psi/psi/utils/serialize.h rename to psi/utils/serialize.h index 03befb95..6290ccaa 100644 --- a/psi/psi/utils/serialize.h +++ b/psi/utils/serialize.h @@ -19,9 +19,9 @@ #include "yacl/base/buffer.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi::utils { +namespace psi::utils { inline yacl::Buffer SerializeSize(size_t size) { proto::SizeProto proto; @@ -65,4 +65,4 @@ inline size_t GetCompareBytesLength(size_t size_a, size_t size_b, return (compare_bits + 7) / 8; } -} // namespace psi::psi::utils +} // namespace psi::utils diff --git a/psi/psi/utils/sync.cc b/psi/utils/sync.cc similarity index 94% rename from psi/psi/utils/sync.cc rename to psi/utils/sync.cc index 37a8e4ca..1222f84d 100644 --- a/psi/psi/utils/sync.cc +++ b/psi/utils/sync.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/sync.h" +#include "psi/utils/sync.h" -#include "psi/psi/utils/serialize.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi { std::vector AllGatherItemsSize( const std::shared_ptr& link_ctx, size_t self_size) { @@ -58,4 +58,4 @@ void BroadcastResult(const std::shared_ptr& link_ctx, } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/sync.h b/psi/utils/sync.h similarity index 97% rename from psi/psi/utils/sync.h rename to psi/utils/sync.h index 8c22ee8b..2fe058f4 100644 --- a/psi/psi/utils/sync.h +++ b/psi/utils/sync.h @@ -25,9 +25,9 @@ #include "yacl/link/link.h" -#include "psi/psi/utils/serializable.pb.h" +#include "psi/utils/serializable.pb.h" -namespace psi::psi { +namespace psi { namespace { @@ -113,4 +113,4 @@ std::vector AllGatherItemsSize( void BroadcastResult(const std::shared_ptr& link_ctx, std::vector* res); -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/test_utils.h b/psi/utils/test_utils.h similarity index 96% rename from psi/psi/utils/test_utils.h rename to psi/utils/test_utils.h index f6e9756f..3065a781 100644 --- a/psi/psi/utils/test_utils.h +++ b/psi/utils/test_utils.h @@ -23,7 +23,7 @@ #include "psi/proto/psi.pb.h" -namespace psi::psi::test { +namespace psi::test { inline std::vector CreateRangeItems(size_t begin, size_t size) { std::vector ret; @@ -66,4 +66,4 @@ inline std::optional GetOverrideCurveType() { return {}; } -} // namespace psi::psi::test +} // namespace psi::test diff --git a/psi/psi/utils/ub_psi_cache.cc b/psi/utils/ub_psi_cache.cc similarity index 97% rename from psi/psi/utils/ub_psi_cache.cc rename to psi/utils/ub_psi_cache.cc index 96f020f5..2000a91c 100644 --- a/psi/psi/utils/ub_psi_cache.cc +++ b/psi/utils/ub_psi_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/ub_psi_cache.h" +#include "psi/utils/ub_psi_cache.h" #include #include @@ -20,9 +20,9 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" -#include "psi/psi/utils/serialize.h" +#include "psi/utils/serialize.h" -namespace psi::psi { +namespace psi { UbPsiCacheProvider::UbPsiCacheProvider(const std::string &file_path, size_t batch_size, size_t data_len) @@ -144,4 +144,4 @@ void UbPsiCache::SaveData(yacl::ByteContainerView item, size_t index, out_stream_->Write(data_with_index.data(), data_with_index.length()); } -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ub_psi_cache.h b/psi/utils/ub_psi_cache.h similarity index 95% rename from psi/psi/utils/ub_psi_cache.h rename to psi/utils/ub_psi_cache.h index 73a3c68e..76f819c3 100644 --- a/psi/psi/utils/ub_psi_cache.h +++ b/psi/utils/ub_psi_cache.h @@ -22,10 +22,10 @@ #include "yacl/base/byte_container_view.h" -#include "psi/psi/utils/batch_provider.h" -#include "psi/psi/utils/io.h" +#include "psi/utils/batch_provider.h" +#include "psi/utils/io.h" -namespace psi::psi { +namespace psi { class UbPsiCacheProvider : public IBasicBatchProvider, public IShuffledBatchProvider { @@ -87,4 +87,4 @@ class UbPsiCache : public IUbPsiCache { std::unique_ptr out_stream_; }; -} // namespace psi::psi +} // namespace psi diff --git a/psi/psi/utils/ub_psi_cache_test.cc b/psi/utils/ub_psi_cache_test.cc similarity index 96% rename from psi/psi/utils/ub_psi_cache_test.cc rename to psi/utils/ub_psi_cache_test.cc index 1b0953fc..0521d2d5 100644 --- a/psi/psi/utils/ub_psi_cache_test.cc +++ b/psi/utils/ub_psi_cache_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/utils/ub_psi_cache.h" +#include "psi/utils/ub_psi_cache.h" #include #include @@ -27,7 +27,7 @@ #include "yacl/crypto/utils/rand.h" #include "yacl/utils/scope_guard.h" -namespace psi::psi { +namespace psi { TEST(UbPsiCacheTest, Simple) { size_t data_len = 12; @@ -81,4 +81,4 @@ TEST(UbPsiCacheTest, Simple) { } } -} // namespace psi::psi +} // namespace psi diff --git a/psi/version.h b/psi/version.h index a5ea3771..43fbd6ff 100644 --- a/psi/version.h +++ b/psi/version.h @@ -17,4 +17,4 @@ #define PSI_VERSION_MAJOR 0 #define PSI_VERSION_MINOR 2 #define PSI_VERSION_PATCH 0 -#define PSI_DEV_IDENTIFIER ".dev240123" +#define PSI_DEV_IDENTIFIER ".dev240219"