src.environment.problem.SOO.COCO_BBOB.kan.MultKAN

Module Contents

Classes

MultKAN

KAN class

Data

KAN

API

class src.environment.problem.SOO.COCO_BBOB.kan.MultKAN.MultKAN(width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu')[source]

Bases: torch.nn.Module

KAN class

Attributes:

grid : int
    the number of grid intervals
k : int
    spline order (the order of the piecewise polynomial)
act_fun : a list of KANLayers
symbolic_fun : a list of Symbolic_KANLayer
depth : int
    depth of KAN
width : list
    number of neurons in each layer.
    Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, and two hidden layers of 5 neurons each.
    With multiplication nodes, [2,[5,3],[5,1],3] means that, on top of the [2,5,5,3] KAN, there are 3 multiplication nodes in layer 1 and 1 multiplication node in layer 2.
mult_arity : int, or list of int lists
    multiplication arity for each multiplication node (the number of numbers to be multiplied)
base_fun : fun
    residual function b(x). An activation function is phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_enabled : bool
    If False, the symbolic branch (symbolic_fun) is ignored in computation (set to zero) to save time. Default: True.
width_in : list
    The number of input neurons for each layer
width_out : list
    The number of output neurons for each layer
base_fun_name : str
    The base function b(x)
grid_eps : float
    The parameter that interpolates between a uniform grid and an adaptive grid (based on sample quantiles)
node_bias : a list of 1D torch.float
node_scale : a list of 1D torch.float
subnode_bias : a list of 1D torch.float
subnode_scale : a list of 1D torch.float
affine_trainable : bool
    indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
sp_trainable : bool
    indicate whether the overall magnitude of splines is trainable
sb_trainable : bool
    indicate whether the overall magnitude of base function is trainable
save_act : bool
    indicate whether intermediate activations are saved in forward pass
node_scores : None or list of 1D torch.float
    node attribution score
edge_scores : None or list of 2D torch.float
    edge attribution score
subnode_scores : None or list of 1D torch.float
    subnode attribution score
cache_data : None or 2D torch.float
    cached input data
acts : None or a list of 2D torch.float
    activations on nodes
auto_save : bool
    indicate whether to automatically save a checkpoint once the model is modified
state_id : int
    the state of the model (used to save checkpoint)
ckpt_path : str
    the folder to store checkpoints
round : int
    the number of times rewind() has been called
device : str
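
The two `width` notations listed above can be brought into one common form. The following is a minimal sketch, not the class's own code: `normalize_width` is a hypothetical helper, and it assumes that a plain integer entry n stands for n addition nodes and no multiplication nodes.

```python
def normalize_width(width):
    """Return [n_sum, n_mult] for every layer of a width specification.

    Hypothetical helper for illustration only. Assumption: an integer
    entry n is read as [n, 0] (n addition nodes, no multiplication nodes).
    """
    normalized = []
    for entry in width:
        if isinstance(entry, int):
            normalized.append([entry, 0])
        else:
            n_sum, n_mult = entry
            normalized.append([n_sum, n_mult])
    return normalized


plain = normalize_width([2, 5, 5, 3])
mixed = normalize_width([2, [5, 3], [5, 1], 3])

# Per-layer counts of addition and multiplication nodes.
n_sum = [layer[0] for layer in mixed]
n_mult = [layer[1] for layer in mixed]

print(plain)   # [[2, 0], [5, 0], [5, 0], [3, 0]]
print(n_sum)   # [2, 5, 5, 3]
print(n_mult)  # [0, 3, 1, 0]
```

The two lists printed last play the role that the `n_sum` and `n_mult` properties are documented to have: the number of addition and multiplication nodes in each layer.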

Initialization

Initialize a KAN model.

Args:

width : list of int
    Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
    With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
grid : int
    number of grid intervals. Default: 3.
k : int
    order of piecewise polynomial. Default: 3.
mult_arity : int, or list of int lists
    multiplication arity for each multiplication node (the number of numbers to be multiplied)
noise_scale : float
    initial injected noise to spline.
base_fun : str
    the residual function b(x). Default: 'silu'
symbolic_enabled : bool
    compute (True) or skip (False) symbolic computations (for efficiency). By default: True. 
affine_trainable : bool
    affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
grid_eps : float
    When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
grid_range : list/np.array of shape (2,)
    the range of the grid. Default: [-1,1]. This argument matters little when fit(update_grid=True) is used (update_grid=True is the default).
sp_trainable : bool
    If true, scale_sp is trainable. Default: True.
sb_trainable : bool
    If true, scale_base is trainable. Default: True.
device : str
    device
seed : int
    random seed
save_act : bool
    indicate whether intermediate activations are saved in forward pass
sparse_init : bool
    sparse initialization (True) or normal dense initialization. Default: False.
auto_save : bool
    indicate whether to automatically save a checkpoint once the model is modified
state_id : int
    the state of the model (used to save checkpoint)
ckpt_path : str
    the folder to store checkpoints. Default: './model'
round : int
    the number of times rewind() has been called

Returns:

self

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
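
The `grid_eps` argument described above blends two ways of placing grid points. The sketch below illustrates the idea in plain Python; it is not the library's implementation. `blended_grid` is a hypothetical helper, and the linear blend `grid_eps * uniform + (1 - grid_eps) * adaptive` is an assumption chosen to match the documented endpoints (1 gives a uniform grid, 0 gives a percentile grid).

```python
def blended_grid(samples, num_intervals, grid_eps):
    """Blend a uniform grid with a sample-quantile grid.

    Hypothetical helper. Assumption: the blend is linear,
    grid_eps * uniform + (1 - grid_eps) * adaptive.
    """
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    # Adaptive knots: samples taken at evenly spaced ranks of the sorted data.
    adaptive = [xs[round(i * (n - 1) / num_intervals)]
                for i in range(num_intervals + 1)]
    # Uniform knots: evenly spaced values between the minimum and maximum.
    uniform = [lo + (hi - lo) * i / num_intervals
               for i in range(num_intervals + 1)]
    return [grid_eps * u + (1 - grid_eps) * a
            for u, a in zip(uniform, adaptive)]


samples = [0.0, 0.1, 0.2, 0.3, 10.0]  # one outlier at 10.0
uniform_grid = blended_grid(samples, 2, grid_eps=1.0)
quantile_grid = blended_grid(samples, 2, grid_eps=0.0)
print(uniform_grid)   # [0.0, 5.0, 10.0]
print(quantile_grid)  # [0.0, 0.2, 10.0]
```

With an outlier in the data, the uniform grid puts its middle knot far from most samples, while the quantile grid keeps it where the data is dense. Intermediate values of `grid_eps` trade off between the two.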

to(device)[source]

move the model to device

Args:

device : str or device

Returns:

self

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)

property width_in[source]

The number of input nodes for each layer

property width_out[source]

The number of output subnodes for each layer

property n_sum[source]

The number of addition nodes for each layer

property n_mult[source]

The number of multiplication nodes for each layer

property feature_score[source]

attribution scores for inputs

initialize_from_another_model(another_model, x)[source]

initialize from another model of the same width, but their ‘grid’ parameter can be different. Note this is equivalent to refine() when we don’t want to keep another_model

Args:

another_model : MultKAN
x : 2D torch.float

Returns:

self

Example

>>> from kan import *
>>> model1 = KAN(width=[2,5,1], grid=3)
>>> model2 = KAN(width=[2,5,1], grid=10)
>>> x = torch.rand(100,2)
>>> model2.initialize_from_another_model(model1, x)

log_history(method_name)[source]

refine(new_grid)[source]

grid refinement

Args:

new_grid : int
    the number of grid intervals after refinement

Returns:

a refined model : MultKAN

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> print(model.grid)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model = model.refine(10)
>>> print(model.grid)
checkpoint directory created: ./model
saving model version 0.0
5
saving model version 0.1
10

saveckpt(path='model')[source]

save the current model to files (configuration file and state file)

Args:

path : str
    the path where checkpoints are saved

Returns:

None

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')

There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state

static loadckpt(path='model')[source]

load checkpoint from path

Args:

path : str
    the path where checkpoints are saved

Returns:

MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
>>> KAN.loadckpt('./mark')

copy()[source]

return a deep copy of the model

Args:

None

Returns:

MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> model2 = model.copy()
>>> model2.act_fun[0].coef.data *= 2
>>> print(model2.act_fun[0].coef.data)
>>> print(model.act_fun[0].coef.data)

rewind(model_id)[source]

rewind to an old version

Args:

model_id : str
    in format '{a}.{b}' where a is the round number, b is the version number in that round 

Returns:

MultKAN

Example

Please refer to tutorials. API 12: Checkpoint, save & load model

checkout(model_id)[source]

check out an old version

Args:

model_id : str
    in format '{a}.{b}' where a is the round number, b is the version number in that round 

Returns:

MultKAN

Example

Same use as rewind, although checkout doesn’t change states
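
Both rewind() and checkout() take `model_id` as a string in the format '{a}.{b}'. The sketch below shows how such an id splits into its two parts; `parse_model_id` is a hypothetical helper written for illustration and is not part of the class.

```python
def parse_model_id(model_id):
    """Split a '{a}.{b}' id into (round number, version number).

    Hypothetical helper for illustration only.
    """
    round_str, version_str = model_id.split(".")
    return int(round_str), int(version_str)


round_number, version = parse_model_id("0.1")
print(round_number, version)  # 0 1

# '0.10' is version 10 of round 0, not the same as '0.1'.
print(parse_model_id("0.10"))  # (0, 10)
```

This is one reason the id is a string rather than a float: as a float, 0.10 and 0.1 would be indistinguishable.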

update_grid_from_samples(x)[source]

update grid from samples

Args:

x : 2D torch.tensor
    inputs

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.act_fun[0].grid)

update_grid(x)[source]

call update_grid_from_samples. This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN

initialize_grid_from_another_model(model, x)[source]

initialize grid from another model

Args:

model : MultKAN
    parent model
x : 2D torch.tensor
    inputs

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
>>> model2.initialize_grid_from_another_model(model, x)
>>> print(model2.act_fun[0].grid)

forward(x, singularity_avoiding=False, y_th=10.0)[source]

forward pass

Args:

x : 2D torch.tensor
    inputs
singularity_avoiding : bool
    whether to avoid singularity for the symbolic branch
y_th : float
    the threshold for singularity

Returns:

y : 2D torch.tensor
    outputs

Example1

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model(x).shape

Example2

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> x = torch.tensor([[1],[-0.01]])
>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
>>> print(model(x))
>>> print(model(x, singularity_avoiding=True))
>>> print(model(x, singularity_avoiding=True, y_th=1.))

set_mode(l, i, j, mode, mask_n=None)[source]

fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True)[source]

set (l,i,j) activation to be symbolic (specified by fun_name)

Args:

l : int
    layer index
i : int
    input neuron index
j : int
    output neuron index
fun_name : str
    function name
fit_params_bool : bool
    obtaining affine parameters through fitting (True) or setting default values (False)
a_range : tuple
    sweeping range of a
b_range : tuple
    sweeping range of b
verbose : bool
    If True, more information is printed.
random : bool
    initialize affine parameters randomly (True) or as [1,0,1,0] (False)
log_history : bool
    indicate whether to log history when the function is called

Returns:

None or r2 (coefficient of determination)

Example 1

when fit_params_bool = False

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))

Example 2

when fit_params_bool = True

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x)  # obtain activations (otherwise the model has no attribute acts)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
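
The default affine parameters [1,0,1,0] mentioned above can be read as "leave the function unchanged". The sketch below assumes, for illustration, that a symbolic edge computes c * f(a * x + b) + d with the parameters ordered (a, b, c, d); `symbolic_edge` is a hypothetical helper and not the library's code.

```python
import math


def symbolic_edge(fun, x, affine=(1.0, 0.0, 1.0, 0.0)):
    """Evaluate c * fun(a * x + b) + d.

    Hypothetical helper. Assumption: the four affine parameters are
    ordered (a, b, c, d), so the default (1, 0, 1, 0) leaves fun unchanged.
    """
    a, b, c, d = affine
    return c * fun(a * x + b) + d


default_value = symbolic_edge(math.sin, 0.5)
fitted_value = symbolic_edge(math.sin, 0.5, affine=(2.0, 0.0, 3.0, 1.0))

print(default_value == math.sin(0.5))             # True
print(fitted_value == 3.0 * math.sin(1.0) + 1.0)  # True
```

Under this reading, fit_params_bool=True searches for (a, b) within a_range and b_range, while fit_params_bool=False keeps the default values.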

unfix_symbolic(l, i, j, log_history=True)[source]

unfix the (l,i,j) activation function.

unfix_symbolic_all()[source]

unfix all activation functions.

get_range(l, i, j, verbose=True)[source]

Get the input range and output range of the (l,i,j) activation

Args:

l : int
    layer index
i : int
    input neuron index
j : int
    output neuron index

Returns:

x_min : float
    minimum of input
x_max : float
    maximum of input
y_min : float
    minimum of output
y_max : float
    maximum of output

Example

>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x)  # do a forward pass to obtain model.acts
>>> model.get_range(0,0,0)

plot(folder='./figures', beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0)[source]

plot KAN

Args:

folder : str
    the folder to store pngs
beta : float
    positive number. control the transparency of each activation. transparency = tanh(beta*l1).
mask : bool
    If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
mode : str
    "supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).
scale : float
    control the size of the diagram
in_vars: None or list of str
    the name(s) of input variables
out_vars: None or list of str
    the name(s) of output variables
title: None or str
    title
varscale : float
    the size of input variables

Returns:

Figure

Example

see more interactive examples in demos

>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x)  # do a forward pass to obtain model.acts
>>> model.plot()
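
The `beta` argument enters through the formula transparency = tanh(beta*l1) given above. The short sketch below evaluates that formula for a few values; `edge_transparency` is a hypothetical name used only here.

```python
import math


def edge_transparency(l1, beta=3.0):
    """Evaluate the documented formula tanh(beta * l1).

    Hypothetical helper; l1 is the magnitude score of one activation.
    """
    return math.tanh(beta * l1)


for l1 in (0.0, 0.1, 1.0):
    print(l1, round(edge_transparency(l1), 3))

# A larger beta saturates the value faster for the same l1.
print(edge_transparency(0.1, beta=10.0) > edge_transparency(0.1, beta=3.0))  # True
```

Because tanh saturates at 1, a larger `beta` makes weak activations render closer to strong ones, and a smaller `beta` spreads them apart.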

reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)[source]

Get regularization

Args:

reg_metric : the regularization metric
    'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
lamb_l1 : float
    l1 penalty strength
lamb_entropy : float
    entropy penalty strength
lamb_coef : float
    coefficient penalty strength
lamb_coefdiff : float
    coefficient smoothness strength

Returns:

reg_ : torch.float

Example

>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
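
To give a feel for what the l1 and entropy penalties reward, here is a simplified sketch over a flat list of edge scores. It is an assumption-laden illustration, not the method's implementation: `l1_entropy_penalty` is a hypothetical helper, the scores are made up, and the coefficient terms (lamb_coef, lamb_coefdiff) are left out.

```python
import math


def l1_entropy_penalty(scores, lamb_l1=1.0, lamb_entropy=2.0, eps=1e-12):
    """Combine an l1 term and an entropy term over non-negative scores.

    Hypothetical helper. Assumption: the penalty is
    lamb_l1 * sum(scores) + lamb_entropy * entropy(scores / sum(scores)).
    """
    l1 = sum(scores)
    probs = [s / (l1 + eps) for s in scores]
    entropy = -sum(p * math.log(p + eps) for p in probs)
    return lamb_l1 * l1 + lamb_entropy * entropy


# Same total magnitude, concentrated on one edge vs spread over four.
sparse = l1_entropy_penalty([1.0, 0.0, 0.0, 0.0])
dense = l1_entropy_penalty([0.25, 0.25, 0.25, 0.25])
print(sparse < dense)  # True
```

The l1 term is identical in both cases; the entropy term is what favors concentrating the magnitude on few edges.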

get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)[source]

Get regularization. This seems unnecessary but in case a class wants to inherit this, it may want to rewrite get_reg, but not reg.

disable_symbolic_in_fit(lamb)[source]

during fitting, disable the symbolic branch if either condition holds: lamb = 0, or none of the symbolic functions is active

get_params()[source]

Get parameters

fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None)[source]

training

Args:

dataset : dic
    contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
opt : str
    "LBFGS" or "Adam"
steps : int
    training steps
log : int
    logging frequency
lamb : float
    overall penalty strength
lamb_l1 : float
    l1 penalty strength
lamb_entropy : float
    entropy penalty strength
lamb_coef : float
    coefficient magnitude penalty strength
lamb_coefdiff : float
    difference of nearby coefficients (smoothness) penalty strength
update_grid : bool
    If True, update grid regularly before stop_grid_update_step
grid_update_num : int
    the number of grid updates before stop_grid_update_step
start_grid_update_step : int
    no grid updates before this training step
stop_grid_update_step : int
    no grid updates after this training step
loss_fn : function
    loss function
lr : float
    learning rate
batch : int
    batch size, if -1 then full.
save_fig_freq : int
    save figure every (save_fig_freq) steps
singularity_avoiding : bool
    indicate whether to avoid singularity for the symbolic part
y_th : float
    singularity threshold (anything above the threshold is considered singular and is softened in some ways)
reg_metric : str
    regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
metrics : a list of metrics (as functions)
    the metrics to be computed in training
display_metrics : a list of functions
    the metric to be displayed in tqdm progress bar

Returns:

results : dic
    results['train_loss'], 1D array of training losses (RMSE)
    results['test_loss'], 1D array of test losses (RMSE)
    results['reg'], 1D array of regularization
    other metrics specified in metrics

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()

Most examples in the tutorials involve the fit() method. Please check them for usage.
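
The grid-update arguments (update_grid, grid_update_num, start_grid_update_step, stop_grid_update_step) together define a schedule. The sketch below lists the steps at which an update would happen under one plausible reading: updates are evenly spaced, every stop_grid_update_step // grid_update_num steps. That spacing rule is an assumption, and `grid_update_steps` is a hypothetical helper rather than part of fit().

```python
def grid_update_steps(steps, update_grid=True, grid_update_num=10,
                      start_grid_update_step=-1, stop_grid_update_step=50):
    """List the training steps at which the grid would be refreshed.

    Hypothetical helper. Assumption: updates are evenly spaced, every
    stop_grid_update_step // grid_update_num steps.
    """
    if not update_grid:
        return []
    # Guard against a zero spacing when grid_update_num is larger
    # than stop_grid_update_step.
    freq = max(1, stop_grid_update_step // grid_update_num)
    return [step for step in range(steps)
            if step % freq == 0
            and start_grid_update_step <= step < stop_grid_update_step]


schedule = grid_update_steps(steps=100)
print(schedule)  # [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]

short_run = grid_update_steps(steps=20)
print(short_run)  # [0, 5, 10, 15]
```

With the default arguments this yields ten updates before step 50. A run shorter than stop_grid_update_step sees fewer updates than grid_update_num.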

prune_node(threshold=0.01, mode='auto', active_neurons_id=None, log_history=True)[source]

pruning nodes

Args:

threshold : float
    if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
mode : str
    'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_node()
>>> model.plot()

prune_edge(threshold=0.03, log_history=True)[source]

pruning edges

Args:

threshold : float
    if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_edge()
>>> model.plot()

prune(node_th=0.01, edge_th=0.03)[source]

prune (both nodes and edges)

Args:

node_th : float
    if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
edge_th : float
    if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune()
>>> model.plot()

prune_input(threshold=0.01, active_inputs=None, log_history=True)[source]

prune inputs

Args:

threshold : float
    if the attribution score of the input feature is below threshold, it is considered irrelevant.
active_inputs : None or list
    if a list is passed, the manual mode will disregard attribution score and prune as instructed.

Returns:

pruned network : MultKAN

Example1

automatic

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input()
>>> model.plot()

Example2

manual

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input(active_inputs=[0,1])
>>> model.plot()
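
The difference between the automatic and manual modes of prune_input comes down to how the kept inputs are chosen. The sketch below shows that selection step only; `select_inputs` is a hypothetical helper, the scores are made up, and treating a score exactly equal to the threshold as irrelevant is an assumption.

```python
def select_inputs(feature_scores, threshold=0.01, active_inputs=None):
    """Return the indices of the input features to keep.

    Hypothetical helper. Manual mode: an explicit list overrides the scores.
    Automatic mode: keep features whose score exceeds the threshold.
    """
    if active_inputs is not None:
        return list(active_inputs)
    return [i for i, score in enumerate(feature_scores) if score > threshold]


scores = [0.9, 0.3, 0.0]  # illustrative attribution scores
auto_kept = select_inputs(scores)
manual_kept = select_inputs(scores, active_inputs=[0, 2])
print(auto_kept)    # [0, 1]
print(manual_kept)  # [0, 2]
```

In manual mode the scores are ignored entirely, which matches the documented behavior of active_inputs.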

remove_edge(l, i, j, log_history=True)[source]

remove activation phi(l,i,j) (set its mask to zero)

remove_node(l, i, mode='all', log_history=True)[source]

remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)

attribute(l=None, i=None, out_score=None, plot=True)[source]

get attribution scores

Args:

l : None or int
    layer index
i : None or int
    neuron index
out_score : None or 1D torch.float
    specify output scores
plot : bool
    when plot = True, display the bar chart

Returns:

attribution scores

Example

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_score

node_attribute()[source]

feature_interaction(l, neuron_th=0.01, feature_th=0.01)[source]

get feature interaction

Args:

l : int
    layer index
neuron_th : float
    threshold to determine whether a neuron is active
feature_th : float
    threshold to determine whether a feature is active

Returns:

dictionary

Example

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_interaction(1)

suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=lambda x: ..., c_loss_fun=lambda x: ..., weight_simple=0.8)[source]

suggest symbolic function

Args:

l : int
    layer index
i : int
    input neuron index (in layer l)
j : int
    output neuron index (in layer l+1)
a_range : tuple
    search range of a
b_range : tuple
    search range of b
lib : list of str
    library of candidate symbolic functions
topk : int
    the number of top functions displayed
verbose : bool
    if verbose = True, print more information
r2_loss_fun : function
    function : r2 -> "bits"
c_loss_fun : function
    function : c -> "bits"
weight_simple : float
    the simplicity weight: the higher it is, the more simplicity is preferred over performance

Returns:

best_name (str), best_fun (function), best_r2 (float), best_c (float)

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.suggest_symbolic(0,1,0)
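
The role of weight_simple can be illustrated with a small ranking sketch. It assumes that each candidate already has an r2 loss and a complexity loss expressed in "bits" (lower is better for both) and that the two are mixed as weight_simple * c_loss + (1 - weight_simple) * r2_loss. That mixing rule, the helper `rank_candidates`, and all numbers below are assumptions made for illustration.

```python
def rank_candidates(candidates, weight_simple=0.8):
    """Order candidate names from best to worst total loss.

    Hypothetical helper. candidates maps a function name to
    (r2_loss, c_loss), both lower-is-better. Assumed mixing rule:
    weight_simple * c_loss + (1 - weight_simple) * r2_loss.
    """
    def total(item):
        r2_loss, c_loss = item[1]
        return weight_simple * c_loss + (1 - weight_simple) * r2_loss

    return [name for name, _ in sorted(candidates.items(), key=total)]


# Made-up losses: 'tanh' fits slightly better than 'sin' but is more complex.
candidates = {"sin": (-10.0, 2.0), "x^2": (-4.0, 2.0), "tanh": (-11.0, 3.0)}

prefer_simple = rank_candidates(candidates, weight_simple=0.8)
prefer_fit = rank_candidates(candidates, weight_simple=0.0)
print(prefer_simple)  # ['sin', 'tanh', 'x^2']
print(prefer_fit)     # ['tanh', 'sin', 'x^2']
```

Raising weight_simple moves simpler functions up the ranking even when a more complex one fits marginally better.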

auto_symbolic(a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1)[source]

automatic symbolic regression for all edges

Args:

a_range : tuple
    search range of a
b_range : tuple
    search range of b
lib : list of str
    library of candidate symbolic functions
verbose : int
    higher values print more information

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()

symbolic_formula(var=None, normalizer=None, output_normalizer=None)[source]

get symbolic formula

Args:

var : None or a list of sympy expression
    input variables
normalizer : [mean, std]
output_normalizer : [mean, std]

Returns:

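symbolic formulas; as indexed in the example, symbolic_formula()[0][0] is the formula of the first output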

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
>>> model.symbolic_formula()[0][0]

expand_depth()[source]

expand network depth by adding an identity layer to the end. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.

Args:

None

Returns:

None

expand_width(layer_id, n_added_nodes, sum_bool=True, mult_arity=2)[source]

expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.

Args:

layer_id : int
    layer index
n_added_nodes : int
    the number of added nodes
sum_bool : bool
    if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
mult_arity : int
    multiplication arity (the number of numbers to be multiplied)

Returns:

None

perturb(mag=1.0, mode='non-intrusive')[source]

perturb a network. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.

Args:

mag : float
    perturbation magnitude
mode : str
    perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}

Returns:

None

module(start_layer, chain)[source]

specify network modules

Args:

start_layer : int
    the earliest layer of the module
chain : str
    specify neurons in the module

Returns:

None

tree(x=None, in_var=None, style='tree', sym_th=0.001, sep_th=0.1, skip_sep_test=False, verbose=False)[source]

turn KAN into a tree

speed(compile=False)[source]

turn on KAN’s speed mode

get_act(x=None)[source]

collect intermediate activations

get_fun(l, i, j)[source]

get function (l,i,j)

history(k='all')[source]

get history

property n_edge[source]

the number of active edges

evaluate(dataset)[source]

swap(l, i1, i2, log_history=True)[source]

property connection_cost[source]

auto_swap_l(l)[source]

auto_swap()[source]

automatically swap neurons so that connection costs are minimized

src.environment.problem.SOO.COCO_BBOB.kan.MultKAN.KAN[source]

None