Skip to content

Commit

Permalink
Merge pull request #9318 from lassewesth/gs3
Browse files Browse the repository at this point in the history
migrate graphsage train
  • Loading branch information
lassewesth authored Jul 5, 2024
2 parents b6027dd + 982d183 commit d412be8
Show file tree
Hide file tree
Showing 47 changed files with 530 additions and 400 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import org.neo4j.gds.algorithms.estimation.AlgorithmEstimator;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainEstimateDefinition;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
Expand Down Expand Up @@ -53,18 +51,6 @@ public <C extends Node2VecBaseConfig> MemoryEstimateResult node2Vec(
);
}

/**
 * Produces a memory estimate for GraphSage training.
 *
 * The work is delegated to {@code algorithmEstimator.estimate}, which receives the configuration's
 * relationship weight property and a {@code GraphSageTrainEstimateDefinition} built from
 * {@code configuration.toMemoryEstimateParameters()}.
 *
 * @param graphNameOrConfiguration handed to the estimator unchanged
 * @param configuration the GraphSage train configuration to estimate for
 * @param <C> the concrete train configuration type
 * @return the result produced by the estimator
 */
public <C extends GraphSageTrainConfig> MemoryEstimateResult graphSageTrain(
    Object graphNameOrConfiguration,
    C configuration
) {
    // read in the same order the original argument list evaluated them
    var relationshipWeightProperty = configuration.relationshipWeightProperty();
    var estimateDefinition = new GraphSageTrainEstimateDefinition(configuration.toMemoryEstimateParameters());

    return algorithmEstimator.estimate(
        graphNameOrConfiguration,
        configuration,
        relationshipWeightProperty,
        estimateDefinition
    );
}

public <C extends HashGNNConfig> MemoryEstimateResult hashGNN(
Object graphNameOrConfiguration,
C configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,23 @@

import org.neo4j.gds.algorithms.AlgorithmComputationResult;
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.algorithms.validation.AfterLoadValidation;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageAlgorithmFactory;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainAlgorithmFactory;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNFactory;
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
import org.neo4j.gds.embeddings.node2vec.Node2VecAlgorithmFactory;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecResult;
import org.neo4j.gds.modelcatalogservices.ModelCatalogService;

import java.util.List;
import java.util.Optional;

public class NodeEmbeddingsAlgorithmsFacade {

private final AlgorithmRunner algorithmRunner;
private final ModelCatalogService modelCatalogService;

public NodeEmbeddingsAlgorithmsFacade(
AlgorithmRunner algorithmRunner,
ModelCatalogService modelCatalogService
AlgorithmRunner algorithmRunner
) {
this.algorithmRunner = algorithmRunner;
this.modelCatalogService = modelCatalogService;
}

AlgorithmComputationResult<Node2VecResult> node2Vec(
Expand All @@ -67,46 +52,6 @@ AlgorithmComputationResult<Node2VecResult> node2Vec(
);
}

/**
 * Runs GraphSage on the named graph using a model resolved from the model catalog.
 *
 * The model is looked up via {@code GraphSageModelResolver.resolveModel} using the configuration's
 * username and model name. The relationship weight property comes from the model's train
 * configuration, not from {@code config}.
 *
 * @param graphName name of the graph handed to the algorithm runner
 * @param config configuration supplying username, model name and label/type filters
 * @return the result produced by {@code algorithmRunner.run}
 */
AlgorithmComputationResult<GraphSageResult> graphSage(
    String graphName,
    GraphSageBaseConfig config
) {
    Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> trainedModel =
        GraphSageModelResolver.resolveModel(modelCatalogService.get(), config.username(), config.modelName());

    // the model's train configuration validates the graph store against the labels and
    // relationship types selected by this run's configuration
    // NOTE(review): judging by the type name this runs after the graph is loaded - confirm
    AfterLoadValidation trainConfigAcceptsGraphStore = graphStore -> trainedModel.trainConfig().graphStoreValidation(
        graphStore,
        config.nodeLabelIdentifiers(graphStore),
        config.internalRelationshipTypes(graphStore)
    );

    var relationshipWeightProperty = trainedModel.trainConfig().relationshipWeightProperty();

    return algorithmRunner.run(
        graphName,
        config,
        relationshipWeightProperty,
        new GraphSageAlgorithmFactory<>(modelCatalogService.get()),
        List.of(trainConfigAcceptsGraphStore)
    );
}

/**
 * Trains a GraphSage model on the named graph.
 *
 * Delegates to {@code algorithmRunner.run} with the configuration's relationship weight property
 * and a fresh {@code GraphSageTrainAlgorithmFactory}.
 *
 * @param graphName name of the graph handed to the algorithm runner
 * @param config the GraphSage train configuration
 * @return the result produced by {@code algorithmRunner.run}
 */
AlgorithmComputationResult<Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>> graphSageTrain(
    String graphName,
    GraphSageTrainConfig config
) {
    // weight property first, factory second: same order as the original argument list
    var relationshipWeightProperty = config.relationshipWeightProperty();
    var trainFactory = new GraphSageTrainAlgorithmFactory();

    return algorithmRunner.run(graphName, config, relationshipWeightProperty, trainFactory);
}

AlgorithmComputationResult<HashGNNResult> hashGNN(
String graphName,
HashGNNConfig config
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public record GraphSageTrainParameters(
ActivationFunction activationFunction
) {

long numberOfBatches(long nodeCount) {
/**
 * Computes how many batches of {@code batchSize} are needed to cover {@code nodeCount} nodes.
 *
 * @param nodeCount number of nodes to split into batches
 * @return {@code nodeCount / batchSize}, rounded up
 */
public long numberOfBatches(long nodeCount) {
    // divide in floating point so a partially filled last batch is rounded up, not truncated
    double exactBatches = nodeCount / (double) batchSize;
    return (long) Math.ceil(exactBatches);
}
public int batchesPerIteration(long nodeCount) {
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public enum Algorithm {
FilteredKNN,
FilteredNodeSimilarity,
GraphSage,
GraphSageTrain,
HarmonicCentrality,
K1Coloring,
KCore,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public enum LabelForProgressTracking {
FilteredKNN("Filtered K-Nearest Neighbours"),
FilteredNodeSimilarity("Filtered Node Similarity"),
GraphSage("GraphSage"),
GraphSageTrain("GraphSageTrain"),
HarmonicCentrality("HarmonicCentrality"),
K1Coloring("K1Coloring"),
KCore("KCoreDecomposition"),
Expand Down Expand Up @@ -92,6 +93,7 @@ public static LabelForProgressTracking from(Algorithm algorithm) {
case FilteredKNN -> FilteredKNN;
case FilteredNodeSimilarity -> FilteredNodeSimilarity;
case GraphSage -> GraphSage;
case GraphSageTrain -> GraphSageTrain;
case HarmonicCentrality -> HarmonicCentrality;
case K1Coloring -> K1Coloring;
case KCore -> KCore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetri
GraphSageModelTrainer.GraphSageTrainMetrics.class
);
}

/**
 * Hands the given GraphSage model to {@code modelCatalog.set}.
 *
 * @param model the model (model data, train configuration and train metrics) to put into the catalog
 */
void store(Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model) {
modelCatalog.set(model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,13 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.procedures.embeddings;
package org.neo4j.gds.applications.algorithms.embeddings;

import org.neo4j.gds.algorithms.TrainResult;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.procedures.embeddings.graphsage.GraphSageTrainResult;

public class GraphSageComputationalResultTransformer {


/**
 * Builds a {@code GraphSageTrainResult} from a train result.
 *
 * @param trainResult source of the algorithm specific fields and the training time in milliseconds
 * @return a new {@code GraphSageTrainResult} carrying those two values
 */
public static GraphSageTrainResult toTrainResult(
    TrainResult<Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>> trainResult
) {
    var algorithmSpecificFields = trainResult.algorithmSpecificFields();
    var trainMillis = trainResult.trainMillis();
    return new GraphSageTrainResult(algorithmSpecificFields, trainMillis);
}
/**
 * A place to put GraphSage models.
 */
public interface GraphSageModelRepository {
/**
 * Stores the given GraphSage model.
 * NOTE(review): where the model ends up is decided by the implementation - presumably the model catalog; confirm.
 *
 * @param model the model (model data, train configuration and train metrics) to store
 */
void store(Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model);
}
Loading

0 comments on commit d412be8

Please sign in to comment.