src.environment.problem.SOO.COCO_BBOB.kan.MultKAN
Module Contents
Classes
MultKAN | KAN class
Data
API
- class src.environment.problem.SOO.COCO_BBOB.kan.MultKAN.MultKAN(width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu')
Bases: torch.nn.Module
KAN class
Attributes:
grid : int
    the number of grid intervals
k : int
    the spline order (the order of the piecewise polynomial)
act_fun : a list of KANLayers
symbolic_fun : a list of Symbolic_KANLayer
    Symbolic_KANLayers
depth : int
    depth of KAN
width : list
    number of neurons in each layer. Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons. With multiplication nodes, [2,[5,3],[5,1],3] means that, besides the [2,5,5,3] KAN, there are 3 (1) multiplication nodes in layer 1 (2).
mult_arity : int, or list of int lists
    multiplication arity for each multiplication node (the number of numbers to be multiplied)
base_fun : fun
    residual function b(x). An activation function is phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_enabled : bool
    when symbolic_enabled = False, the symbolic branch (symbolic_fun) is ignored in computation (set to zero) to save time. Default: True.
width_in : list
    the number of input neurons for each layer
width_out : list
    the number of output neurons for each layer
base_fun_name : str
    the base function b(x)
grid_eps : float
    the parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
node_bias : a list of 1D torch.float
node_scale : a list of 1D torch.float
subnode_bias : a list of 1D torch.float
subnode_scale : a list of 1D torch.float
affine_trainable : bool
    indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
sp_trainable : bool
    indicate whether the overall magnitude of splines is trainable
sb_trainable : bool
    indicate whether the overall magnitude of base function is trainable
save_act : bool
    indicate whether intermediate activations are saved in forward pass
node_scores : None or list of 1D torch.float
    node attribution score
edge_scores : None or list of 2D torch.float
    edge attribution score
subnode_scores : None or list of 1D torch.float
    subnode attribution score
cache_data : None or 2D torch.float
    cached input data
acts : None or a list of 2D torch.float
    activations on nodes
auto_save : bool
    indicate whether to automatically save a checkpoint once the model is modified
state_id : int
    the state of the model (used to save checkpoint)
ckpt_path : str
    the folder to store checkpoints
round : int
    the number of times rewind() has been called
device : str
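The two width notations described above can be brought into one form. The following is a minimal sketch, not the class's own code: normalize_width is a hypothetical helper that rewrites every layer entry as an [addition nodes, multiplication nodes] pair, treating a bare integer as a layer with no multiplication nodes.

```python
def normalize_width(width):
    """Return one [n_add, n_mult] pair per layer (hypothetical helper)."""
    normalized = []
    for entry in width:
        if isinstance(entry, int):
            # A bare integer means "only addition nodes in this layer".
            normalized.append([entry, 0])
        else:
            n_add, n_mult = entry
            normalized.append([int(n_add), int(n_mult)])
    return normalized


print(normalize_width([2, [5, 3], [5, 1], 3]))
# [[2, 0], [5, 3], [5, 1], [3, 0]]
```

With this form, the example [2,[5,3],[5,1],3] reads directly as 3 multiplication nodes in layer 1 and 1 in layer 2.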
Initialization
initialize a KAN model
Args:
width : list of int
    Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specifies the number of neurons in each layer (including inputs/outputs).
    With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specifies the number of addition/multiplication nodes in each layer (including inputs/outputs).
grid : int
    number of grid intervals. Default: 3.
k : int
    order of piecewise polynomial. Default: 3.
mult_arity : int, or list of int lists
    multiplication arity for each multiplication node (the number of numbers to be multiplied)
noise_scale : float
    initial injected noise to spline.
base_fun : str
    the residual function b(x). Default: 'silu'
symbolic_enabled : bool
    compute (True) or skip (False) symbolic computations (for efficiency). Default: True.
affine_trainable : bool
    whether affine parameters are updated. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
grid_eps : float
    When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
grid_range : list/np.array of shape (2,)
    the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default update_grid=True)
sp_trainable : bool
    If True, scale_sp is trainable. Default: True.
sb_trainable : bool
    If True, scale_base is trainable. Default: True.
device : str
    device
seed : int
    random seed
save_act : bool
    indicate whether intermediate activations are saved in forward pass
sparse_init : bool
    sparse initialization (True) or normal dense initialization. Default: False.
auto_save : bool
    indicate whether to automatically save a checkpoint once the model is modified
state_id : int
    the state of the model (used to save checkpoint)
ckpt_path : str
    the folder to store checkpoints. Default: './model'
round : int
    the number of times rewind() has been called
Returns:
self
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
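The role of grid_eps can be illustrated in isolation. The sketch below is not the library's grid code: blended_grid is a hypothetical function that assumes the interpolation is a plain linear blend between a uniform grid over the sample range and a grid placed at sample quantiles (nearest-rank quantiles are used here for simplicity).

```python
def blended_grid(samples, num_intervals, grid_eps):
    """Blend a uniform grid and a quantile grid (illustrative assumption)."""
    xs = sorted(samples)
    lo, hi = xs[0], xs[-1]
    n = len(xs)
    points = []
    for i in range(num_intervals + 1):
        t = i / num_intervals
        uniform = lo + t * (hi - lo)      # evenly spaced point
        adaptive = xs[round(t * (n - 1))]  # nearest-rank sample quantile
        points.append(grid_eps * uniform + (1 - grid_eps) * adaptive)
    return points


samples = [0.0, 0.1, 0.2, 0.3, 10.0]
print(blended_grid(samples, 4, 1.0))  # uniform grid
print(blended_grid(samples, 4, 0.0))  # grid follows the samples
```

With grid_eps = 0 most grid points crowd into the region where the samples are dense, while grid_eps = 1 ignores the sample distribution entirely.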
- to(device)
move the model to device
Args:
device : str or device
Returns:
self
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)
- initialize_from_another_model(another_model, x)
initialize from another model of the same width; the two models' 'grid' parameters may differ. Note that this is equivalent to refine() when another_model does not need to be kept.
Args:
another_model : MultKAN
x : 2D torch.float
Returns:
self
Example
>>> from kan import *
>>> model1 = KAN(width=[2,5,1], grid=3)
>>> model2 = KAN(width=[2,5,1], grid=10)
>>> x = torch.rand(100,2)
>>> model2.initialize_from_another_model(model1, x)
- refine(new_grid)
grid refinement
Args:
new_grid : int
    the number of grid intervals after refinement
Returns:
a refined model : MultKAN
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> print(model.grid)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model = model.refine(10)
>>> print(model.grid)
checkpoint directory created: ./model
saving model version 0.0
5
saving model version 0.1
10
- saveckpt(path='model')
save the current model to files (configuration file and state file)
Args:
path : str
    the path where checkpoints are saved
Returns:
None
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
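A quick way to check that a checkpoint is complete is to list the three file names mentioned above. This is a small sketch under the assumption that the suffixes are appended directly to the path argument, as the note suggests; checkpoint_files is a hypothetical helper, not part of the class.

```python
import os


def checkpoint_files(path):
    """File names expected for a checkpoint prefix (assumed naming scheme)."""
    suffixes = ("_cache_data", "_config.yml", "_state")
    return [path + suffix for suffix in suffixes]


expected = checkpoint_files("./mark")
print(expected)
# ['./mark_cache_data', './mark_config.yml', './mark_state']

# Report which of the expected files are not on disk yet.
missing = [name for name in expected if not os.path.exists(name)]
print("missing:", missing)
```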
- static loadckpt(path='model')
load checkpoint from path
Args:
path : str
    the path where checkpoints are saved
Returns:
MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
>>> KAN.loadckpt('./mark')
- copy()
deepcopy
Args:
None
Returns:
MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> model2 = model.copy()
>>> model2.act_fun[0].coef.data *= 2
>>> print(model2.act_fun[0].coef.data)
>>> print(model.act_fun[0].coef.data)
- rewind(model_id)
rewind to an old version
Args:
model_id : str
    in format '{a}.{b}' where a is the round number, b is the version number in that round
Returns:
MultKAN
Example
Please refer to the tutorial "API 12: Checkpoint, save & load model".
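The '{a}.{b}' format of model_id can be split into its two parts as follows. This is an illustrative sketch only; parse_model_id is a hypothetical helper and not a method of MultKAN.

```python
def parse_model_id(model_id):
    """Split '{a}.{b}' into (round number, version number)."""
    round_str, version_str = model_id.split(".")
    return int(round_str), int(version_str)


print(parse_model_id("0.1"))   # (0, 1)
print(parse_model_id("2.10"))  # (2, 10)
```

Treating the identifier as a string rather than a float matters: '0.1' and '0.10' denote versions 1 and 10 of round 0, which a float would conflate.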
- checkout(model_id)
check out an old version
Args:
model_id : str
    in format '{a}.{b}' where a is the round number, b is the version number in that round
Returns:
MultKAN
Example
Same usage as rewind, except that checkout does not change states.
- update_grid_from_samples(x)
update grid from samples
Args:
x : 2D torch.tensor
    inputs
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.act_fun[0].grid)
- update_grid(x)
call update_grid_from_samples. This wrapper seems unnecessary, but it is retained for the sake of classes that might inherit from MultKAN.
- initialize_grid_from_another_model(model, x)
initialize grid from another model
Args:
model : MultKAN
    parent model
x : 2D torch.tensor
    inputs
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
>>> model2.initialize_grid_from_another_model(model, x)
>>> print(model2.act_fun[0].grid)
- forward(x, singularity_avoiding=False, y_th=10.0)
forward pass
Args:
x : 2D torch.tensor
    inputs
singularity_avoiding : bool
    whether to avoid singularity for the symbolic branch
y_th : float
    the threshold for singularity
Returns:
2D torch.tensor
    outputs
Example1
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model(x).shape
Example2
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> x = torch.tensor([[1],[-0.01]])
>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
>>> print(model(x))
>>> print(model(x, singularity_avoiding=True))
>>> print(model(x, singularity_avoiding=True, y_th=1.))
- fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True)
set (l,i,j) activation to be symbolic (specified by fun_name)
Args:
l : int
    layer index
i : int
    input neuron index
j : int
    output neuron index
fun_name : str
    function name
fit_params_bool : bool
    obtaining affine parameters through fitting (True) or setting default values (False)
a_range : tuple
    sweeping range of a
b_range : tuple
    sweeping range of b
verbose : bool
    If True, more information is printed.
random : bool
    initialize affine parameters randomly (True) or as [1,0,1,0] (False)
log_history : bool
    indicate whether to log history when the function is called
Returns:
None or r2 (coefficient of determination)
Example 1
when fit_params_bool = False
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
Example 2
when fit_params_bool = True
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # obtain activations (otherwise model does not have attributes acts)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
- get_range(l, i, j, verbose=True)
Get the input range and output range of the (l,i,j) activation
Args:
l : int
    layer index
i : int
    input neuron index
j : int
    output neuron index
Returns:
x_min : float
    minimum of input
x_max : float
    maximum of input
y_min : float
    minimum of output
y_max : float
    maximum of output
Example
>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.get_range(0,0,0)
- plot(folder='./figures', beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0)
plot KAN
Args:
folder : str
    the folder to store pngs
beta : float
    positive number. Controls the transparency of each activation: transparency = tanh(beta*l1).
mask : bool
    If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
mode : str
    "supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).
scale : float
    controls the size of the diagram
in_vars : None or list of str
    the name(s) of input variables
out_vars : None or list of str
    the name(s) of output variables
title : None or str
    title
varscale : float
    the size of input variables
Returns:
Figure
Example
see more interactive examples in demos
>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.plot()
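The formula transparency = tanh(beta*l1) from the beta entry can be evaluated directly to see how beta shapes the plot. The sketch below only evaluates that formula; edge_transparency is a hypothetical name, and how the value is mapped onto the drawing is not covered here.

```python
import math


def edge_transparency(l1, beta=3.0):
    """Evaluate transparency = tanh(beta * l1) for a non-negative l1 score."""
    return math.tanh(beta * l1)


# Compare how quickly the value saturates for two choices of beta.
for l1 in (0.0, 0.05, 0.2, 1.0):
    print(l1, round(edge_transparency(l1, beta=3), 3), round(edge_transparency(l1, beta=10), 3))
```

A larger beta saturates sooner, so activations with small l1 become visually closer to those with large l1.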
- reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
Get regularization
Args:
reg_metric : str
    the regularization metric: 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
lamb_l1 : float
    l1 penalty strength
lamb_entropy : float
    entropy penalty strength
lamb_coef : float
    coefficient penalty strength
lamb_coefdiff : float
    coefficient smoothness strength
Returns:
reg_ : torch.float
Example
>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
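To give a feel for what the l1 and entropy strengths weigh against each other, here is a deliberately simplified sketch. It is an assumption for illustration, not the formula used by reg(): it takes a flat list of non-negative edge scores, uses their sum as the L1 term and the Shannon entropy of the normalized scores as the entropy term, and ignores the coefficient terms and the choice of reg_metric. The names l1_and_entropy and simple_reg are hypothetical.

```python
import math


def l1_and_entropy(scores):
    """L1 sum and Shannon entropy of normalized non-negative scores."""
    l1 = sum(scores)
    if l1 == 0:
        return 0.0, 0.0
    probs = [s / l1 for s in scores if s > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return l1, entropy


def simple_reg(scores, lamb_l1=1.0, lamb_entropy=2.0):
    """Weighted sum of the two terms (illustrative assumption)."""
    l1, entropy = l1_and_entropy(scores)
    return lamb_l1 * l1 + lamb_entropy * entropy


# Same total magnitude, but the concentrated case has zero entropy.
print(simple_reg([1.0, 1.0]))
print(simple_reg([2.0, 0.0]))
```

Under this reading, the entropy term favors networks in which a few edges carry most of the magnitude, which is the kind of sparsity that pruning later exploits.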
- get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
Get regularization. This wrapper seems unnecessary, but a class that inherits from MultKAN may want to override get_reg without overriding reg.
- disable_symbolic_in_fit(lamb)
during fitting, disable the symbolic branch if either condition holds: lamb = 0, or none of the symbolic functions is active
- fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None)
training
Args:
dataset : dic
    contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
opt : str
    "LBFGS" or "Adam"
steps : int
    training steps
log : int
    logging frequency
lamb : float
    overall penalty strength
lamb_l1 : float
    l1 penalty strength
lamb_entropy : float
    entropy penalty strength
lamb_coef : float
    coefficient magnitude penalty strength
lamb_coefdiff : float
    difference of nearby coefficients (smoothness) penalty strength
update_grid : bool
    If True, update grid regularly before stop_grid_update_step
grid_update_num : int
    the number of grid updates before stop_grid_update_step
start_grid_update_step : int
    no grid updates before this training step
stop_grid_update_step : int
    no grid updates after this training step
loss_fn : function
    loss function
lr : float
    learning rate
batch : int
    batch size; if -1, the full dataset is used.
save_fig_freq : int
    save figure every (save_fig_freq) steps
singularity_avoiding : bool
    indicate whether to avoid singularity for the symbolic part
y_th : float
    singularity threshold (anything above the threshold is considered singular and is softened in some ways)
reg_metric : str
    regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
metrics : a list of metrics (as functions)
    the metrics to be computed in training
display_metrics : a list of functions
    the metric to be displayed in tqdm progress bar
Returns:
results : dic
    results['train_loss'], 1D array of training losses (RMSE)
    results['test_loss'], 1D array of test losses (RMSE)
    results['reg'], 1D array of regularization
    other metrics specified in metrics
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
Most examples in the tutorials involve the fit() method. Please check them for usage.
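The four grid-update arguments interact, and it can help to enumerate which training steps would trigger an update. The sketch below is one plausible reading of the descriptions above, stated as an assumption rather than the method's actual logic: updates are evenly spaced, grid_update_num of them fit before stop_grid_update_step, and none happen before start_grid_update_step. grid_update_steps is a hypothetical helper.

```python
def grid_update_steps(steps=100, update_grid=True, grid_update_num=10,
                      start_grid_update_step=-1, stop_grid_update_step=50):
    """Training steps at which the grid would be updated (assumed schedule)."""
    if not update_grid:
        return []
    # Assumed spacing: spread grid_update_num updates over [0, stop).
    freq = max(1, stop_grid_update_step // grid_update_num)
    return [
        step for step in range(steps)
        if step % freq == 0
        and start_grid_update_step <= step < stop_grid_update_step
    ]


print(grid_update_steps())
# [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
print(grid_update_steps(steps=20))
# [0, 5, 10, 15]
```

Under this assumed schedule, a short run such as steps=20 with the default stop_grid_update_step=50 performs fewer than grid_update_num updates.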
- prune_node(threshold=0.01, mode='auto', active_neurons_id=None, log_history=True)
pruning nodes
Args:
threshold : float
    if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
mode : str
    'auto' or 'manual'. With 'auto', nodes are automatically pruned using threshold. With 'manual', active_neurons_id should be passed in.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_node()
>>> model.plot()
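The threshold rule in 'auto' mode amounts to a filter over attribution scores. The following sketch shows that rule on a plain list of scores; it is an illustration with a hypothetical helper name (active_neuron_ids) and does not reproduce how the method computes the scores themselves.

```python
def active_neuron_ids(scores, threshold=0.01):
    """Indices of neurons whose score is not below the threshold."""
    return [idx for idx, score in enumerate(scores) if score >= threshold]


scores = [0.5, 0.001, 0.2, 0.0]
print(active_neuron_ids(scores))
# [0, 2]
```

A list of indices in this shape is also what 'manual' mode expects to receive through active_neurons_id.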
- prune_edge(threshold=0.03, log_history=True)
pruning edges
Args:
threshold : float
    if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_edge()
>>> model.plot()
- prune(node_th=0.01, edge_th=0.03)
prune (both nodes and edges)
Args:
node_th : float
    if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
edge_th : float
    if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune()
>>> model.plot()
- prune_input(threshold=0.01, active_inputs=None, log_history=True)
prune inputs
Args:
threshold : float
    if the attribution score of the input feature is below threshold, it is considered irrelevant.
active_inputs : None or list
    if a list is passed, the manual mode will disregard attribution score and prune as instructed.
Returns:
pruned network : MultKAN
Example1
automatic
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input()
>>> model.plot()
Example2
manual
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input(active_inputs=[0,1])
>>> model.plot()
- remove_node(l, i, mode='all', log_history=True)
remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
- attribute(l=None, i=None, out_score=None, plot=True)
get attribution scores
Args:
l : None or int
    layer index
i : None or int
    neuron index
out_score : None or 1D torch.float
    specify output scores
plot : bool
    when plot = True, display a bar plot of the scores
Returns:
attribution scores
Example
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_score
- feature_interaction(l, neuron_th=0.01, feature_th=0.01)
get feature interaction
Args:
l : int
    layer index
neuron_th : float
    threshold to determine whether a neuron is active
feature_th : float
    threshold to determine whether a feature is active
Returns:
dictionary
Example
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_interaction(1)
- suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=lambda x: ..., c_loss_fun=lambda x: ..., weight_simple=0.8)
suggest symbolic function
Args:
l : int
    layer index
i : int
    neuron index in layer l
j : int
    neuron index in layer l+1
a_range : tuple
    search range of a
b_range : tuple
    search range of b
lib : list of str
    library of candidate symbolic functions
topk : int
    the number of top functions displayed
verbose : bool
    if verbose = True, print more information
r2_loss_fun : function
    function : r2 -> "bits"
c_loss_fun : function
    function : c -> "bits"
weight_simple : float
    the simplicity weight: the higher it is, the more simplicity is preferred over performance
Returns:
best_name (str), best_fun (function), best_r2 (float), best_c (float)
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.suggest_symbolic(0,1,0)
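The effect of weight_simple can be illustrated with a toy ranking. This sketch assumes the two "bits" losses are combined as a weighted sum, which is an assumption and not confirmed by the text above; the default r2_loss_fun and c_loss_fun are elided in the signature and are not reproduced here. The candidate names and loss values are made up for illustration, and rank_candidates is a hypothetical helper.

```python
def rank_candidates(candidates, weight_simple=0.8):
    """Order candidates by an assumed weighted sum of complexity and fit losses.

    candidates: list of (name, r2_loss, c_loss), both losses in comparable units.
    """
    scored = [
        (weight_simple * c_loss + (1 - weight_simple) * r2_loss, name)
        for name, r2_loss, c_loss in candidates
    ]
    return [name for _, name in sorted(scored)]


# Made-up values: 'sin' fits better (lower r2_loss) but is more complex.
candidates = [("sin", 0.1, 2.0), ("x", 3.0, 1.0)]
print(rank_candidates(candidates, weight_simple=0.8))  # ['x', 'sin']
print(rank_candidates(candidates, weight_simple=0.2))  # ['sin', 'x']
```

In this toy setting a high weight_simple ranks the simpler function first even though it fits worse, which matches the description of the parameter.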
- auto_symbolic(a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1)
automatic symbolic regression for all edges
Args:
a_range : tuple
    search range of a
b_range : tuple
    search range of b
lib : list of str
    library of candidate symbolic functions
verbose : int
    verbosity level; larger values print more information
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
- symbolic_formula(var=None, normalizer=None, output_normalizer=None)
get symbolic formula
Args:
var : None or a list of sympy expression
    input variables
normalizer : [mean, std]
output_normalizer : [mean, std]
Returns:
symbolic formulas : list of sympy expressions
input variables : list of sympy symbols
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
>>> model.symbolic_formula()[0][0]
- expand_depth()
expand network depth by adding an identity layer to the end. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.
Args:
None
Returns:
None
- expand_width(layer_id, n_added_nodes, sum_bool=True, mult_arity=2)
expand network width. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.
Args:
layer_id : int
    layer index
n_added_nodes : int
    the number of added nodes
sum_bool : bool
    if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
mult_arity : int
    multiplication arity (the number of numbers to be multiplied)
Returns:
None
- perturb(mag=1.0, mode='non-intrusive')
perturb a network. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.
Args:
mag : float
    perturbation magnitude
mode : str
    perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}
Returns:
None
- module(start_layer, chain)
specify network modules
Args:
start_layer : int
    the earliest layer of the module
chain : str
    specify neurons in the module
Returns:
None