Skip to content

Commit

Permalink
use NumpyTupleDataset for padding pattern
Browse files (browse the repository at this point in the history)
  • Loading branch information
knshnb committed Sep 19, 2019
1 parent 4a51521 commit 624f732
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 172 deletions.
7 changes: 0 additions & 7 deletions chainer_chemistry/dataset/preprocessors/gin_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ def get_input_features(self, mol):
adj_array = construct_adj_matrix(mol, out_size=self.out_size)
return atom_array, adj_array

def create_dataset(self, *args, **kwargs):
# args: (atom_array, adj_array, label_array)
data_list = [
PaddingGraphData(x=x, adj=adj, y=y) for (x, adj, y) in zip(*args)
]
return PaddingGraphDataset(data_list)


class GINSparsePreprocessor(MolPreprocessor):
def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False):
Expand Down
3 changes: 2 additions & 1 deletion chainer_chemistry/dataset/preprocessors/mol_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rdkit import Chem

from chainer_chemistry.dataset.preprocessors.base_preprocessor import BasePreprocessor # NOQA
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset # NOQA


class MolPreprocessor(BasePreprocessor):
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_input_features(self, mol):
raise NotImplementedError

def create_dataset(self, *args, **kwargs):
raise NotImplementedError
return NumpyTupleDataset(*args)

def process(self, filepath):
# Not used now...
Expand Down
5 changes: 5 additions & 0 deletions chainer_chemistry/datasets/numpy_tuple_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy

from chainer_chemistry.dataset.indexers.numpy_tuple_dataset_feature_indexer import NumpyTupleDatasetFeatureIndexer # NOQA
from chainer_chemistry.dataset.converters import concat_mols


class NumpyTupleDataset(object):
Expand Down Expand Up @@ -48,6 +49,10 @@ def __len__(self):
def get_datasets(self):
return self._datasets

@property
def converter(self):
return concat_mols

@property
def features(self):
"""Extract features according to the specified index.
Expand Down
4 changes: 1 addition & 3 deletions chainer_chemistry/models/gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, out_dim, node_embedding=False, hidden_channels=16,
self.weight_tying = weight_tying
self.n_edge_types = n_edge_types

def __call__(self, batch, is_real_node=None):
def __call__(self, atom_array, adj, is_real_node=None):
"""forward propagation
Args:
Expand All @@ -91,8 +91,6 @@ def __call__(self, batch, is_real_node=None):
Returns:
numpy.ndarray: final molecule representation
"""
atom_array, adj = batch.x, batch.adj

if atom_array.dtype == self.xp.int32:
h = self.embed(atom_array) # (minibatch, max_num_atoms)
else:
Expand Down
4 changes: 2 additions & 2 deletions chainer_chemistry/models/prediction/graph_conv_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __init__(
self.label_scaler = label_scaler
self.postprocess_fn = postprocess_fn or chainer.functions.identity

def __call__(self, dataset):
x = self.graph_conv(dataset)
def __call__(self, *args, **kwargs):
x = self.graph_conv(*args, **kwargs)

if self.mlp:
x = self.mlp(x)
Expand Down
6 changes: 5 additions & 1 deletion chainer_chemistry/models/prediction/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from chainer import cuda, Variable # NOQA
from chainer import reporter
from chainer_chemistry.models.prediction.base import BaseForwardModel
from chainer_chemistry.dataset.graph_dataset.base_graph_data import BaseGraphData # NOQA


class Regressor(BaseForwardModel):
Expand Down Expand Up @@ -102,7 +103,10 @@ def __call__(self, *args, **kwargs):
"""

# --- Separate `args` and `t` ---
if isinstance(self.label_key, int):
if isinstance(args[0], BaseGraphData):
# for graph dataset
t = args[0].y
elif isinstance(self.label_key, int):
if not (-len(args) <= self.label_key < len(args)):
msg = 'Label key %d is out of bounds' % self.label_key
raise ValueError(msg)
Expand Down
154 changes: 0 additions & 154 deletions chainer_chemistry/models/prediction/regressor2.py

This file was deleted.

10 changes: 6 additions & 4 deletions examples/qm9/train_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from chainer_chemistry import datasets as D
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.links.scaler.standard_scaler import StandardScaler
from chainer_chemistry.models.prediction.regressor2 import Regressor
from chainer_chemistry.models.prediction.regressor import Regressor
from chainer_chemistry.models.prediction import set_up_predictor
from chainer_chemistry.utils import run_train

Expand Down Expand Up @@ -122,9 +122,11 @@ def main():
if args.scale == 'standardize':
print('Fit StandardScaler to the labels.')
scaler = StandardScaler()
y = numpy.array([data.y for data in dataset])
print('y', y.shape)
scaler.fit(y)
if isinstance(dataset, NumpyTupleDataset):
scaler.fit(dataset.get_datasets()[-1])
else:
y = numpy.array([data.y for data in dataset])
scaler.fit(y)
else:
print('No standard scaling was selected.')
scaler = None
Expand Down

0 comments on commit 624f732

Please sign in to comment.