This notebook tries to reproduce and understand the results from Graph Neural Network Bandits. Please see the full repo for source code and utils.
GNN model definition
The authors use GNNs to model a node permutation-invariant reward function. Here this is a single conv-layer GCN with a wide (width-$m$) hidden ReLU layer: $f_{\text{GNN}} (G): \mathcal{G} \rightarrow \mathbb{R}$. Weights are initialized from a unit Gaussian.
Also described here, $f_\text{GNN}$ can be decomposed into
BLOCK:
- AGGREGATION: For node $j$ we have $\bar{h}_j = \frac{\sum_{i \in \mathcal{N}(j)\cup \{j\}} h_i}{c}$, where $c$ ensures this vector is unit-norm. We precompute this step, meaning the input to the model is actually $\bar{h}_\mathcal{G}$.
- TRANSFORMATION: $\frac{1}{\sqrt{m}}\sigma(f_\text{NN}(\bar{h}))$ is a ReLU layer, then rescaled with $\frac{1}{\sqrt{m}}$ to ensure convergence and a closed form NTK.
READOUT: also known as global mean pooling, computes the representation of the graph from the node feature matrix via $\frac{1}{N}\sum_j h_j^{(L)}$.
Note that the input to the model is a precomputed node feature tensor $\bar{h}$ and not a Data object.
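The TRANSFORMATION and READOUT steps above can be sketched in a few lines of NumPy. This is a framework-free illustration under stated assumptions, not the GNN class from the repo: the function name `gnn_forward`, the weight names `W1` and `w2`, and the small width used here are all hypothetical.

```python
import numpy as np

def gnn_forward(hbar, W1, w2):
    """Single-layer GNN on pre-aggregated node features (illustrative sketch).

    hbar: (N, d) aggregated node features, W1: (m, d), w2: (m,).
    """
    m = W1.shape[0]
    # TRANSFORMATION: ReLU layer rescaled by 1/sqrt(m)
    hidden = np.maximum(hbar @ W1.T, 0.0) / np.sqrt(m)
    # per-node scalar output
    node_out = hidden @ w2
    # READOUT: global mean pooling over nodes
    return node_out.mean()

rng = np.random.default_rng(0)
N, d, m = 10, 5, 256                     # assumed sizes for the example
W1 = rng.standard_normal((m, d))         # unit Gaussian init
w2 = rng.standard_normal(m)
hbar = rng.standard_normal((N, d))
hbar /= np.linalg.norm(hbar, axis=1, keepdims=True)  # unit-norm rows, as after aggregation

out = gnn_forward(hbar, W1, w2)
# relabeling the nodes permutes the rows of hbar and should not change the output
out_perm = gnn_forward(hbar[rng.permutation(N)], W1, w2)
```

Because the readout is a mean over nodes, the output does not depend on the node ordering, which is the permutation invariance the reward model relies on.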
from gnn import GNN
model = GNN(width=2048)
print(model)
GNN(
(f1): Linear(in_features=5, out_features=2048, bias=False)
(f2): Linear(in_features=2048, out_features=1, bias=False)
)
Synthetic graph domains
Generate a synthetic dataset as described in section 5. Experiments. Modest parameters are used to reduce computational cost.
- 6 graph domains with Erdős–Rényi graphs parameterized by ($p, N$)
- $p \in [0.1, 0.25, 0.75]$
- $N \in [5, 10]$
- sample 200 graphs per domain $\mathcal{G}_{p,N}$
- node features are in $\mathbb{R}^5$
ENGraph contains both the full graph representation (adjacency matrix) and the precomputed aggregated node feature vectors $\bar{h}_j$.
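To make the sampling and pre-aggregation concrete, here is a minimal NumPy sketch. It is not the ENGraph implementation from the repo: the helper names `sample_er_graph` and `aggregate` are hypothetical, and drawing node features from a standard Gaussian is an assumption made only for this example.

```python
import numpy as np

def sample_er_graph(p, N, d=5, rng=None):
    """Sample an Erdős–Rényi adjacency matrix and node features (illustrative)."""
    rng = np.random.default_rng() if rng is None else rng
    upper = np.triu(rng.random((N, N)) < p, k=1)  # each edge i<j present with probability p
    adj = (upper | upper.T).astype(float)         # symmetric, no self-loops
    feats = rng.standard_normal((N, d))           # ASSUMPTION: Gaussian node features
    return adj, feats

def aggregate(adj, feats):
    """hbar_j = sum of h_i over N(j) and j itself, rescaled to unit norm."""
    summed = (adj + np.eye(len(adj))) @ feats
    return summed / np.linalg.norm(summed, axis=1, keepdims=True)

rng = np.random.default_rng(0)
adj, feats = sample_er_graph(p=0.25, N=10, rng=rng)
hbar = aggregate(adj, feats)   # (N, d) matrix that would be fed to the model
```

Adding the identity to the adjacency matrix includes each node in its own neighbourhood, matching the sum over $\mathcal{N}(j)\cup \{j\}$ in the aggregation step.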
import matplotlib.pyplot as plt
import networkx as nx
import torch_geometric.utils
from utils import ENGraph
# sample some graph domains (200 graphs per domain)
graphs = {}
for p in [0.1, 0.25, 0.75]:
    for N in [5, 10]:
        graphs[(p, N)] = []
        for i in range(200):
            graphs[(p, N)].append(ENGraph(p, N))
# plot sampled graphs and their pre-aggregated node feature matrices
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=[6, 3.5])
for idx, engraph in enumerate(graphs[(0.75, 5)][:3]):
    g = torch_geometric.utils.to_networkx(engraph.graph, to_undirected=True)
    # bottom row: aggregated node features hbar
    axs[1, idx].imshow(engraph.hbar.detach().numpy())
    axs[1, idx].set_yticks(ticks=range(5), labels=[f"x{i}" for i in range(5)])
    axs[1, idx].set_xticks(ticks=range(5), labels=[i + 1 for i in range(5)])
    # top row: the graph itself
    nx.draw_circular(g, ax=axs[0, idx], node_size=50)
GNN as a GP
As a sanity check, we test whether the outputs of a large-width GNN are approximately Gaussian distributed across random initializations. This is consistent with the claim that in the infinite-width limit the GNN converges to a Gaussian process.
# init GNN
model = GNN(width=2048)
model.eval()
# reinit and re-evaluate 1000 times
hbar_test = graphs[(0.25, 5)][0].hbar
outs = []
for _ in range(1000):
    model.reinit()
    y = model(hbar_test)
    outs.append(y.detach().squeeze().item())
# show distribution of model outputs is normal
plt.figure(figsize=[4, 2])
plt.hist(outs)
plt.show()
Compute empirical G-NTK
Here we compute the empirical NTK for a small subset of graph domain $\mathcal{G}_{0.25, 10}$ using jacobian contraction, as here.
- Compute Jacobians using functorch: $\mathbf{g}(G;\mathbf{\theta}_0) = \nabla_\theta{f_\text{GNN}(G)}$
- Compute NTK matrix for 5 graphs: $\mathbf{K} = \mathbf{g} \mathbf{g}^\top$
In the paper they explicitly show the relationship between the “vanilla” NTK and Graph-NTK by decomposing the GNN into FFNNs, making the deterministic Graph-NTK a function of the NTK.
- $\bar{h}_j$ aggregated node feature vec for node $j$
- $f_\text{GNN}(G) = \frac{1}{N} \sum^N_{j=1} f_\text{NN}(\bar{h}_j)$, i.e. single conv-layer GNN is equivalent to an FFNN over aggregated node features
- $k_\text{GNN} = \frac{1}{N^2} \sum_{j, j' = 1}^N k_\text{NN}(\bar{h}_j, \bar{h}_{j'})$, GNN-NTK defined as a function of NTK
In practice, for these experiments we use the empirical NTK due to the incompatibility of torch with the neural-tangents library.
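The relationship between the empirical NTK of the GNN and that of the underlying FFNN can be checked numerically. The sketch below is an illustration under assumptions, not the notebook's `batchNTK`: it uses a two-layer ReLU network with hand-written gradients in NumPy instead of functorch, and the names `grad_nn`, `grad_gnn` and `k_gnn_from_nn` as well as the random stand-in inputs are hypothetical.

```python
import numpy as np

def grad_nn(x, W1, w2):
    """Flattened gradient of f_NN(x) = w2 . relu(W1 x) / sqrt(m) w.r.t. (W1, w2)."""
    m = W1.shape[0]
    pre = W1 @ x
    d_w2 = np.maximum(pre, 0.0) / np.sqrt(m)
    d_W1 = np.outer(w2 * (pre > 0), x) / np.sqrt(m)
    return np.concatenate([d_W1.ravel(), d_w2])

def grad_gnn(hbar, W1, w2):
    """Gradient of f_GNN(G) = (1/N) sum_j f_NN(hbar_j)."""
    return np.mean([grad_nn(x, W1, w2) for x in hbar], axis=0)

rng = np.random.default_rng(0)
N, d, m = 6, 5, 128
W1, w2 = rng.standard_normal((m, d)), rng.standard_normal(m)
hbars = [rng.standard_normal((N, d)) for _ in range(3)]  # stand-ins for hbar matrices

# Jacobian contraction: stack per-graph gradients, then K = g g^T
g = np.stack([grad_gnn(h, W1, w2) for h in hbars])
K_gnn = g @ g.T

def k_gnn_from_nn(h1, h2):
    """k_GNN as the average of k_NN over all node pairs."""
    g1 = np.stack([grad_nn(x, W1, w2) for x in h1])
    g2 = np.stack([grad_nn(x, W1, w2) for x in h2])
    return (g1 @ g2.T).mean()

K_check = np.array([[k_gnn_from_nn(a, b) for b in hbars] for a in hbars])
```

Both routes give the same matrix because the gradient of a mean is the mean of the gradients, so the inner product of two GNN gradients expands into the average of all node-pair FFNN kernels.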
from gnn import GNN
# init GNN
model = GNN(width=2048)
model.eval()
# evaluate one example
hbar_test = graphs[(0.25, 5)][0].hbar
y = model(hbar_test)
if y is not None:
    print(f"output shape: {y.size()}")
    print("predicted reward:", y.detach().item())
output shape: torch.Size([1, 1])
predicted reward: 0.28805238008499146
Plotting 5 graphs from domain $\mathcal{G}_{0.25, 10}$
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=[9, 3.5])
for idx, engraph in enumerate(graphs[(0.25, 10)][:5]):
    g = torch_geometric.utils.to_networkx(engraph.graph, to_undirected=True)
    axs[1, idx].imshow(engraph.hbar.detach().numpy())
    axs[1, idx].set_yticks([])
    axs[1, idx].set_xticks([])
    # only the first panel keeps its axis labels
    axs[1, 0].set_yticks(ticks=range(10), labels=[f"x{i}" for i in range(10)])
    axs[1, 0].set_xticks(ticks=range(5), labels=[i + 1 for i in range(5)])
    nx.draw_circular(g, ax=axs[0, idx], node_size=50)
Compute NTK matrix for 5 graphs
$K_{ij} = k_\text{GNN} (G_i, G_j)$, kernel is scalar since output is scalar.
# sample 5 graphs from one domain (p=0.25, N=10)
train_G = torch.concat(
    [g.hbar.unsqueeze(0) for g in graphs[(0.25, 10)][:5]]
)
# compute empirical NTK matrix for 5 graphs
K = model.batchNTK(train_G, train_G)
fig = plt.figure(figsize=[3,3])
img = plt.imshow(K.detach().numpy())
fig.colorbar(img)
plt.axis("off")
plt.show()
Gaussian Process reward
For the synthetic dataset, the authors learn a smooth “true” reward function for each domain using a GP $f: \mathcal{G} \rightarrow \mathbb{R}$. gpytorch requires tensor objects, so we use the precomputed aggregated node feature tensor as input: $f(\bar{h}): \mathbb{R}^{N \times d} \rightarrow \mathbb{R}$.
- Use a GP prior $f \sim \text{GP}(0, k_\text{GNN})$
- Learn posterior GP with 5 graphs from a particular domain $(G_i, y_i)_{i \leq 5} \in \mathcal{G}_{p,N}$, where $y_i$ are randomly sampled from the standard normal $\mathcal{N}(0,1)$
- Sample from smooth posterior GP to get $f(G_i)$
The final dataset is then $\mathcal{D}_{p,N} = \{(G_i, f(G_i)) \mid G_i \in \mathcal{G}_{p,N} \}$
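The three steps above amount to standard GP regression. Below is a minimal NumPy sketch of them under stated assumptions: an RBF kernel on flattened inputs stands in for $k_\text{GNN}$, the inputs are random placeholders rather than real $\bar{h}$ tensors, and the helper names `rbf_kernel` and `gp_posterior` are hypothetical. It is not the gpytorch-based code used in this notebook.

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=10.0):
    """Stand-in kernel on flattened inputs (NOT the Graph-NTK)."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)

def gp_posterior(kernel, X_train, y_train, X_test, noise=1e-4):
    """Posterior mean and covariance of a zero-mean GP at X_test."""
    K = kernel(X_train, X_train) + noise * np.eye(len(X_train))
    K_s = kernel(X_train, X_test)
    mean = K_s.T @ np.linalg.solve(K, y_train)
    cov = kernel(X_test, X_test) - K_s.T @ np.linalg.solve(K, K_s)
    return mean, (cov + cov.T) / 2   # symmetrize against round-off

rng = np.random.default_rng(0)
X_all = rng.standard_normal((20, 10 * 5))  # 20 placeholder graphs, flattened (N=10, d=5)
X_train = X_all[:5]                        # condition on 5 graphs
y_train = rng.standard_normal(5)           # y_i ~ N(0, 1)

mean, cov = gp_posterior(rbf_kernel, X_train, y_train, X_all)
# one reward function over the whole domain: a single draw from the posterior
rewards = rng.multivariate_normal(mean, cov, check_valid="ignore")
```

With small observation noise the posterior mean nearly interpolates the five random targets, so either the mean or a posterior draw can serve as a smooth reward over the domain.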
Learn GP posterior and reward for each graph domain
import pickle
import gpytorch
import torch
from gp import GNTK, GP_reward
RESAMPLE = False
if RESAMPLE:
    # learn GP posterior and sample reward for dataset
    for domain in graphs.keys():
        # get 5 samples from domain
        train_x = torch.concat(
            [g.hbar.unsqueeze(0) for g in graphs[domain][:5]]
        ).flatten(1)
        train_y = torch.normal(torch.tensor([0.0]*5), torch.tensor([1.0]*5))
        # initialize likelihood and model
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        gp = GP_reward(
            train_x, train_y, likelihood,
            model.batchNTK, x_shape=(domain[1], 5)
        )
        # set to train
        gp.train()
        likelihood.train()
        # init optimizer and loss
        optimizer = torch.optim.Adam(gp.parameters(), lr=0.1)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)
        print(f"fitting GP for domain G{domain} ", end="")
        for i in range(100):
            optimizer.zero_grad()
            output = gp(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(".", end="")
        # predict all graphs in domain
        gp.eval()
        test_x = torch.concat(
            [g.hbar.unsqueeze(0) for g in graphs[domain]]
        ).flatten(1)
        # note: the posterior mean is used as the reward here
        preds = gp(test_x).mean.detach()
        # update dataset
        print(" updating graphs")
        for g, y in zip(graphs[domain], preds):
            g.y = y
    # save updated dataset
    with open("../data/graphs.pkl", "wb") as f:
        pickle.dump(graphs, f)
else:
    # load precomputed dataset with GP reward
    with open("../data/graphs.pkl", "rb") as f:
        graphs = pickle.load(f)
GNN-UCB
Background
The authors propose to use the Graph-NTK to balance exploitation, i.e. training the GNN, and exploration of arms, i.e. acquiring new samples $(G_i, y_i)$. In the lazy (overparameterized) regime, neural networks behave approximately like Gaussian processes, which allows one to quantify the upper confidence bound (UCB) of a GNN with the NTK. For simplicity and understanding we will only implement GNN-UCB, a variant of NeuralUCB.
- $\text{UCB}(G; \mu, \sigma) = \mu(G) + \beta_t\sigma(G)$ is the acquisition function
- $\mu \triangleq f_\text{GNN}$ is straightforward, now how to quantify $\sigma$?
- recall GP posteriors admit a closed form for the posterior variance $k'(x^*) = k(x^*, x^*) - k^\top_{X,x^*}(K_{XX}+\sigma^2I)^{-1}k_{X,x^*}$
What follows is a brief overview of linear / kernelized bandits and Linear-UCB. For in-depth background, please refer to section 4 on regret analysis in the NeuralUCB paper, the original LinUCB paper, or a lecture on linear bandits.
- $y_t = f(\mathbf{x}_t) + \epsilon_t$. One observes the reward $y_t$ from a function with some noise. In this case the function is just linear regression: $f(\mathbf{x}_t) = \mathbf{w}^\top \mathbf{x}_t$.
- Kernelized setting: assuming a ridge loss, we have the closed-form estimate for the model: $\hat{\mathbf{w}}_t = (X_tX_t^\top + \lambda I)^{-1}X_t \mathbf{y}$, where the columns of $X_t$ are the contexts observed so far. Notice that $K = X_tX_t^\top$ is the linear kernel matrix.
- The cumulative regret is $R_T = \sum^T_{t=1}\left(\max_\mathbf{x} f(\mathbf{x}) - f(\mathbf{x}_t)\right)$. It should grow sublinearly in $T$ for the algorithm to converge, since then the average regret $R_T/T$ vanishes.
- The prediction error is bounded by the (scaled) standard deviation of the expected reward: $|f^* - f| \leq \beta_t \sqrt{\mathbf{x}_t^\top (\lambda I + K)^{-1} \mathbf{x}_t}$
- This bound can then be used in the vanilla UCB acquisition function $\text{UCB}(\mathbf{x}_t) = f(\mathbf{x}_t) + \beta_t \sqrt{\mathbf{x}_t^\top (\lambda I + K)^{-1} \mathbf{x}_t}$
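The linear case above fits in a short simulation. The following is a minimal NumPy sketch of Linear-UCB on a toy problem. All settings (number of arms, noise level, $\lambda$, a constant $\beta$, the random arm set) are assumptions chosen for illustration and are not taken from the paper or the repo.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_arms, T = 5, 50, 300
lam, beta, noise = 1.0, 1.0, 0.1          # assumed toy hyperparameters

arms = rng.standard_normal((n_arms, d))
arms /= np.linalg.norm(arms, axis=1, keepdims=True)
w_true = rng.standard_normal(d)           # unknown linear reward f(x) = w^T x
f_true = arms @ w_true

V = lam * np.eye(d)                       # lambda*I + sum_t x_t x_t^T
b = np.zeros(d)                           # sum_t y_t x_t
regret = []
for t in range(T):
    V_inv = np.linalg.inv(V)
    w_hat = V_inv @ b                     # ridge estimate of w
    # exploration bonus sqrt(x^T V^{-1} x) for every arm
    bonus = np.sqrt(np.einsum("ad,de,ae->a", arms, V_inv, arms))
    a = int(np.argmax(arms @ w_hat + beta * bonus))   # UCB acquisition
    y = f_true[a] + noise * rng.standard_normal()     # noisy reward
    V += np.outer(arms[a], arms[a])
    b += y * arms[a]
    regret.append(f_true.max() - f_true[a])           # instantaneous regret

cum_regret = np.cumsum(regret)
```

Plotting `cum_regret` against `t` is one way to inspect whether the curve flattens, which is the sublinear behaviour the regret bullet refers to.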
For GNN-UCB:
- The normalized Jacobian $\mathbf{g}(G)/\sqrt{m}$ can be understood as a feature map $\phi(G)$.
- Here $K = \mathbf{gg}^\top /{m}$ is the kernel matrix.
$$\text{GNN-UCB}(G_t) = f_\text{GNN}(G_t;\theta_{t-1}) + \beta_t \sqrt{\mathbf{g}(G_t;\theta_{t-1})^\top(\lambda I+K)^{-1}\mathbf{g}(G_t;\theta_{t-1})/m}$$
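The exploration term in this formula is the GP posterior variance in disguise. The NumPy sketch below illustrates this with random stand-in gradients (not gradients of a real GNN; the dimensions and the value `mu_new` are hypothetical): the primal form used in the GNN-UCB expression and the kernel form of the GP posterior variance, scaled by $1/\lambda$, give the same number.

```python
import numpy as np

rng = np.random.default_rng(0)
p, t, m, lam, beta = 200, 8, 64, 0.0025, 0.3   # p parameters, t past rounds (toy sizes)

G_hist = rng.standard_normal((t, p))   # stand-in gradients g(G_i) of previously played graphs
g_new = rng.standard_normal(p)         # stand-in gradient of the candidate graph
mu_new = 0.1                           # stand-in GNN prediction for the candidate

Phi = G_hist / np.sqrt(m)              # feature map phi(G) = g(G) / sqrt(m)
phi = g_new / np.sqrt(m)

# primal form: phi^T (lam*I + sum_i phi_i phi_i^T)^{-1} phi, a p x p solve
var_primal = phi @ np.linalg.solve(lam * np.eye(p) + Phi.T @ Phi, phi)

# kernel form: (k(x,x) - k^T (K + lam*I)^{-1} k) / lam, a t x t solve
K = Phi @ Phi.T
k = Phi @ phi
var_kernel = (phi @ phi - k @ np.linalg.solve(K + lam * np.eye(t), k)) / lam

ucb = mu_new + beta * np.sqrt(var_primal)
```

The kernel form only needs a $t \times t$ solve, which may be preferable when the number of parameters is much larger than the number of observed graphs.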
Practical implementation
Using modified instructions specified in Appendix D.2, D.3:
- rounds $T = 160$
- “explore” for $T_0$ steps with random samples of $G_i$ to pre-train
- for the subsequent $T_1$ steps train but re-init model parameters at each step
- for any remaining steps train only every 20 steps
- use $\beta = 0.3$ and $\lambda=0.0025$
GNN training:
- use $m=2048$, $L=2$ with Gaussian weight init for the GNN
- use MSELoss: $\mathcal{L}(\theta) = \frac{1}{t}\sum_{i<t} (f_\text{GNN}(G_i;\theta) - y_i)^2$, without the L2 regularizer $m\lambda \|\theta - \theta^{(0)}\|_2^2$ since we don’t use weight decay
- use Adam optimizer (lr=0.001) instead of SGD (not theoretically correct but practical)
- train network for gradient steps $J_t = \min J$ such that $\mathcal{L}(\theta_{t-1})\leq 10^{-4}$ or the relative change in loss is less than $10^{-3}$
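The stopping rule in the last bullet can be sketched as a small helper. This is an illustration, not the `train` function from the repo: the name `train_until_converged`, the `step` callback interface and the `max_steps` safeguard are assumptions.

```python
def train_until_converged(step, loss_tol=1e-4, rel_tol=1e-3, max_steps=10_000):
    """Call step() (one gradient update, returns the loss) until a stopping rule fires."""
    prev = None
    loss = float("inf")
    for j in range(1, max_steps + 1):
        loss = step()
        # rule 1: the loss itself is small enough
        if loss <= loss_tol:
            return j, loss
        # rule 2: the relative change in loss is below rel_tol
        if prev is not None and abs(prev - loss) < rel_tol * abs(prev):
            return j, loss
        prev = loss
    return max_steps, loss   # hypothetical safeguard against endless training

# toy usage: gradient descent on L(theta) = theta^2
state = {"theta": 1.0}

def step(lr=0.1):
    state["theta"] -= lr * 2 * state["theta"]
    return state["theta"] ** 2

n_steps, final_loss = train_until_converged(step)
```

In the toy run the loss shrinks by a constant factor per step, so the relative-change rule never fires and training stops through the absolute threshold.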
Plot model prediction and UCB
import pickle
from bandit import GNNUCB
from utils import GraphData
SIZE = 25
model.reinit()
ucb = GNNUCB(model, beta=.5)
# load data
with open("../data/graphs.pkl", "rb") as f:
    graphs = pickle.load(f)
data = GraphData(graphs[(0.25, 10)])
# model prediction, true reward and UCB for the first SIZE graphs
f_G = [model(g.hbar).item() for g in data.domain[:SIZE]]
y_G = [g.y.item() for g in data.domain[:SIZE]]
U_G = [ucb(f, g.hbar).item() for f, g in zip(f_G, data.domain[:SIZE])]
plt.figure()
plt.plot(f_G, label="prediction")
plt.plot(U_G, '--', label="UCB")
plt.plot(range(SIZE), y_G, '.', label="reward")
plt.legend()
plt.show()
import pickle
import torch
from gnn import GNN
from utils import GraphData
from bandit import GNNUCB
from bandit import explore, train
# init GNN
model = GNN(width=2048)
model.eval()
# init training setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()
loss_0 = torch.tensor(1e-3)
# load data
with open("../data/graphs.pkl", "rb") as f:
    graphs = pickle.load(f)
data = GraphData(graphs[(0.25, 10)])
y_max = torch.max(torch.tensor([g.y for g in data.domain]))
# init acquisition function
ucb = GNNUCB(model)
# track instant regret
regrets = []
# randomly explore domain to pretrain
print("random exploration (40 steps)")
for t in range(40):
    # torch.randint excludes the upper bound, so this covers the whole domain
    sample_idx = torch.randint(len(data.domain), size=(1,)).item()
    sampled = data.domain[sample_idx]
    data.train.add(sampled)
    # regret of the randomly sampled graph
    regrets.append(y_max - sampled.y)
# "exploration and exploitation"
print("explore and exploit", end="")
for t in range(40, 160):
    # explore (pick graph from domain using prev model)
    y = explore(data, model, ucb, t)
    regrets.append(y_max - y)
    # exploit (train GNN), logging progress every 5 steps
    log = t % 5 == 0
    if log:
        print(f"\nt={t}, regret={y_max-y:.3f}")
    params_t = train(optimizer, criterion, loss_0, data, model, t, log)
random exploration (40 steps)
explore and exploit
t=40, regret=0.478
t=45, regret=0.210
t=50, regret=0.089
t=55, regret=0.210
t=60, regret=0.253
t=65, regret=0.089
t=70, regret=0.000
t=75, regret=0.265
t=80, regret=0.000
t=85, regret=0.043
t=90, regret=0.054
t=95, regret=0.000
t=100, regret=0.000
t=105, regret=0.140
t=110, regret=0.000
t=115, regret=0.100
t=120, regret=0.000
t=125, regret=0.000
t=130, regret=0.000
t=135, regret=0.000
t=140, regret=0.000
t=145, regret=0.089
t=150, regret=0.000
t=155, regret=0.001
Plot cumulative regret
# accumulate instantaneous regret over the explore/exploit rounds (t >= 40)
Rt = 0
ctr = []
for r in regrets[40:]:
    Rt += r
    ctr.append(Rt.item())
plt.figure(figsize=[4, 4])
plt.plot(ctr)
plt.xlabel("t")
plt.ylabel("Rt")
plt.show()