Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel PIRProcessDatabase #120

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 102 additions & 61 deletions Sources/PIRProcessDatabase/ProcessDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable {
}

@main
struct ProcessDatabase: ParsableCommand {
struct ProcessDatabase: AsyncParsableCommand {
static let configuration: CommandConfiguration = .init(
commandName: "PIRProcessDatabase")

Expand All @@ -311,13 +311,18 @@ struct ProcessDatabase: ParsableCommand {
""")
var configFile: String

@Flag(name: .customLong("parallel"),
inversion: .prefixedNo,
help: "Enables parallel processing.")
var parallel = true

/// Performs the processing on the given database.
/// - Parameters:
/// - config: The configuration for the PIR processing.
/// - scheme: The HE scheme.
/// - Throws: Error upon processing the database.
@inlinable
mutating func process<Scheme: HeScheme>(config: Arguments, scheme: Scheme.Type) throws {
mutating func process<Scheme: HeScheme>(config: Arguments, scheme: Scheme.Type) async throws {
let database: [KeywordValuePair] =
try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: config.inputDatabase).native()

Expand All @@ -339,69 +344,39 @@ struct ProcessDatabase: ParsableCommand {
keyCompression: config.keyCompression,
trialsPerShard: config.trialsPerShard)

var evaluationKeyConfig = EvaluationKeyConfig()
let context = try Context(encryptionParameters: processArgs.encryptionParameters)
let keywordDatabase = try KeywordDatabase(rows: database, sharding: processArgs.databaseConfig.sharding)
ProcessDatabase.logger
.info("Sharded database into \(keywordDatabase.shards.count) shards")
for (shardID, shard) in keywordDatabase.shards
.sorted(by: { $0.0.localizedStandardCompare($1.0) == .orderedAscending })
{
func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws {
switch event {
case let .cuckooTableEvent(.createdTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Created cuckoo table \(summary)")
case let .cuckooTableEvent(.expandingTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.finishedExpandingTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Finished expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)):
let reportingPercentage = 10
let shardFraction = shard.rows.count / reportingPercentage
if (index + 1).isMultiple(of: shardFraction) {
let percentage = Float(reportingPercentage * (index + 1)) / Float(shardFraction)
ProcessDatabase.logger
.info("Inserted \(index + 1) / \(shard.rows.count) keywords \(percentage)%")
ProcessDatabase.logger.info("Sharded database into \(keywordDatabase.shards.count) shards")

let shards = keywordDatabase.shards.sorted { $0.0.localizedStandardCompare($1.0) == .orderedAscending }

var evaluationKeyConfig = EvaluationKeyConfig()
if parallel {
try await withThrowingTaskGroup(of: EvaluationKeyConfig.self) { group in
for (shardID, shard) in shards {
group.addTask { @Sendable [self] in
try await processShard(
shardID: shardID,
shard: shard,
config: config,
context: context,
processArgs: processArgs)
}
}
}

ProcessDatabase.logger.info("Processing shard \(shardID) with \(shard.rows.count) rows")
let processed = try ProcessKeywordDatabase.processShard(
shard: shard,
with: processArgs,
onEvent: logEvent)
if config.trialsPerShard > 0 {
guard let row = shard.rows.first else {
throw PirError.emptyDatabase
for try await processedEvaluationKeyConfig in group {
evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union()
}
ProcessDatabase.logger.info("Validating shard \(shardID)")
let validationResults = try ProcessKeywordDatabase
.validateShard(shard: processed,
row: KeywordValuePair(keyword: row.key, value: row.value),
trials: config.trialsPerShard, context: context)
let description = try validationResults.description()
ProcessDatabase.logger.info("ValidationResults \(description)")
}

let outputDatabaseFilename = config.outputDatabase.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try processed.database.save(to: outputDatabaseFilename)
ProcessDatabase.logger.info("Saved shard \(shardID) to \(outputDatabaseFilename)")

let shardEvaluationKeyConfig = processed.evaluationKeyConfig
evaluationKeyConfig = [evaluationKeyConfig, shardEvaluationKeyConfig].union()

let shardPirParameters = try processed.proto(context: context)
let outputParametersFilename = config.outputPirParameters.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try shardPirParameters.save(to: outputParametersFilename)
ProcessDatabase.logger.info("Saved shard \(shardID) PIR parameters to \(outputParametersFilename)")
} else {
for (shardID, shard) in shards {
let processedEvaluationKeyConfig = try await processShard(
shardID: shardID,
shard: shard, config:
config, context: context,
processArgs: processArgs)
evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union()
}
}

if let evaluationKeyConfigFile = config.outputEvaluationKeyConfig {
Expand All @@ -411,14 +386,80 @@ struct ProcessDatabase: ParsableCommand {
}
}

mutating func run() throws {
/// Processes one shard of the keyword database and writes its outputs to disk.
///
/// The shard is processed with `ProcessKeywordDatabase.processShard`, optionally validated,
/// and then the processed database and the shard's PIR parameters are saved. Every log line
/// emitted here carries the shard identifier as `shardID` logger metadata.
/// - Parameters:
///   - shardID: Identifier of the shard; substituted for `SHARD_ID` in the output filenames.
///   - shard: The shard to process.
///   - config: The resolved arguments; supplies `trialsPerShard` and the output filename templates.
///   - context: The context passed to shard validation and PIR parameter serialization.
///   - processArgs: The arguments forwarded to `ProcessKeywordDatabase.processShard`.
/// - Returns: The evaluation key configuration of the processed shard.
/// - Throws: `PirError.emptyDatabase` when validation is requested (`trialsPerShard > 0`) but the
///   shard has no rows; otherwise rethrows any error from processing, validating or saving.
private func processShard<Scheme: HeScheme>(
    shardID: String,
    shard: KeywordDatabaseShard,
    config: ResolvedArguments,
    context: Context<Scheme>,
    processArgs: ProcessKeywordDatabase.Arguments<Scheme>) async throws -> EvaluationKeyConfig
{
    var logger = ProcessDatabase.logger
    logger[metadataKey: "shardID"] = .string(shardID)

    let rowCount = shard.rows.count

    func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws {
        switch event {
        case let .cuckooTableEvent(.createdTable(table)):
            let summary = try table.summarize()
            logger.info("Created cuckoo table \(summary)")
        case let .cuckooTableEvent(.expandingTable(table)):
            let summary = try table.summarize()
            logger.info("Expanding cuckoo table \(summary)")
        case let .cuckooTableEvent(.finishedExpandingTable(table)):
            let summary = try table.summarize()
            logger.info("Finished expanding cuckoo table \(summary)")
        case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)):
            // Report roughly every 10% of the shard.
            let reportingPercentage = 10
            let reportingInterval = rowCount / reportingPercentage
            let insertedCount = index + 1
            // A shard with fewer than 10 rows yields a zero interval; no progress is reported then.
            if reportingInterval > 0, insertedCount.isMultiple(of: reportingInterval) {
                // Derive the percentage from the true row count rather than the truncated
                // interval, so that e.g. 100 of 105 rows is not reported as 100%.
                let percentage = Float(100 * insertedCount) / Float(rowCount)
                logger.info("Inserted \(insertedCount) / \(rowCount) keywords \(percentage)%")
            }
        }
    }

    logger.info("Processing shard with \(rowCount) rows")
    let processed = try ProcessKeywordDatabase.processShard(
        shard: shard,
        with: processArgs,
        onEvent: logEvent)

    if config.trialsPerShard > 0 {
        // Validation queries the shard's first row, so an empty shard cannot be validated.
        guard let row = shard.rows.first else {
            throw PirError.emptyDatabase
        }
        logger.info("Validating shard")
        let validationResults = try ProcessKeywordDatabase
            .validateShard(shard: processed,
                           row: KeywordValuePair(keyword: row.key, value: row.value),
                           trials: config.trialsPerShard, context: context)
        let description = try validationResults.description()
        logger.info("ValidationResults \(description)")
    }

    let outputDatabaseFilename = config.outputDatabase.replacingOccurrences(
        of: "SHARD_ID",
        with: String(shardID))
    try processed.database.save(to: outputDatabaseFilename)
    logger.info("Saved shard to \(outputDatabaseFilename)")

    let shardPirParameters = try processed.proto(context: context)
    let outputParametersFilename = config.outputPirParameters.replacingOccurrences(
        of: "SHARD_ID",
        with: String(shardID))
    try shardPirParameters.save(to: outputParametersFilename)
    logger.info("Saved shard PIR parameters to \(outputParametersFilename)")

    return processed.evaluationKeyConfig
}

/// Entry point of the command: loads the configuration and processes the database.
///
/// Reads the file at `configFile`, decodes it from JSON as `Arguments`, and calls `process`
/// with `Bfv<UInt32>` when the RLWE parameters support `UInt32` scalars, otherwise with
/// `Bfv<UInt64>`.
/// - Throws: Any error from reading or decoding the configuration, or from `process`.
mutating func run() async throws {
let configURL = URL(fileURLWithPath: configFile)
let configData = try Data(contentsOf: configURL)
let config = try JSONDecoder().decode(Arguments.self, from: configData)
// NOTE(review): each branch below contains both a `try process` and a `try await process`
// line; these look like the removed/added sides of a diff rendering — confirm only the
// `try await` form is present in the actual file.
if config.rlweParameters.supportsScalar(UInt32.self) {
try process(config: config, scheme: Bfv<UInt32>.self)
try await process(config: config, scheme: Bfv<UInt32>.self)
} else {
try process(config: config, scheme: Bfv<UInt64>.self)
try await process(config: config, scheme: Bfv<UInt64>.self)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/PrivateInformationRetrieval/KeywordDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ extension Sharding {
}

/// A shard of a ``KeywordDatabase``.
public struct KeywordDatabaseShard: Hashable, Codable {
public struct KeywordDatabaseShard: Hashable, Codable, Sendable {
/// Identifier for the shard.
public let shardID: String
/// Rows in the database.
Expand Down Expand Up @@ -204,7 +204,7 @@ extension KeywordDatabaseShard: Collection {
}

/// Configuration for a ``KeywordDatabase``.
public struct KeywordDatabaseConfig: Hashable, Codable {
public struct KeywordDatabaseConfig: Hashable, Codable, Sendable {
public let sharding: Sharding
public let keywordPirConfig: KeywordPirConfig

Expand Down Expand Up @@ -264,7 +264,7 @@ public struct KeywordDatabase {
/// Utilities for processing a ``KeywordDatabase``.
public enum ProcessKeywordDatabase {
/// Arguments for processing a keyword database.
public struct Arguments<Scheme: HeScheme>: Codable {
public struct Arguments<Scheme: HeScheme>: Codable, Sendable {
/// Database configuration.
public let databaseConfig: KeywordDatabaseConfig
/// Encryption parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import Foundation
import HomomorphicEncryption

/// Configuration for a ``KeywordDatabase``.
public struct KeywordPirConfig: Hashable, Codable {
public struct KeywordPirConfig: Hashable, Codable, Sendable {
/// Number of dimensions in the database.
@usableFromInline let dimensionCount: Int

Expand Down
Loading