Skip to content

Commit

Permalink
migrate modopt write
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Jun 19, 2024
1 parent bd2416f commit e385875
Show file tree
Hide file tree
Showing 20 changed files with 268 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import org.neo4j.gds.algorithms.estimation.AlgorithmEstimator;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationBaseConfig;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationMemoryEstimateDefinition;
import org.neo4j.gds.scc.SccBaseConfig;
import org.neo4j.gds.scc.SccMemoryEstimateDefinition;
import org.neo4j.gds.triangle.IntersectingTriangleCountMemoryEstimateDefinition;
Expand Down Expand Up @@ -77,16 +75,4 @@ public <C extends LocalClusteringCoefficientBaseConfig> MemoryEstimateResult loc
new LocalClusteringCoefficientMemoryEstimateDefinition(configuration.seedProperty())
);
}

public <C extends ModularityOptimizationBaseConfig> MemoryEstimateResult modularityOptimization(
Object graphNameOrConfiguration,
C configuration
) {
return algorithmEstimator.estimate(
graphNameOrConfiguration,
configuration,
configuration.relationshipWeightProperty(),
new ModularityOptimizationMemoryEstimateDefinition()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import org.neo4j.gds.algorithms.AlgorithmComputationResult;
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationBaseConfig;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationFactory;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationResult;
import org.neo4j.gds.scc.SccAlgorithmFactory;
import org.neo4j.gds.scc.SccCommonBaseConfig;
import org.neo4j.gds.triangle.IntersectingTriangleCountFactory;
Expand Down Expand Up @@ -81,16 +78,4 @@ public AlgorithmComputationResult<LocalClusteringCoefficientResult> localCluster
new LocalClusteringCoefficientFactory<>()
);
}

public AlgorithmComputationResult<ModularityOptimizationResult> modularityOptimization(
String graphName,
ModularityOptimizationBaseConfig config
) {
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
new ModularityOptimizationFactory<>()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.neo4j.gds.algorithms.community.specificfields.AlphaSccSpecificFields;
import org.neo4j.gds.algorithms.community.specificfields.CommunityStatisticsSpecificFields;
import org.neo4j.gds.algorithms.community.specificfields.LocalClusteringCoefficientSpecificFields;
import org.neo4j.gds.algorithms.community.specificfields.ModularityOptimizationSpecificFields;
import org.neo4j.gds.algorithms.community.specificfields.StandardCommunityStatisticsSpecificFields;
import org.neo4j.gds.algorithms.community.specificfields.TriangleCountSpecificFields;
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
Expand All @@ -35,7 +34,6 @@
import org.neo4j.gds.config.ArrowConnectionInfo;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationWriteConfig;
import org.neo4j.gds.result.CommunityStatistics;
import org.neo4j.gds.result.StatisticsComputationInstructions;
import org.neo4j.gds.scc.SccAlphaWriteConfig;
Expand Down Expand Up @@ -133,52 +131,6 @@ public NodePropertyWriteResult<AlphaSccSpecificFields> alphaScc(

}

public NodePropertyWriteResult<ModularityOptimizationSpecificFields> modularityOptimization(
String graphName,
ModularityOptimizationWriteConfig configuration,
StatisticsComputationInstructions statisticsComputationInstructions
) {
// 1. Run the algorithm and time the execution
var intermediateResult = runWithTiming(
() -> communityAlgorithmsFacade.modularityOptimization(graphName, configuration)
);
var algorithmResult = intermediateResult.algorithmResult;

Supplier<ModularityOptimizationSpecificFields> emptySupplier = () -> ModularityOptimizationSpecificFields.EMPTY;

return writeToDatabase(
algorithmResult,
configuration,
(result, config) -> CommunityCompanion.nodePropertyValues(
config.isIncremental(),
config.writeProperty(),
config.seedProperty(),
config.consecutiveIds(),
result.asNodeProperties(),
config.minCommunitySize(),
config.concurrency(),
() -> algorithmResult.graphStore().nodeProperty(config.seedProperty())
),
result -> result::communityId,
(result, componentCount, communitySummary) -> new ModularityOptimizationSpecificFields(
result.modularity(),
result.ranIterations(),
result.didConverge(),
result.asNodeProperties().nodeCount(),
componentCount,
communitySummary
),
statisticsComputationInstructions,
intermediateResult.computeMilliseconds,
emptySupplier,
"ModularityOptimizationWrite",
configuration.writeConcurrency(),
configuration.writeProperty(),
configuration.arrowConnectionInfo(),
configuration.resolveResultStore(algorithmResult.resultStore())
);
}

public NodePropertyWriteResult<TriangleCountSpecificFields> triangleCount(
String graphName,
TriangleCountWriteConfig config
Expand Down

This file was deleted.

4 changes: 2 additions & 2 deletions algo/src/main/java/org/neo4j/gds/louvain/Louvain.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ private ModularityOptimizationResult runModularityOptimization(Graph louvainGrap
concurrency,
DEFAULT_BATCH_SIZE,
DefaultPool.INSTANCE,
progressTracker
progressTracker,
terminationFlag
);
modularityOptimization.setTerminationFlag(terminationFlag);

return modularityOptimization.compute();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ public ModularityOptimization(
Concurrency concurrency,
int minBatchSize,
ExecutorService executor,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -113,6 +114,8 @@ public ModularityOptimization(
}

this.modularityManager = ModularityManager.create(graph, concurrency);

this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.neo4j.gds.k1coloring.K1ColoringAlgorithmFactory;
import org.neo4j.gds.k1coloring.K1ColoringBaseConfig;
import org.neo4j.gds.k1coloring.K1ColoringStreamConfigImpl;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;

Expand Down Expand Up @@ -82,7 +83,8 @@ public ModularityOptimization build(
parameters.concurrency(),
parameters.batchSize(),
DefaultPool.INSTANCE,
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.modularity.TestGraphs;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.logging.Log;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -123,7 +124,8 @@ private ModularityOptimizationResult compute(
concurrency,
minBatchSize,
DefaultPool.INSTANCE,
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
).compute();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.logging.Log;

import java.util.Optional;
Expand Down Expand Up @@ -212,7 +213,8 @@ private ModularityOptimizationResult compute(
concurrency,
minBatchSize,
DefaultPool.INSTANCE,
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
).compute();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.logging.Log;

import java.util.Optional;
Expand Down Expand Up @@ -283,7 +284,8 @@ private ModularityOptimizationResult compute(
concurrency,
minBatchSize,
DefaultPool.INSTANCE,
progressTracker
progressTracker,
TerminationFlag.RUNNING_TRUE
).compute();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ ModularityOptimizationResult modularityOptimization(Graph graph, ModularityOptim
parameters.concurrency(),
parameters.batchSize(),
DefaultPool.INSTANCE,
progressTracker
progressTracker,
terminationFlag
);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.louvain.LouvainResult;
import org.neo4j.gds.louvain.LouvainWriteConfig;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationResult;
import org.neo4j.gds.modularityoptimization.ModularityOptimizationWriteConfig;
import org.neo4j.gds.wcc.WccWriteConfig;

import java.util.Optional;
Expand All @@ -52,6 +54,7 @@
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.LabelPropagation;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.Leiden;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.Louvain;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.ModularityOptimization;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.WCC;

public final class CommunityAlgorithmsWriteModeBusinessFacade {
Expand Down Expand Up @@ -198,6 +201,24 @@ public <RESULT> RESULT louvain(
);
}

public <RESULT> RESULT modularityOptimization(
GraphName graphName,
ModularityOptimizationWriteConfig configuration,
ResultBuilder<ModularityOptimizationWriteConfig, ModularityOptimizationResult, RESULT, NodePropertiesWritten> resultBuilder
) {
var writeStep = new ModularityOptimizationWriteStep(writeToDatabase, configuration);

return algorithmProcessingTemplate.processAlgorithm(
graphName,
configuration,
ModularityOptimization,
estimationFacade::modularityOptimization,
graph -> algorithms.modularityOptimization(graph, configuration),
Optional.of(writeStep),
resultBuilder
);
}

public <RESULT> RESULT wcc(
GraphName graphName,
WccWriteConfig configuration,
Expand Down
Loading

0 comments on commit e385875

Please sign in to comment.