
Commit

Merge pull request #398 from knshnb/chainer-graph
Chainer graph
corochann authored Sep 25, 2019
2 parents 04577ba + 56218b1 commit 0aecad6
Showing 35 changed files with 1,372 additions and 35 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -123,12 +123,18 @@ We test supporting the brand-new Graph Warp Module (GWM) [18]-attached models fo

The following datasets are currently supported:

### Chemical
- QM9 [7, 8]
- Tox21 [9]
- MoleculeNet [11]
- ZINC (only 250k dataset) [12, 13]
- User (own) dataset

### Network
- cora [21]
- citeseer [22]
- reddit [23]

## Research Projects

If you use Chainer Chemistry in your research, feel free to submit a
@@ -206,3 +212,9 @@ papers. Use the library at your own risk.
.

[20] Marc Brockschmidt, ``GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation'', arXiv:1906.12192 [cs.ML], 2019.

[21] Andrew Kachites McCallum, Kamal Nigam, Jason Rennie, and Kristie Seymore, Automating the Construction of Internet Portals with Machine Learning. *Information Retrieval*, 2000.

[22] C. Lee Giles, Kurt D. Bollacker, and Steve Lawrence, CiteSeer: An Automatic Citation Indexing System. *Proceedings of the Third ACM Conference on Digital Libraries*, 1998.

[23] William L. Hamilton, Zhitao Ying, and Jure Leskovec, Inductive Representation Learning on Large Graphs. *Advances in Neural Information Processing Systems 30 (NIPS 2017)*, 2017.
Empty file.
73 changes: 73 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/base_graph_data.py
@@ -0,0 +1,73 @@
import numpy
import chainer


class BaseGraphData(object):
"""Base class of graph data """

def __init__(self, *args, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def to_device(self, device):
"""Send self to `device`
Args:
device (chainer.backend.Device): device
Returns:
self sent to `device`
"""
for k, v in self.__dict__.items():
            if isinstance(v, numpy.ndarray):
setattr(self, k, device.send(v))
            elif isinstance(v, chainer.utils.CooMatrix):
data = device.send(v.data.array)
row = device.send(v.row)
col = device.send(v.col)
device_coo_matrix = chainer.utils.CooMatrix(
data, row, col, v.shape, order=v.order)
setattr(self, k, device_coo_matrix)
return self


class PaddingGraphData(BaseGraphData):
"""Graph data class for padding pattern
Args:
x (numpy.ndarray): input node feature
adj (numpy.ndarray): adjacency matrix
y (int or numpy.ndarray): graph or node label
"""

def __init__(self, x=None, adj=None, super_node=None, pos=None, y=None,
**kwargs):
self.x = x
self.adj = adj
self.super_node = super_node
self.pos = pos
self.y = y
self.n_nodes = x.shape[0]
super(PaddingGraphData, self).__init__(**kwargs)


class SparseGraphData(BaseGraphData):
"""Graph data class for sparse pattern
Args:
x (numpy.ndarray): input node feature
edge_index (numpy.ndarray): sources and destinations of edges
edge_attr (numpy.ndarray): attribution of edges
y (int or numpy.ndarray): graph or node label
"""

def __init__(self, x=None, edge_index=None, edge_attr=None,
pos=None, super_node=None, y=None, **kwargs):
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.pos = pos
self.super_node = super_node
self.y = y
self.n_nodes = x.shape[0]
super(SparseGraphData, self).__init__(**kwargs)
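
As a quick orientation for reviewers, here is a minimal usage sketch of the data classes added above. It is not part of the commit; the toy feature values and the `@numpy` CPU device specifier are illustrative assumptions.

```python
import numpy
import chainer

from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
    PaddingGraphData

# A toy graph: 4 nodes with 8-dimensional features and an identity adjacency.
x = numpy.random.rand(4, 8).astype(numpy.float32)
adj = numpy.eye(4, dtype=numpy.float32)
data = PaddingGraphData(x=x, adj=adj, y=numpy.array(1, dtype=numpy.int32))

device = chainer.get_device('@numpy')  # CPU (NumPy) device
data = data.to_device(device)          # arrays are sent, None fields are skipped
print(data.n_nodes)                    # -> 4
```
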
134 changes: 134 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/base_graph_dataset.py
@@ -0,0 +1,134 @@
import chainer
import numpy
from chainer.backend import Device
from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
BaseGraphData
from chainer_chemistry.dataset.graph_dataset.feature_converters \
import batch_with_padding, batch_without_padding, concat, shift_concat, \
concat_with_padding, shift_concat_with_padding


class BaseGraphDataset(object):
"""Base class of graph dataset (list of graph data)"""
    _pattern = ''

    def __init__(self, data_list, *args, **kwargs):
        self.data_list = data_list
        # use per-instance lists so that feature registration is not shared
        # across dataset instances through mutable class attributes
        self._feature_entries = []
        self._feature_batch_method = []

def register_feature(self, key, batch_method, skip_if_none=True):
"""Register feature with batch method
Args:
key (str): name of the feature
batch_method (function): batch method
            skip_if_none (bool, optional): If True, skip registration when the
                feature `key` is None in the first data. Defaults to True.
"""
if skip_if_none and getattr(self.data_list[0], key, None) is None:
return
self._feature_entries.append(key)
self._feature_batch_method.append(batch_method)

def update_feature(self, key, batch_method):
"""Update batch method of the feature
Args:
key (str): name of the feature
batch_method (function): batch method
"""

index = self._feature_entries.index(key)
self._feature_batch_method[index] = batch_method

def __len__(self):
return len(self.data_list)

def __getitem__(self, item):
return self.data_list[item]

def converter(self, batch, device=None):
"""Converter
Args:
batch (list[BaseGraphData]): list of graph data
device (int, optional): specifier of device. Defaults to None.
        Returns:
            BaseGraphData: batched data sent to `device`
"""
if not isinstance(device, Device):
device = chainer.get_device(device)
batch = [method(name, batch, device=device) for name, method in
zip(self._feature_entries, self._feature_batch_method)]
data = BaseGraphData(
**{key: value for key, value in zip(self._feature_entries, batch)})
return data


class PaddingGraphDataset(BaseGraphDataset):
"""Graph dataset class for padding pattern"""
_pattern = 'padding'

def __init__(self, data_list):
super(PaddingGraphDataset, self).__init__(data_list)
self.register_feature('x', batch_with_padding)
self.register_feature('adj', batch_with_padding)
self.register_feature('super_node', batch_with_padding)
self.register_feature('pos', batch_with_padding)
self.register_feature('y', batch_without_padding)
self.register_feature('n_nodes', batch_without_padding)


class SparseGraphDataset(BaseGraphDataset):
"""Graph dataset class for sparse pattern"""
_pattern = 'sparse'

def __init__(self, data_list):
super(SparseGraphDataset, self).__init__(data_list)
self.register_feature('x', concat)
self.register_feature('edge_index', shift_concat)
self.register_feature('edge_attr', concat)
self.register_feature('super_node', concat)
self.register_feature('pos', concat)
self.register_feature('y', batch_without_padding)
self.register_feature('n_nodes', batch_without_padding)

def converter(self, batch, device=None):
"""Converter
add `self.batch`, which represents the index of the graph each node
belongs to.
Args:
batch (list[BaseGraphData]): list of graph data
device (int, optional): specifier of device. Defaults to None.
Returns:
self sent to `device`
"""
data = super(SparseGraphDataset, self).converter(batch, device=device)
if not isinstance(device, Device):
device = chainer.get_device(device)
        data.batch = numpy.concatenate([
            numpy.full(d.x.shape[0], i, dtype=numpy.int)
            for i, d in enumerate(batch)
        ])
data.batch = device.send(data.batch)
return data

    # For experiments only; use `converter` for normal use.
def converter_with_padding(self, batch, device=None):
self.update_feature('x', concat_with_padding)
self.update_feature('edge_index', shift_concat_with_padding)
data = super(SparseGraphDataset, self).converter(batch, device=device)
if not isinstance(device, Device):
device = chainer.get_device(device)
        max_n_nodes = max(d.x.shape[0] for d in batch)
        data.batch = numpy.concatenate([
            numpy.full(max_n_nodes, i, dtype=numpy.int)
            for i, d in enumerate(batch)
        ])
data.batch = device.send(data.batch)
return data
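
A hedged sketch of how the sparse dataset and its converter added above might be used; the chain-graph construction and the `@numpy` device are assumptions made for illustration, not part of the commit.

```python
import numpy
import chainer

from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
    SparseGraphData
from chainer_chemistry.dataset.graph_dataset.base_graph_dataset import \
    SparseGraphDataset


def make_chain_graph(n_nodes):
    # Node features: (n_nodes, 4); edges form a chain 0->1->...->n-1,
    # stored as a (2, n_edges) array of (source, destination) indices.
    x = numpy.random.rand(n_nodes, 4).astype(numpy.float32)
    edge_index = numpy.stack([numpy.arange(n_nodes - 1),
                              numpy.arange(1, n_nodes)]).astype(numpy.int32)
    return SparseGraphData(x=x, edge_index=edge_index,
                           y=numpy.array(0, dtype=numpy.int32))


dataset = SparseGraphDataset([make_chain_graph(3), make_chain_graph(5)])
device = chainer.get_device('@numpy')
batch = dataset.converter([dataset[0], dataset[1]], device=device)
print(batch.x.shape)     # (8, 4): node features of both graphs stacked
print(batch.edge_index)  # edges of the second graph are shifted by 3
print(batch.batch)       # [0 0 0 1 1 1 1 1]
```

The second graph's edge indices are shifted by the node count of the first graph, so the concatenated `edge_index` still refers to valid rows of the stacked `x`, and `batch` records which graph each node came from.
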
115 changes: 115 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/feature_converters.py
@@ -0,0 +1,115 @@
import numpy
from chainer.dataset.convert import _concat_arrays


def batch_with_padding(name, batch, device=None, pad=0):
"""Batch with padding (increase ndim by 1)
Args:
        name (str): property name of graph data
        batch (list[BaseGraphData]): list of base graph data
        device (chainer.backend.Device, optional): device. Defaults to None.
        pad (int, optional): padding value. Defaults to 0.
    Returns:
        array: batched feature array sent to `device`
"""
feat = _concat_arrays(
[getattr(example, name) for example in batch], pad)
return device.send(feat)


def batch_without_padding(name, batch, device=None):
"""Batch without padding (increase ndim by 1)
Args:
        name (str): property name of graph data
        batch (list[BaseGraphData]): list of base graph data
        device (chainer.backend.Device, optional): device. Defaults to None.
    Returns:
        array: batched feature array sent to `device`
"""
feat = _concat_arrays(
[getattr(example, name) for example in batch], None)
return device.send(feat)


def concat_with_padding(name, batch, device=None, pad=0):
"""Concat without padding (ndim does not increase)
Args:
name (str): propaty name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
pad (int, optional): padding value. Defaults to 0.
Returns:
BaseGraphDataset: graph dataset sent to `device`
"""
    feat = batch_with_padding(name, batch, device=device, pad=pad)
    # merge the batch dimension into the first feature dimension so that
    # ndim does not increase compared to a single example
    return feat.reshape((-1,) + feat.shape[2:])


def concat(name, batch, device=None, axis=0):
"""Concat with padding (ndim does not increase)
Args:
name (str): propaty name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
pad (int, optional): padding value. Defaults to 0.
Returns:
BaseGraphDataset: graph dataset sent to `device`
"""
feat = numpy.concatenate([getattr(data, name) for data in batch],
axis=axis)
return device.send(feat)


def shift_concat(name, batch, device=None, shift_attr='x', shift_axis=1):
"""Concat with index shift (ndim does not increase)
Concatenate graphs into a big one.
Used for sparse pattern batching.
Args:
        name (str): property name of graph data
        batch (list[BaseGraphData]): list of base graph data
        device (chainer.backend.Device, optional): device. Defaults to None.
        shift_attr (str, optional): attribute whose first-axis length gives
            the index offset of each graph. Defaults to 'x'.
        shift_axis (int, optional): concatenation axis. Defaults to 1.
    Returns:
        array: concatenated, index-shifted feature array sent to `device`
"""
shift_index_array = numpy.cumsum(
numpy.array([0] + [getattr(data, shift_attr).shape[0]
for data in batch]))
feat = numpy.concatenate([
getattr(data, name) + shift_index_array[i]
for i, data in enumerate(batch)], axis=shift_axis)
return device.send(feat)


def shift_concat_with_padding(name, batch, device=None, shift_attr='x',
shift_axis=1):
"""Concat with index shift and padding (ndim does not increase)
Concatenate graphs into a big one.
Used for sparse pattern batching.
Args:
        name (str): property name of graph data
        batch (list[BaseGraphData]): list of base graph data
        device (chainer.backend.Device, optional): device. Defaults to None.
    Returns:
        array: concatenated, index-shifted feature array sent to `device`
"""
max_n_nodes = max([data.x.shape[0] for data in batch])
shift_index_array = numpy.arange(0, len(batch) * max_n_nodes, max_n_nodes)
feat = numpy.concatenate([
getattr(data, name) + shift_index_array[i]
for i, data in enumerate(batch)], axis=shift_axis)
return device.send(feat)
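
To make the index arithmetic in `shift_concat` concrete, the following sketch (not part of the commit; the two toy graphs and the `@numpy` device are assumptions) batches the `edge_index` of a 3-node and a 2-node chain graph.

```python
import numpy
import chainer

from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
    SparseGraphData
from chainer_chemistry.dataset.graph_dataset.feature_converters import \
    shift_concat

g1 = SparseGraphData(
    x=numpy.zeros((3, 1), dtype=numpy.float32),
    edge_index=numpy.array([[0, 1], [1, 2]], dtype=numpy.int32))
g2 = SparseGraphData(
    x=numpy.zeros((2, 1), dtype=numpy.float32),
    edge_index=numpy.array([[0], [1]], dtype=numpy.int32))

device = chainer.get_device('@numpy')
merged = shift_concat('edge_index', [g1, g2], device=device)
print(merged)
# [[0 1 3]
#  [1 2 4]]  -- indices of the 2nd graph are shifted by g1's node count (3)
```

Shifting each graph's indices by the cumulative node count is what lets the sparse pattern treat a minibatch as one large disconnected graph.
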