
nn_freeze makes a model untrainable. A (broken) version that replaces parameters with constant buffers is left commented out. Judging from PyTorch's model.train() and model.eval() source code, simply setting requires_grad to False on the parameters is sufficient.

def nn_freeze(model):
    # for n,p in model.named_parameters():
    #     x = p.data
    #     delattr(model,n)
    #     model.register_buffer(n,x)
    for param in model.parameters():
        param.requires_grad = False
    for l in model.children(): nn_freeze(l) # model.parameters() is already recursive ; the explicit recursion is kept for safety
    # for l in model.modules(): nn_freeze(l)
    return model

def nn_unfreeze(model):
    for param in model.parameters():
        param.requires_grad = True
    for l in model.children(): nn_unfreeze(l) # fixed : recurse with nn_unfreeze, not nn_freeze
    return model
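
As a quick sanity check (a minimal sketch, not from the original notebook), a frozen layer reports requires_grad == False on every parameter, and its output is not even attached to an autograd graph:

layer = nn_freeze(nn.Linear(4,2))
print(all(not p.requires_grad for p in layer.parameters()))  # True
y = layer(torch.randn(3,4))
print(y.requires_grad)  # False : nothing upstream requires grad, so no graph is built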

Just to be sure this works as intended, here is a frozen example that is unable to train, even after model.train():

def nn_freeze_check():
    torch.cuda.empty_cache()
    σ = nn.GELU()
    model = nn_freeze(nn.Sequential(*[
        nn.Conv2d(1 ,1 ,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Conv2d(1 ,32,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Conv2d(32,10,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Flatten(),
        nn.Linear(10,10),
        nn.Softmax(1),
    ]).to(device))
    class DebugModule(nn.Module):
        def __init__(self):
            super(DebugModule,self).__init__()
            # need to add useless parameters to be optimized, otherwise torch complains and crashes
            self.L = nn.Linear(784,10).to(device)
            self.flat = nn.Flatten()
        def forward(self, x):
            return model(x) + 0*self.L(self.flat(x))
    m = DebugModule()
    mnist_train(m,10)

It is sometimes handy to change an optimizer's learning rate on the fly:

def optim_lr(optimizer,lr):
    for g in optimizer.param_groups:
        g['lr'] = lr
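
For instance (a minimal sketch; the model, optimizer and epoch thresholds are made up), a crude hand-rolled learning-rate decay inside a training loop:

model = nn.Linear(10,1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(30):
    if epoch == 10: optim_lr(optimizer, 1e-4)  # decay once training slows down
    if epoch == 20: optim_lr(optimizer, 1e-5)  # decay again near convergence
    # ... the usual forward / backward / optimizer.step() goes here ...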

modulize makes an nn.Module out of a function.

def modulize(fun):
    """
    Parameters
    ----------
    fun : function
    """
    class Modulized(nn.Module):
        def __init__(self):
            super(Modulized,self).__init__()
        def forward(self, x):
            return fun(x)
    return Modulized()
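
For example (a tiny sketch, not from the original), modulize turns a plain tensor function into something that can sit inside an nn.Sequential:

sin2 = modulize(lambda x: torch.sin(x**2))
net  = nn.Sequential(nn.Linear(8,8), sin2, nn.Linear(8,1))
print(net(torch.randn(5,8)).shape)  # torch.Size([5, 1])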

gradiator makes an nn.Module out of a function and a custom Jacobian, supplied as a vector-Jacobian product (x, grad_out ↦ grad_in).

# https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
def gradiator(fun,grad_fun):
    """
    Parameters
    ----------
    fun      : x ↦ y
    grad_fun : x,grad_out ↦ grad_in | evaluation of the jacobian for backpropagation
    """
    class GradiatorFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x) # save input for backward pass
            return fun(x)
        @staticmethod
        def backward(ctx, grad_output):
            if not ctx.needs_input_grad[0]: return None # if grad not required, don't compute
            x, = ctx.saved_tensors # restore input from context
            grad_input = grad_fun(x,grad_output)
            return grad_input
    return modulize(GradiatorFunction.apply)
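
To see the custom backward in action (a small sketch, not from the original), give x² a deliberately wrong gradient and check that backpropagation uses it:

square = gradiator(lambda x: x**2, lambda x,grad_out: 42*grad_out)  # obviously fake gradient
x = torch.tensor(3.0, requires_grad=True)
square(x).backward()
print(x.grad)  # tensor(42.) : the custom grad_fun was used, not the true derivative 2x = 6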

approximator makes an nn.Module out of a function and a given differentiable approximation, whose gradient is used in the backward pass.

def approximator(fun,app): # PNNFunction
    """
    Parameters
    ----------
    fun : function      | the actual thing
    app : diff function | differentiable approximation of fun
    """
    class ApproximatorFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            ctx.save_for_backward(*args)
            return fun(*args)
        @staticmethod
        def backward(ctx, grad_output):
            args = ctx.saved_tensors
            # https://pytorch.org/docs/stable/generated/torch.set_grad_enabled.html
            torch.set_grad_enabled(True) # so that vjp's output can require_grad
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            y = torch.autograd.functional.vjp(app, args, v=grad_output)
            torch.set_grad_enabled(False)
            return y[1]
    return modulize(ApproximatorFunction.apply) # wrapped in a module so it fits in nn.Sequential

Here is a usage example.

# each assignment overwrites σ : keep the activation you want to test
σ = modulize(lambda x: torch.sin(x**2))
σ = gradiator(lambda x: x   , lambda x,grad_out: grad_out)
σ = gradiator(lambda x: x**2, lambda x,grad_out: 2*x*grad_out)
σ = approximator(lambda x: torch.floor(x), lambda x: x)
mnist_model = nn.Sequential(*[
    nn.Conv2d(1 ,1 ,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Conv2d(1 ,32,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Conv2d(32,10,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Flatten(),
    nn.Linear(10,10),
    nn.Softmax(1),
]).to(device)
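
To check that the surrogate gradient really flows (a small sketch, not from the original), take floor wrapped by approximator: the raw floor has a true derivative of zero almost everywhere, yet the identity approximation supplies the gradient.

x = torch.linspace(-2,2,5, requires_grad=True)
σ_floor = approximator(lambda t: torch.floor(t), lambda t: t)
σ_floor(x).sum().backward()
print(x.grad)  # tensor([1., 1., 1., 1., 1.]) : the gradient of the identity approximation, not of floor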
class NN_ADD_PARAM(nn.Module):
    """
    Parameters
    ----------
    model   : nn.Module | should take a single argument x ∈ ℝ^n or ℂ^n or similar
    θ_shape : shape     | shape of stored parameters to be optimized
    """
    def __init__(self,model,θ_shape=0):
        super(NN_ADD_PARAM,self).__init__()
        self.model = model
        self.θ = torch.nn.Parameter(torch.randn(θ_shape))
    def forward(self,x):
        θθ = self.θ
        if len(x.shape) == 2: # if multiple samples, reshape θ accordingly
            sample_count = x.shape[0]
            θθ = self.θ.repeat(sample_count,1)
        X = torch.hstack((θθ,x)) # actual usage
        return self.model(X)
# TODO :
# example1 : NN_ADD_PARAM(nn_freeze(nn.Conv2d(...)))

class NN_APPROX(nn.Module):
    """ this uses app to compute the gradient
    Parameters
    ----------
    fun : function      | the actual thing
    app : diff function | differentiable approximation of fun
    """
    def __init__(self,fun,app):
        super(NN_APPROX,self).__init__()
        # self.fun = fun
        # self.app = app
        self.f = approximator(fun,app)
    def forward(self,x):
        # a = self.app(x)
        # f = self.fun(x)
        # return f.detach() + a - a.detach()
        return self.f(x)

class NN_DIFF_APPROX(nn.Module):
    """
    Parameters
    ----------
    fun : function(x) | the actual thing
    app : function(x) | differentiable approximation of fun, whose gradient is used in the backward pass
    """
    def __init__(self,fun,app):
        super(NN_DIFF_APPROX,self).__init__()
        self.fun = fun
        self.app = app
    def forward(self,x):
        a = self.app(x)
        f = self.fun(x)
        return f.detach() + a - a.detach()

def PNNFunction(physical,approximation):
    """
    Parameters
    ----------
    physical      : function      | some complex physical phenomenon (a physical driver, or a simulation if not available)
    approximation : diff function | an approximation of the physical function (typically a neural net)
    """
    class PNNFunctionHelper(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            ctx.save_for_backward(*args)
            return physical(*args)
        @staticmethod
        def backward(ctx, grad_output):
            args = ctx.saved_tensors
            # https://pytorch.org/docs/stable/generated/torch.set_grad_enabled.html
            torch.set_grad_enabled(True) # so that vjp's output can require_grad
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            y = torch.autograd.functional.vjp(approximation, args, v=grad_output) # FIXME
            torch.set_grad_enabled(False)
            return y[1]
    return PNNFunctionHelper.apply

class PNN(nn.Module):
    """
    Parameters
    ----------
    physical      : (θ,x...) ↦ anything | 1st arg = internal parameters ; 2nd,3rd... args = input
    approximation : diff function       | same as PNNFunction
    θ_shape       : shape               | shape of stored parameters
    f_args        : (θ,x...) ↦ args     | how to use arguments within the pnn
    """
    def __init__(self,physical,approximation,θ_shape,f_args=lambda θ,*x: (θ,*x)):
        super(PNN,self).__init__()
        self.pnn = PNNFunction(physical,approximation)
        self.θ = torch.nn.Parameter(torch.randn(θ_shape))
        self.f_args = f_args
    def forward(self,*args):
        return self.pnn(*self.f_args(self.θ,*args))
        # return self.pnn(self.θ,*args)
# f_args = lambda θ,*x: (torch.hstack((θ,)+x),)
# g = lambda *x: x
# g(*f_args(torch.zeros(3),torch.ones(3)))

# second definition : fixed-arity (x,θ) variant, shadowing the *args version above
# (this is the one PNN picks up at instantiation time)
def PNNFunction(physical,approximation):
    class PNNFunctionHelper(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, θ):
            ctx.save_for_backward(x, θ) # TODO : save more context for backward
            return physical(x, θ)
        @staticmethod
        def backward(ctx, grad_output):
            x, θ = ctx.saved_tensors
            torch.set_grad_enabled(True) # TODO : this is for vjp's output to require_grad...
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            y = torch.autograd.functional.vjp(approximation, (x, θ), v=grad_output)
            torch.set_grad_enabled(False)
            return y[1]
    return PNNFunctionHelper.apply
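
To show how these wrappers are meant to compose (a hedged sketch with made-up shapes, not from the original notebook): NN_ADD_PARAM prepends trainable parameters to the input of a frozen model, and NN_APPROX routes gradients through a differentiable stand-in.

frozen      = nn_freeze(nn.Linear(16,4))       # expects 4 learned features + 12 input features
with_params = NN_ADD_PARAM(frozen, θ_shape=4)  # θ is hstacked in front of x
x = torch.randn(8,12)
print(with_params(x).shape)                    # torch.Size([8, 4])

rounded = NN_APPROX(torch.floor, lambda t: t)  # floor on the forward pass, identity on the backward pass
print(rounded(torch.randn(8,4)).shape)         # torch.Size([8, 4])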

Physical module version 2

# based on https://laurent.sexy/nn/optics/pdf/deep-physical-neural-networks-trained-with-backpropagation-addendum.pdf p6-7
def PNNFunction(physical,approximation):
    """
    Parameters
    ----------
    physical      : a function    | a physical driver, or a simulation if not available
    approximation : diff function | differentiable approximation of physical
    """
    class PNNFunctionHelper(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            ctx.save_for_backward(*args)
            return physical(*args)
        @staticmethod
        def backward(ctx, grad_output):
            args = ctx.saved_tensors
            torch.set_grad_enabled(True) # make vjp's output require_grad
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            y = torch.autograd.functional.vjp(approximation, args, v=grad_output)
            torch.set_grad_enabled(False)
            return y[1]
    return PNNFunctionHelper.apply

class PNN(nn.Module):
    """
    Parameters
    ----------
    physical      : (θ,x...) ↦ anything | same as PNNFunction
    approximation : diff function       | same as PNNFunction
    θ_shape       : shape               | shape of stored parameters
    f_args        : (θ,x...) ↦ args     | how to use arguments within the pnn
    """
    def __init__(self,physical,approximation,θ_shape,f_args=lambda θ,*x: (θ,*x)):
        super(PNN,self).__init__()
        self.pnn = PNNFunction(physical,approximation)
        self.θ = torch.nn.Parameter(torch.randn(θ_shape))
        self.f_args = f_args
    def forward(self,*args):
        return self.pnn(*self.f_args(self.θ,*args))
        # return self.pnn(self.θ,*args)
# f_args = lambda θ,*x: (torch.hstack((θ,)+x),)
# g = lambda *x: x
# g(*f_args(torch.zeros(3),torch.ones(3)))

# fixed-arity (x,θ) variant, shadowing the *args version above
def PNNFunction(physical,approximation):
    class PNNFunctionHelper(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, θ):
            ctx.save_for_backward(x, θ) # TODO : save more context for backward
            return physical(x, θ)
        @staticmethod
        def backward(ctx, grad_output):
            x, θ = ctx.saved_tensors
            torch.set_grad_enabled(True) # TODO : this is for vjp's output to require_grad...
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            # print("OMG",x.shape,θ.shape)
            y = torch.autograd.functional.vjp(approximation, (x, θ), v=grad_output)
            torch.set_grad_enabled(False)
            return y[1]
    return PNNFunctionHelper.apply
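
A usage sketch (entirely hypothetical; a toy non-differentiable function stands in for real hardware): with the default f_args, the physical function receives (θ, x), the approximation supplies the gradients, and only θ is trained.

physical = lambda θ,x: torch.floor(x @ θ)  # non-differentiable "hardware" whose knobs are θ
approx   = lambda θ,x: x @ θ               # crude differentiable model of that hardware
pnn = PNN(physical, approx, θ_shape=(4,2))

x = torch.randn(8,4)
loss = (pnn(x)**2).mean()
loss.backward()
print(pnn.θ.grad.shape)  # torch.Size([4, 2]) : the gradient of the approximation, routed to θ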

Physical module version 3 (jacobian modelization)
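
No code for version 3 appears in this section; the following is only a speculative sketch of what "jacobian modelization" could mean, in the spirit of gradiator above: instead of differentiating an approximation with vjp, the backward pass calls a model of the vector-Jacobian product directly.

def PNNFunction_jac(physical, vjp_model):
    """
    physical  : (θ,x) ↦ y                | the physical process (or a simulation of it)
    vjp_model : (θ,x,grad_out) ↦ (∂θ,∂x) | modelled vector-jacobian product of physical
    """
    class Helper(torch.autograd.Function):
        @staticmethod
        def forward(ctx, θ, x):
            ctx.save_for_backward(θ, x)
            return physical(θ, x)
        @staticmethod
        def backward(ctx, grad_output):
            θ, x = ctx.saved_tensors
            return vjp_model(θ, x, grad_output) # modelled jacobian : no call to autograd.functional.vjp
    return Helper.apply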

texture sampling

### interpolations ########################
def nearest   (t,v1,v2): λ = (t>0.5).float(); return v1*(1-λ) + v2*λ
def lerp      (t,v1,v2): return v1*(1-t) + v2*t
def smoothstep(t,v1,v2): λ = t*t*(3-2*t);    return v1*(1-λ) + v2*λ

### sampler ###############################
def sampler1d_repeat(tex,x,interpolation=lerp):
    l = len(tex)
    X = x*l # so that x ∈ [0,1] will cover the whole texture
    Xi = X.int() - (X<=0)*1
    t = X-Xi
    return interpolation(t, tex[Xi % l], tex[(Xi + 1) % l])

def sampler1d_clamp(tex,x,interpolation=lerp):
    l = len(tex)
    X = x*l # so that x ∈ [0,1] will cover the whole texture
    Xi = X.int() - (X<=0)*1
    t = X-Xi
    return interpolation(t, tex[torch.clamp(Xi,0,l-1)], tex[torch.clamp(Xi+1,0,l-1)])

def sampler2d_nearest_clamp(tex, coord): # note : modifies coord in place
    c, h, w = tex.shape
    # remap coord from [0,1]x[0,1] to [0,h]x[0,w]
    coord[:, 0] *= h-1
    coord[:, 1] *= w-1
    coord = torch.round(coord).long()
    coord_x = torch.clamp(coord[:, 1],0,w-1) # repeat last pixel
    coord_y = torch.clamp(coord[:, 0],0,h-1) # repeat last pixel
    pixels = tex[:, coord_y, coord_x]
    return pixels.permute(1,0)

### sampler example
def sampler1d_example():
    tex = torch.randn(12)
    res = 1024
    plt.figure(figsize=(25,5))
    plt.plot(torch.linspace(0,1,res),sampler1d_repeat(tex,torch.linspace(-1,2,res),lerp      ))
    plt.plot(torch.linspace(1,2,res),sampler1d_repeat(tex,torch.linspace(-1,2,res),nearest   ))
    plt.plot(torch.linspace(2,3,res),sampler1d_repeat(tex,torch.linspace(-1,2,res),smoothstep))
    plt.figure(figsize=(25,5))
    plt.plot(torch.linspace(0,1,res),sampler1d_clamp(tex,torch.linspace(-1,2,res),lerp      ))
    plt.plot(torch.linspace(1,2,res),sampler1d_clamp(tex,torch.linspace(-1,2,res),nearest   ))
    plt.plot(torch.linspace(2,3,res),sampler1d_clamp(tex,torch.linspace(-1,2,res),smoothstep))
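
A small usage sketch for the 2D sampler (not from the original), fetching three normalized coordinates from a random 3-channel texture:

tex   = torch.rand(3,16,16)                           # (c,h,w)
coord = torch.tensor([[0.0,0.0],[0.5,0.5],[1.0,1.0]]) # rows are (y,x) ∈ [0,1] ; modified in place by the sampler
print(sampler2d_nearest_clamp(tex, coord).shape)      # torch.Size([3, 3]) : one (r,g,b) pixel per coordinate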

a bunch of interesting functions

### polynomials ###########################
def poly(coeff): # returns x ↦ ∑ coeff[n]⋅x^n
    N = range(len(coeff))
    return lambda x: sum([coeff[n]*(x**n) for n in N])

### fake gradient | floor #################
def fake_floor():
    return gradiator( # gradiator (not modulize) : takes a forward function and a fake gradient
        lambda x: torch.floor(x),
        lambda x,grad: grad # fake the gradient to identity
    )

### fake gradient | cantor ################
def cantor_function(x,n=6):
    if n==0: return x
    return ( (((0 <=x) & (x<=1/3)) * cantor_function(3*x,n-1)/2 )
           + ((1/3< x) & (x< 2/3)) * (1/2)
           + (2/3 <=x) * (1/2+cantor_function(3*x-2,n-1)/2) )
cantor_precomputed = cantor_function(torch.linspace(0,1,3**4+1),4).to(device)
def cantor():
    return modulize(lambda x: sampler1d_clamp(cantor_precomputed,x,nearest))
    # cantor = modulize(lambda x: torch.floor(x) + sampler1d_repeat(cantor_precomputed,x,nearest))
def fake_cantor():
    return gradiator(
        lambda x: torch.floor(x) + sampler1d_repeat(cantor_precomputed,x,nearest),
        lambda x,grad: grad # fake the gradient to identity
    )

### fake gradient | sigmoid staircase #####
### fake gradient | sigmoid sinus #########
### fake gradient | binary sinus ##########

### heterogeneous #########################
def multi(σs):
    def multi_(x):
        σcount = len(σs)
        out = torch.zeros_like(x)
        for (i,σ) in enumerate(σs):
            index = torch.arange(i,len(x)-1,σcount)
            out[index] = σ(x[index])
        return out
    return modulize(multi_)

### arbitrary ℝn → ℝm function ############
# returns a *pixelated* arbitrary function [0,1]^{dim}→[0,1]^{dim}
def random_function_rigid(dim_in=5,dim_out=5,res=16,device=device):
    # generate seed
    seed = torch.rand((res,)*dim_in+(dim_out,)).to(device)
    # smooth it a bit
    for i in range(dim_in): seed = (seed + seed.roll(1,dims=i))/2
    for i in range(dim_in): seed = (seed + seed.roll(1,dims=i))/2
    for i in range(dim_in): seed = (seed + seed.roll(1,dims=i))/2
    def f(x):
        i = (((x*res).type(torch.long) % res) + res) % res
        if len(i.shape)==1:
            out = seed
            for j in i: out = out[j]
            return out
        elif len(i.shape)==2:
            # seed[i[:,0],i[:,1],i[:,2],...,i[:,dim_in]]
            # this doesn't work T_T ... so let's use stupid code for now
            # out = seed
            # for j in range(dim_in): out = out[i[:,j]]
            # return out
            if dim_in==1: return seed[i[:,0]]
            if dim_in==2: return seed[i[:,0],i[:,1]]
            if dim_in==3: return seed[i[:,0],i[:,1],i[:,2]]
            if dim_in==4: return seed[i[:,0],i[:,1],i[:,2],i[:,3]]
            if dim_in==5: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4]]
            if dim_in==6: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5]]
            if dim_in==7: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6]]
            if dim_in==8: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6],i[:,7]]
            if dim_in==9: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6],i[:,7],i[:,8]]
    φ = (1 + np.sqrt(5))/2
    # def fractale(x): return f(x) + f(x*2)/2 + f(x*4)/4 + f(x*8)/8
    def fractale(x): return f(x/φ) + f(x+φ*π) + f((x+π)*φ)/2 + f(-(x-2*φ)*φ*π)/4 + f(x*φ*φ*φ)/8
    return fractale

# # returns a *smooth* arbitrary somewhat smooth function [0,1]^{dim}→[0,1]^{dim}
# def random_function(dim_in=5,dim_out=5,res=4):
#     # generate seed
#     seed = torch.rand((res,)*dim_in+(dim_out,))
#     # smooth it a bit
#     for i in range(dim_in): seed = (seed + seed.roll(1,dims=i))/2
#     def f(x):
#         i = (x*(res)).type(torch.long) % (res-1) # int part
#         λ =((x*(res)) % 1)                       # floating part
#         λ = 3*λ*λ - 2*λ*λ*λ                      # make it smoothstep
#         if dim_in==1:
#             λ0 = λ[:,0].repeat(dim_out,1).permute((1,0))
#             return seed[i[:,0]]*(1-λ0) + λ0*(seed[i[:,0]+1])
#         if dim_in==2:
#             λ0 = λ[:,0].repeat(dim_out,1).permute((1,0))
#             λ1 = λ[:,1].repeat(dim_out,1).permute((1,0))
#             a0 = seed[i[:,0],i[:,1]  ]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]  ]
#             a1 = seed[i[:,0],i[:,1]+1]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]+1]
#             return a0*(1-λ1) + λ1*a1
#         if dim_in==3:
#             λ0 = λ[:,0].repeat(dim_out,1).permute((1,0))
#             λ1 = λ[:,1].repeat(dim_out,1).permute((1,0))
#             λ1 = λ[:,2].repeat(dim_out,1).permute((1,0))
#             a00 = seed[i[:,0],i[:,1]  ,i[:,2]  ]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]  ,i[:,2]  ]
#             a01 = seed[i[:,0],i[:,1]+1,i[:,2]  ]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]+1,i[:,2]  ]
#             a10 = seed[i[:,0],i[:,1]  ,i[:,2]+1]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]  ,i[:,2]+1]
#             a11 = seed[i[:,0],i[:,1]+1,i[:,2]+1]*(1-λ0) + λ0*seed[i[:,0]+1,i[:,1]+1,i[:,2]+1]
#             return seed[i[:,0],i[:,1],i[:,2]]
#         # if dim_in==4: return seed[i[:,0],i[:,1],i[:,2],i[:,3]]
#         # if dim_in==5: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4]]
#         # if dim_in==6: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5]]
#         # if dim_in==7: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6]]
#         # if dim_in==8: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6],i[:,7]]
#         # if dim_in==9: return seed[i[:,0],i[:,1],i[:,2],i[:,3],i[:,4],i[:,5],i[:,6],i[:,7],i[:,8]]
#     φ = (1 + np.sqrt(5))/2
#     # def fractale(x): return f(x) + f(x*2)/2 + f(x*4)/4 + f(x*8)/8
#     def fractale(x): return f(x/φ) + f(x) + f(x*φ)/2 + f(-x*φ*π)/4 + f(x*φ*φ*φ)/8
#     return fractale

def random_function_debug_1(res=64):
    im = random_function_rigid(1,3,device='cpu')
    pix = np.zeros((res,1,3))
    xs = torch.linspace(-2,2,res)
    for i,x in enumerate(xs):
        pix[i,0,:] = im(torch.tensor([[x]])).tolist()[0]
    pix = pix.transpose()
    plt.plot(xs,pix[0][0])
    plt.plot(xs,pix[1][0])
    plt.plot(xs,pix[2][0])

def random_function_debug(res=64):
    im = random_function_rigid(2,3,device='cpu')
    pix = np.zeros((res,res,3))
    for i,x in enumerate(torch.linspace(-2,2,res)):
        for j,y in enumerate(torch.linspace(-2,2,res)):
            pix[i,j,:] = im(torch.tensor([[x,y]])).tolist()[0]
    plt.imshow((pix+2)/4)
# plt.figure(figsize=(25,5));random_function_debug_1(res=1024)
# plt.figure(figsize=(10,10));random_function_debug(res=256)

def random_function_light(dim_in=196,dim_out=196,device=device): # assumes dim_in == dim_out
    L1 = nn_freeze(nn.Linear(dim_in,dim_in )).to(device)
    L2 = nn_freeze(nn.Linear(dim_in,dim_out)).to(device)
    L3 = nn_freeze(nn.Linear(dim_in,dim_out)).to(device)
    L2.weight.data = torch.relu(L2.weight.data)
    def f(x):
        x = torch.sigmoid(L2(torch.sin(L1(x)) + torch.sin(torch.relu(L2(x)))))
        x+= torch.exp(-(L3(torch.sin(x))**2))
        return torch.sin(x)
    return modulize(f).to(device)
# random_function_light(device='cpu')(torch.ones(10,196)).shape
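
To close the loop (a hypothetical sketch reusing device and mnist_train from earlier in the notebook), the fake-gradient activations slot straight into the same kind of MNIST model used above:

mixed_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784,64), fake_floor(),                                  # non-differentiable forward, identity backward
    nn.Linear(64,64),  multi([torch.sin, torch.relu, fake_floor()]),  # a different activation on different rows of the batch
    nn.Linear(64,10),  nn.Softmax(1),
).to(device)
# mnist_train(mixed_model,10)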