Skip to content

Commit

Permalink
Merge pull request #9575 from lassewesth/pipe2
Browse files Browse the repository at this point in the history
migrate nc pipelines
  • Loading branch information
lassewesth authored Sep 3, 2024
2 parents 8852e89 + 1c207b0 commit d67611a
Show file tree
Hide file tree
Showing 23 changed files with 470 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.ConfigKeyValidation;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Name;
Expand All @@ -39,25 +32,17 @@

import static org.neo4j.procedure.Mode.READ;

public class NodeClassificationPipelineAddTrainerMethodProcs extends BaseProc {
public class NodeClassificationPipelineAddTrainerMethodProcs {
@Context
public GraphDataScienceProcedures facade;

/**
 * Procedure {@code gds.beta.pipeline.nodeClassification.addLogisticRegression}: adds a logistic
 * regression configuration to the parameter space of a node classification training pipeline.
 *
 * NOTE(review): this body looks like a rendered diff whose +/- markers were lost, so the
 * pre-change and post-change implementations appear back to back (note the two consecutive
 * {@code return} statements). Confirm against the real file before relying on it.
 *
 * @param pipelineName name of the pipeline to extend
 * @param logisticRegressionClassifierConfig trainer configuration map; the procedure argument
 *        is named {@code config} and defaults to {@code {}}
 * @return a stream of {@code NodePipelineInfoResult}
 */
@Procedure(name = "gds.beta.pipeline.nodeClassification.addLogisticRegression", mode = READ)
@Description("Add a logistic regression configuration to the parameter space of the node classification train pipeline.")
public Stream<NodePipelineInfoResult> addLogisticRegression(
@Name("pipelineName") String pipelineName,
@Name(value = "config", defaultValue = "{}") Map<String, Object> logisticRegressionClassifierConfig
) {
// Presumably the removed (pre-change) implementation starts here.
// Fetches the pipeline registered under this name for the calling user, typed as a
// node classification training pipeline.
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);

// Validates the supplied keys against the key set of the default logistic regression config.
var allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys();
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, logisticRegressionClassifierConfig.keySet());

// Wraps the raw map as a tunable trainer config tagged with the LogisticRegression method
// and adds it to the pipeline.
var tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, TrainingMethod.LogisticRegression);
pipeline.addTrainerConfig(
tunableTrainerConfig
);

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
// Presumably the added (post-change) implementation: the whole operation is delegated to the
// injected facade's pipelines API.
return facade.pipelines().addLogisticRegression(pipelineName, logisticRegressionClassifierConfig);
}

@Procedure(name = "gds.beta.pipeline.nodeClassification.addRandomForest", mode = READ)
Expand All @@ -66,17 +51,7 @@ public Stream<NodePipelineInfoResult> addRandomForest(
@Name("pipelineName") String pipelineName,
@Name(value = "config") Map<String, Object> randomForestClassifierConfig
) {
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);

var allowedKeys = RandomForestClassifierTrainerConfig.DEFAULT.configKeys();
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, randomForestClassifierConfig.keySet());

var tunableTrainerConfig = TunableTrainerConfig.of(randomForestClassifierConfig, TrainingMethod.RandomForestClassification);
pipeline.addTrainerConfig(
tunableTrainerConfig
);

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return facade.pipelines().addRandomForest(pipelineName, randomForestClassifierConfig);
}

@Procedure(name = "gds.alpha.pipeline.nodeClassification.addRandomForest", mode = READ, deprecatedBy = "gds.beta.pipeline.nodeClassification.addRandomForest")
Expand All @@ -87,14 +62,9 @@ public Stream<NodePipelineInfoResult> addRandomForestAlpha(
@Name("pipelineName") String pipelineName,
@Name(value = "config") Map<String, Object> randomForestClassifierConfig
) {
executionContext()
.metricsFacade()
.deprecatedProcedures().called("gds.alpha.pipeline.nodeClassification.addRandomForest");
facade.deprecatedProcedures().called("gds.alpha.pipeline.nodeClassification.addRandomForest");
facade.log().warn("Procedure `gds.alpha.pipeline.nodeClassification.addRandomForest` has been deprecated, please use `gds.beta.pipeline.nodeClassification.addRandomForest`.");

executionContext()
.log()
.warn(
"Procedure `gds.alpha.pipeline.nodeClassification.addRandomForest` has been deprecated, please use `gds.beta.pipeline.nodeClassification.addRandomForest`.");
return addRandomForest(pipelineName, randomForestClassifierConfig);
}

Expand All @@ -104,13 +74,6 @@ public Stream<NodePipelineInfoResult> addMLP(
@Name("pipelineName") String pipelineName,
@Name(value = "config", defaultValue = "{}") Map<String, Object> mlpClassifierConfig
) {
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);

var allowedKeys = MLPClassifierTrainConfig.DEFAULT.configKeys();
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, mlpClassifierConfig.keySet());

pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, TrainingMethod.MLPClassification));

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return facade.pipelines().addMLP(pipelineName, mlpClassifierConfig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -33,18 +31,13 @@

import static org.neo4j.procedure.Mode.READ;

/**
 * Procedure class exposing {@code gds.alpha.pipeline.nodeClassification.configureAutoTuning}.
 *
 * NOTE(review): this block looks like a rendered diff whose +/- markers were lost. The two class
 * headers (with and without {@code extends BaseProc}) and the two {@code return} statements are
 * presumably the before/after sides of the same change; confirm against the real file.
 */
public class NodeClassificationPipelineConfigureAutoTuningProc extends BaseProc {
public class NodeClassificationPipelineConfigureAutoTuningProc {
// Field annotated with @Context; the post-change body calls facade.pipelines() on it.
@Context
public GraphDataScienceProcedures facade;

/**
 * Configures the auto-tuning of the named node classification pipeline.
 *
 * @param pipelineName name of the pipeline to configure
 * @param configMap auto-tuning configuration, passed through unchanged
 * @return a stream of {@code NodePipelineInfoResult}
 */
@Procedure(name = "gds.alpha.pipeline.nodeClassification.configureAutoTuning", mode = READ)
@Description("Configures the auto-tuning of the node classification pipeline.")
public Stream<NodePipelineInfoResult> configureAutoTuning(@Name("pipelineName") String pipelineName, @Name("configuration") Map<String, Object> configMap) {
// The result of this lookup is discarded.
// NOTE(review): presumably called only so that a missing or wrongly typed pipeline fails
// before any configuration happens — confirm against PipelineCatalog.getTyped.
PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);
// Presumably the pre-change path: delegates to PipelineCompanion and maps the resulting
// pipeline (cast to NodeClassificationTrainingPipeline) into the result row.
return PipelineCompanion.configureAutoTuning(
username(),
pipelineName,
configMap,
pipeline -> new NodePipelineInfoResult(pipelineName, (NodeClassificationTrainingPipeline) pipeline)
);
// Presumably the post-change path: delegates to the facade's pipelines API.
return facade.pipelines().configureAutoTuning(pipelineName, configMap);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -34,20 +31,13 @@

import static org.neo4j.procedure.Mode.READ;

/**
 * Procedure class exposing {@code gds.beta.pipeline.nodeClassification.configureSplit}.
 *
 * NOTE(review): this block looks like a rendered diff whose +/- markers were lost. The two class
 * headers (with and without {@code extends BaseProc}) and the two {@code return} statements are
 * presumably the before/after sides of the same change; confirm against the real file.
 */
public class NodeClassificationPipelineConfigureSplitProc extends BaseProc {
public class NodeClassificationPipelineConfigureSplitProc {
// Field annotated with @Context; the post-change body calls facade.pipelines() on it.
@Context
public GraphDataScienceProcedures facade;

/**
 * Configures the split of the named node classification training pipeline.
 *
 * @param pipelineName name of the pipeline to configure
 * @param configMap split configuration as a raw map
 * @return a stream of {@code NodePipelineInfoResult}
 */
@Procedure(name = "gds.beta.pipeline.nodeClassification.configureSplit", mode = READ)
@Description("Configures the split of the node classification training pipeline.")
public Stream<NodePipelineInfoResult> configureSplit(@Name("pipelineName") String pipelineName, @Name("configuration") Map<String, Object> configMap) {
// Presumably the pre-change implementation starts here: fetch the calling user's pipeline,
// typed as a node classification training pipeline.
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, NodeClassificationTrainingPipeline.class);

// Build the split config from the wrapped map.
var cypherConfig = CypherMapWrapper.create(configMap);
var config = NodePropertyPredictionSplitConfig.of(cypherConfig);

// Key validation runs after the config object has been built, using that config's own key set.
cypherConfig.requireOnlyKeysFrom(config.configKeys());

pipeline.setSplitConfig(config);

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
// Presumably the post-change implementation: delegates to the facade's pipelines API.
return facade.pipelines().configureSplit(pipelineName, configMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.StringIdentifierValidations;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -32,22 +30,13 @@

import static org.neo4j.procedure.Mode.READ;

/**
 * Procedure class exposing {@code gds.beta.pipeline.nodeClassification.create}.
 *
 * NOTE(review): this block looks like a rendered diff whose +/- markers were lost. The annotated
 * class header, the static {@code create} helper and the first {@code return} in the procedure
 * are presumably the removed side; the second class header, the {@code facade} field and the
 * second {@code return} are presumably the added side. Confirm against the real file.
 */
@SuppressWarnings("immutables:subtype")
public class NodeClassificationPipelineCreateProc extends BaseProc {

/**
 * Creates an empty node classification training pipeline and stores it in the pipeline
 * catalog under the given user and name.
 *
 * @param username user under which the pipeline is stored
 * @param pipelineName name for the new pipeline; validated by
 *        {@code StringIdentifierValidations.validateNoWhiteCharacter} before anything is stored
 * @return result object built from the pipeline name and the new pipeline
 */
public static NodePipelineInfoResult create(String username, String pipelineName) {
StringIdentifierValidations.validateNoWhiteCharacter(pipelineName, "pipelineName");

var pipeline = new NodeClassificationTrainingPipeline();

PipelineCatalog.set(username, pipelineName, pipeline);

return new NodePipelineInfoResult(pipelineName, pipeline);
}
public class NodeClassificationPipelineCreateProc {
// Field annotated with @Context; the post-change body calls facade.pipelines() on it.
@Context
public GraphDataScienceProcedures facade;

/**
 * Creates a node classification training pipeline in the pipeline catalog.
 *
 * @param pipelineName name for the new pipeline
 * @return a stream of {@code NodePipelineInfoResult}
 */
@Procedure(name = "gds.beta.pipeline.nodeClassification.create", mode = READ)
@Description("Creates a node classification training pipeline in the pipeline catalog.")
public Stream<NodePipelineInfoResult> create(@Name("pipelineName") String pipelineName) {
// Presumably the pre-change path: wraps the static helper's result for the calling user.
return Stream.of(create(username(), pipelineName));
// Presumably the post-change path: delegates to the facade's pipelines API.
return facade.pipelines().createPipeline(pipelineName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public Stream<NodePipelineInfoResult> addNodeProperty(

pipeline.addNodePropertyStep(createNodePropertyStep(taskName, procedureConfig));

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}

@Procedure(name = "gds.alpha.pipeline.nodeRegression.selectFeatures", mode = READ)
Expand All @@ -74,6 +74,6 @@ public Stream<NodePipelineInfoResult> selectFeatures(
throw new IllegalArgumentException("The value of `featureProperties` is required to be a list of strings.");
}

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public Stream<NodePipelineInfoResult> addLogisticRegression(

pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, TrainingMethod.LinearRegression));

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}

@Procedure(name = "gds.alpha.pipeline.nodeRegression.addRandomForest", mode = READ)
Expand All @@ -68,6 +68,6 @@ public Stream<NodePipelineInfoResult> addRandomForest(

pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, TrainingMethod.RandomForestRegression));

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public Stream<NodePipelineInfoResult> configureAutoTuning(@Name("pipelineName")
username(),
pipelineName,
configMap,
pipeline -> new NodePipelineInfoResult(pipelineName, (NodeRegressionTrainingPipeline) pipeline)
pipeline -> NodePipelineInfoResult.create(pipelineName, (NodeRegressionTrainingPipeline) pipeline)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ public Stream<NodePipelineInfoResult> configureSplit(

pipeline.setSplitConfig(config);

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ public Stream<NodePipelineInfoResult> create(@Name("pipelineName") String pipeli

PipelineCatalog.set(username(), pipelineName, pipeline);

return Stream.of(new NodePipelineInfoResult(pipelineName, pipeline));
return Stream.of(NodePipelineInfoResult.create(pipelineName, pipeline));
}
}
Loading

0 comments on commit d67611a

Please sign in to comment.