Skip to content

Commit

Permalink
Merge pull request #9288 from lassewesth/gs2
Browse files Browse the repository at this point in the history
migrate gs stream
  • Loading branch information
lassewesth authored Jul 2, 2024
2 parents 0326a75 + a779912 commit c9e5b95
Show file tree
Hide file tree
Showing 22 changed files with 356 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import org.neo4j.gds.algorithms.AlgorithmComputationResult;
import org.neo4j.gds.algorithms.StreamComputationResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageStreamConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
import org.neo4j.gds.embeddings.hashgnn.HashGNNStreamConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecResult;
Expand All @@ -47,18 +45,6 @@ public StreamComputationResult<Node2VecResult> node2Vec(
return createStreamComputationResult(result);
}

/**
 * Runs GraphSage through the node embeddings algorithms facade and wraps the
 * outcome as a stream computation result.
 *
 * @param graphName name of the graph handed to the algorithms facade
 * @param config    stream configuration handed to the algorithms facade
 * @return whatever {@code createStreamComputationResult} produces for the facade's result
 */
public StreamComputationResult<GraphSageResult> graphSage(String graphName, GraphSageStreamConfig config) {
    return createStreamComputationResult(
        this.nodeEmbeddingsAlgorithmsFacade.graphSage(graphName, config)
    );
}

public StreamComputationResult<HashGNNResult> hashGNN(
String graphName,
HashGNNStreamConfig config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,24 @@

import org.neo4j.gds.algorithms.estimation.AlgorithmEstimator;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainEstimateDefinition;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecMemoryEstimateDefinition;
import org.neo4j.gds.modelcatalogservices.ModelCatalogService;

import java.util.Optional;

import static org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver.resolveModel;

public class NodeEmbeddingsAlgorithmsEstimateBusinessFacade {

private final AlgorithmEstimator algorithmEstimator;
private final ModelCatalogService modelCatalogService;


public NodeEmbeddingsAlgorithmsEstimateBusinessFacade(
AlgorithmEstimator algorithmEstimator, ModelCatalogService modelCatalogService
AlgorithmEstimator algorithmEstimator
) {
this.algorithmEstimator = algorithmEstimator;
this.modelCatalogService = modelCatalogService;
}

public <C extends Node2VecBaseConfig> MemoryEstimateResult node2Vec(
Expand All @@ -60,24 +53,6 @@ public <C extends Node2VecBaseConfig> MemoryEstimateResult node2Vec(
);
}

/**
 * Estimates memory for GraphSage using the parameters of an already trained model.
 *
 * <p>The model is resolved from the model catalog by the configuration's username and
 * model name; the memory estimate parameters are taken from that model's train
 * configuration, not from {@code configuration} itself.
 *
 * @param graphNameOrConfiguration passed through unchanged to the estimator
 * @param configuration            supplies username and model name for model resolution
 * @param mutating                 forwarded to {@code GraphSageMemoryEstimateDefinition}
 * @return the result produced by {@code algorithmEstimator.estimate}
 */
public <C extends GraphSageBaseConfig> MemoryEstimateResult graphSage(
    Object graphNameOrConfiguration,
    C configuration,
    boolean mutating
) {
    var trainedModel = resolveModel(
        modelCatalogService.get(),
        configuration.username(),
        configuration.modelName()
    );
    var estimateParameters = trainedModel.trainConfig().toMemoryEstimateParameters();
    var estimateDefinition = new GraphSageMemoryEstimateDefinition(estimateParameters, mutating);

    return algorithmEstimator.estimate(
        graphNameOrConfiguration,
        configuration,
        Optional.empty(),
        estimateDefinition
    );
}

public <C extends GraphSageTrainConfig> MemoryEstimateResult graphSageTrain(
Object graphNameOrConfiguration,
C configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
import org.neo4j.gds.applications.algorithms.machinery.WriteNodePropertyService;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.ArrowConnectionInfo;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageWriteConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecWriteConfig;

import java.util.Optional;
Expand Down Expand Up @@ -77,30 +75,6 @@ public NodePropertyWriteResult<Node2VecSpecificFields> node2Vec(
);
}

/**
 * Runs GraphSage for the named graph and writes the resulting embeddings as a node property.
 *
 * <p>The computation is executed via {@code AlgorithmRunner.runWithTiming}; the algorithm
 * result and the measured compute milliseconds are then handed to {@code writeToDatabase}
 * together with the write settings read from {@code configuration}.
 *
 * @param graphName     name of the graph handed to the algorithms facade
 * @param configuration write configuration; supplies write concurrency, write property,
 *                      Arrow connection info and the result store resolution
 * @return the result of {@code writeToDatabase}; the {@code Long} it carries is presumably
 *         the node count supplied below (confirm against {@code writeToDatabase})
 */
public NodePropertyWriteResult<Long> graphSage(
    String graphName,
    GraphSageWriteConfig configuration
) {

    var intermediateResult = AlgorithmRunner.runWithTiming(
        () -> nodeEmbeddingsAlgorithmsFacade.graphSage(graphName, configuration)
    );

    return writeToDatabase(
        intermediateResult.algorithmResult,
        configuration,
        // Embeddings are exposed to the writer as node property values.
        (result) -> NodePropertyValuesAdapter.adapt(result.embeddings()),
        (result) -> intermediateResult.algorithmResult.graph().nodeCount(),
        intermediateResult.computeMilliseconds,
        // NOTE(review): constant zero supplier; its meaning is defined by writeToDatabase's
        // signature, which is not fully visible here. Upper-case suffix avoids the `0l`/`01` misread.
        () -> 0L,
        "GraphSageWrite",
        configuration.writeConcurrency(),
        configuration.writeProperty(),
        configuration.arrowConnectionInfo(),
        configuration.resolveResultStore(intermediateResult.algorithmResult.resultStore())
    );
}

<RESULT, CONFIG extends AlgoBaseConfig, ASF> NodePropertyWriteResult<ASF> writeToDatabase(
AlgorithmComputationResult<RESULT> algorithmResult,
CONFIG configuration,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.WCC;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FastRP;

class FastRPWriteStep implements MutateOrWriteStep<FastRPResult, NodePropertiesWritten> {
private final WriteToDatabase writeToDatabase;
Expand All @@ -57,7 +57,7 @@ public NodePropertiesWritten execute(
resultStore,
configuration,
configuration,
WCC,
FastRP,
jobId,
nodePropertyValues
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageMutateConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.ml.core.EmbeddingUtils;

class GraphSageValidationHook implements PostLoadValidationHook {
private final GraphSageMutateConfig configuration;
private final GraphSageBaseConfig configuration;
private final Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model;

GraphSageValidationHook(
GraphSageMutateConfig configuration,
GraphSageBaseConfig configuration,
Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model
) {
this.configuration = configuration;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.embeddings;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
import org.neo4j.gds.applications.algorithms.machinery.MutateOrWriteStep;
import org.neo4j.gds.applications.algorithms.machinery.WriteToDatabase;
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageWriteConfig;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.GraphSage;

/**
 * Write step for GraphSage: adapts the computed embeddings to node property values
 * and delegates the actual write to {@link WriteToDatabase}.
 */
class GraphSageWriteStep implements MutateOrWriteStep<GraphSageResult, NodePropertiesWritten> {
    private final WriteToDatabase writeToDatabase;
    private final GraphSageWriteConfig configuration;

    GraphSageWriteStep(WriteToDatabase writeToDatabase, GraphSageWriteConfig configuration) {
        this.writeToDatabase = writeToDatabase;
        this.configuration = configuration;
    }

    /**
     * Writes the embeddings held by {@code result} under the GraphSage progress-tracking label.
     *
     * @return whatever {@code writeToDatabase.perform} reports as written
     */
    @Override
    public NodePropertiesWritten execute(
        Graph graph,
        GraphStore graphStore,
        ResultStore resultStore,
        GraphSageResult result,
        JobId jobId
    ) {
        // NOTE(review): perform() takes the configuration in two positions; presumably two
        // different views of the same write configuration - confirm against WriteToDatabase.
        return writeToDatabase.perform(
            graph,
            graphStore,
            resultStore,
            configuration,
            configuration,
            GraphSage,
            jobId,
            NodePropertyValuesAdapter.adapt(result.embeddings())
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.embeddings.fastrp.FastRPBaseConfig;
import org.neo4j.gds.embeddings.fastrp.FastRPMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageMemoryEstimateDefinition;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.model.ModelConfig;
Expand Down Expand Up @@ -60,4 +61,14 @@ public MemoryEstimation graphSage(ModelConfig configuration, boolean mutating) {

return new GraphSageMemoryEstimateDefinition(memoryEstimateParameters, mutating).memoryEstimation();
}

/**
 * Produces a memory estimate result for GraphSage on the given graph name or configuration.
 *
 * <p>The memory estimation itself is obtained from the {@code graphSage(configuration, false)}
 * call, i.e. with the boolean flag fixed to {@code false}, and then run through the
 * estimation template.
 *
 * @param configuration            GraphSage configuration used for both estimation and the template
 * @param graphNameOrConfiguration passed through unchanged to the estimation template
 * @return the result of {@code algorithmEstimationTemplate.estimate}
 */
public MemoryEstimateResult graphSage(GraphSageBaseConfig configuration, Object graphNameOrConfiguration) {
    // NOTE(review): the inner call is expected to bind to the (ModelConfig, boolean) overload,
    // which assumes GraphSageBaseConfig is a ModelConfig - confirm.
    return algorithmEstimationTemplate.estimate(
        configuration,
        graphNameOrConfiguration,
        graphSage(configuration, false)
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,28 @@
import org.neo4j.gds.applications.algorithms.machinery.ResultBuilder;
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.fastrp.FastRPStreamConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageStreamConfig;

import java.util.List;
import java.util.Optional;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FastRP;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.GraphSage;

public class NodeEmbeddingAlgorithmsStreamModeBusinessFacade {
private final GraphSageModelCatalog graphSageModelCatalog;
private final NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade;
private final NodeEmbeddingAlgorithms algorithms;
private final AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience;

public NodeEmbeddingAlgorithmsStreamModeBusinessFacade(
GraphSageModelCatalog graphSageModelCatalog,
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade,
NodeEmbeddingAlgorithms algorithms,
AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience
) {
this.graphSageModelCatalog = graphSageModelCatalog;
this.estimationFacade = estimationFacade;
this.algorithms = algorithms;
this.algorithmProcessingTemplateConvenience = algorithmProcessingTemplateConvenience;
Expand All @@ -56,4 +65,27 @@ public <RESULT> RESULT fastRP(
resultBuilder
);
}

/**
 * Runs GraphSage in stream mode through the algorithm processing template.
 *
 * <p>The model is fetched from the GraphSage model catalog using {@code configuration}.
 * The relationship weight property handed to the template comes from that model's train
 * configuration, and a {@link GraphSageValidationHook} built from the configuration and the
 * model is registered as the only validation hook. No mutate-or-write step is supplied.
 *
 * @param graphName     graph to run on
 * @param configuration stream configuration
 * @param resultBuilder builds the caller-facing result
 * @return whatever the processing template returns for {@code resultBuilder}
 */
public <RESULT> RESULT graphSage(
    GraphName graphName,
    GraphSageStreamConfig configuration,
    ResultBuilder<GraphSageStreamConfig, GraphSageResult, RESULT, Void> resultBuilder
) {
    var trainedModel = graphSageModelCatalog.get(configuration);
    var weightPropertyFromTraining = trainedModel.trainConfig().relationshipWeightProperty();

    return algorithmProcessingTemplateConvenience.processAlgorithm(
        weightPropertyFromTraining,
        graphName,
        configuration,
        Optional.of(List.of(new GraphSageValidationHook(configuration, trainedModel))),
        GraphSage,
        () -> estimationFacade.graphSage(configuration, false),
        graph -> algorithms.graphSage(graph, configuration),
        Optional.empty(),
        resultBuilder
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,31 @@
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageWriteConfig;
import org.neo4j.gds.logging.Log;

import java.util.List;
import java.util.Optional;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FastRP;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.GraphSage;

public final class NodeEmbeddingAlgorithmsWriteModeBusinessFacade {
private final GraphSageModelCatalog graphSageModelCatalog;
private final NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade;
private final NodeEmbeddingAlgorithms algorithms;
private final AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience;
private final WriteToDatabase writeToDatabase;

private NodeEmbeddingAlgorithmsWriteModeBusinessFacade(
GraphSageModelCatalog graphSageModelCatalog,
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade,
NodeEmbeddingAlgorithms algorithms,
AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience,
WriteToDatabase writeToDatabase
) {
this.graphSageModelCatalog = graphSageModelCatalog;
this.estimationFacade = estimationFacade;
this.algorithms = algorithms;
this.algorithmProcessingTemplateConvenience = algorithmProcessingTemplateConvenience;
Expand All @@ -53,6 +62,7 @@ private NodeEmbeddingAlgorithmsWriteModeBusinessFacade(

public static NodeEmbeddingAlgorithmsWriteModeBusinessFacade create(
Log log,
GraphSageModelCatalog graphSageModelCatalog,
RequestScopedDependencies requestScopedDependencies,
WriteContext writeContext,
NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade,
Expand All @@ -63,6 +73,7 @@ public static NodeEmbeddingAlgorithmsWriteModeBusinessFacade create(
var writeToDatabase = new WriteToDatabase(writeNodePropertyService);

return new NodeEmbeddingAlgorithmsWriteModeBusinessFacade(
graphSageModelCatalog,
estimationFacade,
algorithms,
algorithmProcessingTemplateConvenience,
Expand All @@ -87,4 +98,29 @@ public <RESULT> RESULT fastRP(
resultBuilder
);
}

/**
 * Runs GraphSage in write mode through the algorithm processing template.
 *
 * <p>The model is fetched from the GraphSage model catalog using {@code configuration}.
 * The relationship weight property handed to the template comes from that model's train
 * configuration, a {@link GraphSageValidationHook} built from the configuration and the model
 * is registered as the only validation hook, and a {@link GraphSageWriteStep} is supplied as
 * the write step.
 *
 * @param graphName     graph to run on
 * @param configuration write configuration
 * @param resultBuilder builds the caller-facing result
 * @return whatever the processing template returns for {@code resultBuilder}
 */
public <RESULT> RESULT graphSage(
    GraphName graphName,
    GraphSageWriteConfig configuration,
    ResultBuilder<GraphSageWriteConfig, GraphSageResult, RESULT, NodePropertiesWritten> resultBuilder
) {
    var trainedModel = graphSageModelCatalog.get(configuration);
    var weightPropertyFromTraining = trainedModel.trainConfig().relationshipWeightProperty();

    var modelValidation = new GraphSageValidationHook(configuration, trainedModel);
    var embeddingsWriter = new GraphSageWriteStep(writeToDatabase, configuration);

    return algorithmProcessingTemplateConvenience.processAlgorithm(
        weightPropertyFromTraining,
        graphName,
        configuration,
        Optional.of(List.of(modelValidation)),
        GraphSage,
        () -> estimationFacade.graphSage(configuration, false),
        graph -> algorithms.graphSage(graph, configuration),
        Optional.of(embeddingsWriter),
        resultBuilder
    );
}
}
Loading

0 comments on commit c9e5b95

Please sign in to comment.