From dc87af42dd82d9e982526c4ec20a7a616d58ce6a Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Sun, 17 Nov 2024 23:25:28 +0800 Subject: [PATCH 1/8] Add some changes --- psi/psi21_experiment/el_c_psi/BUILD.bazel | 61 ++++++ psi/psi21_experiment/el_c_psi/README.md | 10 + psi/psi21_experiment/el_c_psi/el_c_psi.cc | 165 +++++++++++++++ psi/psi21_experiment/el_c_psi/el_c_psi.h | 59 ++++++ .../el_c_psi/el_c_psi_benchmark.cc | 87 ++++++++ .../el_c_psi/el_c_psi_test.cc | 171 +++++++++++++++ psi/psi21_experiment/el_c_psi/el_hashing.cc | 69 ++++++ psi/psi21_experiment/el_c_psi/el_hashing.h | 65 ++++++ psi/psi21_experiment/el_c_psi/el_opprf.cc | 188 +++++++++++++++++ psi/psi21_experiment/el_c_psi/el_opprf.h | 35 ++++ psi/psi21_experiment/el_mp_psi/BUILD.bazel | 61 ++++++ psi/psi21_experiment/el_mp_psi/README.md | 8 + psi/psi21_experiment/el_mp_psi/el_hashing.cc | 69 ++++++ psi/psi21_experiment/el_mp_psi/el_hashing.h | 64 ++++++ psi/psi21_experiment/el_mp_psi/el_mp_psi.cc | 143 +++++++++++++ psi/psi21_experiment/el_mp_psi/el_mp_psi.h | 66 ++++++ .../el_mp_psi/el_mp_psi_benchmark.cc | 85 ++++++++ .../el_mp_psi/el_mp_psi_test.cc | 149 +++++++++++++ psi/psi21_experiment/el_mp_psi/el_sender.cc | 116 +++++++++++ psi/psi21_experiment/el_mp_psi/el_sender.h | 35 ++++ psi/psi21_experiment/el_q_psi/BUILD.bazel | 61 ++++++ psi/psi21_experiment/el_q_psi/README.md | 10 + psi/psi21_experiment/el_q_psi/el_hashing.cc | 69 ++++++ psi/psi21_experiment/el_q_psi/el_hashing.h | 65 ++++++ psi/psi21_experiment/el_q_psi/el_opprf.cc | 189 +++++++++++++++++ psi/psi21_experiment/el_q_psi/el_opprf.h | 35 ++++ psi/psi21_experiment/el_q_psi/el_q_psi.cc | 165 +++++++++++++++ psi/psi21_experiment/el_q_psi/el_q_psi.h | 60 ++++++ .../el_q_psi/el_q_psi_benchmark.cc | 87 ++++++++ .../el_q_psi/el_q_psi_test.cc | 196 ++++++++++++++++++ 30 files changed, 2643 insertions(+) create mode 100644 psi/psi21_experiment/el_c_psi/BUILD.bazel create mode 100644 psi/psi21_experiment/el_c_psi/README.md create mode 100644 psi/psi21_experiment/el_c_psi/el_c_psi.cc create mode 100644 psi/psi21_experiment/el_c_psi/el_c_psi.h create mode 100644 psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc create mode 100644 psi/psi21_experiment/el_c_psi/el_c_psi_test.cc create mode 100644 psi/psi21_experiment/el_c_psi/el_hashing.cc create mode 100644 psi/psi21_experiment/el_c_psi/el_hashing.h create mode 100644 psi/psi21_experiment/el_c_psi/el_opprf.cc create mode 100644 psi/psi21_experiment/el_c_psi/el_opprf.h create mode 100644 psi/psi21_experiment/el_mp_psi/BUILD.bazel create mode 100644 psi/psi21_experiment/el_mp_psi/README.md create mode 100644 psi/psi21_experiment/el_mp_psi/el_hashing.cc create mode 100644 psi/psi21_experiment/el_mp_psi/el_hashing.h create mode 100644 psi/psi21_experiment/el_mp_psi/el_mp_psi.cc create mode 100644 psi/psi21_experiment/el_mp_psi/el_mp_psi.h create mode 100644 psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc create mode 100644 psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc create mode 100644 psi/psi21_experiment/el_mp_psi/el_sender.cc create mode 100644 psi/psi21_experiment/el_mp_psi/el_sender.h create mode 100644 psi/psi21_experiment/el_q_psi/BUILD.bazel create mode 100644 psi/psi21_experiment/el_q_psi/README.md create mode 100644 psi/psi21_experiment/el_q_psi/el_hashing.cc create mode 100644 psi/psi21_experiment/el_q_psi/el_hashing.h create mode 100644 psi/psi21_experiment/el_q_psi/el_opprf.cc create mode 100644 psi/psi21_experiment/el_q_psi/el_opprf.h create mode 100644 psi/psi21_experiment/el_q_psi/el_q_psi.cc create mode 100644 psi/psi21_experiment/el_q_psi/el_q_psi.h create mode 100644 psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc create mode 100644 psi/psi21_experiment/el_q_psi/el_q_psi_test.cc diff --git a/psi/psi21_experiment/el_c_psi/BUILD.bazel b/psi/psi21_experiment/el_c_psi/BUILD.bazel new file mode 100644 index 00000000..b60d86bb --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/BUILD.bazel @@ -0,0 +1,61 @@ +# Copyright 2024 zhangwfjh +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_binary( + name = "el_c_psi_benchmark", + srcs = ["el_c_psi_benchmark.cc"], + deps = [ + ":el_c_psi", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "el_c_psi", + srcs = [ + "el_hashing.cc", + "el_c_psi.cc", + "el_opprf.cc", + ], + hdrs = [ + "el_hashing.h", + "el_c_psi.h", + "el_opprf.h", + ], + deps = [ + "//psi/utils:communication", + "//psi/utils:sync", + "//psi/utils:test_utils", + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:exception", + "@yacl//yacl/base:int128", + "@yacl//yacl/crypto/hash:hash_utils", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/kernel/algorithms:base_ot", + "@yacl//yacl/kernel/algorithms:iknp_ote", + "@yacl//yacl/kernel/algorithms:kkrt_ote", + "@yacl//yacl/link", + ], +) + +psi_cc_test( + name = "el_c_psi_test", + srcs = ["el_c_psi_test.cc"], + tags = ["manual"], + deps = [":el_c_psi"], +) diff --git a/psi/psi21_experiment/el_c_psi/README.md b/psi/psi21_experiment/el_c_psi/README.md new file mode 100644 index 00000000..f3dd2202 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/README.md @@ -0,0 +1,10 @@ +论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI + +论文地址:https://www.xueshufan.com/publication/3150904314 + +方案概括:circuit_psi
+1、参与方分别对持有的隐私集合中的元素进行分桶,通过不经意伪随机函数协议为元素计算伪随机函数值;
+2、发送方为元素的伪随机函数值生成不经意键值存储,发送给接收方;
+3、接收方将桶中元素的伪随机函数值与不经意键值存储进行异或操作,生成桶向量;
+4、将接收方的桶向量及发送方选取的向量输入到零分享测试电路中,测试所有输入是否为零的加性秘密共享,生成每个参与方的比特秘密共享;
+5、基于所有参与方的比特秘密共享,通过计算对称函数的电路得到在交集上的对称函数。
diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi.cc b/psi/psi21_experiment/el_c_psi/el_c_psi.cc new file mode 100644 index 00000000..ca3d7427 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_c_psi.cc @@ -0,0 +1,165 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" + +#include + +#include "psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "psi/utils/communication.h" +#include "psi/utils/sync.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace { + +constexpr uint32_t kLinkRecvTimeout = 60 * 60 * 1000; + +} // namespace + +NcParty::NcParty(const Options& options) : options_{options} { + auto [ctx, wsize, me, leader] = CollectContext(); + ctx->SetRecvTimeout(kLinkRecvTimeout); + p2p_.resize(wsize); + for (size_t dst{}; dst != wsize; ++dst) { + if (me != dst) { + p2p_[dst] = CreateP2PLinkCtx("el_c_psi", ctx, dst); + } + } +} + +std::vector NcParty::Run( + const std::vector& inputs) { + auto [ctx, wsize, me, leader] = CollectContext(); + auto counts = AllGatherItemsSize(ctx, inputs.size()); + size_t count{}; + for (auto cnt : counts) { + if (cnt == 0) { + return {}; + } + count = std::max(cnt, count); + } + + auto items = EncodeInputs(inputs, count); + auto shares = ZeroSharing(count); + auto recv_share = SwapShares(items, shares); + auto recons = Reconstruct(items, recv_share); + std::vector intersection; + for (size_t k{}; k != count; ++k) { + if (recons[k] == 0) { + intersection.push_back("1"); + } else { + intersection.push_back("0"); + } + } + return intersection; +} + +std::vector NcParty::EncodeInputs( + const std::vector& inputs, size_t count) const { + std::vector items; + items.reserve(count); + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(items), + [](std::string_view input) { return yacl::crypto::Blake3_128(input); }); + // Add random dummy elements + std::generate_n(std::back_inserter(items), count - inputs.size(), + yacl::crypto::FastRandU128); + return items; +} + +auto NcParty::ZeroSharing(size_t count) const -> std::vector { + auto [ctx, wsize, me, leader] = CollectContext(); + std::vector shares(wsize, Share(count)); + for (size_t k{}; k != count; ++k) { + uint64_t sum{}; + for (size_t dst{1}; dst != wsize; ++dst) { + sum ^= shares[dst][k] = yacl::crypto::FastRandU64(); + } + shares[0][k] = sum; + } + return shares; +} + +auto NcParty::SwapShares(const std::vector& items, + const std::vector& shares) const -> Share { + auto [ctx, wsize, me, leader] = CollectContext(); + auto count = shares.front().size(); + std::vector recv_shares(count); + std::vector> futures(wsize); + // NOTE: First Send Then Receive for peers of smaller ranks + for (size_t id{}; id != me; ++id) { + futures[id] = std::async( + [&](size_t id) { + ElOpprfSend(p2p_[id], items, shares[id]); + return ElOpprfRecv(p2p_[id], items); + }, + id); + } + // NOTE: First Receive Then Send for peers of larger ranks + for (size_t id{me + 1}; id != wsize; ++id) { + futures[id] = std::async( + [&](size_t id) { + auto ret = ElOpprfRecv(p2p_[id], items); + ElOpprfSend(p2p_[id], items, shares[id]); + return ret; + }, + id); + } + for (size_t id{}; id != wsize; ++id) { + recv_shares[id] = (me == id ? shares[id] : futures[id].get()); + } + + Share share(count); // S(x_k) + for (size_t k{}; k != count; ++k) { + for (size_t src{}; src != wsize; ++src) { + share[k] ^= recv_shares[src][k]; + } + } + return share; +} + +auto NcParty::Reconstruct(const std::vector& items, + const Share& share) const -> Share { + auto [ctx, wsize, me, leader] = CollectContext(); + auto count = items.size(); + if (me == leader) { + std::vector recv_shares(count); + std::vector> futures(wsize); + for (size_t src{}; src != wsize; ++src) { + if (me != src) { + futures[src] = std::async( + [&](size_t src) { return ElOpprfRecv(p2p_[src], items); }, src); + } + } + for (size_t src{}; src != wsize; ++src) { + recv_shares[src] = (me == src ? share : futures[src].get()); + } + Share recons(count); // sum of S_i(x_k) over i + for (size_t k{}; k != count; ++k) { + for (size_t src{}; src != wsize; ++src) { + recons[k] ^= recv_shares[src][k]; + } + } + return recons; + } else { + ElOpprfSend(p2p_[leader], items, share); + return share; + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi.h b/psi/psi21_experiment/el_c_psi/el_c_psi.h new file mode 100644 index 00000000..00176003 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_c_psi.h @@ -0,0 +1,59 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Practical Multi-party Private Set Intersection from Symmetric-Key Techniques +// https://eprint.iacr.org/2017/799.pdf + +class NcParty { + public: + struct Options { + std::shared_ptr link_ctx; + size_t leader_rank; + }; + + NcParty(const Options& options); + virtual std::vector Run(const std::vector& inputs); + + private: + using Share = std::vector; + + std::vector EncodeInputs(const std::vector& inputs, + size_t count) const; + std::vector ZeroSharing(size_t count) const; + Share SwapShares(const std::vector& items, + const std::vector& shares) const; + Share Reconstruct(const std::vector& items, + const Share& share) const; + + // (ctx, world_size, my_rank, leader_rank) + auto CollectContext() const { + return std::make_tuple(options_.link_ctx, options_.link_ctx->WorldSize(), + options_.link_ctx->Rank(), options_.leader_rank); + } + + Options options_; + std::vector> p2p_; +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc b/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc new file mode 100644 index 00000000..08721c75 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc @@ -0,0 +1,87 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "benchmark/benchmark.h" +#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "yacl/base/exception.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/link/test_util.h" + +namespace { +std::vector CreateRangeItems(size_t begin, size_t size) { + std::vector ret(size); + for (size_t i = 0; i < size; i++) { + auto hash = yacl::crypto::Blake3(std::to_string(begin + i)); + memcpy(&ret[i], hash.data(), sizeof(uint128_t)); + } + return ret; +} + +void ElCPsiSend(const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + // auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); + // return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); + std::vector shares; + for (size_t i = 0; i < items_hash.size(); i++) { + uint64_t item = 0; + shares.push_back(item); + } + + return psi::psi::ElOpprfSend(link_ctx, items_hash, shares); +} + +std::vector ElCPsiRecv( + const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + // auto ot_send = psi::kkrt::GetKkrtOtReceiverOptions(link_ctx, 512); + // return psi::kkrt::KkrtPsiRecv(link_ctx, ot_send, items_hash); + return psi::psi::ElOpprfRecv(link_ctx, items_hash); +} + +} // namespace + +static void BM_El_C_Psi(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + size_t n = state.range(0); + auto alice_items = CreateRangeItems(1, n); + auto bob_items = CreateRangeItems(2, n); + + auto contexts = yacl::link::test::SetupWorld(2); + + state.ResumeTiming(); + + std::future kkrt_psi_sender = + std::async([&] { return ElCPsiSend(contexts[0], alice_items); }); + std::future> kkrt_psi_receiver = + std::async([&] { return ElCPsiRecv(contexts[1], bob_items); }); + + kkrt_psi_sender.get(); + auto results_b = kkrt_psi_receiver.get(); + } +} + +// [256k, 512k, 1m, 2m, 4m, 8m] +BENCHMARK(BM_El_C_Psi) + ->Unit(benchmark::kMillisecond) + ->Arg(256 << 10) + ->Arg(512 << 10) + ->Arg(1 << 20) + ->Arg(2 << 20) + ->Arg(4 << 20) + ->Arg(8 << 20); diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc b/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc new file mode 100644 index 00000000..f48b4ee3 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc @@ -0,0 +1,171 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "psi/utils/test_utils.h" +#include "yacl/link/test_util.h" + +namespace psi::psi { + +namespace { + +struct NCTestParams { + std::vector item_size; + size_t intersection_size; +}; + +std::vector> CreateNPartyItems( + const NCTestParams& params) { + std::vector> ret(params.item_size.size() + 1); + ret[params.item_size.size()] = + test::CreateRangeItems(1, params.intersection_size); + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + ret[idx] = + test::CreateRangeItems((idx + 1) * 1000000, params.item_size[idx]); + } + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + std::set idx_set; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, params.item_size[idx] - 1); + + while (idx_set.size() < params.intersection_size) { + idx_set.insert(dis(gen)); + } + size_t j = 0; + for (const auto& iter : idx_set) { + ret[idx][iter] = ret[params.item_size.size()][j++]; + } + } + return ret; +} + +} // namespace + +class NCPsiTest : public testing::TestWithParam {}; + +// FIXME : this test is not stable in arm env +TEST_P(NCPsiTest, Works) { + std::vector> items; + std::vector> resultvec; + std::vector finalresult; + + auto params = GetParam(); + items = CreateNPartyItems(params); + size_t leader_rank = 0; + uint128_t maxlength = 0; + uint128_t n = params.item_size.size() - 1; + + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + std::vector> items1; + items1.push_back(items[0]); + items1.push_back(items[i + 1]); + leader_rank = 0; + + auto ctxs = yacl::link::test::SetupWorld(2); + auto proc = [&](int idx) -> std::vector { + NcParty::Options opts; + opts.link_ctx = ctxs[idx]; + opts.leader_rank = leader_rank; + NcParty op(opts); + // for (size_t j{}; j != items1[idx].size(); ++j) { + // SPDLOG_INFO(" items[{}][{}] = {}, size{}", idx, i, items[idx][i], + // items[idx].size()); + // } + + return op.Run(items[idx]); + }; + + size_t world_size = ctxs.size(); + std::vector>> f_links(world_size); + for (size_t j = 0; j < world_size; j++) { + f_links[j] = std::async(proc, j); + } + sleep(1); + + std::vector result; + result = f_links[0].get(); + resultvec.push_back(result); + + /*for (size_t j = 0; j < result.size(); j++) { + SPDLOG_INFO("i{} j{}, result[j] {} size{}", i, j, result[j], + result.size()); + }*/ + } + + maxlength = items[0].size(); + std::vector qpsivector; + for (size_t j = 0; j < maxlength; j++) { + uint128_t sum = 0; + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + // 如果有的集合没有那么多项就continue + // results[i] = f_links[i].get(); + if (resultvec[i].size() <= j) { + continue; + } + + // SPDLOG_INFO(" result[{}][{}] = {}", i, j, resultvec[i][j]); + auto it = resultvec[i].begin() + j; + std::string element = *it; + if (element == "1") { + sum++; + } + } + if (sum >= n) { + // todo//推入对应input元素 之后再查输入变量从param中怎么取出推入 + qpsivector.push_back(1); + } else { + qpsivector.push_back(0); + } + } + + // std::vector intersectionnparty; + for (size_t k{}; k != items[0].size(); ++k) { + if (qpsivector[k] == 1) { + finalresult.push_back(items[0][k]); + } + } + + /*for (size_t i{}; i != finalresult.size(); ++i) { + SPDLOG_INFO("intersectionnparty = {}", finalresult[i]); + }*/ + + std::vector intersection = items[params.item_size.size()]; + std::sort(intersection.begin(), intersection.end()); + + std::sort(finalresult.begin(), finalresult.end()); + EXPECT_EQ(finalresult.size(), intersection.size()); + EXPECT_EQ(finalresult, intersection); +} + +INSTANTIATE_TEST_SUITE_P( + Works_Instances, NCPsiTest, + // testing::Values(NCTestParams{{1, 3}, 1})); + testing::Values(NCTestParams{{0, 3}, 0}, // + NCTestParams{{3, 0}, 0}, // + NCTestParams{{0, 0}, 0}, // + NCTestParams{{4, 3}, 2}, // + NCTestParams{{20, 17, 14}, 10}, // + NCTestParams{{20, 17, 14, 30}, 10}, // + NCTestParams{{20, 17, 14, 30, 35}, 11}, // + NCTestParams{{20, 17, 14, 30, 35}, 0})); +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_hashing.cc b/psi/psi21_experiment/el_c_psi/el_hashing.cc new file mode 100644 index 00000000..b5432915 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_hashing.cc @@ -0,0 +1,69 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_c_psi/el_hashing.h" + +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/rand/rand.h" + +namespace psi::psi { + +void KmprtCuckooHashing::Insert(uint128_t elem) { + auto insert_into = [this, &elem](uint8_t c) { + for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) { + uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c]; + uint8_t idx = (rand_idx + 1) % num_hashes_[c]; + size_t addr; + do { + addr = HashU128{}(elem, idx) % num_bins_[c]; + if (auto &bin = bins_[c][addr]; bin == NONE || bin == elem) { + bin = std::exchange(elem, NONE); + return; + } + idx = (idx + 1) % num_hashes_[c]; + } while (idx != rand_idx); + std::swap(bins_[c][addr], elem); + } + }; + for (uint8_t c{}; c != 2; ++c) { + insert_into(c); + } + YACL_ENFORCE_EQ(elem, NONE, "Failed to insert element."); +} + +auto KmprtCuckooHashing::Lookup(uint128_t elem) const + -> std::pair { + for (uint8_t c{}; c != 2; ++c) { + for (uint8_t idx{}; idx != num_hashes_[c]; ++idx) { + if (size_t addr = HashU128{}(elem, idx) % num_bins_[c]; + bins_[c][addr] == elem) { + return {c, addr}; + } + } + } + return {-1, -1}; +} + +void KmprtSimpleHashing::Insert(std::pair point) { + for (uint8_t c{}; c != 2; ++c) { + for (size_t idx{}; idx != num_hashes_[c]; ++idx) { + bins_[c][HashU128{}(point.first, idx) % num_bins_[c]].emplace(point); + } + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_hashing.h b/psi/psi21_experiment/el_c_psi/el_hashing.h new file mode 100644 index 00000000..be80df99 --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_hashing.h @@ -0,0 +1,65 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "yacl/base/int128.h" + +namespace psi::psi { + +struct HashU128 { + size_t operator()(uint128_t x, uint8_t idx = 0) const { + return absl::Uint128High64(x) + idx * absl::Uint128Low64(x); + } +}; + +template +struct KmprtDoubleHashing { + KmprtDoubleHashing(size_t m1, size_t m2) : num_bins_{m1, m2} { + bins_[0].resize(m1); + bins_[1].resize(m2); + } + + const Bin &GetBin(uint8_t c, size_t addr) const { return bins_[c][addr]; } + + const uint8_t num_hashes_[2]{3, 2}; + const size_t num_bins_[2]; + std::vector bins_[2]; +}; + +class KmprtCuckooHashing : public KmprtDoubleHashing { + public: + constexpr static uint128_t NONE{}; + + KmprtCuckooHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(uint128_t); + std::pair Lookup(uint128_t) const; +}; + +class KmprtSimpleHashing + : public KmprtDoubleHashing< + std::unordered_map> { + public: + KmprtSimpleHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(std::pair); +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_opprf.cc b/psi/psi21_experiment/el_c_psi/el_opprf.cc new file mode 100644 index 00000000..1f8b6add --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_opprf.cc @@ -0,0 +1,188 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_c_psi/el_opprf.h" + +#include +#include +#include + +#include "el_c_psi/el_hashing.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/iknp_ote.h" +#include "yacl/kernel/algorithms/kkrt_ote.h" +#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace yc = yacl::crypto; + +namespace { + +// PSI-related constants +// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) +constexpr float ZETA[]{1.12f, 0.17f}; +// constexpr size_t BETA[]{31, 63}; +constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs + +// OTe-related constants +constexpr size_t NUM_BASE_OT{128}; +constexpr size_t NUM_INKP_OT{512}; +constexpr size_t BATCH_SIZE{896}; + +static auto ro = yc::RandomOracle::GetDefault(); + +} // namespace + +// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) + +std::vector ElOpprfRecv( + const std::shared_ptr& ctx, + const std::vector& queries) { + const size_t size{queries.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtReceiver receiver; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto choice = yc::RandBits(NUM_BASE_OT); + auto base_ot = yc::BaseOtRecv(ctx, choice, NUM_BASE_OT); + auto store = yc::IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); + receiver.Init(ctx, store, num_ot); + receiver.SetBatchSize(BATCH_SIZE); + + // Step 1. Hashes queries into Cuckoo hashing + KmprtCuckooHashing hashing{bin_sizes[0], bin_sizes[1]}; + for (size_t i{}; i != size; ++i) { + hashing.Insert(queries[i]); + } + + std::vector evals; + evals.reserve(num_ot); + size_t ot_idx{}, b{}; + std::array batch_evals; + // Step 2. For each bin, invokes single-query OPPRF + for (uint8_t c{}; c != 2; ++c) { + size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; + size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; + for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { + auto elem = hashing.GetBin(c, addr); + elem == KmprtCuckooHashing::NONE && (elem = yc::FastRandU128()); + receiver.Encode( + ot_idx, elem, + {reinterpret_cast(&batch_evals[b++]), sizeof(uint64_t)}); + if (auto batch_size = (ot_idx - ot_begin) % BATCH_SIZE + 1; + batch_size == BATCH_SIZE || ot_idx + 1 == ot_end) { + b = 0; + receiver.SendCorrection(ctx, batch_size); + // For each query in a batch + for (size_t i{}; i != batch_size; ++i) { + uint128_t nonce = yacl::DeserializeUint128( + ctx->Recv(ctx->NextRank(), "Receive OPPRF nonce")); + std::vector table(TABLE_SIZE[c]); + auto buf = + ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); + std::memcpy(table.data(), buf.data(), + TABLE_SIZE[c] * sizeof(uint64_t)); + uint64_t eval = batch_evals[i]; + auto index = + ro.Gen(absl::MakeSpan(reinterpret_cast(&eval), + sizeof eval), + nonce) % + table.size(); + evals.emplace_back(eval ^ table[index]); + } + } + } + } + + // Step 3. Filters and obtains the results + std::vector results(size); + std::transform(queries.cbegin(), queries.cend(), results.begin(), + [&](auto q) { + auto [c, addr] = hashing.Lookup(q); + return evals[c * bin_sizes[0] + addr]; + }); + return results; +} + +void ElOpprfSend(const std::shared_ptr& ctx, + const std::vector& xs, + const std::vector& ys) { + YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); + const size_t size{xs.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtSender sender; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto base_ot = yc::BaseOtSend(ctx, NUM_BASE_OT); + auto choice = yc::RandBits(NUM_INKP_OT); + auto store = yc::IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); + sender.Init(ctx, store, num_ot); + sender.SetBatchSize(BATCH_SIZE); + + // Step 1. Hashes points into Simple hashing + KmprtSimpleHashing hashing{bin_sizes[0], bin_sizes[1]}; + for (size_t i{}; i != size; ++i) { + hashing.Insert({xs[i], ys[i]}); + } + size_t ot_idx{}; + auto evaluator = sender.GetOprf(); + // Step 2. For each bin, invokes single-query OPPRF + for (uint8_t c{}; c != 2; ++c) { + size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; + size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; + // For each programmable point in a batch + for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { + if ((ot_idx - ot_begin) % BATCH_SIZE == 0) { + sender.RecvCorrection( + ctx, ot_idx + BATCH_SIZE >= ot_end ? ot_end - ot_idx : BATCH_SIZE); + } + auto bin = hashing.GetBin(c, addr); + uint128_t nonce; + std::vector table; + bool separable; + do { + separable = true; + nonce = yc::FastRandSeed(); + table.assign(TABLE_SIZE[c], uint64_t{0}); + for (auto it = bin.cbegin(); it != bin.cend(); ++it) { + uint64_t eval = evaluator->Eval(ot_idx, it->first); + auto index = + ro.Gen({reinterpret_cast(&eval), sizeof eval}, + nonce) % + table.size(); + if (table[index] != uint64_t{0}) { + separable = false; + break; + } + table[index] = eval ^ it->second; + } + } while (!separable); + for (size_t i{}; i != TABLE_SIZE[c]; ++i) { + table[i] == uint64_t{0} && (table[i] = yc::FastRandU128()); + } + ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(nonce), + fmt::format("OPPRF:Nonce={}", nonce)); + yacl::Buffer buf(table.data(), table.size() * sizeof(uint64_t)); + ctx->SendAsync(ctx->NextRank(), buf, "OPPRF:EncryptionTable"); + } + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_c_psi/el_opprf.h b/psi/psi21_experiment/el_c_psi/el_opprf.h new file mode 100644 index 00000000..a26a095c --- /dev/null +++ b/psi/psi21_experiment/el_c_psi/el_opprf.h @@ -0,0 +1,35 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) + +std::vector ElOpprfRecv(const std::shared_ptr&, + const std::vector& queries); + +void ElOpprfSend(const std::shared_ptr&, + const std::vector& xs, + const std::vector& ys); + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/BUILD.bazel b/psi/psi21_experiment/el_mp_psi/BUILD.bazel new file mode 100644 index 00000000..f4aa9260 --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/BUILD.bazel @@ -0,0 +1,61 @@ +# Copyright 2024 zhangwfjh +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_binary( + name = "el_mp_psi_benchmark", + srcs = ["el_mp_psi_benchmark.cc"], + deps = [ + ":el_mp_psi", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "el_mp_psi", + srcs = [ + "el_hashing.cc", + "el_mp_psi.cc", + "el_sender.cc", + ], + hdrs = [ + "el_hashing.h", + "el_mp_psi.h", + "el_sender.h", + ], + deps = [ + "//psi/utils:communication", + "//psi/utils:sync", + "//psi/utils:test_utils", + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:exception", + "@yacl//yacl/base:int128", + "@yacl//yacl/crypto/hash:hash_utils", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/kernel/algorithms:base_ot", + "@yacl//yacl/kernel/algorithms:iknp_ote", + "@yacl//yacl/kernel/algorithms:kkrt_ote", + "@yacl//yacl/link", + ], +) + +psi_cc_test( + name = "el_mp_psi_test", + srcs = ["el_mp_psi_test.cc"], + tags = ["manual"], + deps = [":el_mp_psi"], +) diff --git a/psi/psi21_experiment/el_mp_psi/README.md b/psi/psi21_experiment/el_mp_psi/README.md new file mode 100644 index 00000000..78f422ba --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/README.md @@ -0,0 +1,8 @@ +论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI + +论文地址:https://www.xueshufan.com/publication/3150904314 + +方案概括:multiparty_psi
+1、参与者为$P_1$,...,$P_n$。选择$P_1$作为接收方,$P_2$,...,$P_n$为发送方。所用参与者共同产生三个hash函数$h_1,h_2,h_3$。$P_1$通过Cuchoo Hashing产生一个$Table_1$,$P_2,...,P_n$使用普通的三次hash产生$Table_2,...,Table_n$(就是将一个数x hash三次放在表的三个位置上)。
+2、使用PSM协议,将$P_1$作为接收方发送$Table_1$,$P_2,...,P_n$为发送方发送$Table_2,...,Table_n$。
+3、结束时$P_1$对于表中的每一个位置的元素查询了n-2次,计作$y_{ij}$,其中ij代表在$P_i$的表中查询第j个元素,$P_2$,...,$P_n$获得$w_{ij}$。$P_1$计算$_1=\sum_{i=2}^{n} -y_{ij}$,$P_2$,...,$P_n$计算$_i=w_{ij}$,若查询到元素,可以看出来所有$$加起来为0。
\ No newline at end of file diff --git a/psi/psi21_experiment/el_mp_psi/el_hashing.cc b/psi/psi21_experiment/el_mp_psi/el_hashing.cc new file mode 100644 index 00000000..1eafee8e --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_hashing.cc @@ -0,0 +1,69 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_mp_psi/el_hashing.h" + +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/rand/rand.h" + +namespace psi::psi { + +void KmprtCuckooHashing::Insert(uint128_t elem) { + auto insert_into = [this, &elem](uint8_t c) { + for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) { + uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c]; + uint8_t idx = (rand_idx + 1) % num_hashes_[c]; + size_t addr; + do { + addr = HashU128{}(elem, idx) % num_bins_[c]; + if (auto &bin = bins_[c][addr]; bin == NONE || bin == elem) { + bin = std::exchange(elem, NONE); + return; + } + idx = (idx + 1) % num_hashes_[c]; + } while (idx != rand_idx); + std::swap(bins_[c][addr], elem); + } + }; + for (uint8_t c{}; c != 2; ++c) { + insert_into(c); + } + YACL_ENFORCE_EQ(elem, NONE, "Failed to insert element."); +} + +auto KmprtCuckooHashing::Lookup(uint128_t elem) const + -> std::pair { + for (uint8_t c{}; c != 2; ++c) { + for (uint8_t idx{}; idx != num_hashes_[c]; ++idx) { + if (size_t addr = HashU128{}(elem, idx) % num_bins_[c]; + bins_[c][addr] == elem) { + return {c, addr}; + } + } + } + return {-1, -1}; +} + +void KmprtSimpleHashing::Insert(std::pair point) { + for (uint8_t c{}; c != 2; ++c) { + for (size_t idx{}; idx != num_hashes_[c]; ++idx) { + bins_[c][HashU128{}(point.first, idx) % num_bins_[c]].emplace(point); + } + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_hashing.h b/psi/psi21_experiment/el_mp_psi/el_hashing.h new file mode 100644 index 00000000..3eaf309d --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_hashing.h @@ -0,0 +1,64 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "yacl/base/int128.h" + +namespace psi::psi { + +struct HashU128 { + size_t operator()(uint128_t x, uint8_t idx = 0) const { + return absl::Uint128High64(x) + idx * absl::Uint128Low64(x); + } +}; + +template +struct KmprtDoubleHashing { + KmprtDoubleHashing(size_t m1, size_t m2) : num_bins_{m1, m2} { + bins_[0].resize(m1); + bins_[1].resize(m2); + } + + const Bin &GetBin(uint8_t c, size_t addr) const { return bins_[c][addr]; } + + const uint8_t num_hashes_[2]{3, 2}; + const size_t num_bins_[2]; + std::vector bins_[2]; +}; + +class KmprtCuckooHashing : public KmprtDoubleHashing { + public: + constexpr static uint128_t NONE{}; + + KmprtCuckooHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(uint128_t); + std::pair Lookup(uint128_t) const; +}; + +class KmprtSimpleHashing + : public KmprtDoubleHashing< + std::unordered_map> { + public: + KmprtSimpleHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(std::pair); +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc new file mode 100644 index 00000000..3e777e6a --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc @@ -0,0 +1,143 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" + +#include + +#include "psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "psi/utils/communication.h" +#include "psi/utils/sync.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace { + +constexpr uint32_t kLinkRecvTimeout = 60 * 60 * 1000; + +} // namespace + +NmpParty::NmpParty(const Options& options) : options_{options} { + auto [ctx, wsize, me, leader] = CollectContext(); + ctx->SetRecvTimeout(kLinkRecvTimeout); + p2p_.resize(wsize); + for (size_t dst{}; dst != wsize; ++dst) { + if (me != dst) { + p2p_[dst] = CreateP2PLinkCtx("el_mp_psi", ctx, dst); + } + } +} + +std::vector NmpParty::Run( + const std::vector& inputs) { + auto [ctx, wsize, me, leader] = CollectContext(); + auto counts = AllGatherItemsSize(ctx, inputs.size()); + size_t count{}; + for (auto cnt : counts) { + if (cnt == 0) { + return {}; + } + count = std::max(cnt, count); + } + + auto items = EncodeInputs(inputs, count); + auto shares = ZeroSharing(count); + auto recv_share = ConvertShares(items, count, shares); + auto recons = Reconstruct(items, count, recv_share); + std::vector intersection; + for (size_t i{}; i != recons.size(); ++i) { + // SPDLOG_INFO("recons[i] = {}, size{}, i{}", + // recons[i], recons.size(), i); + std::stringstream ss; + ss << recons[i]; + intersection.push_back(ss.str()); + } + return intersection; +} + +std::vector NmpParty::EncodeInputs( + const std::vector& inputs, size_t count) const { + std::vector items; + items.reserve(count); + std::transform(inputs.begin(), inputs.end(), std::back_inserter(items), + [](std::string_view input) { + /* SPDLOG_INFO("input {},encode {} size {}",input, + yacl::crypto::Blake3_128(input), input.size()); */ + return yacl::crypto::Blake3_128(input); + }); + + std::generate_n(std::back_inserter(items), count - inputs.size(), + yacl::crypto::FastRandU128); + return items; +} + +auto NmpParty::ZeroSharing(size_t count) const -> std::vector { + auto [ctx, wsize, me, leader] = CollectContext(); + std::vector shares(wsize, Share(count)); + for (size_t k{}; k != count; ++k) { + uint64_t sum{}; + for (size_t dst{1}; dst != wsize; ++dst) { + sum ^= shares[dst][k] = yacl::crypto::FastRandU64(); + } + shares[0][k] = sum; + } + return shares; +} + +std::vector NmpParty::ConvertShares( + const std::vector& items, size_t count, + const std::vector& shares) const { + count = count + 1; + auto [ctx, wsize, me, leader] = CollectContext(); + std::vector recv_shares(count); + + for (size_t id{}; id != me; ++id) { + ElSend(p2p_[id], items, shares[id]); + return ElRecv(p2p_[id], items); + } + + for (size_t id{me + 1}; id != wsize; ++id) { + auto ret = ElRecv(p2p_[id], items); + ElSend(p2p_[id], items, shares[id]); + return ret; + } + + return recv_shares; +} + +std::vector findSame(const std::vector& nLeft, + const std::vector& nRight) { + std::vector nResult; + for (std::vector::const_iterator nIterator = nLeft.begin(); + nIterator != nLeft.end(); nIterator++) { + if (std::find(nRight.begin(), nRight.end(), *nIterator) != nRight.end()) { + nResult.push_back(0); + } else { + nResult.push_back(*nIterator); + } + } + + return nResult; +} + +std::vector NmpParty::Reconstruct( + const std::vector& items, size_t count, + const std::vector& shares) const { + count = count + 1; + return findSame(items, shares); +} +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi.h b/psi/psi21_experiment/el_mp_psi/el_mp_psi.h new file mode 100644 index 00000000..6ecc369b --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi.h @@ -0,0 +1,66 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Practical Multi-party Private Set Intersection from Symmetric-Key Techniques +// https://eprint.iacr.org/2017/799.pdf + +class NmpParty { + public: + struct Options { + std::shared_ptr link_ctx; + size_t leader_rank; + }; + + NmpParty(const Options& options); + virtual std::vector Run(const std::vector& inputs); + + private: + using Share = std::vector; + + std::vector EncodeInputs(const std::vector& inputs, + size_t count) const; + std::vector ZeroSharing(size_t count) const; + std::vector ConvertShares(const std::vector& items, + size_t count, + const std::vector& shares) const; + std::vector Reconstruct( + const std::vector& items, size_t count, + const std::vector& shares) const; + // std::vector SwapShares(const std::vector& items, + // const std::vector& shares) const; + // auto Reconstruct(const std::vector& items, + // const Share& share) const; + + // (ctx, world_size, my_rank, leader_rank) + auto CollectContext() const { + return std::make_tuple(options_.link_ctx, options_.link_ctx->WorldSize(), + options_.link_ctx->Rank(), options_.leader_rank); + } + + Options options_; + std::vector> p2p_; +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc new file mode 100644 index 00000000..61fc18c4 --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc @@ -0,0 +1,85 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "benchmark/benchmark.h" +#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" +#include "psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "yacl/base/exception.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/link/test_util.h" + +namespace { +std::vector CreateRangeItems(size_t begin, size_t size) { + std::vector ret(size); + for (size_t i = 0; i < size; i++) { + auto hash = yacl::crypto::Blake3(std::to_string(begin + i)); + memcpy(&ret[i], hash.data(), sizeof(uint128_t)); + } + return ret; +} + +void ElMpPsiSend(const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + // auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); + // return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); + std::vector shares; + for (size_t i = 0; i < items_hash.size(); i++) { + uint64_t item = 0; + shares.push_back(item); + } + + return psi::psi::ElSend(link_ctx, items_hash, shares); +} + +std::vector ElMpPsiRecv( + const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + return psi::psi::ElRecv(link_ctx, items_hash); +} + +} // namespace + +static void BM_El_Mp_Psi(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + size_t n = state.range(0); + auto alice_items = CreateRangeItems(1, n); + auto bob_items = CreateRangeItems(2, n); + + auto contexts = yacl::link::test::SetupWorld(2); + + state.ResumeTiming(); + + std::future kkrt_psi_sender = + std::async([&] { return ElMpPsiSend(contexts[0], alice_items); }); + std::future> kkrt_psi_receiver = + std::async([&] { return ElMpPsiRecv(contexts[1], bob_items); }); + + kkrt_psi_sender.get(); + auto results_b = kkrt_psi_receiver.get(); + } +} + +// [256k, 512k, 1m, 2m, 4m, 8m] +BENCHMARK(BM_El_Mp_Psi) + ->Unit(benchmark::kMillisecond) + ->Arg(256 << 10) + ->Arg(512 << 10) + ->Arg(1 << 20) + ->Arg(2 << 20) + ->Arg(4 << 20) + ->Arg(8 << 20); diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc new file mode 100644 index 00000000..d951f153 --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc @@ -0,0 +1,149 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "psi/utils/test_utils.h" +#include "yacl/link/test_util.h" + +namespace psi::psi { + +namespace { + +struct NMpTestParams { + std::vector item_size; + size_t intersection_size; +}; + +std::vector> CreateNPartyItems( + const NMpTestParams& params) { + std::vector> ret(params.item_size.size() + 1); + ret[params.item_size.size()] = + test::CreateRangeItems(1, params.intersection_size); + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + ret[idx] = + test::CreateRangeItems((idx + 1) * 1000000, params.item_size[idx]); + } + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + std::set idx_set; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, params.item_size[idx] - 1); + + while (idx_set.size() < params.intersection_size) { + idx_set.insert(dis(gen)); + } + size_t j = 0; + for (const auto& iter : idx_set) { + ret[idx][iter] = ret[params.item_size.size()][j++]; + } + } + return ret; +} + +} // namespace + +class NMpPsiTest : public testing::TestWithParam {}; + +// FIXME : this test is not stable in arm env +TEST_P(NMpPsiTest, Works) { + std::vector> items; + std::vector> resultvec; + std::vector finalresult; + + auto params = GetParam(); + items = CreateNPartyItems(params); + size_t leader_rank = 0; + uint128_t maxlength = 0; + + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + std::vector> items1; + items1.push_back(items[0]); + items1.push_back(items[i + 1]); + leader_rank = 0; + + auto ctxs = yacl::link::test::SetupWorld(2); + auto proc = [&](int idx) -> std::vector { + NmpParty::Options opts; + opts.link_ctx = ctxs[idx]; + opts.leader_rank = leader_rank; + NmpParty op(opts); + // for (size_t j{}; j != items1[idx].size(); ++j) { + // SPDLOG_INFO(" items[{}][{}] = {}, size{}", idx, i, items[idx][i], + // items[idx].size()); + // } + + return op.Run(items[idx]); + }; + + size_t world_size = ctxs.size(); + std::vector>> f_links(world_size); + for (size_t j = 0; j < world_size; j++) { + f_links[j] = std::async(proc, j); + } + sleep(1); + + std::vector result; + result = f_links[0].get(); + resultvec.push_back(result); + + // for (size_t j = 0; j < result.size() ; j++) { + // SPDLOG_INFO("i{} j{}, result[j] {} size{}",i,j,result[j], + // result.size()); + // } + } + + maxlength = items[0].size(); + for (size_t j = 0; j < maxlength; j++) { + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + if (resultvec[i].size() <= j) { + break; + } + if (resultvec[i][j] != "0") { + break; + } + + if (i == params.item_size.size() - 2) { + finalresult.push_back(items[0][j]); + } + } + } + + std::vector intersection = items[params.item_size.size()]; + std::sort(intersection.begin(), intersection.end()); + + std::sort(finalresult.begin(), finalresult.end()); + EXPECT_EQ(finalresult.size(), intersection.size()); + EXPECT_EQ(finalresult, intersection); +} + +INSTANTIATE_TEST_SUITE_P( + Works_Instances, NMpPsiTest, + testing::Values(NMpTestParams{{0, 3}, 0}, // + NMpTestParams{{3, 0}, 0}, // + NMpTestParams{{0, 0}, 0}, // + NMpTestParams{{4, 3}, 2}, // + NMpTestParams{{20, 17, 14}, 10}, // + NMpTestParams{{20, 17, 14, 30}, 10}, // + NMpTestParams{{20, 17, 14, 30, 35}, 11}, // + NMpTestParams{{20, 17, 14, 30, 35}, 0})); + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_sender.cc b/psi/psi21_experiment/el_mp_psi/el_sender.cc new file mode 100644 index 00000000..672e3dd8 --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_sender.cc @@ -0,0 +1,116 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_mp_psi/el_sender.h" + +#include +#include +#include + +#include "psi/psi21_experiment/el_mp_psi/el_hashing.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/iknp_ote.h" +#include "yacl/kernel/algorithms/kkrt_ote.h" +#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace yc = yacl::crypto; + +namespace { + +// PSI-related constants +// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) +constexpr float ZETA[]{1.12f, 0.17f}; +// constexpr size_t BETA[]{31, 63}; +constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs + +// OTe-related constants +constexpr size_t NUM_BASE_OT{128}; +constexpr size_t NUM_INKP_OT{512}; +constexpr size_t BATCH_SIZE{896}; + +static auto ro = yc::RandomOracle::GetDefault(); + +} // namespace + +// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) + +std::vector ElRecv( + const std::shared_ptr& ctx, + const std::vector& queries) { + const size_t size{queries.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtReceiver receiver; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto choice = yc::RandBits(NUM_BASE_OT); + auto base_ot = yc::BaseOtRecv(ctx, choice, NUM_BASE_OT); + auto store = yc::IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); + receiver.Init(ctx, store, num_ot); + receiver.SetBatchSize(BATCH_SIZE); + + uint128_t nonce = yacl::DeserializeUint128( + ctx->Recv(ctx->NextRank(), "Receive OPPRF nonce")); + uint128_t xssize = yacl::DeserializeUint128( + ctx->Recv(ctx->NextRank(), "Receive OPPRF Xssize")); + auto buf = ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); + std::vector table(xssize); + std::memcpy(table.data(), buf.data(), table.size() * sizeof(uint128_t)); + // todo + for (size_t i{}; i != table.size(); ++i) { + table[i] = table[i] + (nonce >> 10); + // SPDLOG_INFO(" table[i] = {}, size{}", + // table[i], xssize); + } + + return table; +} + +void ElSend(const std::shared_ptr& ctx, + const std::vector& xs, + const std::vector& ys) { + YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); + const size_t size{xs.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtSender sender; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto base_ot = yc::BaseOtSend(ctx, NUM_BASE_OT); + auto choice = yc::RandBits(NUM_INKP_OT); + auto store = yc::IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); + std::vector zs; + sender.Init(ctx, store, num_ot); + sender.SetBatchSize(BATCH_SIZE); + + uint128_t nonce; + nonce = yc::FastRandSeed(); + for (size_t i{}; i != xs.size(); ++i) { + zs.push_back(xs[i] - (nonce >> 10)); + } + ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(nonce), + fmt::format("OPPRF:Nonce={}", nonce)); + uint128_t xssize = zs.size(); + ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(xssize), + fmt::format("OPPRF:Xssize={}", xs.size())); + yacl::Buffer buf(zs.data(), zs.size() * sizeof(uint128_t)); + ctx->SendAsync(ctx->NextRank(), buf, "OPPRF:EncryptionTable"); +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_mp_psi/el_sender.h b/psi/psi21_experiment/el_mp_psi/el_sender.h new file mode 100644 index 00000000..82feac81 --- /dev/null +++ b/psi/psi21_experiment/el_mp_psi/el_sender.h @@ -0,0 +1,35 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) + +std::vector ElRecv(const std::shared_ptr&, + const std::vector& queries); + +void ElSend(const std::shared_ptr&, + const std::vector& xs, + const std::vector& ys); + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/BUILD.bazel b/psi/psi21_experiment/el_q_psi/BUILD.bazel new file mode 100644 index 00000000..ed5299e8 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/BUILD.bazel @@ -0,0 +1,61 @@ +# Copyright 2024 zhangwfjh +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_binary( + name = "el_q_psi_benchmark", + srcs = ["el_q_psi_benchmark.cc"], + deps = [ + ":el_q_psi", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +psi_cc_library( + name = "el_q_psi", + srcs = [ + "el_hashing.cc", + "el_q_psi.cc", + "el_opprf.cc", + ], + hdrs = [ + "el_hashing.h", + "el_q_psi.h", + "el_opprf.h", + ], + deps = [ + "//psi/utils:communication", + "//psi/utils:sync", + "//psi/utils:test_utils", + "@com_google_absl//absl/types:span", + "@yacl//yacl/base:exception", + "@yacl//yacl/base:int128", + "@yacl//yacl/crypto/hash:hash_utils", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/kernel/algorithms:base_ot", + "@yacl//yacl/kernel/algorithms:iknp_ote", + "@yacl//yacl/kernel/algorithms:kkrt_ote", + "@yacl//yacl/link", + ], +) + +psi_cc_test( + name = "el_q_psi_test", + srcs = ["el_q_psi_test.cc"], + tags = ["manual"], + deps = [":el_q_psi"], +) diff --git a/psi/psi21_experiment/el_q_psi/README.md b/psi/psi21_experiment/el_q_psi/README.md new file mode 100644 index 00000000..0426ae00 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/README.md @@ -0,0 +1,10 @@ +论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI + +论文地址:https://www.xueshufan.com/publication/3150904314 + +方案概括:quorum_psi
+1、参与方分别对持有的隐私集合中的元素进行分桶,通过不经意伪随机函数协议为元素计算伪随机函数值;
+2、发送方为元素的伪随机函数值生成不经意键值存储,发送给接收方;
+3、接收方将桶中元素的伪随机函数值与不经意键值存储进行异或操作,生成桶向量;
+4、将接收方的桶向量及发送方选取的向量输入到零分享测试电路中,测试所有输入是否为零的加性秘密共享,生成每个参与方的比特秘密共享;
+5、基于所有参与方的比特秘密共享,通过计算对称函数的电路得到在交集上的对称函数。
\ No newline at end of file diff --git a/psi/psi21_experiment/el_q_psi/el_hashing.cc b/psi/psi21_experiment/el_q_psi/el_hashing.cc new file mode 100644 index 00000000..e29aab62 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_hashing.cc @@ -0,0 +1,69 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_q_psi/el_hashing.h" + +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/rand/rand.h" + +namespace psi::psi { + +void KmprtCuckooHashing::Insert(uint128_t elem) { + auto insert_into = [this, &elem](uint8_t c) { + for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) { + uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c]; + uint8_t idx = (rand_idx + 1) % num_hashes_[c]; + size_t addr; + do { + addr = HashU128{}(elem, idx) % num_bins_[c]; + if (auto &bin = bins_[c][addr]; bin == NONE || bin == elem) { + bin = std::exchange(elem, NONE); + return; + } + idx = (idx + 1) % num_hashes_[c]; + } while (idx != rand_idx); + std::swap(bins_[c][addr], elem); + } + }; + for (uint8_t c{}; c != 2; ++c) { + insert_into(c); + } + YACL_ENFORCE_EQ(elem, NONE, "Failed to insert element."); +} + +auto KmprtCuckooHashing::Lookup(uint128_t elem) const + -> std::pair { + for (uint8_t c{}; c != 2; ++c) { + for (uint8_t idx{}; idx != num_hashes_[c]; ++idx) { + if (size_t addr = HashU128{}(elem, idx) % num_bins_[c]; + bins_[c][addr] == elem) { + return {c, addr}; + } + } + } + return {-1, -1}; +} + +void KmprtSimpleHashing::Insert(std::pair point) { + for (uint8_t c{}; c != 2; ++c) { + for (size_t idx{}; idx != num_hashes_[c]; ++idx) { + bins_[c][HashU128{}(point.first, idx) % num_bins_[c]].emplace(point); + } + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_hashing.h b/psi/psi21_experiment/el_q_psi/el_hashing.h new file mode 100644 index 00000000..be80df99 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_hashing.h @@ -0,0 +1,65 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "yacl/base/int128.h" + +namespace psi::psi { + +struct HashU128 { + size_t operator()(uint128_t x, uint8_t idx = 0) const { + return absl::Uint128High64(x) + idx * absl::Uint128Low64(x); + } +}; + +template +struct KmprtDoubleHashing { + KmprtDoubleHashing(size_t m1, size_t m2) : num_bins_{m1, m2} { + bins_[0].resize(m1); + bins_[1].resize(m2); + } + + const Bin &GetBin(uint8_t c, size_t addr) const { return bins_[c][addr]; } + + const uint8_t num_hashes_[2]{3, 2}; + const size_t num_bins_[2]; + std::vector bins_[2]; +}; + +class KmprtCuckooHashing : public KmprtDoubleHashing { + public: + constexpr static uint128_t NONE{}; + + KmprtCuckooHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(uint128_t); + std::pair Lookup(uint128_t) const; +}; + +class KmprtSimpleHashing + : public KmprtDoubleHashing< + std::unordered_map> { + public: + KmprtSimpleHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} + + void Insert(std::pair); +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_opprf.cc b/psi/psi21_experiment/el_q_psi/el_opprf.cc new file mode 100644 index 00000000..d6a9742b --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_opprf.cc @@ -0,0 +1,189 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_q_psi/el_opprf.h" + +#include +#include +#include + +#include "psi/psi21_experiment/el_q_psi/el_hashing.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/iknp_ote.h" +#include "yacl/kernel/algorithms/kkrt_ote.h" +#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace yc = yacl::crypto; + +namespace { + +// PSI-related constants +// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) +constexpr float ZETA[]{1.12f, 0.17f}; +// constexpr size_t BETA[]{31, 63}; +constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs + +// OTe-related constants +constexpr size_t NUM_BASE_OT{128}; +constexpr size_t NUM_INKP_OT{512}; +constexpr size_t BATCH_SIZE{896}; + +static auto ro = yc::RandomOracle::GetDefault(); + +} // namespace + +// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) + +std::vector ElOpprfRecv( + const std::shared_ptr& ctx, + const std::vector& queries) { + const size_t size{queries.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtReceiver receiver; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto choice = yc::RandBits(NUM_BASE_OT); + auto base_ot = yc::BaseOtRecv(ctx, choice, NUM_BASE_OT); + auto store = yc::IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); + receiver.Init(ctx, store, num_ot); + receiver.SetBatchSize(BATCH_SIZE); + + // Step 1. Hashes queries into Cuckoo hashing + KmprtCuckooHashing hashing{bin_sizes[0], bin_sizes[1]}; + for (size_t i{}; i != size; ++i) { + hashing.Insert(queries[i]); + } + + std::vector evals; + evals.reserve(num_ot); + size_t ot_idx{}, b{}; + std::array batch_evals; + // Step 2. For each bin, invokes single-query OPPRF + for (uint8_t c{}; c != 2; ++c) { + size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; + size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; + for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { + auto elem = hashing.GetBin(c, addr); + elem == KmprtCuckooHashing::NONE && (elem = yc::FastRandU128()); + receiver.Encode( + ot_idx, elem, + {reinterpret_cast(&batch_evals[b++]), sizeof(uint64_t)}); + if (auto batch_size = (ot_idx - ot_begin) % BATCH_SIZE + 1; + batch_size == BATCH_SIZE || ot_idx + 1 == ot_end) { + b = 0; + receiver.SendCorrection(ctx, batch_size); + // For each query in a batch + for (size_t i{}; i != batch_size; ++i) { + uint128_t nonce = yacl::DeserializeUint128( + ctx->Recv(ctx->NextRank(), "Receive OPPRF nonce")); + std::vector table(TABLE_SIZE[c]); + auto buf = + ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); + std::memcpy(table.data(), buf.data(), + TABLE_SIZE[c] * sizeof(uint64_t)); + uint64_t eval = batch_evals[i]; + auto index = + ro.Gen(absl::MakeSpan(reinterpret_cast(&eval), + sizeof eval), + nonce) % + table.size(); + evals.emplace_back(eval ^ table[index]); + } + } + } + } + + // Step 3. Filters and obtains the results + std::vector results(size); + std::transform(queries.cbegin(), queries.cend(), results.begin(), + [&](auto q) { + auto [c, addr] = hashing.Lookup(q); + return evals[c * bin_sizes[0] + addr]; + }); + return results; +} + +void ElOpprfSend(const std::shared_ptr& ctx, + const std::vector& xs, + const std::vector& ys) { + YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); + const size_t size{xs.size()}; + const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), + static_cast(std::ceil(size * ZETA[1]))}; + // Step 0. Prepares OPRF + yc::KkrtOtExtSender sender; + size_t num_ot{bin_sizes[0] + bin_sizes[1]}; + auto base_ot = yc::BaseOtSend(ctx, NUM_BASE_OT); + auto choice = yc::RandBits(NUM_INKP_OT); + auto store = yc::IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); + sender.Init(ctx, store, num_ot); + sender.SetBatchSize(BATCH_SIZE); + + // Step 1. Hashes points into Simple hashing + KmprtSimpleHashing hashing{bin_sizes[0], bin_sizes[1]}; + for (size_t i{}; i != size; ++i) { + hashing.Insert({xs[i], ys[i]}); + } + size_t ot_idx{}; + auto evaluator = sender.GetOprf(); + // Step 2. For each bin, invokes single-query OPPRF + for (uint8_t c{}; c != 2; ++c) { + size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; + size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; + // For each programmable point in a batch + for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { + if ((ot_idx - ot_begin) % BATCH_SIZE == 0) { + sender.RecvCorrection( + ctx, ot_idx + BATCH_SIZE >= ot_end ? ot_end - ot_idx : BATCH_SIZE); + } + auto bin = hashing.GetBin(c, addr); + uint128_t nonce; + std::vector table; + bool separable; + do { + separable = true; + nonce = yc::FastRandSeed(); + table.assign(TABLE_SIZE[c], uint64_t{0}); + for (auto it = bin.cbegin(); it != bin.cend(); ++it) { + uint64_t eval = evaluator->Eval(ot_idx, it->first); + auto index = + ro.Gen({reinterpret_cast(&eval), sizeof eval}, + nonce) % + table.size(); + if (table[index] != uint64_t{0}) { + separable = false; + break; + } + table[index] = eval ^ it->second; + } + } while (!separable); + for (size_t i{}; i != TABLE_SIZE[c]; ++i) { + table[i] == uint64_t{0} && (table[i] = yc::FastRandU128()); + } + + ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(nonce), + fmt::format("OPPRF:Nonce={}", nonce)); + yacl::Buffer buf(table.data(), table.size() * sizeof(uint64_t)); + ctx->SendAsync(ctx->NextRank(), buf, "OPPRF:EncryptionTable"); + } + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_opprf.h b/psi/psi21_experiment/el_q_psi/el_opprf.h new file mode 100644 index 00000000..a26a095c --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_opprf.h @@ -0,0 +1,35 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) + +std::vector ElOpprfRecv(const std::shared_ptr&, + const std::vector& queries); + +void ElOpprfSend(const std::shared_ptr&, + const std::vector& xs, + const std::vector& ys); + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi.cc b/psi/psi21_experiment/el_q_psi/el_q_psi.cc new file mode 100644 index 00000000..00026102 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_q_psi.cc @@ -0,0 +1,165 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" + +#include + +#include "psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "psi/utils/communication.h" +#include "psi/utils/sync.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace { + +constexpr uint32_t kLinkRecvTimeout = 60 * 60 * 1000; + +} // namespace + +NcParty::NcParty(const Options& options) : options_{options} { + auto [ctx, wsize, me, leader] = CollectContext(); + ctx->SetRecvTimeout(kLinkRecvTimeout); + p2p_.resize(wsize); + for (size_t dst{}; dst != wsize; ++dst) { + if (me != dst) { + p2p_[dst] = CreateP2PLinkCtx("el_q_psi", ctx, dst); + } + } +} + +std::vector NcParty::Run( + const std::vector& inputs) { + auto [ctx, wsize, me, leader] = CollectContext(); + auto counts = AllGatherItemsSize(ctx, inputs.size()); + size_t count{}; + + for (auto cnt : counts) { + if (cnt == 0) { + return {}; + } + count = std::max(cnt, count); + } + auto items = EncodeInputs(inputs, count); + auto shares = ZeroSharing(count); + auto recv_share = SwapShares(items, shares); + auto recons = Reconstruct(items, recv_share); + std::vector intersection; + for (size_t k{}; k != count; ++k) { + if (recons[k] == 0) { + intersection.push_back("1"); + } else { + intersection.push_back("0"); + } + } + return intersection; +} + +std::vector NcParty::EncodeInputs( + const std::vector& inputs, size_t count) const { + std::vector items; + items.reserve(count); + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(items), + [](std::string_view input) { return yacl::crypto::Blake3_128(input); }); + // Add random dummy elements + std::generate_n(std::back_inserter(items), count - inputs.size(), + yacl::crypto::FastRandU128); + return items; +} + +auto NcParty::ZeroSharing(size_t count) const -> std::vector { + auto [ctx, wsize, me, leader] = CollectContext(); + std::vector shares(wsize, Share(count)); + for (size_t k{}; k != count; ++k) { + uint64_t sum{}; + for (size_t dst{1}; dst != wsize; ++dst) { + sum ^= shares[dst][k] = yacl::crypto::FastRandU64(); + } + shares[0][k] = sum; + } + return shares; +} + +auto NcParty::SwapShares(const std::vector& items, + const std::vector& shares) const -> Share { + auto [ctx, wsize, me, leader] = CollectContext(); + auto count = shares.front().size(); + std::vector recv_shares(count); + std::vector> futures(wsize); + // NOTE: First Send Then Receive for peers of smaller ranks + for (size_t id{}; id != me; ++id) { + futures[id] = std::async( + [&](size_t id) { + ElOpprfSend(p2p_[id], items, shares[id]); + return ElOpprfRecv(p2p_[id], items); + }, + id); + } + // NOTE: First Receive Then Send for peers of larger ranks + for (size_t id{me + 1}; id != wsize; ++id) { + futures[id] = std::async( + [&](size_t id) { + auto ret = ElOpprfRecv(p2p_[id], items); + ElOpprfSend(p2p_[id], items, shares[id]); + return ret; + }, + id); + } + for (size_t id{}; id != wsize; ++id) { + recv_shares[id] = (me == id ? shares[id] : futures[id].get()); + } + + Share share(count); // S(x_k) + for (size_t k{}; k != count; ++k) { + for (size_t src{}; src != wsize; ++src) { + share[k] ^= recv_shares[src][k]; + } + } + return share; +} + +auto NcParty::Reconstruct(const std::vector& items, + const Share& share) const -> Share { + auto [ctx, wsize, me, leader] = CollectContext(); + auto count = items.size(); + if (me == leader) { + std::vector recv_shares(count); + std::vector> futures(wsize); + for (size_t src{}; src != wsize; ++src) { + if (me != src) { + futures[src] = std::async( + [&](size_t src) { return ElOpprfRecv(p2p_[src], items); }, src); + } + } + for (size_t src{}; src != wsize; ++src) { + recv_shares[src] = (me == src ? share : futures[src].get()); + } + Share recons(count); // sum of S_i(x_k) over i + for (size_t k{}; k != count; ++k) { + for (size_t src{}; src != wsize; ++src) { + recons[k] ^= recv_shares[src][k]; + } + } + return recons; + } else { + ElOpprfSend(p2p_[leader], items, share); + return share; + } +} + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi.h b/psi/psi21_experiment/el_q_psi/el_q_psi.h new file mode 100644 index 00000000..35f3026d --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_q_psi.h @@ -0,0 +1,60 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { + +// Practical Multi-party Private Set Intersection from Symmetric-Key Techniques +// https://eprint.iacr.org/2017/799.pdf + +class NcParty { + public: + struct Options { + std::shared_ptr link_ctx; + size_t leader_rank; + }; + + NcParty(const Options& options); + virtual std::vector Run(const std::vector& inputs); + + private: + using Share = std::vector; + + std::vector EncodeInputs(const std::vector& inputs, + size_t count) const; + std::vector ZeroSharing(size_t count) const; + Share SwapShares(const std::vector& items, + const std::vector& shares) const; + Share Reconstruct(const std::vector& items, + const Share& share) const; + + // (ctx, world_size, my_rank, leader_rank) + auto CollectContext() const { + return std::make_tuple(options_.link_ctx, options_.link_ctx->WorldSize(), + options_.link_ctx->Rank(), options_.leader_rank); + } + + Options options_; + std::vector> p2p_; +}; + +} // namespace psi::psi diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc b/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc new file mode 100644 index 00000000..42e195b3 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc @@ -0,0 +1,87 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "benchmark/benchmark.h" +#include "psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "yacl/base/exception.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/link/test_util.h" + +namespace { +std::vector CreateRangeItems(size_t begin, size_t size) { + std::vector ret(size); + for (size_t i = 0; i < size; i++) { + auto hash = yacl::crypto::Blake3(std::to_string(begin + i)); + memcpy(&ret[i], hash.data(), sizeof(uint128_t)); + } + return ret; +} + +void ElQPsiSend(const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + // auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); + // return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); + std::vector shares; + for (size_t i = 0; i < items_hash.size(); i++) { + uint64_t item = 0; + shares.push_back(item); + } + + return psi::psi::ElOpprfSend(link_ctx, items_hash, shares); +} + +std::vector ElQPsiRecv( + const std::shared_ptr& link_ctx, + const std::vector& items_hash) { + // auto ot_send = psi::kkrt::GetKkrtOtReceiverOptions(link_ctx, 512); + // return psi::kkrt::KkrtPsiRecv(link_ctx, ot_send, items_hash); + return psi::psi::ElOpprfRecv(link_ctx, items_hash); +} + +} // namespace + +static void BM_El_Q_Psi(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + size_t n = state.range(0); + auto alice_items = CreateRangeItems(1, n); + auto bob_items = CreateRangeItems(2, n); + + auto contexts = yacl::link::test::SetupWorld(2); + + state.ResumeTiming(); + + std::future kkrt_psi_sender = + std::async([&] { return ElQPsiSend(contexts[0], alice_items); }); + std::future> kkrt_psi_receiver = + std::async([&] { return ElQPsiRecv(contexts[1], bob_items); }); + + kkrt_psi_sender.get(); + auto results_b = kkrt_psi_receiver.get(); + } +} + +// [256k, 512k, 1m, 2m, 4m, 8m] +BENCHMARK(BM_El_Q_Psi) + ->Unit(benchmark::kMillisecond) + ->Arg(256 << 10) + ->Arg(512 << 10) + ->Arg(1 << 20) + ->Arg(2 << 20) + ->Arg(4 << 20) + ->Arg(8 << 20); diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc b/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc new file mode 100644 index 00000000..38ed2848 --- /dev/null +++ b/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc @@ -0,0 +1,196 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "psi/utils/test_utils.h" +#include "yacl/link/test_util.h" + +namespace psi::psi { + +namespace { + +struct NCTestParams { + std::vector item_size; + size_t intersection_size; + size_t n; +}; + +std::vector> CreateNPartyItems( + const NCTestParams& params) { + std::vector> ret(params.item_size.size() + 1); + ret[params.item_size.size()] = + test::CreateRangeItems(1, params.intersection_size); + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + ret[idx] = + test::CreateRangeItems((idx + 1) * 1000000, params.item_size[idx]); + } + + for (size_t idx = 0; idx < params.item_size.size(); ++idx) { + std::set idx_set; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, params.item_size[idx] - 1); + + while (idx_set.size() < params.intersection_size) { + idx_set.insert(dis(gen)); + } + size_t j = 0; + for (const auto& iter : idx_set) { + ret[idx][iter] = ret[params.item_size.size()][j++]; + } + + if (/*idx > dis(gen) &&*/ idx >= params.n) { + break; + } + } + return ret; +} + +} // namespace + +class NCPsiTest : public testing::TestWithParam {}; + +// FIXME : this test is not stable in arm env +TEST_P(NCPsiTest, Works) { + std::vector> items; + std::vector> resultvec; + std::vector finalresult; + + auto params = GetParam(); + size_t n = params.n; + items = CreateNPartyItems(params); + size_t leader_rank = 0; + uint128_t maxlength = 0; + + if (n >= params.item_size.size()) { + SPDLOG_INFO("param error: n > items[0].size() "); + return; + } + + if (n <= 0) { + SPDLOG_INFO("param error: n <= 0 "); + return; + } + + /* + for (size_t j{}; j != items[0].size(); ++j) { + SPDLOG_INFO(" items[{}][{}] = {}, size{}", 0, j, items[0][j], + items[0].size()); + } + */ + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + std::vector> items1; + items1.push_back(items[0]); + items1.push_back(items[i + 1]); + leader_rank = 0; + + /* + for (size_t j{}; j != items[i + 1].size(); ++j) { + SPDLOG_INFO(" items[{}][{}] = {}, size{}", i+1, j, items[i+1][j], + items[i+1].size()); + } + */ + + auto ctxs = yacl::link::test::SetupWorld(2); + auto proc = [&](int idx) -> std::vector { + NcParty::Options opts; + opts.link_ctx = ctxs[idx]; + opts.leader_rank = leader_rank; + NcParty op(opts); + + return op.Run(items[idx]); + }; + + size_t world_size = ctxs.size(); + std::vector>> f_links(world_size); + for (size_t j = 0; j < world_size; j++) { + f_links[j] = std::async(proc, j); + } + sleep(1); + + std::vector result; + result = f_links[0].get(); + resultvec.push_back(result); + + /* + for (size_t j = 0; j < result.size(); j++) { + SPDLOG_INFO("i{} j{}, result[j] {} size{}", i, j, result[j], + result.size()); + }*/ + } + + maxlength = items[0].size(); + std::vector qpsivector; + for (size_t j = 0; j < maxlength; j++) { + uint128_t sum = 0; + for (size_t i = 0; i < params.item_size.size() - 1; i++) { + // 如果有的集合没有那么多项就continue + // results[i] = f_links[i].get(); + if (resultvec[i].size() <= j) { + continue; + } + + // SPDLOG_INFO(" result[{}][{}] = {}", i, j, resultvec[i][j]); + auto it = resultvec[i].begin() + j; + std::string element = *it; + if (element == "1") { + sum++; + } + } + if (sum >= n) { + // todo//推入对应input元素 之后再查输入变量从param中怎么取出推入 + qpsivector.push_back(1); + } else { + qpsivector.push_back(0); + } + } + + // std::vector intersectionnparty; + for (size_t k{}; k != items[0].size(); ++k) { + if (qpsivector[k] == 1) { + finalresult.push_back(items[0][k]); + } + } + + for (size_t i{}; i != finalresult.size(); ++i) { + SPDLOG_INFO("intersectionnparty = {}", finalresult[i]); + } + + std::vector intersection = items[params.item_size.size()]; + std::sort(intersection.begin(), intersection.end()); + + std::sort(finalresult.begin(), finalresult.end()); + EXPECT_EQ(finalresult.size(), intersection.size()); + EXPECT_EQ(finalresult, intersection); +} + +INSTANTIATE_TEST_SUITE_P( + Works_Instances, NCPsiTest, + testing::Values(NCTestParams{{0, 3}, 0, 1}, // + NCTestParams{{3, 0}, 0, 1}, // + NCTestParams{{0, 0}, 0, 1}, // + NCTestParams{{4, 3}, 2, 1}, // + NCTestParams{{20, 17, 14}, 10, 2}, // + NCTestParams{{20, 17, 14, 30}, 10, 3}, // + NCTestParams{{20, 17, 14, 30, 35}, 11, 4}, // + NCTestParams{{20, 17, 14, 30, 35}, 0, 4})); +// testing::Values(NCTestParams{{3, 0}, 0, 1})); +} // namespace psi::psi From c0b013b1c52f96e6c149e8597087bb166a46b128 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Mon, 18 Nov 2024 10:28:41 +0800 Subject: [PATCH 2/8] Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI --- psi/psi21_experiment/el_c_psi/el_c_psi.cc | 8 ++++---- psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc | 4 ++-- psi/psi21_experiment/el_c_psi/el_c_psi_test.cc | 4 ++-- psi/psi21_experiment/el_c_psi/el_hashing.cc | 2 +- psi/psi21_experiment/el_c_psi/el_opprf.cc | 4 ++-- psi/psi21_experiment/el_mp_psi/el_hashing.cc | 2 +- psi/psi21_experiment/el_mp_psi/el_mp_psi.cc | 6 +++--- psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc | 4 ++-- psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc | 4 ++-- psi/psi21_experiment/el_mp_psi/el_sender.cc | 4 ++-- psi/psi21_experiment/el_q_psi/el_hashing.cc | 2 +- psi/psi21_experiment/el_q_psi/el_opprf.cc | 4 ++-- psi/psi21_experiment/el_q_psi/el_q_psi.cc | 8 ++++---- psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc | 4 ++-- psi/psi21_experiment/el_q_psi/el_q_psi_test.cc | 4 ++-- 15 files changed, 32 insertions(+), 32 deletions(-) diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi.cc b/psi/psi21_experiment/el_c_psi/el_c_psi.cc index ca3d7427..4d06ae3f 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi.cc +++ b/psi/psi21_experiment/el_c_psi/el_c_psi.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" #include -#include "psi/psi21_experiment/el_c_psi/el_opprf.h" -#include "psi/utils/communication.h" -#include "psi/utils/sync.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "psi/psi/utils/communication.h" +#include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/rand/rand.h" #include "yacl/utils/serialize.h" diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc b/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc index 08721c75..fb51c731 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc +++ b/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" -#include "psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc b/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc index f48b4ee3..9de3523d 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc +++ b/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" #include #include #include #include "gtest/gtest.h" -#include "psi/utils/test_utils.h" +#include "psi/psi/utils/test_utils.h" #include "yacl/link/test_util.h" namespace psi::psi { diff --git a/psi/psi21_experiment/el_c_psi/el_hashing.cc b/psi/psi21_experiment/el_c_psi/el_hashing.cc index b5432915..aa37b76b 100644 --- a/psi/psi21_experiment/el_c_psi/el_hashing.cc +++ b/psi/psi21_experiment/el_c_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_c_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_c_psi/el_opprf.cc b/psi/psi21_experiment/el_c_psi/el_opprf.cc index 1f8b6add..0f501a34 100644 --- a/psi/psi21_experiment/el_c_psi/el_opprf.cc +++ b/psi/psi21_experiment/el_c_psi/el_opprf.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" #include #include #include -#include "el_c_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_c_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_hashing.cc b/psi/psi21_experiment/el_mp_psi/el_hashing.cc index 1eafee8e..ae9aa501 100644 --- a/psi/psi21_experiment/el_mp_psi/el_hashing.cc +++ b/psi/psi21_experiment/el_mp_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_mp_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc index 3e777e6a..2b8d6647 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc @@ -16,9 +16,9 @@ #include -#include "psi/psi21_experiment/el_mp_psi/el_sender.h" -#include "psi/utils/communication.h" -#include "psi/utils/sync.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "psi/psi/utils/communication.h" +#include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/rand/rand.h" #include "yacl/utils/serialize.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc index 61fc18c4..f513ead8 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" -#include "psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_mp_psi.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc b/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc index d951f153..53d2e397 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc +++ b/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_mp_psi.h" #include #include #include #include "gtest/gtest.h" -#include "psi/utils/test_utils.h" +#include "psi/psi/utils/test_utils.h" #include "yacl/link/test_util.h" namespace psi::psi { diff --git a/psi/psi21_experiment/el_mp_psi/el_sender.cc b/psi/psi21_experiment/el_mp_psi/el_sender.cc index 672e3dd8..21d6b534 100644 --- a/psi/psi21_experiment/el_mp_psi/el_sender.cc +++ b/psi/psi21_experiment/el_mp_psi/el_sender.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" #include #include #include -#include "psi/psi21_experiment/el_mp_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_mp_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_q_psi/el_hashing.cc b/psi/psi21_experiment/el_q_psi/el_hashing.cc index e29aab62..d5a27d8c 100644 --- a/psi/psi21_experiment/el_q_psi/el_hashing.cc +++ b/psi/psi21_experiment/el_q_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_q_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_q_psi/el_opprf.cc b/psi/psi21_experiment/el_q_psi/el_opprf.cc index d6a9742b..90e5c640 100644 --- a/psi/psi21_experiment/el_q_psi/el_opprf.cc +++ b/psi/psi21_experiment/el_q_psi/el_opprf.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" #include #include #include -#include "psi/psi21_experiment/el_q_psi/el_hashing.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi.cc b/psi/psi21_experiment/el_q_psi/el_q_psi.cc index 00026102..300b115f 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi.cc +++ b/psi/psi21_experiment/el_q_psi/el_q_psi.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" #include -#include "psi/psi21_experiment/el_q_psi/el_opprf.h" -#include "psi/utils/communication.h" -#include "psi/utils/sync.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "psi/psi/utils/communication.h" +#include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/rand/rand.h" #include "yacl/utils/serialize.h" diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc b/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc index 42e195b3..83dd3944 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc +++ b/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi21_experiment/el_q_psi/el_opprf.h" -#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc b/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc index 38ed2848..386c6fa8 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc +++ b/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" #include #include #include #include "gtest/gtest.h" -#include "psi/utils/test_utils.h" +#include "psi/psi/utils/test_utils.h" #include "yacl/link/test_util.h" namespace psi::psi { From 5c642fb333de860c002ee2a9c4a3fa34a0717543 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Tue, 19 Nov 2024 19:34:59 +0800 Subject: [PATCH 3/8] add the README.md --- psi/psi21_experiment/README.md | 15 +++++++++++++++ psi/psi21_experiment/el_c_psi/README.md | 10 ---------- psi/psi21_experiment/el_mp_psi/README.md | 8 -------- psi/psi21_experiment/el_q_psi/README.md | 10 ---------- 4 files changed, 15 insertions(+), 28 deletions(-) create mode 100644 psi/psi21_experiment/README.md delete mode 100644 psi/psi21_experiment/el_c_psi/README.md delete mode 100644 psi/psi21_experiment/el_mp_psi/README.md delete mode 100644 psi/psi21_experiment/el_q_psi/README.md diff --git a/psi/psi21_experiment/README.md b/psi/psi21_experiment/README.md new file mode 100644 index 00000000..6afef00e --- /dev/null +++ b/psi/psi21_experiment/README.md @@ -0,0 +1,15 @@ +# Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI + +(https://www.xueshufan.com/publication/3150904314) + +  Our protocols for all three problem settings, namely, mPSI, circuit PSI and qPSI, broadly have two phases. At a high level, in the first phase, a fixed designated party, say $P_1$, interacts with all other parties $P_2,...,P_n$ using 2-party protocols. In the second phase, all parties engage in 𝑛-party protocols to compute a circuit to get the requisite output.
+**multiparty_psi**
+  For mPSI, in the first phase, we invoke a two-party functionality, which we call weak private set membership (wPSM) functionality, +between a leader, $P_1$ and each $P_i$ (for $i ∈ {2, · · · , 𝑛}$). Informally, the wPSM functionality, when invoked between $P_1$ and $P_i$ on their individual private sets does the following: for each element in $P_1$’s set, it outputs the same random value to both $P_1$ and $P_i$ , if that element is in $P_i$’s set, and outputs independent random values, otherwise. By invoking only n instances ofthe wPSM functionality overall, we ensure that the total communication complexity of this phase is linear in n.
+  In the second phase, all the parties together run a secure multiparty computation to obtain shares of 0 for each element in $P_1$’s set that is in the intersection and shares of a random element for other elements.
+**circuit_psi**
+  For circuit psi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if $P_1$ and $P_i$ received the same random value, and shares of 0, otherwise.
+  In the second phase, for every element of $P_1$, all parties must get shares of 1 if that element belongs to the intersection, and shares of 0, otherwise.
+**quorum_psi**
+  For qpsi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if $P_1$ and $P_i$ received the same random value, and shares of 0, otherwise.
+  In the second phase, we appropriately choose another polynomial such that for each element in $P_1$’s set, the polynomial evaluates to 0 if and only if that element belongs to the quorum intersection, and random otherwise.
\ No newline at end of file diff --git a/psi/psi21_experiment/el_c_psi/README.md b/psi/psi21_experiment/el_c_psi/README.md deleted file mode 100644 index f3dd2202..00000000 --- a/psi/psi21_experiment/el_c_psi/README.md +++ /dev/null @@ -1,10 +0,0 @@ -论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI - -论文地址:https://www.xueshufan.com/publication/3150904314 - -方案概括:circuit_psi
-1、参与方分别对持有的隐私集合中的元素进行分桶,通过不经意伪随机函数协议为元素计算伪随机函数值;
-2、发送方为元素的伪随机函数值生成不经意键值存储,发送给接收方;
-3、接收方将桶中元素的伪随机函数值与不经意键值存储进行异或操作,生成桶向量;
-4、将接收方的桶向量及发送方选取的向量输入到零分享测试电路中,测试所有输入是否为零的加性秘密共享,生成每个参与方的比特秘密共享;
-5、基于所有参与方的比特秘密共享,通过计算对称函数的电路得到在交集上的对称函数。
diff --git a/psi/psi21_experiment/el_mp_psi/README.md b/psi/psi21_experiment/el_mp_psi/README.md deleted file mode 100644 index 78f422ba..00000000 --- a/psi/psi21_experiment/el_mp_psi/README.md +++ /dev/null @@ -1,8 +0,0 @@ -论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI - -论文地址:https://www.xueshufan.com/publication/3150904314 - -方案概括:multiparty_psi
-1、参与者为$P_1$,...,$P_n$。选择$P_1$作为接收方,$P_2$,...,$P_n$为发送方。所用参与者共同产生三个hash函数$h_1,h_2,h_3$。$P_1$通过Cuchoo Hashing产生一个$Table_1$,$P_2,...,P_n$使用普通的三次hash产生$Table_2,...,Table_n$(就是将一个数x hash三次放在表的三个位置上)。
-2、使用PSM协议,将$P_1$作为接收方发送$Table_1$,$P_2,...,P_n$为发送方发送$Table_2,...,Table_n$。
-3、结束时$P_1$对于表中的每一个位置的元素查询了n-2次,计作$y_{ij}$,其中ij代表在$P_i$的表中查询第j个元素,$P_2$,...,$P_n$获得$w_{ij}$。$P_1$计算$_1=\sum_{i=2}^{n} -y_{ij}$,$P_2$,...,$P_n$计算$_i=w_{ij}$,若查询到元素,可以看出来所有$
$加起来为0。
\ No newline at end of file diff --git a/psi/psi21_experiment/el_q_psi/README.md b/psi/psi21_experiment/el_q_psi/README.md deleted file mode 100644 index 0426ae00..00000000 --- a/psi/psi21_experiment/el_q_psi/README.md +++ /dev/null @@ -1,10 +0,0 @@ -论文题目:Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI - -论文地址:https://www.xueshufan.com/publication/3150904314 - -方案概括:quorum_psi
-1、参与方分别对持有的隐私集合中的元素进行分桶,通过不经意伪随机函数协议为元素计算伪随机函数值;
-2、发送方为元素的伪随机函数值生成不经意键值存储,发送给接收方;
-3、接收方将桶中元素的伪随机函数值与不经意键值存储进行异或操作,生成桶向量;
-4、将接收方的桶向量及发送方选取的向量输入到零分享测试电路中,测试所有输入是否为零的加性秘密共享,生成每个参与方的比特秘密共享;
-5、基于所有参与方的比特秘密共享,通过计算对称函数的电路得到在交集上的对称函数。
\ No newline at end of file From 9990654a72a18ade0826842e7620d2387b85c18a Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Tue, 19 Nov 2024 20:19:21 +0800 Subject: [PATCH 4/8] modify readme format --- psi/psi21_experiment/README.md | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/psi/psi21_experiment/README.md b/psi/psi21_experiment/README.md index 6afef00e..2e23d9cc 100644 --- a/psi/psi21_experiment/README.md +++ b/psi/psi21_experiment/README.md @@ -1,15 +1,18 @@ -# Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI +Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI (https://www.xueshufan.com/publication/3150904314) -  Our protocols for all three problem settings, namely, mPSI, circuit PSI and qPSI, broadly have two phases. At a high level, in the first phase, a fixed designated party, say $P_1$, interacts with all other parties $P_2,...,P_n$ using 2-party protocols. In the second phase, all parties engage in 𝑛-party protocols to compute a circuit to get the requisite output.
-**multiparty_psi**
-  For mPSI, in the first phase, we invoke a two-party functionality, which we call weak private set membership (wPSM) functionality, -between a leader, $P_1$ and each $P_i$ (for $i ∈ {2, · · · , 𝑛}$). Informally, the wPSM functionality, when invoked between $P_1$ and $P_i$ on their individual private sets does the following: for each element in $P_1$’s set, it outputs the same random value to both $P_1$ and $P_i$ , if that element is in $P_i$’s set, and outputs independent random values, otherwise. By invoking only n instances ofthe wPSM functionality overall, we ensure that the total communication complexity of this phase is linear in n.
-  In the second phase, all the parties together run a secure multiparty computation to obtain shares of 0 for each element in $P_1$’s set that is in the intersection and shares of a random element for other elements.
-**circuit_psi**
-  For circuit psi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if $P_1$ and $P_i$ received the same random value, and shares of 0, otherwise.
-  In the second phase, for every element of $P_1$, all parties must get shares of 1 if that element belongs to the intersection, and shares of 0, otherwise.
-**quorum_psi**
-  For qpsi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if $P_1$ and $P_i$ received the same random value, and shares of 0, otherwise.
-  In the second phase, we appropriately choose another polynomial such that for each element in $P_1$’s set, the polynomial evaluates to 0 if and only if that element belongs to the quorum intersection, and random otherwise.
\ No newline at end of file +Our protocols for all three problem settings, namely, mPSI, circuit PSI and qPSI, broadly have two phases. At a high level, in the first phase, a fixed designated party, say P1, interacts with all other parties P2,...,Pn using 2-party protocols. In the second phase, all parties engage in n-party protocols to compute a circuit to get the requisite output. + +**multiparty_psi** + +For mPSI, in the first phase, we invoke a two-party functionality, which we call weak private set membership (wPSM) functionality, +between a leader, P1 and each Pi (for i ∈ {2, · · · , n}). Informally, the wPSM functionality, when invoked between P1 and Pi on their individual private sets does the following: for each element in P1’s set, it outputs the same random value to both P1 and Pi , if that element is in Pi’s set, and outputs independent random values, otherwise. By invoking only n instances ofthe wPSM functionality overall, we ensure that the total communication complexity of this phase is linear in n. In the second phase, all the parties together run a secure multiparty computation to obtain shares of 0 for each element in P1’s set that is in the intersection and shares of a random element for other elements. + +**circuit_psi** + +For circuit psi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if P1 and Pi received the same random value, and shares of 0, otherwise. In the second phase, for every element of P1, all parties must get shares of 1 if that element belongs to the intersection, and shares of 0, otherwise. + +**quorum_psi** + +For qpsi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if P1 and Pi received the same random value, and shares of 0, otherwise. In the second phase, we appropriately choose another polynomial such that for each element in P1’s set, the polynomial evaluates to 0 if and only if that element belongs to the quorum intersection, and random otherwise. \ No newline at end of file From 98513ae5e2a433e4bbeadd9d6670699f110cac15 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Wed, 20 Nov 2024 14:56:29 +0800 Subject: [PATCH 5/8] Efficient Linear Multiparty PSI and Extension to Circuit/Quorum PSI --- {psi/psi21_experiment => experimental/psi/psi21}/README.md | 0 .../psi/psi21}/el_c_psi/BUILD.bazel | 0 .../psi/psi21}/el_c_psi/el_c_psi.cc | 4 ++-- .../psi/psi21}/el_c_psi/el_c_psi.h | 0 .../psi/psi21}/el_c_psi/el_c_psi_benchmark.cc | 4 ++-- .../psi/psi21}/el_c_psi/el_c_psi_test.cc | 2 +- .../psi/psi21}/el_c_psi/el_hashing.cc | 2 +- .../psi/psi21}/el_c_psi/el_hashing.h | 0 .../psi/psi21}/el_c_psi/el_opprf.cc | 4 ++-- .../psi/psi21}/el_c_psi/el_opprf.h | 0 .../psi/psi21}/el_mp_psi/BUILD.bazel | 0 .../psi/psi21/el_mp_psi}/el_hashing.cc | 2 +- .../psi/psi21}/el_mp_psi/el_hashing.h | 0 .../psi/psi21}/el_mp_psi/el_mp_psi.cc | 4 ++-- .../psi/psi21}/el_mp_psi/el_mp_psi.h | 0 .../psi/psi21}/el_mp_psi/el_mp_psi_benchmark.cc | 4 ++-- .../psi/psi21}/el_mp_psi/el_mp_psi_test.cc | 2 +- .../psi/psi21}/el_mp_psi/el_sender.cc | 4 ++-- .../psi/psi21}/el_mp_psi/el_sender.h | 0 .../psi/psi21}/el_q_psi/BUILD.bazel | 0 .../psi/psi21/el_q_psi}/el_hashing.cc | 2 +- .../psi/psi21}/el_q_psi/el_hashing.h | 0 .../psi/psi21}/el_q_psi/el_opprf.cc | 4 ++-- .../psi/psi21}/el_q_psi/el_opprf.h | 0 .../psi/psi21}/el_q_psi/el_q_psi.cc | 4 ++-- .../psi/psi21}/el_q_psi/el_q_psi.h | 0 .../psi/psi21}/el_q_psi/el_q_psi_benchmark.cc | 4 ++-- .../psi/psi21}/el_q_psi/el_q_psi_test.cc | 2 +- 28 files changed, 24 insertions(+), 24 deletions(-) rename {psi/psi21_experiment => experimental/psi/psi21}/README.md (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/BUILD.bazel (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_c_psi.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_c_psi.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_c_psi_benchmark.cc (96%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_c_psi_test.cc (98%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_hashing.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_hashing.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_opprf.cc (98%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_c_psi/el_opprf.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/BUILD.bazel (100%) rename {psi/psi21_experiment/el_q_psi => experimental/psi/psi21/el_mp_psi}/el_hashing.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_hashing.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_mp_psi.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_mp_psi.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_mp_psi_benchmark.cc (95%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_mp_psi_test.cc (98%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_sender.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_mp_psi/el_sender.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/BUILD.bazel (100%) rename {psi/psi21_experiment/el_mp_psi => experimental/psi/psi21/el_q_psi}/el_hashing.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_hashing.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_opprf.cc (98%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_opprf.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_q_psi.cc (97%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_q_psi.h (100%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_q_psi_benchmark.cc (96%) rename {psi/psi21_experiment => experimental/psi/psi21}/el_q_psi/el_q_psi_test.cc (99%) diff --git a/psi/psi21_experiment/README.md b/experimental/psi/psi21/README.md similarity index 100% rename from psi/psi21_experiment/README.md rename to experimental/psi/psi21/README.md diff --git a/psi/psi21_experiment/el_c_psi/BUILD.bazel b/experimental/psi/psi21/el_c_psi/BUILD.bazel similarity index 100% rename from psi/psi21_experiment/el_c_psi/BUILD.bazel rename to experimental/psi/psi21/el_c_psi/BUILD.bazel diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi.cc b/experimental/psi/psi21/el_c_psi/el_c_psi.cc similarity index 97% rename from psi/psi21_experiment/el_c_psi/el_c_psi.cc rename to experimental/psi/psi21/el_c_psi/el_c_psi.cc index 4d06ae3f..64aca255 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi.cc +++ b/experimental/psi/psi21/el_c_psi/el_c_psi.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" #include -#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "experimental/psi/psi21/el_c_psi/el_opprf.h" #include "psi/psi/utils/communication.h" #include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi.h b/experimental/psi/psi21/el_c_psi/el_c_psi.h similarity index 100% rename from psi/psi21_experiment/el_c_psi/el_c_psi.h rename to experimental/psi/psi21/el_c_psi/el_c_psi.h diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc b/experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc similarity index 96% rename from psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc rename to experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc index fb51c731..3823161d 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi_benchmark.cc +++ b/experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" -#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" +#include "experimental/psi/psi21/el_c_psi/el_opprf.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc b/experimental/psi/psi21/el_c_psi/el_c_psi_test.cc similarity index 98% rename from psi/psi21_experiment/el_c_psi/el_c_psi_test.cc rename to experimental/psi/psi21/el_c_psi/el_c_psi_test.cc index 9de3523d..3a4ca20d 100644 --- a/psi/psi21_experiment/el_c_psi/el_c_psi_test.cc +++ b/experimental/psi/psi21/el_c_psi/el_c_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_c_psi/el_c_psi.h" +#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" #include #include diff --git a/psi/psi21_experiment/el_c_psi/el_hashing.cc b/experimental/psi/psi21/el_c_psi/el_hashing.cc similarity index 97% rename from psi/psi21_experiment/el_c_psi/el_hashing.cc rename to experimental/psi/psi21/el_c_psi/el_hashing.cc index aa37b76b..9956709b 100644 --- a/psi/psi21_experiment/el_c_psi/el_hashing.cc +++ b/experimental/psi/psi21/el_c_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_c_psi/el_hashing.h" +#include "experimental/psi/psi21/el_c_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_c_psi/el_hashing.h b/experimental/psi/psi21/el_c_psi/el_hashing.h similarity index 100% rename from psi/psi21_experiment/el_c_psi/el_hashing.h rename to experimental/psi/psi21/el_c_psi/el_hashing.h diff --git a/psi/psi21_experiment/el_c_psi/el_opprf.cc b/experimental/psi/psi21/el_c_psi/el_opprf.cc similarity index 98% rename from psi/psi21_experiment/el_c_psi/el_opprf.cc rename to experimental/psi/psi21/el_c_psi/el_opprf.cc index 0f501a34..f685583b 100644 --- a/psi/psi21_experiment/el_c_psi/el_opprf.cc +++ b/experimental/psi/psi21/el_c_psi/el_opprf.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_c_psi/el_opprf.h" +#include "experimental/psi/psi21/el_c_psi/el_opprf.h" #include #include #include -#include "psi/psi/psi21_experiment/el_c_psi/el_hashing.h" +#include "experimental/psi/psi21/el_c_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_c_psi/el_opprf.h b/experimental/psi/psi21/el_c_psi/el_opprf.h similarity index 100% rename from psi/psi21_experiment/el_c_psi/el_opprf.h rename to experimental/psi/psi21/el_c_psi/el_opprf.h diff --git a/psi/psi21_experiment/el_mp_psi/BUILD.bazel b/experimental/psi/psi21/el_mp_psi/BUILD.bazel similarity index 100% rename from psi/psi21_experiment/el_mp_psi/BUILD.bazel rename to experimental/psi/psi21/el_mp_psi/BUILD.bazel diff --git a/psi/psi21_experiment/el_q_psi/el_hashing.cc b/experimental/psi/psi21/el_mp_psi/el_hashing.cc similarity index 97% rename from psi/psi21_experiment/el_q_psi/el_hashing.cc rename to experimental/psi/psi21/el_mp_psi/el_hashing.cc index d5a27d8c..41085622 100644 --- a/psi/psi21_experiment/el_q_psi/el_hashing.cc +++ b/experimental/psi/psi21/el_mp_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_q_psi/el_hashing.h" +#include "experimental/psi/psi21/el_mp_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_mp_psi/el_hashing.h b/experimental/psi/psi21/el_mp_psi/el_hashing.h similarity index 100% rename from psi/psi21_experiment/el_mp_psi/el_hashing.h rename to experimental/psi/psi21/el_mp_psi/el_hashing.h diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc similarity index 97% rename from psi/psi21_experiment/el_mp_psi/el_mp_psi.cc rename to experimental/psi/psi21/el_mp_psi/el_mp_psi.cc index 2b8d6647..b1ccb313 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi21_experiment/el_mp_psi/el_mp_psi.h" +#include "experimental/psi/psi21/el_mp_psi/el_mp_psi.h" #include -#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "experimental/psi/psi21/el_mp_psi/el_sender.h" #include "psi/psi/utils/communication.h" #include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi.h b/experimental/psi/psi21/el_mp_psi/el_mp_psi.h similarity index 100% rename from psi/psi21_experiment/el_mp_psi/el_mp_psi.h rename to experimental/psi/psi21/el_mp_psi/el_mp_psi.h diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi_benchmark.cc similarity index 95% rename from psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc rename to experimental/psi/psi21/el_mp_psi/el_mp_psi_benchmark.cc index f513ead8..7653dd10 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi_benchmark.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi/psi21_experiment/el_mp_psi/el_mp_psi.h" -#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "experimental/psi/psi21/el_mp_psi/el_mp_psi.h" +#include "experimental/psi/psi21/el_mp_psi/el_sender.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc similarity index 98% rename from psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc rename to experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc index 53d2e397..aba1dc15 100644 --- a/psi/psi21_experiment/el_mp_psi/el_mp_psi_test.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_mp_psi/el_mp_psi.h" +#include "experimental/psi/psi21/el_mp_psi/el_mp_psi.h" #include #include diff --git a/psi/psi21_experiment/el_mp_psi/el_sender.cc b/experimental/psi/psi21/el_mp_psi/el_sender.cc similarity index 97% rename from psi/psi21_experiment/el_mp_psi/el_sender.cc rename to experimental/psi/psi21/el_mp_psi/el_sender.cc index 21d6b534..1d4bedf5 100644 --- a/psi/psi21_experiment/el_mp_psi/el_sender.cc +++ b/experimental/psi/psi21/el_mp_psi/el_sender.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_mp_psi/el_sender.h" +#include "experimental/psi/psi21/el_mp_psi/el_sender.h" #include #include #include -#include "psi/psi/psi21_experiment/el_mp_psi/el_hashing.h" +#include "experimental/psi/psi21/el_mp_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_mp_psi/el_sender.h b/experimental/psi/psi21/el_mp_psi/el_sender.h similarity index 100% rename from psi/psi21_experiment/el_mp_psi/el_sender.h rename to experimental/psi/psi21/el_mp_psi/el_sender.h diff --git a/psi/psi21_experiment/el_q_psi/BUILD.bazel b/experimental/psi/psi21/el_q_psi/BUILD.bazel similarity index 100% rename from psi/psi21_experiment/el_q_psi/BUILD.bazel rename to experimental/psi/psi21/el_q_psi/BUILD.bazel diff --git a/psi/psi21_experiment/el_mp_psi/el_hashing.cc b/experimental/psi/psi21/el_q_psi/el_hashing.cc similarity index 97% rename from psi/psi21_experiment/el_mp_psi/el_hashing.cc rename to experimental/psi/psi21/el_q_psi/el_hashing.cc index ae9aa501..ba6e621a 100644 --- a/psi/psi21_experiment/el_mp_psi/el_hashing.cc +++ b/experimental/psi/psi21/el_q_psi/el_hashing.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_mp_psi/el_hashing.h" +#include "experimental/psi/psi21/el_q_psi/el_hashing.h" #include #include diff --git a/psi/psi21_experiment/el_q_psi/el_hashing.h b/experimental/psi/psi21/el_q_psi/el_hashing.h similarity index 100% rename from psi/psi21_experiment/el_q_psi/el_hashing.h rename to experimental/psi/psi21/el_q_psi/el_hashing.h diff --git a/psi/psi21_experiment/el_q_psi/el_opprf.cc b/experimental/psi/psi21/el_q_psi/el_opprf.cc similarity index 98% rename from psi/psi21_experiment/el_q_psi/el_opprf.cc rename to experimental/psi/psi21/el_q_psi/el_opprf.cc index 90e5c640..c060512c 100644 --- a/psi/psi21_experiment/el_q_psi/el_opprf.cc +++ b/experimental/psi/psi21/el_q_psi/el_opprf.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "experimental/psi/psi21/el_q_psi/el_opprf.h" #include #include #include -#include "psi/psi/psi21_experiment/el_q_psi/el_hashing.h" +#include "experimental/psi/psi21/el_q_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" #include "yacl/kernel/algorithms/base_ot.h" diff --git a/psi/psi21_experiment/el_q_psi/el_opprf.h b/experimental/psi/psi21/el_q_psi/el_opprf.h similarity index 100% rename from psi/psi21_experiment/el_q_psi/el_opprf.h rename to experimental/psi/psi21/el_q_psi/el_opprf.h diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi.cc b/experimental/psi/psi21/el_q_psi/el_q_psi.cc similarity index 97% rename from psi/psi21_experiment/el_q_psi/el_q_psi.cc rename to experimental/psi/psi21/el_q_psi/el_q_psi.cc index 300b115f..6129ecc0 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi.cc +++ b/experimental/psi/psi21/el_q_psi/el_q_psi.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" #include -#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" +#include "experimental/psi/psi21/el_q_psi/el_opprf.h" #include "psi/psi/utils/communication.h" #include "psi/psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi.h b/experimental/psi/psi21/el_q_psi/el_q_psi.h similarity index 100% rename from psi/psi21_experiment/el_q_psi/el_q_psi.h rename to experimental/psi/psi21/el_q_psi/el_q_psi.h diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc b/experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc similarity index 96% rename from psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc rename to experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc index 83dd3944..d80daf29 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi_benchmark.cc +++ b/experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc @@ -16,8 +16,8 @@ #include #include "benchmark/benchmark.h" -#include "psi/psi/psi21_experiment/el_q_psi/el_opprf.h" -#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "experimental/psi/psi21/el_q_psi/el_opprf.h" +#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" #include "yacl/base/exception.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/link/test_util.h" diff --git a/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc b/experimental/psi/psi21/el_q_psi/el_q_psi_test.cc similarity index 99% rename from psi/psi21_experiment/el_q_psi/el_q_psi_test.cc rename to experimental/psi/psi21/el_q_psi/el_q_psi_test.cc index 386c6fa8..4fcacbab 100644 --- a/psi/psi21_experiment/el_q_psi/el_q_psi_test.cc +++ b/experimental/psi/psi21/el_q_psi/el_q_psi_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "psi/psi/psi21_experiment/el_q_psi/el_q_psi.h" +#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" #include #include From 3f634f9b0d0376195274a92c1d62dcebc7538dd9 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Fri, 22 Nov 2024 15:23:06 +0800 Subject: [PATCH 6/8] modify readme format --- experimental/psi/psi21/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/psi/psi21/README.md b/experimental/psi/psi21/README.md index 2e23d9cc..4119fae2 100644 --- a/experimental/psi/psi21/README.md +++ b/experimental/psi/psi21/README.md @@ -15,4 +15,4 @@ For circuit psi, the first phase additionally includes conversion ofthe outputs **quorum_psi** -For qpsi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if P1 and Pi received the same random value, and shares of 0, otherwise. In the second phase, we appropriately choose another polynomial such that for each element in P1’s set, the polynomial evaluates to 0 if and only if that element belongs to the quorum intersection, and random otherwise. \ No newline at end of file +For qpsi, the first phase additionally includes conversion ofthe outputs from the wPSM functionality to arithmetic shares of 1 if P1 and Pi received the same random value, and shares of 0, otherwise. In the second phase, we appropriately choose another polynomial such that for each element in P1’s set, the polynomial evaluates to 0 if and only if that element belongs to the quorum intersection, and random otherwise. From fea374cde170ffadfb6bbe163b6bfa8088a6d685 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Fri, 14 Feb 2025 14:49:14 +0800 Subject: [PATCH 7/8] modify mpsi --- experimental/psi/psi21/el_mp_psi/BUILD.bazel | 6 + experimental/psi/psi21/el_mp_psi/Mersenne.cc | 79 ++ experimental/psi/psi21/el_mp_psi/Mersenne.h | 397 ++++++++ experimental/psi/psi21/el_mp_psi/el_mp_psi.cc | 12 +- experimental/psi/psi21/el_mp_psi/el_mp_psi.h | 1 + .../psi/psi21/el_mp_psi/el_mp_psi_test.cc | 2 +- .../psi/psi21/el_mp_psi/el_protocol.cc | 889 ++++++++++++++++++ .../psi/psi21/el_mp_psi/el_protocol.h | 52 + experimental/psi/psi21/el_mp_psi/el_sender.cc | 12 +- experimental/psi/psi21/el_mp_psi/el_sender.h | 9 +- 10 files changed, 1443 insertions(+), 16 deletions(-) create mode 100644 experimental/psi/psi21/el_mp_psi/Mersenne.cc create mode 100644 experimental/psi/psi21/el_mp_psi/Mersenne.h create mode 100644 experimental/psi/psi21/el_mp_psi/el_protocol.cc create mode 100644 experimental/psi/psi21/el_mp_psi/el_protocol.h diff --git a/experimental/psi/psi21/el_mp_psi/BUILD.bazel b/experimental/psi/psi21/el_mp_psi/BUILD.bazel index f4aa9260..f64e8181 100644 --- a/experimental/psi/psi21/el_mp_psi/BUILD.bazel +++ b/experimental/psi/psi21/el_mp_psi/BUILD.bazel @@ -16,9 +16,11 @@ load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") package(default_visibility = ["//visibility:public"]) +WARNING_OPTS = ["-Wall"] psi_cc_binary( name = "el_mp_psi_benchmark", srcs = ["el_mp_psi_benchmark.cc"], + copts = WARNING_OPTS, deps = [ ":el_mp_psi", "@com_github_google_benchmark//:benchmark_main", @@ -31,11 +33,15 @@ psi_cc_library( "el_hashing.cc", "el_mp_psi.cc", "el_sender.cc", + "Mersenne.cc", + "el_protocol.cc", ], hdrs = [ "el_hashing.h", "el_mp_psi.h", "el_sender.h", + "Mersenne.h", + "el_protocol.h", ], deps = [ "//psi/utils:communication", diff --git a/experimental/psi/psi21/el_mp_psi/Mersenne.cc b/experimental/psi/psi21/el_mp_psi/Mersenne.cc new file mode 100644 index 00000000..5c8ff121 --- /dev/null +++ b/experimental/psi/psi21/el_mp_psi/Mersenne.cc @@ -0,0 +1,79 @@ +// \author Avishay Yanay +// \organization Bar-Ilan University +// \email ay.yanay@gmail.com +// +// MIT License +// +// Copyright (c) 2018 AvishayYanay +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "Mersenne.h" + +template <> +ZpMersenneIntElement1 TemplateField::GetElement(long b) { + if (b == 1) { + return *m_ONE; + } + if (b == 0) { + return *m_ZERO; + } else { + ZpMersenneIntElement1 element(b); + return element; + } +} + +template <> +ZpMersenneLongElement1 TemplateField::GetElement( + long b) { + if (b == 1) { + return *m_ONE; + } + if (b == 0) { + return *m_ZERO; + } else { + ZpMersenneLongElement1 element(b); + return element; + } +} + +template <> +void TemplateField::elementToBytes( + unsigned char* elemenetInBytes, ZpMersenneIntElement1& element) { + memcpy(elemenetInBytes, (byte*)(&element.elem), 4); +} + +template <> +void TemplateField::elementToBytes( + unsigned char* elemenetInBytes, ZpMersenneLongElement1& element) { + memcpy(elemenetInBytes, (byte*)(&element.elem), 8); +} + +template <> +ZpMersenneIntElement1 TemplateField::bytesToElement( + unsigned char* elemenetInBytes) { + return ZpMersenneIntElement1((unsigned int)(*(unsigned int*)elemenetInBytes)); +} + +template <> +ZpMersenneLongElement1 TemplateField::bytesToElement( + unsigned char* elemenetInBytes) { + return ZpMersenneLongElement1( + (unsigned long)(*(unsigned long*)elemenetInBytes)); +} diff --git a/experimental/psi/psi21/el_mp_psi/Mersenne.h b/experimental/psi/psi21/el_mp_psi/Mersenne.h new file mode 100644 index 00000000..f4ba95aa --- /dev/null +++ b/experimental/psi/psi21/el_mp_psi/Mersenne.h @@ -0,0 +1,397 @@ +#pragma once + +// \author Avishay Yanay +// \organization Bar-Ilan University +// \email ay.yanay@gmail.com +// +// MIT License +// +// Copyright (c) 2018 AvishayYanay +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "NTL/ZZ.h" +#include "NTL/ZZ_p.h" +#include "gmp.h" +typedef unsigned char byte; + +#include +#include +#include + +using namespace NTL; + +class ZpMersenneIntElement1 { + // private: + public: // TODO return to private after tesing + static const unsigned int p = 2147483647; + unsigned int elem; + + public: + ZpMersenneIntElement1() { elem = 0; }; + ZpMersenneIntElement1(int elem) { + this->elem = elem; + if (this->elem < p) { + return; + } + this->elem -= p; + if (this->elem < p) { + return; + } + this->elem -= p; + } + + ZpMersenneIntElement1& operator=(const ZpMersenneIntElement1& other) { + elem = other.elem; + return *this; + }; + bool operator!=(const ZpMersenneIntElement1& other) { + return !(other.elem == elem); + }; + + ZpMersenneIntElement1 operator+(const ZpMersenneIntElement1& f2) { + ZpMersenneIntElement1 answer; + + answer.elem = (elem + f2.elem); + + if (answer.elem >= p) answer.elem -= p; + + return answer; + } + ZpMersenneIntElement1 operator-(const ZpMersenneIntElement1& f2) { + ZpMersenneIntElement1 answer; + + int temp = (int)elem - (int)f2.elem; + + if (temp < 0) { + answer.elem = temp + p; + } else { + answer.elem = temp; + } + + return answer; + } + ZpMersenneIntElement1 operator/(const ZpMersenneIntElement1& f2) { + // code taken from NTL for the function XGCD + int a = f2.elem; + int b = p; + long s; + + int u, v, q, r; + long u0, v0, u1, v1, u2, v2; + + int aneg = 0; + + if (a < 0) { + if (a < -NTL_MAX_LONG) Error("XGCD: integer overflow"); + a = -a; + aneg = 1; + } + + if (b < 0) { + if (b < -NTL_MAX_LONG) Error("XGCD: integer overflow"); + b = -b; + } + + u1 = 1; + v1 = 0; + u2 = 0; + v2 = 1; + u = a; + v = b; + + while (v != 0) { + q = u / v; + r = u % v; + u = v; + v = r; + u0 = u2; + v0 = v2; + u2 = u1 - q * u2; + v2 = v1 - q * v2; + u1 = u0; + v1 = v0; + } + + if (aneg) u1 = -u1; + + s = u1; + + if (s < 0) s = s + p; + + ZpMersenneIntElement1 inverse(s); + + return inverse * (*this); + } + + ZpMersenneIntElement1 operator*(const ZpMersenneIntElement1& f2) { + ZpMersenneIntElement1 answer; + + long multLong = (long)elem * (long)f2.elem; + + // get the bottom 31 bit + unsigned int bottom = multLong & p; + + // get the top 31 bits + unsigned int top = (multLong >> 31); + + answer.elem = bottom + top; + + // maximim the value of 2p-2 + if (answer.elem >= p) answer.elem -= p; + + // return ZpMersenneIntElement((bottom + top) %p); + return answer; + } + + ZpMersenneIntElement1& operator+=(const ZpMersenneIntElement1& f2) { + elem = (f2.elem + elem) % p; + return *this; + }; + ZpMersenneIntElement1& operator*=(const ZpMersenneIntElement1& f2) { + long multLong = (long)elem * (long)f2.elem; + + // get the bottom 31 bit + unsigned int bottom = multLong & p; + + // get the top 31 bits + unsigned int top = (multLong >> 31); + + elem = bottom + top; + + // maximim the value of 2p-2 + if (elem >= p) elem -= p; + + return *this; + } +}; + +inline std::ostream& operator<<(std::ostream& s, + const ZpMersenneIntElement1& a) { + return s << a.elem; +}; + +class ZpMersenneLongElement1 { + // private: + public: // TODO return to private after tesing + static const unsigned long long p = 2305843009213693951; + unsigned long long elem; + + ZpMersenneLongElement1() { elem = 0; }; + ZpMersenneLongElement1(unsigned long elem) { + this->elem = elem; + if (this->elem >= p) { + this->elem = (this->elem & p) + (this->elem >> 61); + + if (this->elem >= p) this->elem -= p; + } + } + + inline ZpMersenneLongElement1& operator=(const ZpMersenneLongElement1& other) + + { + elem = other.elem; + return *this; + }; + inline bool operator!=(const ZpMersenneLongElement1& other) + + { + return !(other.elem == elem); + }; + + ZpMersenneLongElement1 operator+(const ZpMersenneLongElement1& f2) { + ZpMersenneLongElement1 answer; + + answer.elem = (elem + f2.elem); + + if (answer.elem >= p) answer.elem -= p; + + return answer; + } + + ZpMersenneLongElement1 operator-(const ZpMersenneLongElement1& f2) { + ZpMersenneLongElement1 answer; + + int64_t temp = elem - f2.elem; + + if (temp < 0) { + answer.elem = temp + p; + } else { + answer.elem = temp; + } + + return answer; + } + + ZpMersenneLongElement1 operator/(const ZpMersenneLongElement1& f2) { + ZpMersenneLongElement1 answer; + mpz_t d; + mpz_t result; + mpz_t mpz_elem; + mpz_t mpz_me; + mpz_init_set_str(d, "2305843009213693951", 10); + mpz_init(mpz_elem); + mpz_init(mpz_me); + + mpz_set_ui(mpz_elem, f2.elem); + mpz_set_ui(mpz_me, elem); + + mpz_init(result); + + mpz_invert(result, mpz_elem, d); + + mpz_mul(result, result, mpz_me); + mpz_mod(result, result, d); + + answer.elem = mpz_get_ui(result); + + return answer; + } + + ZpMersenneLongElement1 operator*(const ZpMersenneLongElement1& f2) { + ZpMersenneLongElement1 answer; + + unsigned long long high; + unsigned long long low = _mulx_u64(elem, f2.elem, &high); + + unsigned long long low61 = (low & p); + unsigned long long low61to64 = (low >> 61); + unsigned long long highShift3 = (high << 3); + + unsigned long long res = low61 + low61to64 + highShift3; + + if (res >= p) res -= p; + + answer.elem = res; + + return answer; + } + + ZpMersenneLongElement1& operator+=(const ZpMersenneLongElement1& f2) { + elem = (elem + f2.elem); + + if (elem >= p) elem -= p; + + return *this; + } + + ZpMersenneLongElement1& operator*=(const ZpMersenneLongElement1& f2) { + unsigned long long high; + unsigned long long low = _mulx_u64(elem, f2.elem, &high); + + unsigned long long low61 = (low & p); + unsigned long long low61to64 = (low >> 61); + unsigned long long highShift3 = (high << 3); + + unsigned long long res = low61 + low61to64 + highShift3; + + if (res >= p) res -= p; + + elem = res; + + return *this; + } +}; + +inline std::ostream& operator<<(std::ostream& s, + const ZpMersenneLongElement1& a) { + return s << a.elem; +}; + +template +class TemplateField { + private: + long fieldParam; + int elementSizeInBytes; + int elementSizeInBits; + FieldType* m_ZERO; + FieldType* m_ONE; + + public: + /** + * the function create a field by: + * generate the irreducible polynomial x^8 + x^4 + x^3 + x + 1 to work with + * init the field with the newly generated polynomial + */ + TemplateField(long fieldParam); + + /** + * return the field + */ + + std::string elementToString(const FieldType& element); + FieldType stringToElement(const std::string& str); + + void elementToBytes(unsigned char* output, FieldType& element); + + FieldType bytesToElement(unsigned char* elemenetInBytes); + void elementVectorToByteVector(std::vector& elementVector, + std::vector& byteVector); + + FieldType* GetZero(); + FieldType* GetOne(); + + int getElementSizeInBytes() { return elementSizeInBytes; } + int getElementSizeInBits() { return elementSizeInBits; } + /* + * The i-th field element. The ordering is arbitrary, *except* that + * the 0-th field element must be the neutral w.r.t. addition, and the + * 1-st field element must be the neutral w.r.t. multiplication. + */ + FieldType GetElement(long b); + FieldType Random(); + ~TemplateField(); +}; + +template +std::string TemplateField::elementToString( + const FieldType& element) { + std::ostringstream stream; + stream << element; + std::string str = stream.str(); + return str; +} + +template +FieldType TemplateField::stringToElement(const std::string& str) { + FieldType element; + + std::istringstream iss(str); + iss >> element; + + return element; +} + +template +FieldType* TemplateField::GetZero() { + return m_ZERO; +} + +template +FieldType* TemplateField::GetOne() { + return m_ONE; +} + +template +TemplateField::~TemplateField() { + delete m_ZERO; + delete m_ONE; +} diff --git a/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc index b1ccb313..5e6be9d9 100644 --- a/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc @@ -16,9 +16,10 @@ #include +#include "experimental/psi/psi21/el_mp_psi/el_protocol.h" #include "experimental/psi/psi21/el_mp_psi/el_sender.h" -#include "psi/psi/utils/communication.h" -#include "psi/psi/utils/sync.h" +#include "psi/utils/communication.h" +#include "psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/rand/rand.h" #include "yacl/utils/serialize.h" @@ -42,8 +43,8 @@ NmpParty::NmpParty(const Options& options) : options_{options} { } } -std::vector NmpParty::Run( - const std::vector& inputs) { +// template +std::vector NmpParty::Run(const std::vector& inputs) { auto [ctx, wsize, me, leader] = CollectContext(); auto counts = AllGatherItemsSize(ctx, inputs.size()); size_t count{}; @@ -66,6 +67,9 @@ std::vector NmpParty::Run( ss << recons[i]; intersection.push_back(ss.str()); } + + randomParse(items, count, shares, p2p_, me, wsize, M, N); + return intersection; } diff --git a/experimental/psi/psi21/el_mp_psi/el_mp_psi.h b/experimental/psi/psi21/el_mp_psi/el_mp_psi.h index 6ecc369b..3d2e43f9 100644 --- a/experimental/psi/psi21/el_mp_psi/el_mp_psi.h +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi.h @@ -26,6 +26,7 @@ namespace psi::psi { // Practical Multi-party Private Set Intersection from Symmetric-Key Techniques // https://eprint.iacr.org/2017/799.pdf +// template class NmpParty { public: struct Options { diff --git a/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc index aba1dc15..38ad695d 100644 --- a/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc @@ -19,7 +19,7 @@ #include #include "gtest/gtest.h" -#include "psi/psi/utils/test_utils.h" +#include "psi/utils/test_utils.h" #include "yacl/link/test_util.h" namespace psi::psi { diff --git a/experimental/psi/psi21/el_mp_psi/el_protocol.cc b/experimental/psi/psi21/el_mp_psi/el_protocol.cc new file mode 100644 index 00000000..b0ff6e1b --- /dev/null +++ b/experimental/psi/psi21/el_mp_psi/el_protocol.cc @@ -0,0 +1,889 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "experimental/psi/psi21/el_mp_psi/el_sender.h" +// #include + +#include + +#include "experimental/psi/psi21/el_mp_psi/Mersenne.h" +#include "experimental/psi/psi21/el_mp_psi/el_protocol.h" +#include "experimental/psi/psi21/el_mp_psi/el_sender.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/kernel/algorithms/base_ot.h" +#include "yacl/kernel/algorithms/iknp_ote.h" +#include "yacl/kernel/algorithms/kkrt_ote.h" +#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" + +namespace psi::psi { + +namespace yc = yacl::crypto; + +namespace { + +// PSI-related constants +// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) +constexpr float ZETA[]{1.12f, 0.17f}; +// constexpr size_t BETA[]{31, 63}; +constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs + +// OTe-related constants +constexpr size_t NUM_BASE_OT{128}; +constexpr size_t NUM_INKP_OT{512}; +constexpr size_t BATCH_SIZE{896}; + +static auto ro = yc::RandomOracle::GetDefault(); + +} // namespace + +// template +// vector> parties; +// std::vector a_vals; + +// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) +#define flag_print false + +template +class ProtocolParty { + uint64_t num_bins; + int N, M, T, m_partyId; + TemplateField* field; + std::vector parties; + std::vector randomTAnd2TShares; + VDMTranspose matrix_vand_transpose; + VDM matrix_vand; + HIM matrix_for_interpolate; + + public: + std::vector items_; + size_t count_; + std::vector> shares_; + std::vector> p2p_; + size_t me; + size_t wsize; + + public: + void roundFunctionSync(std::vector>& sendBufs, + std::vector>& recBufs, int round); + void roundFunctionSyncForP1(std::vector& myShare, + std::vector>& recBufs); + void generateRandomShares(int numOfRandoms, + std::vector& randomElementsToFill); + void modDoubleRandom(uint64_t no_random, + std::vector& randomElementsToFill); + void generateRandom2TAndTShares(int numOfRandomPairs, + std::vector& randomElementsToFill); + void addShareOpen(uint64_t num_vals, std::vector& shares, + std::vector& secrets); + void reshare(std::vector& vals, std::vector& shares); + void additiveToThreshold(); + void sendFromP1(std::vector& sendBuf); + // void DNHonestMultiplication(FieldType *a, FieldType *b, + // std::vector &cToFill, int numOfTrupples); + FieldType interpolate(std::vector x); + void DNHonestMultiplication(std::vector a, + std::vector b, + std::vector& cToFill, + uint64_t numOfTrupples, int offset); + void mult_sj(); + uint64_t evaluateCircuit(); + void protocolIni(int M_, int N_, size_t me_id); + // void randomParse(); +}; + +template +void ProtocolParty::roundFunctionSync( + std::vector>& sendBufs, + std::vector>& recBufs, int round) { + // cout<<"in roundFunctionSync "<< round<< endl; + + int numThreads = parties.size(); + int numPartiesForEachThread; + + if (parties.size() <= numThreads) { + numThreads = parties.size(); + numPartiesForEachThread = 1; + } else { + numPartiesForEachThread = (parties.size() + numThreads - 1) / numThreads; + } + + recBufs[m_partyId] = move(sendBufs[m_partyId]); + count_ = count_ + 1; + // auto [ctx, wsize, me, leader] = CollectContext(); + // std::vector recv_shares(count); + + for (size_t id{}; id != me; ++id) { + ElSend(p2p_[id], items_, shares_[id]); + ElRecv(p2p_[id], items_); + return; + } + + for (size_t id{me + 1}; id != wsize; ++id) { + ElRecv(p2p_[id], items_); + ElSend(p2p_[id], items_, shares_[id]); + return; + } +} +template +void ProtocolParty::roundFunctionSyncForP1( + std::vector& myShare, std::vector>& recBufs) { + // cout<<"in roundFunctionSyncBroadcast "<< endl; + + int numThreads = parties.size(); + // int numThreads = 10; + int numPartiesForEachThread; + + if (parties.size() <= numThreads) { + numThreads = parties.size(); + numPartiesForEachThread = 1; + } else { + numPartiesForEachThread = (parties.size() + numThreads - 1) / numThreads; + } + + recBufs[m_partyId] = myShare; + count_ = count_ + 1; + for (size_t id{}; id != me; ++id) { + ElSend(p2p_[id], items_, shares_[id]); + ElRecv(p2p_[id], items_); + } +} + +template +void ProtocolParty::generateRandomShares( + int numOfRandoms, std::vector& randomElementsToFill) { + numOfRandoms = 100; + randomElementsToFill.resize(numOfRandoms); + int index = 0; + std::vector> recBufsBytes(N); + // std::vector> recBufsBytes(N); + int robin = 0; + int no_random = numOfRandoms; + + std::vector x1(N), y1(N), x2(N), y2(N), t1(N), r1(N), t2(N), r2(N); + ; + + std::vector> sendBufsElements(N); + std::vector> sendBufsBytes(N); + // std:: vector> sendBufsBytes(N); + + // the number of buckets (each bucket requires one double-sharing + // from each party and gives N-2T random double-sharings) + int no_buckets = (no_random / (N - T)) + 1; + + // sharingBufTElements.resize(no_buckets*(N-2*T)); // my shares of the + // double-sharings sharingBuf2TElements.resize(no_buckets*(N-2*T)); // my + // shares of the double-sharings + + // maybe add some elements if a partial bucket is needed + randomElementsToFill.resize(no_buckets * (N - T)); + + for (int i = 0; i < N; i++) { + sendBufsElements[i].resize(no_buckets); + sendBufsBytes[i].resize(no_buckets * field->getElementSizeInBytes()); + recBufsBytes[i].resize(no_buckets * field->getElementSizeInBytes()); + } + + /** + * generate random sharings. + * first degree t. + * + */ + for (int k = 0; k < no_buckets; k++) { + // generate random degree-T polynomial + for (int i = 0; i < T + 1; i++) { + // A random field element, uniform distribution, note that x1[0] is the + // secret which is also random + x1[i] = field->Random(); + } + matrix_vand.MatrixMult(x1, y1, T + 1); // eval poly at alpha-positions + + // prepare shares to be sent + for (int i = 0; i < N; i++) { + // cout << "y1[ " <getElementSizeInBytes(); + for (int i = 0; i < N; i++) { + for (int j = 0; j < sendBufsElements[i].size(); j++) { + field->elementToBytes((sendBufsBytes[i].data() + (j * fieldByteSize)), + sendBufsElements[i][j]); + } + } + + roundFunctionSync(sendBufsBytes, recBufsBytes, 4); + + if (flag_print) { + for (int i = 0; i < N; i++) { + for (int k = 0; k < sendBufsBytes[0].size(); k++) { + std::cout << "roundfunction4 send to " << i << " element: " << k << " " + << (int)sendBufsBytes[i][k] << std::endl; + } + } + for (int i = 0; i < N; i++) { + for (int k = 0; k < recBufsBytes[0].size(); k++) { + std::cout << "roundfunction4 receive from " << i << " element: " << k + << " " << (int)recBufsBytes[i][k] << std::endl; + } + } + } + + for (int k = 0; k < no_buckets; k++) { + for (int i = 0; i < N; i++) { + t1[i] = + field->bytesToElement((recBufsBytes[i].data() + (k * fieldByteSize))); + } + matrix_vand_transpose.MatrixMult(t1, r1, N - T); + + // copy the resulting vector to the array of randoms + for (int i = 0; i < N - T; i++) { + randomElementsToFill[index] = r1[i]; + index++; + } + } +} + +/* + * prepare additive and T-threshold sharings of secret random value r_j using + * DN07's protocol + */ +#define mpsi_print false +template +void ProtocolParty::modDoubleRandom( + uint64_t no_random, std::vector& randomElementsToFill) { + TemplateField* field = new TemplateField(0); + // int N = 5;//佈°彗¶佀~Y仼 佅¥ + // int T = N/2 -1;//佈°彗¶佀~Y仼 佅¥ + + // cout << this->m_partyId << ": Generating double sharings..." << endl; + int index = 0; + + std::vector x1(N), y1(N), y2(N), t1(N), r1(N), t2(N), r2(N); + + std::vector> sendBufsElements(N); + + std::vector> sendBufsBytes(N); + std::vector> recBufsBytes(N); + // the number of buckets (each bucket requires one double-sharing + // from each party and gives N-2T random double-sharings) + uint64_t no_buckets = (no_random / (N - T)) + 1; + + // int fieldByteSize = this->field->getElementSizeInBytes(); + int fieldByteSize = field->getElementSizeInBytes(); + + // maybe add some elements if a partial bucket is needed + randomElementsToFill.resize(no_buckets * (N - T) * 2); + + for (int i = 0; i < N; i++) { + sendBufsElements[i].resize(no_buckets * 2); + sendBufsBytes[i].resize(no_buckets * fieldByteSize * 2); + recBufsBytes[i].resize(no_buckets * fieldByteSize * 2); + } + + // cout << this->m_partyId << ": no_random: " << no_random << " no_buckets: " + // << no_buckets << " N: " << N << " T: " << T << endl; + + /** + * generate random sharings. + * first degree T, then additive + * + */ + for (uint64_t k = 0; k < no_buckets; k++) { + // generate random degree-T polynomial + for (int i = 0; i < T + 1; i++) { + // A random field element, uniform distribution, + // note that x1[0] is the secret which is also random + // x1[i] = this->field->Random(); + x1[i] = field->Random(); + } + + matrix_vand.MatrixMult(x1, y1, T + 1); // eval poly at alpha-positions + + y2[0] = x1[0]; + // generate N-1 random elements + for (int i = 1; i < N; i++) { + // A random field element, uniform distribution + // y2[i] = this->field->Random(); + y2[i] = field->Random(); + // all y2[i] generated so far are additive shares of the secret x1[0] + y2[0] = y2[0] - y2[i]; + } + + // prepare shares to be sent + for (int i = 0; i < N; i++) { + // cout << "y1[ " <field->elementToBytes(sendBufsBytes[i].data() + (j * + // fieldByteSize), sendBufsElements[i][j]); + field->elementToBytes(sendBufsBytes[i].data() + (j * fieldByteSize), + sendBufsElements[i][j]); + } + } + + roundFunctionSync(sendBufsBytes, recBufsBytes, 1); + + for (uint64_t k = 0; k < no_buckets; k++) { + for (int i = 0; i < N; i++) { + t1[i] = field->bytesToElement(recBufsBytes[i].data() + + (2 * k * fieldByteSize)); + t2[i] = field->bytesToElement(recBufsBytes[i].data() + + ((2 * k + 1) * fieldByteSize)); + // t1[i] = this->field->bytesToElement(recBufsBytes[i].data() + (2*k * + // fieldByteSize)); t2[i] = + // this->field->bytesToElement(recBufsBytes[i].data() + ((2*k +1) * + // fieldByteSize)); + } + + matrix_vand_transpose.MatrixMult(t1, r1, N - T); + matrix_vand_transpose.MatrixMult(t2, r2, N - T); + + // copy the resulting vector to the array of randoms + for (int i = 0; i < (N - T); i++) { + randomElementsToFill[index * 2] = r1[i]; + randomElementsToFill[index * 2 + 1] = r2[i]; + index++; + } + } + + if (mpsi_print == true) { + // std::cout << this->m_partyId << ": First pair of shares is " << + // randomElementsToFill[0] << " " << randomElementsToFill[1] << std::endl; + std::cout << ": First pair of shares is " << randomElementsToFill[0] << " " + << randomElementsToFill[1] << std::endl; + } +} + +template +void ProtocolParty::generateRandom2TAndTShares( + int numOfRandomPairs, std::vector& randomElementsToFill) { + TemplateField* field = new TemplateField(0); + // int N = 5;//佈°彗¶佀~Y仼 佅¥ + // int T = N/2 -1;//佈°彗¶佀~Y仼 佅¥ + int index = 0; + std::vector> recBufsBytes(N); + int robin = 0; + int no_random = numOfRandomPairs; + + std::vector x1(N), y1(N), x2(N), y2(N), t1(N), r1(N), t2(N), r2(N); + ; + + std::vector> sendBufsElements(N); + std::vector> sendBufsBytes(N); + + // the number of buckets (each bucket requires one double-sharing + // from each party and gives N-2T random double-sharings) + int no_buckets = (no_random / (N - T)) + 1; + + // sharingBufTElements.resize(no_buckets*(N-2*T)); // my shares of the + // double-sharings sharingBuf2TElements.resize(no_buckets*(N-2*T)); // my + // shares of the double-sharings + + // maybe add some elements if a partial bucket is needed + randomElementsToFill.resize(no_buckets * (N - T) * 2); + + for (int i = 0; i < N; i++) { + sendBufsElements[i].resize(no_buckets * 2); + sendBufsBytes[i].resize(no_buckets * field->getElementSizeInBytes() * 2); + recBufsBytes[i].resize(no_buckets * field->getElementSizeInBytes() * 2); + } + + /** + * generate random sharings. + * first degree t. + * + */ + for (int k = 0; k < no_buckets; k++) { + // generate random degree-T polynomial + for (int i = 0; i < T + 1; i++) { + // A random field element, uniform distribution, note that x1[0] is the + // secret which is also random + x1[i] = field->Random(); + } + + matrix_vand.MatrixMult(x1, y1, T + 1); // eval poly at alpha-positions + + x2[0] = x1[0]; + // generate random degree-T polynomial + for (int i = 1; i < 2 * T + 1; i++) { + // A random field element, uniform distribution, note that x1[0] is the + // secret which is also random + x2[i] = field->Random(); + } + + matrix_vand.MatrixMult(x2, y2, 2 * T + 1); + + // prepare shares to be sent + for (int i = 0; i < N; i++) { + // cout << "y1[ " <getElementSizeInBytes(); + for (int i = 0; i < N; i++) { + for (int j = 0; j < sendBufsElements[i].size(); j++) { + field->elementToBytes(sendBufsBytes[i].data() + (j * fieldByteSize), + sendBufsElements[i][j]); + } + } + + roundFunctionSync(sendBufsBytes, recBufsBytes, 4); + + if (flag_print) { + for (int i = 0; i < N; i++) { + for (int k = 0; k < sendBufsBytes[0].size(); k++) { + std::cout << "roundfunction4 send to " << i << " element: " << k << " " + << (int)sendBufsBytes[i][k] << std::endl; + } + } + for (int i = 0; i < N; i++) { + for (int k = 0; k < recBufsBytes[0].size(); k++) { + std::cout << "roundfunction4 receive from " << i << " element: " << k + << " " << (int)recBufsBytes[i][k] << std::endl; + } + } + } + + for (int k = 0; k < no_buckets; k++) { + for (int i = 0; i < N; i++) { + t1[i] = field->bytesToElement(recBufsBytes[i].data() + + (2 * k * fieldByteSize)); + t2[i] = field->bytesToElement(recBufsBytes[i].data() + + ((2 * k + 1) * fieldByteSize)); + } + matrix_vand_transpose.MatrixMult(t1, r1, N - T); + matrix_vand_transpose.MatrixMult(t2, r2, N - T); + + // copy the resulting vector to the array of randoms + for (int i = 0; i < (N - T); i++) { + randomElementsToFill[index * 2] = r1[i]; + randomElementsToFill[index * 2 + 1] = r2[i]; + index++; + } + } +} + +template +void /*CircuitPSI::*/ ProtocolParty::addShareOpen( + uint64_t num_vals, std::vector& shares, + std::vector& secrets) { + // cout << this->m_partyId << ": Reconstructing additive shares..." << endl; + + // int fieldByteSize = this->field->getElementSizeInBytes(); + int fieldByteSize = 10; + std::vector> recBufsBytes; + std::vector sendBufsBytes; + std::vector aPlusRSharesBytes(num_vals * fieldByteSize); + int i; + uint64_t j; + // int N = 10;//this->N; + // int m_partyId = 0; // + + secrets.resize(num_vals); + for (j = 0; j < num_vals; j++) { + field->elementToBytes(aPlusRSharesBytes.data() + (j * fieldByteSize), + shares[j]); + } + if (m_partyId == 0) { + recBufsBytes.resize(N); + + for (i = 0; i < N; i++) { + recBufsBytes[i].resize(num_vals * fieldByteSize); + } + + roundFunctionSyncForP1(aPlusRSharesBytes, recBufsBytes); + } + + if (/*this-*>*/ m_partyId == 0) { + for (j = 0; j < num_vals; j++) { + // secrets[j] = *(this->field->GetZero()); + for (i = 0; i < N; i++) { + secrets[j] += + field->bytesToElement(recBufsBytes[i].data() + (j * fieldByteSize)); + } + } + } +} + +template +void ProtocolParty::reshare(std::vector& vals, + std::vector& shares) { + uint64_t no_vals = vals.size(); + + std::vector x1(N), y1(N); + // std::vector> sendBufsElements(N); + std::vector> sendBufsElements(N); + std::vector> sendBufsBytes(N); + std::vector> recBufsBytes(N); + std::vector> recBufsElements(N); + + int fieldByteSize = 10; /// this->field->getElementSizeInBytes(); + + // int m_partyId = 0; + if (/*///this->*/ m_partyId == 0) { + // generate T-sharings of the values in vals + for (uint64_t k = 0; k < no_vals; k++) { + // set x1[0] as the secret to be shared + x1[0] = vals[k]; + // generate random degree-T polynomial + for (int i = 1; i < T + 1; i++) { + // A random field element, uniform distribution + // x1[i] = this->field->Random(); + } + + /// this->matrix_vand.MatrixMult(x1, y1,T+1); // eval poly at + /// alpha-positions + + // prepare shares to be sent + for (int i = 0; i < N; i++) { + // cout << "y1[ " <field->elementToBytes(sendBufsBytes[i].data() + (j * + /// fieldByteSize), sendBufsElements[i][j]); + } + // cout << sendBufsElements[i].size() << " " << sendBufsBytes[i].size() << + // " " << recBufsBytes[i].size(); + } + } else { + for (int i = 0; i < N; i++) { + sendBufsBytes[i].resize(num_bins * fieldByteSize); + recBufsBytes[i].resize(num_bins * fieldByteSize); + for (uint64_t j = 0; j < num_bins; j++) { + /// this->field->elementToBytes(sendBufsBytes[i].data(), + /// *(this->field->GetZero())); + } + } + } + + // cout << "byte conversion done \n"; + + // this-> + roundFunctionSync(sendBufsBytes, recBufsBytes, 2); + + // cout << "roundFunctionSync() done "; + + if (/*///this->*/ m_partyId != 0) { + for (uint64_t k = 0; k < no_vals; k++) { + /// shares[k] = this->field->bytesToElement(recBufsBytes[0].data() + (k * + /// fieldByteSize)); + } + } + // cout << "converted back to field elements...\n"; + + if (mpsi_print == true) { + /// cout << this->m_partyId << "First t-sharing received is: " << shares[0] + /// << endl; + } +} + +template +void /*MPSI_Party::*/ +ProtocolParty::additiveToThreshold() { + uint64_t j; + std::vector reconar; // reconstructed aj+rj + std::vector add_a; + uint64_t num_bins = 10; + std::vector a_vals; + std::vector randomTAndAddShares; + + reconar.resize(num_bins); + // add additive share of rj to corresponding share of aj + for (j = 0; j < num_bins; j++) { + add_a[j] = add_a[j] + randomTAndAddShares[j * 2 + 1]; + } + + // reconstruct additive shares, store in reconar + addShareOpen(num_bins, add_a, reconar); + + // reshare and save in a_vals; + reshare(reconar, a_vals); + + // subtract T-threshold shares of rj + for (j = 0; j < num_bins; j++) { + a_vals[j] = a_vals[j] - randomTAndAddShares[j * 2]; + } +} + +template +void ProtocolParty::sendFromP1(std::vector& sendBuf) { + // cout<<"in roundFunctionSyncBroadcast "<< endl; + + int numThreads = parties.size(); + int numPartiesForEachThread; + + if (parties.size() <= numThreads) { + numThreads = parties.size(); + numPartiesForEachThread = 1; + } else { + numPartiesForEachThread = (parties.size() + numThreads - 1) / numThreads; + } + + // recieve the data using threads + // vector threads(numThreads); + for (int t = 0; t < numThreads; t++) { + /*if ((t + 1) * numPartiesForEachThread <= parties.size()) { + // threads[t] = thread(&ProtocolParty::sendDataFromP1, this, + ref(sendBuf), + // t * numPartiesForEachThread, (t + 1) * + numPartiesForEachThread); } else { + // threads[t] = thread(&ProtocolParty::sendDataFromP1, this, + ref(sendBuf), t * numPartiesForEachThread, parties.size()); + }*/ + } + for (int t = 0; t < numThreads; t++) { + // threads[t].join(); + } + + count_ = count_ + 1; + // auto [ctx, wsize, me, leader] = CollectContext(); + // std::vector recv_shares(count); + + for (size_t id{}; id != me; ++id) { + ElSend(p2p_[id], items_, shares_[id]); + ElRecv(p2p_[id], items_); + } +} + +// Interpolate polynomial at position Zero +template +FieldType ProtocolParty::interpolate(std::vector x) { + std::vector y(N); // result of interpolate + matrix_for_interpolate.MatrixMult(x, y); + return y[0]; +} + +template +void ProtocolParty::DNHonestMultiplication( + std::vector a, std::vector b, + std::vector& cToFill, uint64_t numOfTrupples, int offset) { + int index = 0; + // int numOfMultGates = circuit.getNrOfMultiplicationGates(); + uint64_t numOfMultGates = numOfTrupples; + int fieldByteSize = field->getElementSizeInBytes(); + std::vector xyMinusRShares( + numOfTrupples); // hold both in the same vector to send in one batch + std::vector xyMinusRSharesBytes( + numOfTrupples * + fieldByteSize); // hold both in the same vector to send in one batch + + std::vector + xyMinusR; // hold both in the same vector to send in one batch + std::vector xyMinusRBytes; + + std::vector> recBufsBytes; + + // int offset = numOfMultGates*2; + // int offset = 0; + // generate the shares for x+a and y+b. do it in the same array to send once + for (int k = 0; k < numOfTrupples; k++) // go over only the logit gates + { + xyMinusRShares[k] = a[k] * b[k] - randomTAnd2TShares[offset + 2 * k + 1]; + } + // set the acctual number of mult gate proccessed in this layer + xyMinusRSharesBytes.resize(numOfTrupples * fieldByteSize); + xyMinusR.resize(numOfTrupples); + xyMinusRBytes.resize(numOfTrupples * fieldByteSize); + for (int j = 0; j < xyMinusRShares.size(); j++) { + field->elementToBytes(xyMinusRSharesBytes.data() + (j * fieldByteSize), + xyMinusRShares[j]); + } + + if (m_partyId == 0) { + // just party 1 needs the recbuf + recBufsBytes.resize(N); + + for (int i = 0; i < N; i++) { + recBufsBytes[i].resize(numOfTrupples * fieldByteSize); + } + + // receive the shares from all the other parties + roundFunctionSyncForP1(xyMinusRSharesBytes, recBufsBytes); + } + + // reconstruct the shares recieved from the other parties + if (m_partyId == 0) { + std::vector xyMinurAllShares(N), yPlusB(N); + + for (int k = 0; k < numOfTrupples; k++) // go over only the logit gates + { + for (int i = 0; i < N; i++) { + xyMinurAllShares[i] = + field->bytesToElement(recBufsBytes[i].data() + (k * fieldByteSize)); + } + + // reconstruct the shares by P0 + xyMinusR[k] = interpolate(xyMinurAllShares); + + // convert to bytes + field->elementToBytes(xyMinusRBytes.data() + (k * fieldByteSize), + xyMinusR[k]); + } + + // send the reconstructed vector to all the other parties + sendFromP1(xyMinusRBytes); + } + + // fill the xPlusAAndYPlusB array for all the parties except for party 1 that + // already have this array filled + if (m_partyId != 0) { + for (int i = 0; i < numOfTrupples; i++) { + xyMinusR[i] = + field->bytesToElement(xyMinusRBytes.data() + (i * fieldByteSize)); + } + } + + index = 0; + + // after the xPlusAAndYPlusB array is filled, we are ready to fill the output + // of the mult gates + for (int k = 0; k < numOfTrupples; k++) // go over only the logit gates + { + cToFill[k] = randomTAnd2TShares[offset + 2 * k] + xyMinusR[k]; + } +} + +template +void ProtocolParty::mult_sj() { + int fieldByteSize = field->getElementSizeInBytes(); + std::vector masks; + std::vector a_vals; + + //// + std::vector mult_outs; // threshold shares of s_j*a_j + /// + int num_bins = 64; + ; + std::vector multbytes(num_bins * fieldByteSize); + std::vector> recBufsBytes; + int i; + uint64_t j; + + DNHonestMultiplication(masks, a_vals, mult_outs, num_bins, 0); + for (j = 0; j < num_bins; j++) { + // this-> + field->elementToBytes(multbytes.data() + (j * fieldByteSize), mult_outs[j]); + } + + if (m_partyId == 0) { + recBufsBytes.resize(N); + for (i = 0; i < N; i++) { + recBufsBytes[i].resize(num_bins * fieldByteSize); + } + // this-> + roundFunctionSyncForP1(multbytes, recBufsBytes); + } + + if (m_partyId == 0) { + std::vector x1(N); + for (j = 0; j < num_bins; j++) { + for (i = 0; i < N; i++) { + x1[i] = + field->bytesToElement(recBufsBytes[i].data() + (j * fieldByteSize)); + } + // outputs[j] = this->interpolate(x1); + } + } +} + +template +uint64_t ProtocolParty::evaluateCircuit() { + additiveToThreshold(); + mult_sj(); + return 0; +} + +template +void ProtocolParty::protocolIni(int M_, int N_, size_t me_id) { + N = N_; + T = N / 2 - 1; + m_partyId = me_id; + M = M_; + // ini party num +} + +void randomParse(std::vector& items, size_t count, + std::vector>& shares, + std::vector> p2p, + size_t me_id, size_t wsize, int M, int N) { + uint64_t num_bins = 64; + std::vector masks; + std::vector randomTAndAddShares; + std::vector randomTAnd2TShares; + masks.resize(num_bins); + + ProtocolParty mpsi; + mpsi.items_ = items; + mpsi.count_ = count; + mpsi.shares_ = shares; + mpsi.p2p_ = p2p; + mpsi.protocolIni(M, N, me_id); + mpsi.me = me_id; + mpsi.wsize = wsize; + mpsi.generateRandomShares(num_bins, masks); + mpsi.modDoubleRandom(num_bins, randomTAndAddShares); + mpsi.generateRandom2TAndTShares(num_bins, randomTAnd2TShares); + + mpsi.evaluateCircuit(); +} + +} // namespace psi::psi diff --git a/experimental/psi/psi21/el_mp_psi/el_protocol.h b/experimental/psi/psi21/el_mp_psi/el_protocol.h new file mode 100644 index 00000000..dfa493d6 --- /dev/null +++ b/experimental/psi/psi21/el_mp_psi/el_protocol.h @@ -0,0 +1,52 @@ +// Copyright 2024 zhangwfjh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/link.h" + +namespace psi::psi { +template +void generateRandomShares(int numOfRnadoms, + std::vector& randomElementsToFill); +// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) + +// prepare additive and T-threshold sharings of secret random value r_j using +// DN07's protocol template void modDoubleRandom(uint64_t +// no_random, std::vector& randomElementsToFill); + +// template +// void generateRandom2TAndTShares(int numOfRandomPairs, std::vector& +// randomElementsToFill); + +// template +void randomParse(std::vector& items, size_t count, + std::vector>& shares, + std::vector> p2p, + size_t me_id, size_t wsize, int M, int N); + +// evaluate circuit +// template +// uint64_t evaluateCircuit(); + +// convert additive shares to T-threshold +// void additiveToThreshold(); +// void mult_sj(); + +} // namespace psi::psi diff --git a/experimental/psi/psi21/el_mp_psi/el_sender.cc b/experimental/psi/psi21/el_mp_psi/el_sender.cc index 1d4bedf5..c117758e 100644 --- a/experimental/psi/psi21/el_mp_psi/el_sender.cc +++ b/experimental/psi/psi21/el_mp_psi/el_sender.cc @@ -17,7 +17,9 @@ #include #include #include +// #include +#include "experimental/psi/psi21/el_mp_psi/Mersenne.h" #include "experimental/psi/psi21/el_mp_psi/el_hashing.h" #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/ro.h" @@ -49,10 +51,10 @@ static auto ro = yc::RandomOracle::GetDefault(); } // namespace // Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) +#define flag_print false -std::vector ElRecv( - const std::shared_ptr& ctx, - const std::vector& queries) { +std::vector ElRecv(const std::shared_ptr& ctx, + const std::vector& queries) { const size_t size{queries.size()}; const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), static_cast(std::ceil(size * ZETA[1]))}; @@ -72,7 +74,6 @@ std::vector ElRecv( auto buf = ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); std::vector table(xssize); std::memcpy(table.data(), buf.data(), table.size() * sizeof(uint128_t)); - // todo for (size_t i{}; i != table.size(); ++i) { table[i] = table[i] + (nonce >> 10); // SPDLOG_INFO(" table[i] = {}, size{}", @@ -83,8 +84,7 @@ std::vector ElRecv( } void ElSend(const std::shared_ptr& ctx, - const std::vector& xs, - const std::vector& ys) { + const std::vector& xs, const std::vector& ys) { YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); const size_t size{xs.size()}; const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), diff --git a/experimental/psi/psi21/el_mp_psi/el_sender.h b/experimental/psi/psi21/el_mp_psi/el_sender.h index 82feac81..941b5342 100644 --- a/experimental/psi/psi21/el_mp_psi/el_sender.h +++ b/experimental/psi/psi21/el_mp_psi/el_sender.h @@ -23,13 +23,12 @@ namespace psi::psi { -// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) - std::vector ElRecv(const std::shared_ptr&, - const std::vector& queries); + const std::vector& queries); void ElSend(const std::shared_ptr&, - const std::vector& xs, - const std::vector& ys); + const std::vector& xs, const std::vector& ys); + +// convert additive shares to T-threshold } // namespace psi::psi From 42d3a2dcb529c9582707483b5a9b85e19e4e9dc0 Mon Sep 17 00:00:00 2001 From: xsjwyh Date: Fri, 14 Feb 2025 15:50:23 +0800 Subject: [PATCH 8/8] format mpsi with clang-format --- .vscode/settings.json | 5 + experimental/psi/psi21/el_c_psi/BUILD.bazel | 61 ------ experimental/psi/psi21/el_c_psi/el_c_psi.cc | 165 --------------- experimental/psi/psi21/el_c_psi/el_c_psi.h | 59 ------ .../psi/psi21/el_c_psi/el_c_psi_benchmark.cc | 87 -------- .../psi/psi21/el_c_psi/el_c_psi_test.cc | 171 --------------- experimental/psi/psi21/el_c_psi/el_hashing.cc | 69 ------ experimental/psi/psi21/el_c_psi/el_hashing.h | 65 ------ experimental/psi/psi21/el_c_psi/el_opprf.cc | 188 ----------------- experimental/psi/psi21/el_c_psi/el_opprf.h | 35 ---- experimental/psi/psi21/el_mp_psi/el_mp_psi.cc | 5 +- .../psi/psi21/el_mp_psi/el_mp_psi_test.cc | 3 +- experimental/psi/psi21/el_q_psi/BUILD.bazel | 61 ------ experimental/psi/psi21/el_q_psi/el_hashing.cc | 69 ------ experimental/psi/psi21/el_q_psi/el_hashing.h | 65 ------ experimental/psi/psi21/el_q_psi/el_opprf.cc | 189 ----------------- experimental/psi/psi21/el_q_psi/el_opprf.h | 35 ---- experimental/psi/psi21/el_q_psi/el_q_psi.cc | 165 --------------- experimental/psi/psi21/el_q_psi/el_q_psi.h | 60 ------ .../psi/psi21/el_q_psi/el_q_psi_benchmark.cc | 87 -------- .../psi/psi21/el_q_psi/el_q_psi_test.cc | 196 ------------------ 21 files changed, 10 insertions(+), 1830 deletions(-) create mode 100644 .vscode/settings.json delete mode 100644 experimental/psi/psi21/el_c_psi/BUILD.bazel delete mode 100644 experimental/psi/psi21/el_c_psi/el_c_psi.cc delete mode 100644 experimental/psi/psi21/el_c_psi/el_c_psi.h delete mode 100644 experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc delete mode 100644 experimental/psi/psi21/el_c_psi/el_c_psi_test.cc delete mode 100644 experimental/psi/psi21/el_c_psi/el_hashing.cc delete mode 100644 experimental/psi/psi21/el_c_psi/el_hashing.h delete mode 100644 experimental/psi/psi21/el_c_psi/el_opprf.cc delete mode 100644 experimental/psi/psi21/el_c_psi/el_opprf.h delete mode 100644 experimental/psi/psi21/el_q_psi/BUILD.bazel delete mode 100644 experimental/psi/psi21/el_q_psi/el_hashing.cc delete mode 100644 experimental/psi/psi21/el_q_psi/el_hashing.h delete mode 100644 experimental/psi/psi21/el_q_psi/el_opprf.cc delete mode 100644 experimental/psi/psi21/el_q_psi/el_opprf.h delete mode 100644 experimental/psi/psi21/el_q_psi/el_q_psi.cc delete mode 100644 experimental/psi/psi21/el_q_psi/el_q_psi.h delete mode 100644 experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc delete mode 100644 experimental/psi/psi21/el_q_psi/el_q_psi_test.cc diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..56e50026 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.associations": { + "set": "cpp" + } +} \ No newline at end of file diff --git a/experimental/psi/psi21/el_c_psi/BUILD.bazel b/experimental/psi/psi21/el_c_psi/BUILD.bazel deleted file mode 100644 index b60d86bb..00000000 --- a/experimental/psi/psi21/el_c_psi/BUILD.bazel +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 zhangwfjh -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_binary( - name = "el_c_psi_benchmark", - srcs = ["el_c_psi_benchmark.cc"], - deps = [ - ":el_c_psi", - "@com_github_google_benchmark//:benchmark_main", - ], -) - -psi_cc_library( - name = "el_c_psi", - srcs = [ - "el_hashing.cc", - "el_c_psi.cc", - "el_opprf.cc", - ], - hdrs = [ - "el_hashing.h", - "el_c_psi.h", - "el_opprf.h", - ], - deps = [ - "//psi/utils:communication", - "//psi/utils:sync", - "//psi/utils:test_utils", - "@com_google_absl//absl/types:span", - "@yacl//yacl/base:exception", - "@yacl//yacl/base:int128", - "@yacl//yacl/crypto/hash:hash_utils", - "@yacl//yacl/crypto/rand", - "@yacl//yacl/kernel/algorithms:base_ot", - "@yacl//yacl/kernel/algorithms:iknp_ote", - "@yacl//yacl/kernel/algorithms:kkrt_ote", - "@yacl//yacl/link", - ], -) - -psi_cc_test( - name = "el_c_psi_test", - srcs = ["el_c_psi_test.cc"], - tags = ["manual"], - deps = [":el_c_psi"], -) diff --git a/experimental/psi/psi21/el_c_psi/el_c_psi.cc b/experimental/psi/psi21/el_c_psi/el_c_psi.cc deleted file mode 100644 index 64aca255..00000000 --- a/experimental/psi/psi21/el_c_psi/el_c_psi.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" - -#include - -#include "experimental/psi/psi21/el_c_psi/el_opprf.h" -#include "psi/psi/utils/communication.h" -#include "psi/psi/utils/sync.h" -#include "yacl/crypto/hash/hash_utils.h" -#include "yacl/crypto/rand/rand.h" -#include "yacl/utils/serialize.h" - -namespace psi::psi { - -namespace { - -constexpr uint32_t kLinkRecvTimeout = 60 * 60 * 1000; - -} // namespace - -NcParty::NcParty(const Options& options) : options_{options} { - auto [ctx, wsize, me, leader] = CollectContext(); - ctx->SetRecvTimeout(kLinkRecvTimeout); - p2p_.resize(wsize); - for (size_t dst{}; dst != wsize; ++dst) { - if (me != dst) { - p2p_[dst] = CreateP2PLinkCtx("el_c_psi", ctx, dst); - } - } -} - -std::vector NcParty::Run( - const std::vector& inputs) { - auto [ctx, wsize, me, leader] = CollectContext(); - auto counts = AllGatherItemsSize(ctx, inputs.size()); - size_t count{}; - for (auto cnt : counts) { - if (cnt == 0) { - return {}; - } - count = std::max(cnt, count); - } - - auto items = EncodeInputs(inputs, count); - auto shares = ZeroSharing(count); - auto recv_share = SwapShares(items, shares); - auto recons = Reconstruct(items, recv_share); - std::vector intersection; - for (size_t k{}; k != count; ++k) { - if (recons[k] == 0) { - intersection.push_back("1"); - } else { - intersection.push_back("0"); - } - } - return intersection; -} - -std::vector NcParty::EncodeInputs( - const std::vector& inputs, size_t count) const { - std::vector items; - items.reserve(count); - std::transform( - inputs.begin(), inputs.end(), std::back_inserter(items), - [](std::string_view input) { return yacl::crypto::Blake3_128(input); }); - // Add random dummy elements - std::generate_n(std::back_inserter(items), count - inputs.size(), - yacl::crypto::FastRandU128); - return items; -} - -auto NcParty::ZeroSharing(size_t count) const -> std::vector { - auto [ctx, wsize, me, leader] = CollectContext(); - std::vector shares(wsize, Share(count)); - for (size_t k{}; k != count; ++k) { - uint64_t sum{}; - for (size_t dst{1}; dst != wsize; ++dst) { - sum ^= shares[dst][k] = yacl::crypto::FastRandU64(); - } - shares[0][k] = sum; - } - return shares; -} - -auto NcParty::SwapShares(const std::vector& items, - const std::vector& shares) const -> Share { - auto [ctx, wsize, me, leader] = CollectContext(); - auto count = shares.front().size(); - std::vector recv_shares(count); - std::vector> futures(wsize); - // NOTE: First Send Then Receive for peers of smaller ranks - for (size_t id{}; id != me; ++id) { - futures[id] = std::async( - [&](size_t id) { - ElOpprfSend(p2p_[id], items, shares[id]); - return ElOpprfRecv(p2p_[id], items); - }, - id); - } - // NOTE: First Receive Then Send for peers of larger ranks - for (size_t id{me + 1}; id != wsize; ++id) { - futures[id] = std::async( - [&](size_t id) { - auto ret = ElOpprfRecv(p2p_[id], items); - ElOpprfSend(p2p_[id], items, shares[id]); - return ret; - }, - id); - } - for (size_t id{}; id != wsize; ++id) { - recv_shares[id] = (me == id ? shares[id] : futures[id].get()); - } - - Share share(count); // S(x_k) - for (size_t k{}; k != count; ++k) { - for (size_t src{}; src != wsize; ++src) { - share[k] ^= recv_shares[src][k]; - } - } - return share; -} - -auto NcParty::Reconstruct(const std::vector& items, - const Share& share) const -> Share { - auto [ctx, wsize, me, leader] = CollectContext(); - auto count = items.size(); - if (me == leader) { - std::vector recv_shares(count); - std::vector> futures(wsize); - for (size_t src{}; src != wsize; ++src) { - if (me != src) { - futures[src] = std::async( - [&](size_t src) { return ElOpprfRecv(p2p_[src], items); }, src); - } - } - for (size_t src{}; src != wsize; ++src) { - recv_shares[src] = (me == src ? share : futures[src].get()); - } - Share recons(count); // sum of S_i(x_k) over i - for (size_t k{}; k != count; ++k) { - for (size_t src{}; src != wsize; ++src) { - recons[k] ^= recv_shares[src][k]; - } - } - return recons; - } else { - ElOpprfSend(p2p_[leader], items, share); - return share; - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_c_psi.h b/experimental/psi/psi21/el_c_psi/el_c_psi.h deleted file mode 100644 index 00176003..00000000 --- a/experimental/psi/psi21/el_c_psi/el_c_psi.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include -#include -#include - -#include "yacl/base/int128.h" -#include "yacl/link/link.h" - -namespace psi::psi { - -// Practical Multi-party Private Set Intersection from Symmetric-Key Techniques -// https://eprint.iacr.org/2017/799.pdf - -class NcParty { - public: - struct Options { - std::shared_ptr link_ctx; - size_t leader_rank; - }; - - NcParty(const Options& options); - virtual std::vector Run(const std::vector& inputs); - - private: - using Share = std::vector; - - std::vector EncodeInputs(const std::vector& inputs, - size_t count) const; - std::vector ZeroSharing(size_t count) const; - Share SwapShares(const std::vector& items, - const std::vector& shares) const; - Share Reconstruct(const std::vector& items, - const Share& share) const; - - // (ctx, world_size, my_rank, leader_rank) - auto CollectContext() const { - return std::make_tuple(options_.link_ctx, options_.link_ctx->WorldSize(), - options_.link_ctx->Rank(), options_.leader_rank); - } - - Options options_; - std::vector> p2p_; -}; - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc b/experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc deleted file mode 100644 index 3823161d..00000000 --- a/experimental/psi/psi21/el_c_psi/el_c_psi_benchmark.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "benchmark/benchmark.h" -#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" -#include "experimental/psi/psi21/el_c_psi/el_opprf.h" -#include "yacl/base/exception.h" -#include "yacl/crypto/hash/hash_utils.h" -#include "yacl/link/test_util.h" - -namespace { -std::vector CreateRangeItems(size_t begin, size_t size) { - std::vector ret(size); - for (size_t i = 0; i < size; i++) { - auto hash = yacl::crypto::Blake3(std::to_string(begin + i)); - memcpy(&ret[i], hash.data(), sizeof(uint128_t)); - } - return ret; -} - -void ElCPsiSend(const std::shared_ptr& link_ctx, - const std::vector& items_hash) { - // auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); - // return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); - std::vector shares; - for (size_t i = 0; i < items_hash.size(); i++) { - uint64_t item = 0; - shares.push_back(item); - } - - return psi::psi::ElOpprfSend(link_ctx, items_hash, shares); -} - -std::vector ElCPsiRecv( - const std::shared_ptr& link_ctx, - const std::vector& items_hash) { - // auto ot_send = psi::kkrt::GetKkrtOtReceiverOptions(link_ctx, 512); - // return psi::kkrt::KkrtPsiRecv(link_ctx, ot_send, items_hash); - return psi::psi::ElOpprfRecv(link_ctx, items_hash); -} - -} // namespace - -static void BM_El_C_Psi(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - auto alice_items = CreateRangeItems(1, n); - auto bob_items = CreateRangeItems(2, n); - - auto contexts = yacl::link::test::SetupWorld(2); - - state.ResumeTiming(); - - std::future kkrt_psi_sender = - std::async([&] { return ElCPsiSend(contexts[0], alice_items); }); - std::future> kkrt_psi_receiver = - std::async([&] { return ElCPsiRecv(contexts[1], bob_items); }); - - kkrt_psi_sender.get(); - auto results_b = kkrt_psi_receiver.get(); - } -} - -// [256k, 512k, 1m, 2m, 4m, 8m] -BENCHMARK(BM_El_C_Psi) - ->Unit(benchmark::kMillisecond) - ->Arg(256 << 10) - ->Arg(512 << 10) - ->Arg(1 << 20) - ->Arg(2 << 20) - ->Arg(4 << 20) - ->Arg(8 << 20); diff --git a/experimental/psi/psi21/el_c_psi/el_c_psi_test.cc b/experimental/psi/psi21/el_c_psi/el_c_psi_test.cc deleted file mode 100644 index 3a4ca20d..00000000 --- a/experimental/psi/psi21/el_c_psi/el_c_psi_test.cc +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_c_psi/el_c_psi.h" - -#include -#include -#include - -#include "gtest/gtest.h" -#include "psi/psi/utils/test_utils.h" -#include "yacl/link/test_util.h" - -namespace psi::psi { - -namespace { - -struct NCTestParams { - std::vector item_size; - size_t intersection_size; -}; - -std::vector> CreateNPartyItems( - const NCTestParams& params) { - std::vector> ret(params.item_size.size() + 1); - ret[params.item_size.size()] = - test::CreateRangeItems(1, params.intersection_size); - - for (size_t idx = 0; idx < params.item_size.size(); ++idx) { - ret[idx] = - test::CreateRangeItems((idx + 1) * 1000000, params.item_size[idx]); - } - - for (size_t idx = 0; idx < params.item_size.size(); ++idx) { - std::set idx_set; - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, params.item_size[idx] - 1); - - while (idx_set.size() < params.intersection_size) { - idx_set.insert(dis(gen)); - } - size_t j = 0; - for (const auto& iter : idx_set) { - ret[idx][iter] = ret[params.item_size.size()][j++]; - } - } - return ret; -} - -} // namespace - -class NCPsiTest : public testing::TestWithParam {}; - -// FIXME : this test is not stable in arm env -TEST_P(NCPsiTest, Works) { - std::vector> items; - std::vector> resultvec; - std::vector finalresult; - - auto params = GetParam(); - items = CreateNPartyItems(params); - size_t leader_rank = 0; - uint128_t maxlength = 0; - uint128_t n = params.item_size.size() - 1; - - for (size_t i = 0; i < params.item_size.size() - 1; i++) { - std::vector> items1; - items1.push_back(items[0]); - items1.push_back(items[i + 1]); - leader_rank = 0; - - auto ctxs = yacl::link::test::SetupWorld(2); - auto proc = [&](int idx) -> std::vector { - NcParty::Options opts; - opts.link_ctx = ctxs[idx]; - opts.leader_rank = leader_rank; - NcParty op(opts); - // for (size_t j{}; j != items1[idx].size(); ++j) { - // SPDLOG_INFO(" items[{}][{}] = {}, size{}", idx, i, items[idx][i], - // items[idx].size()); - // } - - return op.Run(items[idx]); - }; - - size_t world_size = ctxs.size(); - std::vector>> f_links(world_size); - for (size_t j = 0; j < world_size; j++) { - f_links[j] = std::async(proc, j); - } - sleep(1); - - std::vector result; - result = f_links[0].get(); - resultvec.push_back(result); - - /*for (size_t j = 0; j < result.size(); j++) { - SPDLOG_INFO("i{} j{}, result[j] {} size{}", i, j, result[j], - result.size()); - }*/ - } - - maxlength = items[0].size(); - std::vector qpsivector; - for (size_t j = 0; j < maxlength; j++) { - uint128_t sum = 0; - for (size_t i = 0; i < params.item_size.size() - 1; i++) { - // 如果有的集合没有那么多项就continue - // results[i] = f_links[i].get(); - if (resultvec[i].size() <= j) { - continue; - } - - // SPDLOG_INFO(" result[{}][{}] = {}", i, j, resultvec[i][j]); - auto it = resultvec[i].begin() + j; - std::string element = *it; - if (element == "1") { - sum++; - } - } - if (sum >= n) { - // todo//推入对应input元素 之后再查输入变量从param中怎么取出推入 - qpsivector.push_back(1); - } else { - qpsivector.push_back(0); - } - } - - // std::vector intersectionnparty; - for (size_t k{}; k != items[0].size(); ++k) { - if (qpsivector[k] == 1) { - finalresult.push_back(items[0][k]); - } - } - - /*for (size_t i{}; i != finalresult.size(); ++i) { - SPDLOG_INFO("intersectionnparty = {}", finalresult[i]); - }*/ - - std::vector intersection = items[params.item_size.size()]; - std::sort(intersection.begin(), intersection.end()); - - std::sort(finalresult.begin(), finalresult.end()); - EXPECT_EQ(finalresult.size(), intersection.size()); - EXPECT_EQ(finalresult, intersection); -} - -INSTANTIATE_TEST_SUITE_P( - Works_Instances, NCPsiTest, - // testing::Values(NCTestParams{{1, 3}, 1})); - testing::Values(NCTestParams{{0, 3}, 0}, // - NCTestParams{{3, 0}, 0}, // - NCTestParams{{0, 0}, 0}, // - NCTestParams{{4, 3}, 2}, // - NCTestParams{{20, 17, 14}, 10}, // - NCTestParams{{20, 17, 14, 30}, 10}, // - NCTestParams{{20, 17, 14, 30, 35}, 11}, // - NCTestParams{{20, 17, 14, 30, 35}, 0})); -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_hashing.cc b/experimental/psi/psi21/el_c_psi/el_hashing.cc deleted file mode 100644 index 9956709b..00000000 --- a/experimental/psi/psi21/el_c_psi/el_hashing.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_c_psi/el_hashing.h" - -#include -#include - -#include "yacl/base/exception.h" -#include "yacl/crypto/rand/rand.h" - -namespace psi::psi { - -void KmprtCuckooHashing::Insert(uint128_t elem) { - auto insert_into = [this, &elem](uint8_t c) { - for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) { - uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c]; - uint8_t idx = (rand_idx + 1) % num_hashes_[c]; - size_t addr; - do { - addr = HashU128{}(elem, idx) % num_bins_[c]; - if (auto &bin = bins_[c][addr]; bin == NONE || bin == elem) { - bin = std::exchange(elem, NONE); - return; - } - idx = (idx + 1) % num_hashes_[c]; - } while (idx != rand_idx); - std::swap(bins_[c][addr], elem); - } - }; - for (uint8_t c{}; c != 2; ++c) { - insert_into(c); - } - YACL_ENFORCE_EQ(elem, NONE, "Failed to insert element."); -} - -auto KmprtCuckooHashing::Lookup(uint128_t elem) const - -> std::pair { - for (uint8_t c{}; c != 2; ++c) { - for (uint8_t idx{}; idx != num_hashes_[c]; ++idx) { - if (size_t addr = HashU128{}(elem, idx) % num_bins_[c]; - bins_[c][addr] == elem) { - return {c, addr}; - } - } - } - return {-1, -1}; -} - -void KmprtSimpleHashing::Insert(std::pair point) { - for (uint8_t c{}; c != 2; ++c) { - for (size_t idx{}; idx != num_hashes_[c]; ++idx) { - bins_[c][HashU128{}(point.first, idx) % num_bins_[c]].emplace(point); - } - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_hashing.h b/experimental/psi/psi21/el_c_psi/el_hashing.h deleted file mode 100644 index be80df99..00000000 --- a/experimental/psi/psi21/el_c_psi/el_hashing.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "absl/numeric/int128.h" -#include "yacl/base/int128.h" - -namespace psi::psi { - -struct HashU128 { - size_t operator()(uint128_t x, uint8_t idx = 0) const { - return absl::Uint128High64(x) + idx * absl::Uint128Low64(x); - } -}; - -template -struct KmprtDoubleHashing { - KmprtDoubleHashing(size_t m1, size_t m2) : num_bins_{m1, m2} { - bins_[0].resize(m1); - bins_[1].resize(m2); - } - - const Bin &GetBin(uint8_t c, size_t addr) const { return bins_[c][addr]; } - - const uint8_t num_hashes_[2]{3, 2}; - const size_t num_bins_[2]; - std::vector bins_[2]; -}; - -class KmprtCuckooHashing : public KmprtDoubleHashing { - public: - constexpr static uint128_t NONE{}; - - KmprtCuckooHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} - - void Insert(uint128_t); - std::pair Lookup(uint128_t) const; -}; - -class KmprtSimpleHashing - : public KmprtDoubleHashing< - std::unordered_map> { - public: - KmprtSimpleHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} - - void Insert(std::pair); -}; - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_opprf.cc b/experimental/psi/psi21/el_c_psi/el_opprf.cc deleted file mode 100644 index f685583b..00000000 --- a/experimental/psi/psi21/el_c_psi/el_opprf.cc +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_c_psi/el_opprf.h" - -#include -#include -#include - -#include "experimental/psi/psi21/el_c_psi/el_hashing.h" -#include "yacl/crypto/rand/rand.h" -#include "yacl/crypto/tools/ro.h" -#include "yacl/kernel/algorithms/base_ot.h" -#include "yacl/kernel/algorithms/iknp_ote.h" -#include "yacl/kernel/algorithms/kkrt_ote.h" -#include "yacl/link/link.h" -#include "yacl/utils/serialize.h" - -namespace psi::psi { - -namespace yc = yacl::crypto; - -namespace { - -// PSI-related constants -// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) -constexpr float ZETA[]{1.12f, 0.17f}; -// constexpr size_t BETA[]{31, 63}; -constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs - -// OTe-related constants -constexpr size_t NUM_BASE_OT{128}; -constexpr size_t NUM_INKP_OT{512}; -constexpr size_t BATCH_SIZE{896}; - -static auto ro = yc::RandomOracle::GetDefault(); - -} // namespace - -// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) - -std::vector ElOpprfRecv( - const std::shared_ptr& ctx, - const std::vector& queries) { - const size_t size{queries.size()}; - const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), - static_cast(std::ceil(size * ZETA[1]))}; - // Step 0. Prepares OPRF - yc::KkrtOtExtReceiver receiver; - size_t num_ot{bin_sizes[0] + bin_sizes[1]}; - auto choice = yc::RandBits(NUM_BASE_OT); - auto base_ot = yc::BaseOtRecv(ctx, choice, NUM_BASE_OT); - auto store = yc::IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); - receiver.Init(ctx, store, num_ot); - receiver.SetBatchSize(BATCH_SIZE); - - // Step 1. Hashes queries into Cuckoo hashing - KmprtCuckooHashing hashing{bin_sizes[0], bin_sizes[1]}; - for (size_t i{}; i != size; ++i) { - hashing.Insert(queries[i]); - } - - std::vector evals; - evals.reserve(num_ot); - size_t ot_idx{}, b{}; - std::array batch_evals; - // Step 2. For each bin, invokes single-query OPPRF - for (uint8_t c{}; c != 2; ++c) { - size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; - size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; - for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { - auto elem = hashing.GetBin(c, addr); - elem == KmprtCuckooHashing::NONE && (elem = yc::FastRandU128()); - receiver.Encode( - ot_idx, elem, - {reinterpret_cast(&batch_evals[b++]), sizeof(uint64_t)}); - if (auto batch_size = (ot_idx - ot_begin) % BATCH_SIZE + 1; - batch_size == BATCH_SIZE || ot_idx + 1 == ot_end) { - b = 0; - receiver.SendCorrection(ctx, batch_size); - // For each query in a batch - for (size_t i{}; i != batch_size; ++i) { - uint128_t nonce = yacl::DeserializeUint128( - ctx->Recv(ctx->NextRank(), "Receive OPPRF nonce")); - std::vector table(TABLE_SIZE[c]); - auto buf = - ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); - std::memcpy(table.data(), buf.data(), - TABLE_SIZE[c] * sizeof(uint64_t)); - uint64_t eval = batch_evals[i]; - auto index = - ro.Gen(absl::MakeSpan(reinterpret_cast(&eval), - sizeof eval), - nonce) % - table.size(); - evals.emplace_back(eval ^ table[index]); - } - } - } - } - - // Step 3. Filters and obtains the results - std::vector results(size); - std::transform(queries.cbegin(), queries.cend(), results.begin(), - [&](auto q) { - auto [c, addr] = hashing.Lookup(q); - return evals[c * bin_sizes[0] + addr]; - }); - return results; -} - -void ElOpprfSend(const std::shared_ptr& ctx, - const std::vector& xs, - const std::vector& ys) { - YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); - const size_t size{xs.size()}; - const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), - static_cast(std::ceil(size * ZETA[1]))}; - // Step 0. Prepares OPRF - yc::KkrtOtExtSender sender; - size_t num_ot{bin_sizes[0] + bin_sizes[1]}; - auto base_ot = yc::BaseOtSend(ctx, NUM_BASE_OT); - auto choice = yc::RandBits(NUM_INKP_OT); - auto store = yc::IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); - sender.Init(ctx, store, num_ot); - sender.SetBatchSize(BATCH_SIZE); - - // Step 1. Hashes points into Simple hashing - KmprtSimpleHashing hashing{bin_sizes[0], bin_sizes[1]}; - for (size_t i{}; i != size; ++i) { - hashing.Insert({xs[i], ys[i]}); - } - size_t ot_idx{}; - auto evaluator = sender.GetOprf(); - // Step 2. For each bin, invokes single-query OPPRF - for (uint8_t c{}; c != 2; ++c) { - size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; - size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; - // For each programmable point in a batch - for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { - if ((ot_idx - ot_begin) % BATCH_SIZE == 0) { - sender.RecvCorrection( - ctx, ot_idx + BATCH_SIZE >= ot_end ? ot_end - ot_idx : BATCH_SIZE); - } - auto bin = hashing.GetBin(c, addr); - uint128_t nonce; - std::vector table; - bool separable; - do { - separable = true; - nonce = yc::FastRandSeed(); - table.assign(TABLE_SIZE[c], uint64_t{0}); - for (auto it = bin.cbegin(); it != bin.cend(); ++it) { - uint64_t eval = evaluator->Eval(ot_idx, it->first); - auto index = - ro.Gen({reinterpret_cast(&eval), sizeof eval}, - nonce) % - table.size(); - if (table[index] != uint64_t{0}) { - separable = false; - break; - } - table[index] = eval ^ it->second; - } - } while (!separable); - for (size_t i{}; i != TABLE_SIZE[c]; ++i) { - table[i] == uint64_t{0} && (table[i] = yc::FastRandU128()); - } - ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(nonce), - fmt::format("OPPRF:Nonce={}", nonce)); - yacl::Buffer buf(table.data(), table.size() * sizeof(uint64_t)); - ctx->SendAsync(ctx->NextRank(), buf, "OPPRF:EncryptionTable"); - } - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_c_psi/el_opprf.h b/experimental/psi/psi21/el_c_psi/el_opprf.h deleted file mode 100644 index a26a095c..00000000 --- a/experimental/psi/psi21/el_c_psi/el_opprf.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "yacl/base/int128.h" -#include "yacl/link/link.h" - -namespace psi::psi { - -// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) - -std::vector ElOpprfRecv(const std::shared_ptr&, - const std::vector& queries); - -void ElOpprfSend(const std::shared_ptr&, - const std::vector& xs, - const std::vector& ys); - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc index 5e6be9d9..f683a00a 100644 --- a/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi.cc @@ -18,12 +18,13 @@ #include "experimental/psi/psi21/el_mp_psi/el_protocol.h" #include "experimental/psi/psi21/el_mp_psi/el_sender.h" -#include "psi/utils/communication.h" -#include "psi/utils/sync.h" #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/rand/rand.h" #include "yacl/utils/serialize.h" +#include "psi/utils/communication.h" +#include "psi/utils/sync.h" + namespace psi::psi { namespace { diff --git a/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc index 38ad695d..27392880 100644 --- a/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc +++ b/experimental/psi/psi21/el_mp_psi/el_mp_psi_test.cc @@ -19,9 +19,10 @@ #include #include "gtest/gtest.h" -#include "psi/utils/test_utils.h" #include "yacl/link/test_util.h" +#include "psi/utils/test_utils.h" + namespace psi::psi { namespace { diff --git a/experimental/psi/psi21/el_q_psi/BUILD.bazel b/experimental/psi/psi21/el_q_psi/BUILD.bazel deleted file mode 100644 index ed5299e8..00000000 --- a/experimental/psi/psi21/el_q_psi/BUILD.bazel +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 zhangwfjh -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_binary( - name = "el_q_psi_benchmark", - srcs = ["el_q_psi_benchmark.cc"], - deps = [ - ":el_q_psi", - "@com_github_google_benchmark//:benchmark_main", - ], -) - -psi_cc_library( - name = "el_q_psi", - srcs = [ - "el_hashing.cc", - "el_q_psi.cc", - "el_opprf.cc", - ], - hdrs = [ - "el_hashing.h", - "el_q_psi.h", - "el_opprf.h", - ], - deps = [ - "//psi/utils:communication", - "//psi/utils:sync", - "//psi/utils:test_utils", - "@com_google_absl//absl/types:span", - "@yacl//yacl/base:exception", - "@yacl//yacl/base:int128", - "@yacl//yacl/crypto/hash:hash_utils", - "@yacl//yacl/crypto/rand", - "@yacl//yacl/kernel/algorithms:base_ot", - "@yacl//yacl/kernel/algorithms:iknp_ote", - "@yacl//yacl/kernel/algorithms:kkrt_ote", - "@yacl//yacl/link", - ], -) - -psi_cc_test( - name = "el_q_psi_test", - srcs = ["el_q_psi_test.cc"], - tags = ["manual"], - deps = [":el_q_psi"], -) diff --git a/experimental/psi/psi21/el_q_psi/el_hashing.cc b/experimental/psi/psi21/el_q_psi/el_hashing.cc deleted file mode 100644 index ba6e621a..00000000 --- a/experimental/psi/psi21/el_q_psi/el_hashing.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_q_psi/el_hashing.h" - -#include -#include - -#include "yacl/base/exception.h" -#include "yacl/crypto/rand/rand.h" - -namespace psi::psi { - -void KmprtCuckooHashing::Insert(uint128_t elem) { - auto insert_into = [this, &elem](uint8_t c) { - for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) { - uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c]; - uint8_t idx = (rand_idx + 1) % num_hashes_[c]; - size_t addr; - do { - addr = HashU128{}(elem, idx) % num_bins_[c]; - if (auto &bin = bins_[c][addr]; bin == NONE || bin == elem) { - bin = std::exchange(elem, NONE); - return; - } - idx = (idx + 1) % num_hashes_[c]; - } while (idx != rand_idx); - std::swap(bins_[c][addr], elem); - } - }; - for (uint8_t c{}; c != 2; ++c) { - insert_into(c); - } - YACL_ENFORCE_EQ(elem, NONE, "Failed to insert element."); -} - -auto KmprtCuckooHashing::Lookup(uint128_t elem) const - -> std::pair { - for (uint8_t c{}; c != 2; ++c) { - for (uint8_t idx{}; idx != num_hashes_[c]; ++idx) { - if (size_t addr = HashU128{}(elem, idx) % num_bins_[c]; - bins_[c][addr] == elem) { - return {c, addr}; - } - } - } - return {-1, -1}; -} - -void KmprtSimpleHashing::Insert(std::pair point) { - for (uint8_t c{}; c != 2; ++c) { - for (size_t idx{}; idx != num_hashes_[c]; ++idx) { - bins_[c][HashU128{}(point.first, idx) % num_bins_[c]].emplace(point); - } - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_hashing.h b/experimental/psi/psi21/el_q_psi/el_hashing.h deleted file mode 100644 index be80df99..00000000 --- a/experimental/psi/psi21/el_q_psi/el_hashing.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "absl/numeric/int128.h" -#include "yacl/base/int128.h" - -namespace psi::psi { - -struct HashU128 { - size_t operator()(uint128_t x, uint8_t idx = 0) const { - return absl::Uint128High64(x) + idx * absl::Uint128Low64(x); - } -}; - -template -struct KmprtDoubleHashing { - KmprtDoubleHashing(size_t m1, size_t m2) : num_bins_{m1, m2} { - bins_[0].resize(m1); - bins_[1].resize(m2); - } - - const Bin &GetBin(uint8_t c, size_t addr) const { return bins_[c][addr]; } - - const uint8_t num_hashes_[2]{3, 2}; - const size_t num_bins_[2]; - std::vector bins_[2]; -}; - -class KmprtCuckooHashing : public KmprtDoubleHashing { - public: - constexpr static uint128_t NONE{}; - - KmprtCuckooHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} - - void Insert(uint128_t); - std::pair Lookup(uint128_t) const; -}; - -class KmprtSimpleHashing - : public KmprtDoubleHashing< - std::unordered_map> { - public: - KmprtSimpleHashing(size_t m1, size_t m2) : KmprtDoubleHashing{m1, m2} {} - - void Insert(std::pair); -}; - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_opprf.cc b/experimental/psi/psi21/el_q_psi/el_opprf.cc deleted file mode 100644 index c060512c..00000000 --- a/experimental/psi/psi21/el_q_psi/el_opprf.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_q_psi/el_opprf.h" - -#include -#include -#include - -#include "experimental/psi/psi21/el_q_psi/el_hashing.h" -#include "yacl/crypto/rand/rand.h" -#include "yacl/crypto/tools/ro.h" -#include "yacl/kernel/algorithms/base_ot.h" -#include "yacl/kernel/algorithms/iknp_ote.h" -#include "yacl/kernel/algorithms/kkrt_ote.h" -#include "yacl/link/link.h" -#include "yacl/utils/serialize.h" - -namespace psi::psi { - -namespace yc = yacl::crypto; - -namespace { - -// PSI-related constants -// Ref https://eprint.iacr.org/2017/799.pdf (Table 2) -constexpr float ZETA[]{1.12f, 0.17f}; -// constexpr size_t BETA[]{31, 63}; -constexpr size_t TABLE_SIZE[]{32, 64}; // next power of BETAs - -// OTe-related constants -constexpr size_t NUM_BASE_OT{128}; -constexpr size_t NUM_INKP_OT{512}; -constexpr size_t BATCH_SIZE{896}; - -static auto ro = yc::RandomOracle::GetDefault(); - -} // namespace - -// Ref https://eprint.iacr.org/2017/799.pdf (Figure 6, 7) - -std::vector ElOpprfRecv( - const std::shared_ptr& ctx, - const std::vector& queries) { - const size_t size{queries.size()}; - const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), - static_cast(std::ceil(size * ZETA[1]))}; - // Step 0. Prepares OPRF - yc::KkrtOtExtReceiver receiver; - size_t num_ot{bin_sizes[0] + bin_sizes[1]}; - auto choice = yc::RandBits(NUM_BASE_OT); - auto base_ot = yc::BaseOtRecv(ctx, choice, NUM_BASE_OT); - auto store = yc::IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); - receiver.Init(ctx, store, num_ot); - receiver.SetBatchSize(BATCH_SIZE); - - // Step 1. Hashes queries into Cuckoo hashing - KmprtCuckooHashing hashing{bin_sizes[0], bin_sizes[1]}; - for (size_t i{}; i != size; ++i) { - hashing.Insert(queries[i]); - } - - std::vector evals; - evals.reserve(num_ot); - size_t ot_idx{}, b{}; - std::array batch_evals; - // Step 2. For each bin, invokes single-query OPPRF - for (uint8_t c{}; c != 2; ++c) { - size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; - size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; - for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { - auto elem = hashing.GetBin(c, addr); - elem == KmprtCuckooHashing::NONE && (elem = yc::FastRandU128()); - receiver.Encode( - ot_idx, elem, - {reinterpret_cast(&batch_evals[b++]), sizeof(uint64_t)}); - if (auto batch_size = (ot_idx - ot_begin) % BATCH_SIZE + 1; - batch_size == BATCH_SIZE || ot_idx + 1 == ot_end) { - b = 0; - receiver.SendCorrection(ctx, batch_size); - // For each query in a batch - for (size_t i{}; i != batch_size; ++i) { - uint128_t nonce = yacl::DeserializeUint128( - ctx->Recv(ctx->NextRank(), "Receive OPPRF nonce")); - std::vector table(TABLE_SIZE[c]); - auto buf = - ctx->Recv(ctx->NextRank(), "Receive OPPRF EncryptionTable"); - std::memcpy(table.data(), buf.data(), - TABLE_SIZE[c] * sizeof(uint64_t)); - uint64_t eval = batch_evals[i]; - auto index = - ro.Gen(absl::MakeSpan(reinterpret_cast(&eval), - sizeof eval), - nonce) % - table.size(); - evals.emplace_back(eval ^ table[index]); - } - } - } - } - - // Step 3. Filters and obtains the results - std::vector results(size); - std::transform(queries.cbegin(), queries.cend(), results.begin(), - [&](auto q) { - auto [c, addr] = hashing.Lookup(q); - return evals[c * bin_sizes[0] + addr]; - }); - return results; -} - -void ElOpprfSend(const std::shared_ptr& ctx, - const std::vector& xs, - const std::vector& ys) { - YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch."); - const size_t size{xs.size()}; - const size_t bin_sizes[]{static_cast(std::ceil(size * ZETA[0])), - static_cast(std::ceil(size * ZETA[1]))}; - // Step 0. Prepares OPRF - yc::KkrtOtExtSender sender; - size_t num_ot{bin_sizes[0] + bin_sizes[1]}; - auto base_ot = yc::BaseOtSend(ctx, NUM_BASE_OT); - auto choice = yc::RandBits(NUM_INKP_OT); - auto store = yc::IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); - sender.Init(ctx, store, num_ot); - sender.SetBatchSize(BATCH_SIZE); - - // Step 1. Hashes points into Simple hashing - KmprtSimpleHashing hashing{bin_sizes[0], bin_sizes[1]}; - for (size_t i{}; i != size; ++i) { - hashing.Insert({xs[i], ys[i]}); - } - size_t ot_idx{}; - auto evaluator = sender.GetOprf(); - // Step 2. For each bin, invokes single-query OPPRF - for (uint8_t c{}; c != 2; ++c) { - size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]}; - size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot}; - // For each programmable point in a batch - for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) { - if ((ot_idx - ot_begin) % BATCH_SIZE == 0) { - sender.RecvCorrection( - ctx, ot_idx + BATCH_SIZE >= ot_end ? ot_end - ot_idx : BATCH_SIZE); - } - auto bin = hashing.GetBin(c, addr); - uint128_t nonce; - std::vector table; - bool separable; - do { - separable = true; - nonce = yc::FastRandSeed(); - table.assign(TABLE_SIZE[c], uint64_t{0}); - for (auto it = bin.cbegin(); it != bin.cend(); ++it) { - uint64_t eval = evaluator->Eval(ot_idx, it->first); - auto index = - ro.Gen({reinterpret_cast(&eval), sizeof eval}, - nonce) % - table.size(); - if (table[index] != uint64_t{0}) { - separable = false; - break; - } - table[index] = eval ^ it->second; - } - } while (!separable); - for (size_t i{}; i != TABLE_SIZE[c]; ++i) { - table[i] == uint64_t{0} && (table[i] = yc::FastRandU128()); - } - - ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(nonce), - fmt::format("OPPRF:Nonce={}", nonce)); - yacl::Buffer buf(table.data(), table.size() * sizeof(uint64_t)); - ctx->SendAsync(ctx->NextRank(), buf, "OPPRF:EncryptionTable"); - } - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_opprf.h b/experimental/psi/psi21/el_q_psi/el_opprf.h deleted file mode 100644 index a26a095c..00000000 --- a/experimental/psi/psi21/el_q_psi/el_opprf.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "yacl/base/int128.h" -#include "yacl/link/link.h" - -namespace psi::psi { - -// Table-based OPPRF, see https://eprint.iacr.org/2017/799.pdf (Figure 6) - -std::vector ElOpprfRecv(const std::shared_ptr&, - const std::vector& queries); - -void ElOpprfSend(const std::shared_ptr&, - const std::vector& xs, - const std::vector& ys); - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_q_psi.cc b/experimental/psi/psi21/el_q_psi/el_q_psi.cc deleted file mode 100644 index 6129ecc0..00000000 --- a/experimental/psi/psi21/el_q_psi/el_q_psi.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" - -#include - -#include "experimental/psi/psi21/el_q_psi/el_opprf.h" -#include "psi/psi/utils/communication.h" -#include "psi/psi/utils/sync.h" -#include "yacl/crypto/hash/hash_utils.h" -#include "yacl/crypto/rand/rand.h" -#include "yacl/utils/serialize.h" - -namespace psi::psi { - -namespace { - -constexpr uint32_t kLinkRecvTimeout = 60 * 60 * 1000; - -} // namespace - -NcParty::NcParty(const Options& options) : options_{options} { - auto [ctx, wsize, me, leader] = CollectContext(); - ctx->SetRecvTimeout(kLinkRecvTimeout); - p2p_.resize(wsize); - for (size_t dst{}; dst != wsize; ++dst) { - if (me != dst) { - p2p_[dst] = CreateP2PLinkCtx("el_q_psi", ctx, dst); - } - } -} - -std::vector NcParty::Run( - const std::vector& inputs) { - auto [ctx, wsize, me, leader] = CollectContext(); - auto counts = AllGatherItemsSize(ctx, inputs.size()); - size_t count{}; - - for (auto cnt : counts) { - if (cnt == 0) { - return {}; - } - count = std::max(cnt, count); - } - auto items = EncodeInputs(inputs, count); - auto shares = ZeroSharing(count); - auto recv_share = SwapShares(items, shares); - auto recons = Reconstruct(items, recv_share); - std::vector intersection; - for (size_t k{}; k != count; ++k) { - if (recons[k] == 0) { - intersection.push_back("1"); - } else { - intersection.push_back("0"); - } - } - return intersection; -} - -std::vector NcParty::EncodeInputs( - const std::vector& inputs, size_t count) const { - std::vector items; - items.reserve(count); - std::transform( - inputs.begin(), inputs.end(), std::back_inserter(items), - [](std::string_view input) { return yacl::crypto::Blake3_128(input); }); - // Add random dummy elements - std::generate_n(std::back_inserter(items), count - inputs.size(), - yacl::crypto::FastRandU128); - return items; -} - -auto NcParty::ZeroSharing(size_t count) const -> std::vector { - auto [ctx, wsize, me, leader] = CollectContext(); - std::vector shares(wsize, Share(count)); - for (size_t k{}; k != count; ++k) { - uint64_t sum{}; - for (size_t dst{1}; dst != wsize; ++dst) { - sum ^= shares[dst][k] = yacl::crypto::FastRandU64(); - } - shares[0][k] = sum; - } - return shares; -} - -auto NcParty::SwapShares(const std::vector& items, - const std::vector& shares) const -> Share { - auto [ctx, wsize, me, leader] = CollectContext(); - auto count = shares.front().size(); - std::vector recv_shares(count); - std::vector> futures(wsize); - // NOTE: First Send Then Receive for peers of smaller ranks - for (size_t id{}; id != me; ++id) { - futures[id] = std::async( - [&](size_t id) { - ElOpprfSend(p2p_[id], items, shares[id]); - return ElOpprfRecv(p2p_[id], items); - }, - id); - } - // NOTE: First Receive Then Send for peers of larger ranks - for (size_t id{me + 1}; id != wsize; ++id) { - futures[id] = std::async( - [&](size_t id) { - auto ret = ElOpprfRecv(p2p_[id], items); - ElOpprfSend(p2p_[id], items, shares[id]); - return ret; - }, - id); - } - for (size_t id{}; id != wsize; ++id) { - recv_shares[id] = (me == id ? shares[id] : futures[id].get()); - } - - Share share(count); // S(x_k) - for (size_t k{}; k != count; ++k) { - for (size_t src{}; src != wsize; ++src) { - share[k] ^= recv_shares[src][k]; - } - } - return share; -} - -auto NcParty::Reconstruct(const std::vector& items, - const Share& share) const -> Share { - auto [ctx, wsize, me, leader] = CollectContext(); - auto count = items.size(); - if (me == leader) { - std::vector recv_shares(count); - std::vector> futures(wsize); - for (size_t src{}; src != wsize; ++src) { - if (me != src) { - futures[src] = std::async( - [&](size_t src) { return ElOpprfRecv(p2p_[src], items); }, src); - } - } - for (size_t src{}; src != wsize; ++src) { - recv_shares[src] = (me == src ? share : futures[src].get()); - } - Share recons(count); // sum of S_i(x_k) over i - for (size_t k{}; k != count; ++k) { - for (size_t src{}; src != wsize; ++src) { - recons[k] ^= recv_shares[src][k]; - } - } - return recons; - } else { - ElOpprfSend(p2p_[leader], items, share); - return share; - } -} - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_q_psi.h b/experimental/psi/psi21/el_q_psi/el_q_psi.h deleted file mode 100644 index 35f3026d..00000000 --- a/experimental/psi/psi21/el_q_psi/el_q_psi.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "yacl/base/int128.h" -#include "yacl/link/link.h" - -namespace psi::psi { - -// Practical Multi-party Private Set Intersection from Symmetric-Key Techniques -// https://eprint.iacr.org/2017/799.pdf - -class NcParty { - public: - struct Options { - std::shared_ptr link_ctx; - size_t leader_rank; - }; - - NcParty(const Options& options); - virtual std::vector Run(const std::vector& inputs); - - private: - using Share = std::vector; - - std::vector EncodeInputs(const std::vector& inputs, - size_t count) const; - std::vector ZeroSharing(size_t count) const; - Share SwapShares(const std::vector& items, - const std::vector& shares) const; - Share Reconstruct(const std::vector& items, - const Share& share) const; - - // (ctx, world_size, my_rank, leader_rank) - auto CollectContext() const { - return std::make_tuple(options_.link_ctx, options_.link_ctx->WorldSize(), - options_.link_ctx->Rank(), options_.leader_rank); - } - - Options options_; - std::vector> p2p_; -}; - -} // namespace psi::psi diff --git a/experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc b/experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc deleted file mode 100644 index d80daf29..00000000 --- a/experimental/psi/psi21/el_q_psi/el_q_psi_benchmark.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "benchmark/benchmark.h" -#include "experimental/psi/psi21/el_q_psi/el_opprf.h" -#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" -#include "yacl/base/exception.h" -#include "yacl/crypto/hash/hash_utils.h" -#include "yacl/link/test_util.h" - -namespace { -std::vector CreateRangeItems(size_t begin, size_t size) { - std::vector ret(size); - for (size_t i = 0; i < size; i++) { - auto hash = yacl::crypto::Blake3(std::to_string(begin + i)); - memcpy(&ret[i], hash.data(), sizeof(uint128_t)); - } - return ret; -} - -void ElQPsiSend(const std::shared_ptr& link_ctx, - const std::vector& items_hash) { - // auto ot_recv = psi::kkrt::GetKkrtOtSenderOptions(link_ctx, 512); - // return psi::kkrt::KkrtPsiSend(link_ctx, ot_recv, items_hash); - std::vector shares; - for (size_t i = 0; i < items_hash.size(); i++) { - uint64_t item = 0; - shares.push_back(item); - } - - return psi::psi::ElOpprfSend(link_ctx, items_hash, shares); -} - -std::vector ElQPsiRecv( - const std::shared_ptr& link_ctx, - const std::vector& items_hash) { - // auto ot_send = psi::kkrt::GetKkrtOtReceiverOptions(link_ctx, 512); - // return psi::kkrt::KkrtPsiRecv(link_ctx, ot_send, items_hash); - return psi::psi::ElOpprfRecv(link_ctx, items_hash); -} - -} // namespace - -static void BM_El_Q_Psi(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - auto alice_items = CreateRangeItems(1, n); - auto bob_items = CreateRangeItems(2, n); - - auto contexts = yacl::link::test::SetupWorld(2); - - state.ResumeTiming(); - - std::future kkrt_psi_sender = - std::async([&] { return ElQPsiSend(contexts[0], alice_items); }); - std::future> kkrt_psi_receiver = - std::async([&] { return ElQPsiRecv(contexts[1], bob_items); }); - - kkrt_psi_sender.get(); - auto results_b = kkrt_psi_receiver.get(); - } -} - -// [256k, 512k, 1m, 2m, 4m, 8m] -BENCHMARK(BM_El_Q_Psi) - ->Unit(benchmark::kMillisecond) - ->Arg(256 << 10) - ->Arg(512 << 10) - ->Arg(1 << 20) - ->Arg(2 << 20) - ->Arg(4 << 20) - ->Arg(8 << 20); diff --git a/experimental/psi/psi21/el_q_psi/el_q_psi_test.cc b/experimental/psi/psi21/el_q_psi/el_q_psi_test.cc deleted file mode 100644 index 4fcacbab..00000000 --- a/experimental/psi/psi21/el_q_psi/el_q_psi_test.cc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2024 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "experimental/psi/psi21/el_q_psi/el_q_psi.h" - -#include -#include -#include - -#include "gtest/gtest.h" -#include "psi/psi/utils/test_utils.h" -#include "yacl/link/test_util.h" - -namespace psi::psi { - -namespace { - -struct NCTestParams { - std::vector item_size; - size_t intersection_size; - size_t n; -}; - -std::vector> CreateNPartyItems( - const NCTestParams& params) { - std::vector> ret(params.item_size.size() + 1); - ret[params.item_size.size()] = - test::CreateRangeItems(1, params.intersection_size); - - for (size_t idx = 0; idx < params.item_size.size(); ++idx) { - ret[idx] = - test::CreateRangeItems((idx + 1) * 1000000, params.item_size[idx]); - } - - for (size_t idx = 0; idx < params.item_size.size(); ++idx) { - std::set idx_set; - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, params.item_size[idx] - 1); - - while (idx_set.size() < params.intersection_size) { - idx_set.insert(dis(gen)); - } - size_t j = 0; - for (const auto& iter : idx_set) { - ret[idx][iter] = ret[params.item_size.size()][j++]; - } - - if (/*idx > dis(gen) &&*/ idx >= params.n) { - break; - } - } - return ret; -} - -} // namespace - -class NCPsiTest : public testing::TestWithParam {}; - -// FIXME : this test is not stable in arm env -TEST_P(NCPsiTest, Works) { - std::vector> items; - std::vector> resultvec; - std::vector finalresult; - - auto params = GetParam(); - size_t n = params.n; - items = CreateNPartyItems(params); - size_t leader_rank = 0; - uint128_t maxlength = 0; - - if (n >= params.item_size.size()) { - SPDLOG_INFO("param error: n > items[0].size() "); - return; - } - - if (n <= 0) { - SPDLOG_INFO("param error: n <= 0 "); - return; - } - - /* - for (size_t j{}; j != items[0].size(); ++j) { - SPDLOG_INFO(" items[{}][{}] = {}, size{}", 0, j, items[0][j], - items[0].size()); - } - */ - for (size_t i = 0; i < params.item_size.size() - 1; i++) { - std::vector> items1; - items1.push_back(items[0]); - items1.push_back(items[i + 1]); - leader_rank = 0; - - /* - for (size_t j{}; j != items[i + 1].size(); ++j) { - SPDLOG_INFO(" items[{}][{}] = {}, size{}", i+1, j, items[i+1][j], - items[i+1].size()); - } - */ - - auto ctxs = yacl::link::test::SetupWorld(2); - auto proc = [&](int idx) -> std::vector { - NcParty::Options opts; - opts.link_ctx = ctxs[idx]; - opts.leader_rank = leader_rank; - NcParty op(opts); - - return op.Run(items[idx]); - }; - - size_t world_size = ctxs.size(); - std::vector>> f_links(world_size); - for (size_t j = 0; j < world_size; j++) { - f_links[j] = std::async(proc, j); - } - sleep(1); - - std::vector result; - result = f_links[0].get(); - resultvec.push_back(result); - - /* - for (size_t j = 0; j < result.size(); j++) { - SPDLOG_INFO("i{} j{}, result[j] {} size{}", i, j, result[j], - result.size()); - }*/ - } - - maxlength = items[0].size(); - std::vector qpsivector; - for (size_t j = 0; j < maxlength; j++) { - uint128_t sum = 0; - for (size_t i = 0; i < params.item_size.size() - 1; i++) { - // 如果有的集合没有那么多项就continue - // results[i] = f_links[i].get(); - if (resultvec[i].size() <= j) { - continue; - } - - // SPDLOG_INFO(" result[{}][{}] = {}", i, j, resultvec[i][j]); - auto it = resultvec[i].begin() + j; - std::string element = *it; - if (element == "1") { - sum++; - } - } - if (sum >= n) { - // todo//推入对应input元素 之后再查输入变量从param中怎么取出推入 - qpsivector.push_back(1); - } else { - qpsivector.push_back(0); - } - } - - // std::vector intersectionnparty; - for (size_t k{}; k != items[0].size(); ++k) { - if (qpsivector[k] == 1) { - finalresult.push_back(items[0][k]); - } - } - - for (size_t i{}; i != finalresult.size(); ++i) { - SPDLOG_INFO("intersectionnparty = {}", finalresult[i]); - } - - std::vector intersection = items[params.item_size.size()]; - std::sort(intersection.begin(), intersection.end()); - - std::sort(finalresult.begin(), finalresult.end()); - EXPECT_EQ(finalresult.size(), intersection.size()); - EXPECT_EQ(finalresult, intersection); -} - -INSTANTIATE_TEST_SUITE_P( - Works_Instances, NCPsiTest, - testing::Values(NCTestParams{{0, 3}, 0, 1}, // - NCTestParams{{3, 0}, 0, 1}, // - NCTestParams{{0, 0}, 0, 1}, // - NCTestParams{{4, 3}, 2, 1}, // - NCTestParams{{20, 17, 14}, 10, 2}, // - NCTestParams{{20, 17, 14, 30}, 10, 3}, // - NCTestParams{{20, 17, 14, 30, 35}, 11, 4}, // - NCTestParams{{20, 17, 14, 30, 35}, 0, 4})); -// testing::Values(NCTestParams{{3, 0}, 0, 1})); -} // namespace psi::psi