This notebook tries to reproduce and understand the results from Graph Neural Network Bandits. Please see the full repo for source code and utils.


GNN model definition

The authors use GNNs to model a node permutation-invariant reward function. This will be a single conv-layer GCN with a wide (width-$m$) hidden ReLU layer: $f_{\text{GNN}} (G): \mathcal{G} \rightarrow \mathbb{R}$. Weights are initialized from a unit Gaussian.

Also described here, $f_\text{GNN}$ can be decomposed into

  • BLOCK:

    • AGGREGATION: For node $j$ we have $\bar{h}_j = \frac{1}{c}\sum_{i \in \mathcal{N}(j)\cup \{j\}} h_i$, where $c$ ensures this vector is unit-norm. We will precompute this though, meaning the input to the model is actually $\bar{h}_\mathcal{G}$.
    • TRANSFORMATION: $\frac{1}{\sqrt{m}}\sigma(f_\text{NN}(\bar{h}))$ is a ReLU layer, then rescaled with $\frac{1}{\sqrt{m}}$ to ensure convergence and a closed form NTK.
  • READOUT: aka global mean pooling, computes the representation of the graph from the node feature matrix via $\frac{1}{N}\sum_j h_j^{(L)}$.

Note that the input to the model is a precomputed node feature tensor $\bar{h}$ and not a Data object.
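The decomposition above can be sketched in a few lines of NumPy. This is a minimal illustration and not the repo's `GNN` class: the function names `aggregate` and `gnn_forward`, the row-wise normalization used for $c$, and the exact placement of the $1/\sqrt{m}$ factor are assumptions.

```python
import numpy as np

def aggregate(adj, h):
    """Sum features over each node's closed neighborhood, then rescale rows to unit norm."""
    a_hat = adj + np.eye(adj.shape[0])            # self-loops give N(j) ∪ {j}
    h_sum = a_hat @ h                             # shape (N, d)
    norms = np.linalg.norm(h_sum, axis=1, keepdims=True)
    return h_sum / np.maximum(norms, 1e-12)       # c taken to be the row norm (assumed)

def gnn_forward(hbar, w1, w2):
    """Width-m ReLU layer scaled by 1/sqrt(m), linear output, mean readout over nodes."""
    m = w1.shape[0]
    hidden = np.maximum(hbar @ w1.T, 0.0) / np.sqrt(m)   # shape (N, m)
    node_out = hidden @ w2                                # shape (N,)
    return float(node_out.mean())

rng = np.random.default_rng(0)
N, d, m = 6, 5, 2048
upper = np.triu(rng.random((N, N)) < 0.5, k=1).astype(float)
adj = upper + upper.T                              # symmetric, no self-loops
h = rng.standard_normal((N, d))
w1 = rng.standard_normal((m, d))                   # unit Gaussian init
w2 = rng.standard_normal(m)

out = gnn_forward(aggregate(adj, h), w1, w2)

# Relabeling the nodes should not change the output.
perm = rng.permutation(N)
out_perm = gnn_forward(aggregate(adj[np.ix_(perm, perm)], h[perm]), w1, w2)
print(out, out_perm)
```

Because the readout is a mean over nodes and the aggregation only depends on neighborhoods, permuting the node order leaves the scalar output unchanged up to floating-point error.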

from gnn import GNN

model = GNN(width=2048)
print(model)
GNN(
  (f1): Linear(in_features=5, out_features=2048, bias=False)
  (f2): Linear(in_features=2048, out_features=1, bias=False)
)

Synthetic graph domains

Generate a synthetic dataset as described in section 5. Experiments. Modest parameters are used to reduce computational cost.

  • 6 graph domains with Erdős–Rényi graphs parameterized by ($p, N$)
  • $p \in \{0.1, 0.25, 0.75\}$
  • $N \in \{5, 10\}$
  • sample 200 graphs per domain $\mathcal{G}_{p,N}$
  • node features are in $\mathbb{R}^5$

ENGraph contains both the full graph representation (adjacency matrix) and the precomputed aggregated node feature vectors $\bar{h}_j$.
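For readers without the repo at hand, the sampling step can be sketched as follows. This is not the `ENGraph` implementation: `sample_er_graph` is a hypothetical helper, and drawing node features from a standard normal is an assumption, since the feature distribution is not stated here.

```python
import numpy as np

def sample_er_graph(p, n_nodes, d=5, rng=None):
    """Return (adjacency, features) for one Erdős–Rényi graph G(n_nodes, p)."""
    rng = np.random.default_rng() if rng is None else rng
    # Each of the n(n-1)/2 possible edges is present independently with probability p.
    upper = np.triu(rng.random((n_nodes, n_nodes)) < p, k=1)
    adj = (upper | upper.T).astype(float)
    feats = rng.standard_normal((n_nodes, d))      # assumed feature distribution
    return adj, feats

rng = np.random.default_rng(0)
domains = {
    (p, n): [sample_er_graph(p, n, rng=rng) for _ in range(200)]
    for p in (0.1, 0.25, 0.75)
    for n in (5, 10)
}
adj, feats = domains[(0.75, 10)][0]
print(adj.shape, feats.shape)
```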

import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils

from utils import ENGraph

# sample some graph domains (200 graphs per domain)
graphs = {}
for p in [0.1, 0.25, 0.75]:
    for N in [5, 10]:
        graphs[(p,N)] = []
        for i in range(200):
            graphs[(p,N)].append(ENGraph(p, N))
            
# plot sampled graphs and their pre-aggregated node feature matrices
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=[6, 3.5])
for idx, engraph in enumerate(graphs[(0.75, 5)][:3]):
    g = torch_geometric.utils.to_networkx(engraph.graph, to_undirected=True)
    axs[1, idx].imshow(engraph.hbar.detach().numpy())
    axs[1, idx].set_yticks(ticks = range(5), labels=[f"x{i}" for i in range(5)])
    axs[1, idx].set_xticks(ticks = range(5), labels=[i+1 for i in range(5)])
    nx.draw_circular(g, ax=axs[0, idx], node_size=50)


GNN as a GP

As a sanity check, we verify that a large-width GNN has approximately Gaussian-distributed outputs over random initializations, consistent with the GNN converging to a Gaussian process in the infinite-width limit.

# init GNN
model = GNN(width=2048)
model.eval()

# reinit and re-evaluate 1000 times
hbar_test = graphs[(0.25, 5)][0].hbar
outs = []
for _ in range(1000):
    model.reinit()
    y = model(hbar_test)
    outs.append(y.detach().squeeze().item())

# show distribution of model outputs is normal
plt.figure(figsize=[4,2])
plt.hist(outs)
plt.show()
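The same check can be reproduced without the repo's `GNN` class. The sketch below assumes the two-layer form printed earlier (no biases, unit Gaussian weights, $1/\sqrt{m}$ scaling, mean readout) and uses random unit-norm rows as a stand-in for $\bar{h}$. It reports the sample mean, variance, skewness and excess kurtosis, the last two of which are 0 for an exact Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, m, n_draws = 5, 5, 512, 2000

# Fixed input: N aggregated node features with unit-norm rows (stand-in for hbar).
hbar = rng.standard_normal((N, d))
hbar /= np.linalg.norm(hbar, axis=1, keepdims=True)

outs = np.empty(n_draws)
for i in range(n_draws):
    w1 = rng.standard_normal((m, d))               # fresh unit Gaussian init
    w2 = rng.standard_normal(m)
    hidden = np.maximum(hbar @ w1.T, 0.0) / np.sqrt(m)
    outs[i] = (hidden @ w2).mean()                 # mean readout over nodes

mean, var = outs.mean(), outs.var()
z = (outs - mean) / outs.std()
skew, excess_kurt = np.mean(z**3), np.mean(z**4) - 3.0
print(f"mean={mean:.3f} var={var:.3f} skew={skew:.3f} excess kurtosis={excess_kurt:.3f}")
```

With unit-norm inputs each node output has variance $1/2$ at initialization, so the variance of the node average lies between $1/(2N)$ and $1/2$.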



Compute empirical G-NTK

Here we compute the empirical NTK for a small subset of graph domain $\mathcal{G}_{0.25, 10}$ using jacobian contraction, as here.

  • Compute jacobians using functorch: $\mathbf{g}(G;\theta_0) = \nabla_\theta f_\text{GNN}(G;\theta)\big|_{\theta=\theta_0}$
  • Compute NTK matrix for 5 graphs: $\mathbf{K} = \mathbf{g} \mathbf{g}^\top$

In the paper they explicitly show the relationship between the “vanilla” NTK and Graph-NTK by decomposing the GNN into FFNNs, making the deterministic Graph-NTK a function of the NTK.

  • $\bar{h}_j$ aggregated node feature vec for node $j$
  • $f_\text{GNN}(G) = \frac{1}{N} \sum^N_{j=1} f_\text{NN}(\bar{h}_j)$, i.e. single conv-layer GNN is equivalent to an FFNN over aggregated node features
  • $k_\text{GNN} = \frac{1}{N^2} \sum_{j, j' = 1}^N k_\text{NN}(\bar{h}_j, \bar{h}_{j'})$, GNN-NTK defined as a function of NTK

In practice, for these experiments we use the empirical NTK due to incompatibility of torch with the neural-tangents library.
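To make the relation between the two kernels concrete, here is a small NumPy sketch with hand-written gradients for a two-layer ReLU network (mirroring the bias-free `f1`/`f2` layers printed earlier). It is not the repo's `batchNTK`; the helper names and the random stand-ins for $\bar{h}$ are assumptions. It builds the empirical G-NTK once by Jacobian contraction and once by averaging the node-level NTK over all node pairs.

```python
import numpy as np

def nn_grad(x, w1, w2):
    """Gradient of f_NN(x) = w2 . relu(w1 x) / sqrt(m) w.r.t. (w1, w2), flattened."""
    m = w1.shape[0]
    pre = w1 @ x                                        # shape (m,)
    g_w1 = np.outer(w2 * (pre > 0), x) / np.sqrt(m)     # d f / d w1, shape (m, d)
    g_w2 = np.maximum(pre, 0.0) / np.sqrt(m)            # d f / d w2, shape (m,)
    return np.concatenate([g_w1.ravel(), g_w2])

def gnn_grad(hbar, w1, w2):
    """f_GNN is the node average of f_NN, so its gradient is the average node gradient."""
    return np.mean([nn_grad(x, w1, w2) for x in hbar], axis=0)

rng = np.random.default_rng(0)
d, m = 5, 256
w1, w2 = rng.standard_normal((m, d)), rng.standard_normal(m)
hbars = [rng.standard_normal((10, d)) for _ in range(5)]    # stand-ins for hbar

# Empirical G-NTK by Jacobian contraction: K = J J^T with one row per graph.
J = np.stack([gnn_grad(hb, w1, w2) for hb in hbars])         # shape (5, p)
K = J @ J.T

def k_gnn(hb_a, hb_b):
    """Average the node-level empirical NTK over all node pairs of two graphs."""
    Ja = np.stack([nn_grad(x, w1, w2) for x in hb_a])
    Jb = np.stack([nn_grad(x, w1, w2) for x in hb_b])
    return (Ja @ Jb.T).mean()

K_pairs = np.array([[k_gnn(a, b) for b in hbars] for a in hbars])
print(np.abs(K - K_pairs).max())
```

Both constructions agree because the gradient of a mean is the mean of the gradients; the Jacobian-contraction route only needs one gradient per graph.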

from gnn import GNN

# init GNN
model = GNN(width=2048)
model.eval()

# evaluate one example
hbar_test = graphs[(0.25, 5)][0].hbar
y = model(hbar_test)
if y is not None:
    print(f"output shape: {y.size()}")
    print("predicted reward:", y.detach().item())
output shape: torch.Size([1, 1])
predicted reward: 0.28805238008499146

Plotting 5 graphs from domain $\mathcal{G}_{0.25, 10}$

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=[9, 3.5])
for idx, engraph in enumerate(graphs[(0.25, 10)][:5]):
    g = torch_geometric.utils.to_networkx(engraph.graph, to_undirected=True)
    axs[1, idx].imshow(engraph.hbar.detach().numpy())
    axs[1, idx].set_yticks([])
    axs[1, idx].set_xticks([])
    axs[1, 0].set_yticks(ticks = range(10), labels=[f"x{i}" for i in range(10)])
    axs[1, 0].set_xticks(ticks = range(5), labels=[i+1 for i in range(5)])
    nx.draw_circular(g, ax=axs[0, idx], node_size=50)


Compute NTK matrix for 5 graphs

$K_{ij} = k_\text{GNN} (G_i, G_j)$, kernel is scalar since output is scalar.

# sample 5 graphs from one domain (p=0.25, N=10)
train_G = torch.concat(
    [g.hbar.unsqueeze(0) for g in graphs[(0.25, 10)][:5]]
)

# compute empirical NTK matrix for 5 graphs
K = model.batchNTK(train_G, train_G)
fig = plt.figure(figsize=[3,3])
img = plt.imshow(K.detach().numpy())
fig.colorbar(img)
plt.axis("off")
plt.show()



Gaussian Process reward

For the synthetic dataset, the authors learn a smooth “true” reward function for each domain using a GP $f: \mathcal{G} \rightarrow \mathbb{R}$. gpytorch requires tensor objects so we use the precomputed aggregated node feature tensor as input: $f(\bar{h}): \mathbb{R}^{N \times d} \rightarrow \mathbb{R}$.

  • Use a GP prior $f \sim \text{GP}(0, k_\text{GNN})$
  • Learn posterior GP with 5 graphs from a particular domain, $(G_i, y_i)_{i \leq 5}$ with $G_i \in \mathcal{G}_{p,N}$, where $y_i$ are randomly sampled from $\mathcal{N}(0,1)$
  • Evaluate the smooth posterior GP (the code below uses the posterior mean) to get $f(G_i)$

The final dataset is then $\mathcal{D}_{p,N} = \{(G_i, f(G_i)) \mid G_i \in \mathcal{G}_{p,N} \}$
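The three steps can be sketched in plain NumPy, independent of gpytorch. This is an illustration under stated assumptions: an RBF node kernel stands in for the NTK, the noise level `noise` is made up, `k_graph` and `gram` are hypothetical helpers, and the reward is taken to be the posterior mean.

```python
import numpy as np

def k_graph(hb_a, hb_b):
    """Graph kernel as the average of a node kernel over all node pairs.
    An RBF node kernel stands in for the NTK here (an assumption for brevity)."""
    sq = ((hb_a[:, None, :] - hb_b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq).mean()

def gram(graphs_a, graphs_b):
    """Kernel matrix between two lists of graphs."""
    return np.array([[k_graph(a, b) for b in graphs_b] for a in graphs_a])

rng = np.random.default_rng(0)
domain = [rng.standard_normal((5, 5)) for _ in range(50)]   # stand-ins for hbar
train = domain[:5]
y = rng.standard_normal(5)                                   # y_i ~ N(0, 1)

noise = 1e-2                                                 # assumed observation noise
K_xx = gram(train, train) + noise * np.eye(len(train))
alpha = np.linalg.solve(K_xx, y)

# Posterior mean at every graph in the domain: k(G, X) (K_XX + noise I)^{-1} y
rewards = gram(domain, train) @ alpha
dataset = list(zip(domain, rewards))
print(rewards[:5], y)
```

At the training graphs the posterior mean is a shrunken version of the sampled $y_i$; the amount of shrinkage is controlled by the noise term.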

Learn GP posterior and reward for each graph domain

import pickle

import gpytorch
import torch

from gp import GNTK, GP_reward

RESAMPLE = False

if RESAMPLE:
    # learn GP posterior and sample reward for dataset
    for domain in graphs.keys():
        # get 5 samples from domain
        train_x = torch.concat(
            [g.hbar.unsqueeze(0) for g in graphs[domain][:5]]
        ).flatten(1)
        train_y = torch.normal(torch.tensor([0.0]*5), torch.tensor([1.0]*5))

        # initialize likelihood and model
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        gp = GP_reward(
            train_x, train_y, likelihood, 
            model.batchNTK, x_shape=(domain[1], 5)
        )
        
        # set to train
        gp.train()
        likelihood.train()

        # init optimizer and loss
        optimizer = torch.optim.Adam(gp.parameters(), lr=0.1)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)

        print(f"fitting GP for domain G{domain} ", end="")
        for i in range(100):
            optimizer.zero_grad()
            output = gp(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()
            if i%10==0: print(".",end="")

        # predict all graphs in domain
        gp.eval()
        test_x = torch.concat(
            [g.hbar.unsqueeze(0) for g in graphs[domain]]
        ).flatten(1)
        # posterior mean, detached so the stored rewards carry no autograd graph
        preds = gp(test_x).mean.detach()

        # update dataset
        print(" updating graphs")
        for g, y in zip(graphs[domain], preds):
            g.y = y
            
    # save updated dataset
    with open("../data/graphs.pkl", "wb") as fh:
        pickle.dump(graphs, fh)
else:
    # load precomputed dataset with GP reward
    with open("../data/graphs.pkl", "rb") as fh:
        graphs = pickle.load(fh)

GNN-UCB

Background

The authors propose to use the Graph-NTK to balance exploitation, i.e. training the GNN, and exploration of arms, i.e. acquiring new samples $(G_i, y_i)$. In the lazy (overparameterized) regime, wide neural networks behave approximately like Gaussian processes, which allows one to quantify the upper confidence bound (UCB) of a GNN with the NTK. For simplicity and understanding we will only implement GNN-UCB, a variant of NeuralUCB.

  • $\text{UCB}(G; \mu, \sigma) = \mu(G) + \beta_t\sigma(G)$ is the acquisition function
  • $\mu \triangleq f_\text{GNN}$ is straightforward, now how to quantify $\sigma$?
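  • recall that the GP posterior variance at a test point $x^*$ admits the closed form $k'(x^*, x^*) = k(x^*, x^*) - k^\top_{X,x^*}(K_{XX}+\sigma^2I)^{-1}k_{X,x^*}$, where $\sigma^2$ is the observation noise variance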

Brief overview of linear / kernelized bandits and Linear-UCB for understanding. Please refer to section 4. on regret analysis in the NeuralUCB paper, the original LinUCB paper, or a lecture on linear bandits for in-depth background.

  • $y_t = f(\mathbf{x}_t) + \epsilon_t$. One observes the reward $y_t$ from a function with some noise. In this case just linear regression: $f(\mathbf{x}_t) = \mathbf{w}^\top \mathbf{x}_t$.
  • Kernelized setting: stack the observed contexts as rows of $X_t \in \mathbb{R}^{t \times d}$. Assuming a ridge loss, we have the closed form estimate for the model: $\hat{\mathbf{w}}_t = (X_t^\top X_t + \lambda I)^{-1}X_t^\top \mathbf{y}_t = X_t^\top (K + \lambda I)^{-1} \mathbf{y}_t$. Notice that $K = X_tX_t^\top$ is the linear kernel matrix.
  • The cumulative regret is $R_T = \sum^T_{t=1}\left(\max_\mathbf{x} f(\mathbf{x}) - f(\mathbf{x}_t)\right)$; this should grow sublinearly in $T$ for the algorithm to converge.
  • The error of the prediction $\hat{f}_t(\mathbf{x}) = \hat{\mathbf{w}}_t^\top \mathbf{x}$ is bounded by the (scaled) standard deviation of the expected reward: $|f(\mathbf{x}) - \hat{f}_t(\mathbf{x})| \leq \beta_t \sqrt{\mathbf{x}^\top (\lambda I + X_t^\top X_t)^{-1} \mathbf{x}}$
  • This bound can then be used in the vanilla UCB acquisition function $\text{UCB}(\mathbf{x}) = \hat{f}_t(\mathbf{x}) + \beta_t \sqrt{\mathbf{x}^\top (\lambda I + X_t^\top X_t)^{-1} \mathbf{x}}$
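A minimal NumPy sketch of the linear case, assuming contexts are stored as rows of a matrix `X` of shape $(t, d)$ and using made-up values for $\lambda$ and $\beta$. It checks that the ridge estimate can be computed either with a $d \times d$ solve or through the linear kernel matrix, and that the confidence width can be written with kernel quantities only, which is what connects it to the GP posterior variance above.

```python
import numpy as np

rng = np.random.default_rng(0)
t, d, lam, beta = 20, 5, 0.1, 0.3               # illustrative values
X = rng.standard_normal((t, d))                 # rows are the contexts x_1..x_t
w_true = rng.standard_normal(d)
y = X @ w_true + 0.1 * rng.standard_normal(t)   # noisy linear rewards

# Ridge estimate, primal (d x d solve) and dual / kernel (t x t solve) forms.
A = X.T @ X + lam * np.eye(d)
w_primal = np.linalg.solve(A, X.T @ y)
K = X @ X.T                                     # linear kernel matrix
w_dual = X.T @ np.linalg.solve(K + lam * np.eye(t), y)

# Squared confidence width for a new arm x, in feature space and via the kernel.
x = rng.standard_normal(d)
width_sq = x @ np.linalg.solve(A, x)
k_x = X @ x
width_sq_kernel = (x @ x - k_x @ np.linalg.solve(K + lam * np.eye(t), k_x)) / lam

ucb = w_primal @ x + beta * np.sqrt(width_sq)
print(np.allclose(w_primal, w_dual), np.isclose(width_sq, width_sq_kernel))
```

The kernel form of the width is the GP posterior variance with noise $\lambda$, divided by $\lambda$; this follows from the Woodbury identity.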

For GNN-UCB:

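  • The normalized jacobian $\mathbf{g}(G)/\sqrt{m}$ can be understood as a feature map $\phi(G)$.
  • Here $K = \sum_{i<t} \mathbf{g}(G_i)\mathbf{g}(G_i)^\top / m$ is a $p \times p$ matrix, with $p$ the number of parameters. It plays the role of the kernel matrix in feature space and shares its nonzero eigenvalues with the empirical NTK matrix.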

$$\text{GNN-UCB}(G_t) = f_\text{GNN}(G_t;\theta_{t-1}) + \beta_t \sqrt{\mathbf{g}(G_t;\theta_{t-1})^\top(\lambda I+K)^{-1}\mathbf{g}(G_t;\theta_{t-1})/m}$$
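The acquisition step can be sketched as below. Random vectors stand in for the normalized gradients $\mathbf{g}(G)/\sqrt{m}$ and for the GNN predictions, so this only illustrates the bookkeeping, not the repo's `GNNUCB` class; `select_arm` is a hypothetical helper. In practice the feature dimension equals the number of network parameters, which makes the explicit inverse used here expensive.

```python
import numpy as np

def select_arm(mu, feats, Z_inv, beta):
    """Return the index maximizing mu + beta * sqrt(phi^T Z^{-1} phi), plus all widths."""
    widths = np.sqrt(np.einsum("ip,pq,iq->i", feats, Z_inv, feats))
    return int(np.argmax(mu + beta * widths)), widths

rng = np.random.default_rng(0)
n_arms, p, lam, beta = 30, 8, 0.0025, 0.3
feats = rng.standard_normal((n_arms, p)) / np.sqrt(p)   # stand-ins for g(G)/sqrt(m)
mu = rng.standard_normal(n_arms)                        # stand-ins for f_GNN(G)

Z = lam * np.eye(p)                                     # lambda I + sum of phi phi^T
chosen, widths_before = select_arm(mu, feats, np.linalg.inv(Z), beta)

# After observing the chosen arm, add its outer product to Z.
Z += np.outer(feats[chosen], feats[chosen])
_, widths_after = select_arm(mu, feats, np.linalg.inv(Z), beta)
print(chosen, widths_before[chosen], widths_after[chosen])
```

Observing an arm shrinks its own width the most and never increases the width of any other arm, which is how the exploration bonus decays over rounds.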

Practical implementation

Using modified instructions specified in Appendix D.2, D.3:

  • rounds $T = 160$
  • “explore” for $T_0$ steps with random samples of $G_i$ to pre-train
  • for the subsequent $T_1$ steps train but re-init model parameters at each step
  • for any remaining steps train only every 20 steps
  • use $\beta = 0.3$ and $\lambda=0.0025$
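The round schedule above can be written as a small helper. `phase` is hypothetical; $T_0 = 40$ matches the random-exploration loop used later in this notebook, while $T_1 = 40$ and the exact alignment of the every-20-steps rule are assumptions made only for illustration.

```python
def phase(t, T0=40, T1=40, every=20):
    """Map a 0-based round index t to what the bandit loop does in that round."""
    if t < T0:
        return "random-explore"          # pre-training data, no UCB yet
    if t < T0 + T1:
        return "train-with-reinit"       # retrain from a fresh init each round
    # afterwards only train on every `every`-th round (assumed alignment)
    return "train" if (t - T0 - T1) % every == 0 else "skip-training"

schedule = [phase(t) for t in range(160)]
print(schedule[0], schedule[40], schedule[80], schedule[81])
```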

GNN training:

  • use $m=2048$, $L=2$ with gaussian weight init for the GNN
  • use MSELoss: $\mathcal{L}(\theta) = \frac{1}{t}\sum_{i<t} (f_\text{GNN}(G_i;\theta) - y_i)^2$; the L2 regularizer $m\lambda \|\theta - \theta^{(0)}\|_2^2$ is omitted since we don't use weight decay
  • use Adam optimizer (lr=0.001) instead of SGD (not theoretically correct but practical)
  • train the network for $J_t$ gradient steps, where $J_t$ is the smallest $J$ such that $\mathcal{L}(\theta_{t-1})\leq 10^{-4}$ or the relative change in loss is less than $10^{-3}$
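The stopping rule in the last bullet can be sketched as a pure-Python check on the recorded losses. The helper names and the `max_steps` cap are assumptions and do not come from the repo's `train` function.

```python
def should_stop(losses, abs_tol=1e-4, rel_tol=1e-3):
    """Stop once the latest loss is below abs_tol or its relative change is below rel_tol."""
    if not losses:
        return False
    if losses[-1] <= abs_tol:
        return True
    if len(losses) >= 2 and losses[-2] > 0:
        return abs(losses[-2] - losses[-1]) / losses[-2] < rel_tol
    return False

def steps_until_stop(loss_trace, max_steps=1000):
    """Return J, the smallest number of recorded steps after which should_stop fires."""
    seen = []
    for j, loss in enumerate(loss_trace[:max_steps], start=1):
        seen.append(loss)
        if should_stop(seen):
            return j
    return min(len(loss_trace), max_steps)

# A loss that halves every step first drops below 1e-4 at 0.5**14,
# which is the 15th recorded value.
trace = [0.5 ** k for k in range(30)]
print(steps_until_stop(trace))  # 15
```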

Plot model prediction and UCB

import pickle

from bandit import GNNUCB
from utils import GraphData

SIZE = 25
model.reinit()
ucb = GNNUCB(model, beta=.5)
# load data
graphs = pickle.load(open("../data/graphs.pkl","rb"))
data = GraphData(graphs[(0.25, 10)])

f_G = [model(g.hbar).item() for g in data.domain[:SIZE]]
y_G = [g.y.item() for g in data.domain[:SIZE]]
U_G = [ucb(f, g.hbar).item() for f, g in zip(f_G, data.domain[:SIZE])]
plt.figure()
plt.plot(f_G)
plt.plot(U_G, '--')
plt.plot(range(SIZE), y_G, '.')
plt.show()


import collections
from utils import GraphData
from bandit import GNNUCB
from bandit import explore, train

# init GNN
model = GNN(width=2048)
model.eval()

# init training setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()
loss_0 = torch.tensor(1e-3)

# load data
graphs = pickle.load(open("../data/graphs.pkl","rb"))
data = GraphData(graphs[(0.25, 10)])
y_max = torch.max(torch.tensor([g.y for g in data.domain]))
    
# init acquisition function
ucb = GNNUCB(model)

# track instant regret
regrets = []
    
# randomly explore domain to pretrain
print("random exploration (40 steps)")
for t in range(40):
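    # randint's upper bound is exclusive, so this covers every graph in the domain
    sample_idx = torch.randint(len(data.domain), size=(1,)).item()
    sample = data.domain[sample_idx]
    data.train.add(sample)
    # regret of the graph that was actually sampled in this round
    regrets.append(y_max - sample.y)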
    
# "exploration and exploitation"
print("explore and exploit",end="")
for t in range(40, 160):
    # explore (pick graph from domain using prev model)
    y = explore(data, model, ucb, t)
    regrets.append(y_max - y)
    
    # exploit (train GNN), logging every 5th round
    log = t % 5 == 0
    if log:
        print(f"\nt={t}, regret={y_max-y:.3f}")
    params_t = train(optimizer, criterion, loss_0, data, model, t, log)
random exploration (40 steps)
explore and exploit
t=40, regret=0.478
........................................
t=45, regret=0.210
.............................................
t=50, regret=0.089
..................................................
t=55, regret=0.210
.......................................................
t=60, regret=0.253
............................................................
t=65, regret=0.089
.................................................................
t=70, regret=0.000
......................................................................
t=75, regret=0.265
...........................................................................
t=80, regret=0.000
................................................................................
t=85, regret=0.043
.....................................................................................
t=90, regret=0.054
..........................................................................................
t=95, regret=0.000
...............................................................................................
t=100, regret=0.000
....................................................................................................
t=105, regret=0.140
.........................................................................................................
t=110, regret=0.000
..............................................................................................................
t=115, regret=0.100
...................................................................................................................
t=120, regret=0.000
........................................................................................................................
t=125, regret=0.000
.............................................................................................................................
t=130, regret=0.000
..................................................................................................................................
t=135, regret=0.000
.......................................................................................................................................
t=140, regret=0.000
......................................................................
t=145, regret=0.089
.................................................................................................................................................
t=150, regret=0.000
........................................................................................................................................
t=155, regret=0.001
...........................................................................................................................................................

Plot cumulative regret

Rt = 0
ctr = []
for r in regrets[40:]:
    Rt += r
    ctr.append(Rt.item())

plt.figure(figsize=[4,4])
plt.plot(ctr)
plt.xlabel("t")
plt.ylabel("Rt")
plt.show()
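A quick way to judge whether the curve is sublinear is to look at the average regret $R_t / t$, which should shrink over time. The helpers below are a sketch with hypothetical names, and the toy sequence is synthetic, not a result from this notebook.

```python
from itertools import accumulate

def cumulative_regret(instant_regrets):
    """R_t = sum of instantaneous regrets up to and including round t."""
    return list(accumulate(instant_regrets))

def average_regret(instant_regrets):
    """R_t / t, which tends to 0 when cumulative regret is sublinear."""
    return [R / t for t, R in enumerate(cumulative_regret(instant_regrets), start=1)]

# Toy regrets decaying like 1/t (synthetic, for illustration only).
toy = [1.0 / t for t in range(1, 201)]
R = cumulative_regret(toy)
avg = average_regret(toy)
print(f"R_200={R[-1]:.3f}, R_200/200={avg[-1]:.4f}")
```

Applied to the notebook's data, `regrets[40:]` would first need to be converted to floats, e.g. with `.item()`.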
