I'm going to show that this works with both NumPy and PyTorch without any extra work:
import torch
import numpy as np
Agnostic cat
I need an agnostic cat function that produces concatenated arrays or tensors regardless of which kind is passed. It also uses this hack so that it doesn't matter if NumPy or PyTorch isn't installed.
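The definition of cat itself isn't shown in this export. Here is a minimal sketch of one way such a function could work, assuming the "hack" amounts to dispatching on the module that defines the input's type and importing the matching library lazily. This is an illustration, not necessarily the original code.

```python
def cat(xs, axis):
    """Concatenate a list of NumPy arrays or PyTorch tensors along `axis`.

    Neither library is imported at module level, so this still works
    when only one of them is installed.
    """
    # Dispatch on the module that defines the first element's type
    # (assumption: all elements are of the same kind)
    module = type(xs[0]).__module__
    if module.startswith("torch"):
        import torch
        return torch.cat(xs, axis)
    if module.startswith("numpy"):
        import numpy as np
        return np.concatenate(xs, axis)
    raise TypeError(f"cat: unsupported type {type(xs[0])!r}")
```

Checking the type's module name avoids calling isinstance against a class from a library that might not be importable.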
Showing that this works with either arrays or tensors:
np.random.seed(0)
xs = [np.random.randn(2,2) for _ in range(4)]
c = cat(xs, 1)
c
xs = [torch.tensor(x) for x in xs]
_c = cat(xs, 1)
assert np.abs(c - _c.numpy()).max() < 1e-3
_c
Pairing
This is closely based on Kai Arulkumaran's relation function, so I'm going to test the results of the einops operations against that code. The first function I need concatenates the Cartesian product of pairs along the trailing dimension. Here is an illustration of what a Cartesian product is:
import urllib.request
from IPython.display import display, SVG
u = "https://upload.wikimedia.org/wikipedia/commons/4/4e/Cartesian_Product_qtl1.svg"
with urllib.request.urlopen(u) as f:
    svg_string = f.read()
display(SVG(svg_string))
It might be possible to do this with torch.cartesian_prod and gather, but it also needs to be batched. Instead, I do it here using einops' repeat.
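The repeat-based idea can be illustrated without PyTorch or einops. The NumPy sketch below (the name cartesian_pairs_np is hypothetical) builds the same batched pairs with np.repeat and np.tile, which lay out elements like the einops patterns 'b o c -> b (o m) c' and 'b o c -> b (m o) c' respectively. With np.repeat supplying the first half, the row order matches itertools.product, i.e. a Cartesian product with the first element as the outer index.

```python
import numpy as np

def cartesian_pairs_np(x):
    """Batched pairing of every object with every object along axis 1.

    x has shape (b, o, c); the result has shape (b, o*o, 2*c), where
    row j*o + k holds the concatenation [x[:, j], x[:, k]].
    """
    b, o, c = x.shape
    # Each object repeated o times in a row: row pos holds x[:, pos // o]
    first = np.repeat(x, o, axis=1)
    # The whole object sequence tiled o times: row pos holds x[:, pos % o]
    second = np.tile(x, (1, o, 1))
    return np.concatenate([first, second], axis=2)
```

Swapping first and second gives the other ordering, where the second element is the outer index.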
torch.manual_seed(0)
x = torch.randn(4,8,16)
Copying Kai's code and using einops:
def kai_prodpair(x):
    b, o, c = x.shape
    return torch.cat((x.unsqueeze(1).expand(b, o, o, c).contiguous().view(b, o * o, c),
                      x.unsqueeze(2).expand(b, o, o, c).contiguous().view(b, o * o, c)), 2)

def prodpair(x):
    b, o, c = x.shape
    return cat([repeat(x, 'b o c -> b (m o) c', m=o),
                repeat(x, 'b o c -> b (o m) c', m=o)], 2)
I'll need to check whether a lot of tensors are equal, so here's a small utility function:
def allclose(a, b, epsilon=1e-3):
    return torch.abs(a - b).max() < epsilon
assert allclose(kai_prodpair(x), prodpair(x))
I wanted to check whether this is precisely a Cartesian product, but it isn't, because of the order of the arguments passed to cat.
def cartesian_trailing(x):
    out = []
    b, o, c = x.shape
    for i in range(b):
        for j in range(o):
            for k in range(o):
                # Pair object j with object k (j is the outer index);
                # renamed from `a, b` so the batch size `b` isn't shadowed
                xj, xk = x[i, j], x[i, k]
                out.append(rearrange(cat([xj, xk], 0), 'c2 -> () c2'))
    return rearrange(cat(out, 0), '(b osq) c2 -> b osq c2', osq=o**2, c2=2*c)
eq = allclose(prodpair(x), cartesian_trailing(x))
print(f"Is prodpair a cartesian product? {'yes' if eq.item() else 'no'}")
So I'll reverse that order in the implementation, just so I can say it really is a Cartesian product.
err = allclose(prodpair(x), cartesian_trailing(x))
assert err
print(f"Is prodpair a cartesian product now? {'yes' if err.item() else 'no'}")
def kai_append_embedding(pairs, embedding):
    b, osq, c2 = pairs.shape
    return torch.cat((pairs, embedding.unsqueeze(1).expand(b, osq, embedding.size(1))), 2)
Checking the einops version is correct:
b, o, c = x.shape
pairs = prodpair(x)
embedding = torch.randn(b, c)
assert allclose(append_embedding(pairs, embedding), kai_append_embedding(pairs, embedding))
Applying g
g is a function applied to all pairs and their embeddings. It's assumed to take a two-dimensional tensor as input, so the tensor is rearranged before and after applying it. All that's left to compute the relational function is to sum the resulting representations over all pairs. I've added the option to reduce using mean instead, if required, via the kwarg reduction.
def kai_relation(input, g, embedding=None, max_pairwise=None):
    r"""Applies an all-to-all pairwise relation function to a set of objects.
    See :class:`~torch.nn.Relation` for details.
    """
    # Batch size, number of objects, feature size
    b, o, c = input.size()
    # Create pairwise matrix
    pairs = torch.cat((input.unsqueeze(1).expand(b, o, o, c).contiguous().view(b, o * o, c),
                       input.unsqueeze(2).expand(b, o, o, c).contiguous().view(b, o * o, c)), 2)
    # Append embedding if provided
    if embedding is not None:
        pairs = torch.cat((pairs, embedding.unsqueeze(1).expand(b, o ** 2, embedding.size(1))), 2)
    # Calculate new feature size
    c = pairs.size(2)
    # Pack into batches
    pairs = pairs.view(b * o ** 2, c)
    # Pass through g
    if max_pairwise is None:
        output = g(pairs)
    else:
        outputs = []
        for batch in range(0, b * o ** 2, max_pairwise):
            outputs.append(g(pairs[batch:batch + max_pairwise]))
        output = torch.cat(outputs, 0)
    # Unpack
    output = output.view(b, o ** 2, output.size(1)).sum(1).squeeze(1)
    return output
def dummy_g(x):
    assert x.ndim == 2
    return x
assert allclose(kai_relation(x, dummy_g), relation(x, dummy_g))
assert allclose(torch.tensor(relation(x.numpy(), dummy_g)), relation(x, dummy_g))