Skip to content

Commit

Permalink
naive jit implementation for cooling (#1408)
Browse files Browse the repository at this point in the history
Co-authored-by: Andy Maloney <[email protected]>
  • Loading branch information
dmiracle and amaloney authored Mar 1, 2025
1 parent 9d07be1 commit 347c590
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions datashader/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from __future__ import annotations

import numba as nb
import numpy as np
import param
import scipy.sparse


class LayoutAlgorithm(param.ParameterizedFunction):
"""
Baseclass for all graph layout algorithms.
Expand Down Expand Up @@ -172,13 +172,25 @@ def _merge_points_with_nodes(nodes, points, params):
n[params.y] = points[:, 1]
return n


def cooling(matrix, points, temperature, params):
dt = temperature / float(params.iterations + 1)
displacement = np.zeros((params.dim, len(points)))
for iteration in range(params.iterations):
matrix = matrix.toarray()
c_params = {
'iterations': params.iterations,
'dim': params.dim,
'k': params.k,
'nohubs': params.nohubs,
'linlog': params.linlog
}
_cooling(matrix, points, temperature, **c_params)


@nb.jit(nopython=True, nogil=True, parallel=True)
def _cooling(matrix, points, temperature, iterations, dim, k, nohubs, linlog):
dt = temperature / float(iterations + 1)
displacement = np.zeros((dim, len(points)))
for iteration in range(iterations):
displacement *= 0
for i in range(matrix.shape[0]):
for i in nb.prange(matrix.shape[0]):
# difference between this row's node position and all others
delta = (points[i] - points).T

Expand All @@ -189,16 +201,16 @@ def cooling(matrix, points, temperature, params):
distance = np.where(distance < 0.01, 0.01, distance)

# the adjacency matrix row
ai = matrix[i].toarray()
ai = matrix[i]

# displacement "force"
dist = params.k * params.k / distance ** 2
dist = k * k / distance ** 2

if params.nohubs:
if nohubs:
dist = dist / float(ai.sum(axis=1) + 1)
if params.linlog:
if linlog:
dist = np.log(dist + 1)
displacement[:, i] += (delta * (dist - ai * distance / params.k)).sum(axis=1)
displacement[:, i] += (delta * (dist - ai * distance / k)).sum(axis=1)

# update points
length = np.sqrt((displacement ** 2).sum(axis=0))
Expand Down

0 comments on commit 347c590

Please sign in to comment.