NBeatsBlock

class NBeatsBlock(input_size: int, theta_size: int, basis_function: torch.nn.modules.module.Module, num_layers: int, layer_size: int)

Bases: torch.nn.modules.module.Module

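Base N-BEATS block which takes a basis function as an argument. The block computes theta_size basis parameters from its input and passes them to basis_function, which turns them into a backcast and a forecast.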

Parameters
  • input_size (int) – Length of the in-sample (lookback) window fed to the block.

  • theta_size (int) – Number of parameters for the basis function.

  • basis_function (nn.Module) – Basis function which takes the parameters and produces backcast and forecast.

  • num_layers (int) – Number of fully connected layers in the block.

  • layer_size (int) – Number of units in each fully connected layer.
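As a rough illustration of how these parameters fit together, the sketch below assumes the block layout described in the N-BEATS paper: num_layers fully connected ReLU layers of width layer_size, followed by a linear projection to theta_size basis parameters that are handed to basis_function. SketchNBeatsBlock and GenericBasis are hypothetical names used only here; this is not the library's own implementation. GenericBasis simply splits theta, so it requires theta_size to equal the backcast size plus the forecast size.

```python
from typing import Tuple

import torch
from torch import nn


class GenericBasis(nn.Module):
    """Hypothetical 'generic' basis: splits theta into backcast and forecast."""

    def __init__(self, backcast_size: int, forecast_size: int):
        super().__init__()
        self.backcast_size = backcast_size
        self.forecast_size = forecast_size

    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # First backcast_size values reconstruct the input window,
        # the last forecast_size values are the prediction.
        return theta[:, : self.backcast_size], theta[:, -self.forecast_size :]


class SketchNBeatsBlock(nn.Module):
    """Hypothetical stand-in for NBeatsBlock (sketch, not the library's code)."""

    def __init__(self, input_size: int, theta_size: int,
                 basis_function: nn.Module, num_layers: int, layer_size: int):
        super().__init__()
        # num_layers linear layers: input_size -> layer_size -> ... -> layer_size
        sizes = [input_size] + [layer_size] * num_layers
        self.layers = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
        )
        # Final projection to the theta_size basis parameters.
        self.theta = nn.Linear(layer_size, theta_size)
        self.basis_function = basis_function

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = x
        for layer in self.layers:
            h = torch.relu(layer(h))
        # The basis function maps theta to (backcast, forecast).
        return self.basis_function(self.theta(h))


block = SketchNBeatsBlock(input_size=12, theta_size=12 + 4,
                          basis_function=GenericBasis(12, 4),
                          num_layers=4, layer_size=64)
backcast, forecast = block(torch.randn(8, 12))
print(backcast.shape, forecast.shape)  # torch.Size([8, 12]) torch.Size([8, 4])
```

Keeping the basis function as a separate module means the same block body can be reused with other bases (for example trend or seasonality bases) by changing only theta_size and basis_function.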

Methods

forward(x)

Run the block on x and return its backcast and forecast.

forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Forward pass: compute the basis parameters from x and apply basis_function to them.

Parameters

x (torch.Tensor) – Input tensor holding the in-sample window; its last dimension must equal input_size.

Returns

Tuple (backcast, forecast), as produced by basis_function.

Return type

Tuple[torch.Tensor, torch.Tensor]
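The (backcast, forecast) pair is what allows blocks to be chained. Below is a minimal sketch of the doubly residual stacking described in the N-BEATS paper, assuming each block is called on the residual left by the previous block: the backcast is subtracted from that residual and the partial forecasts are summed. stack_forward and ToyBlock are hypothetical helpers for illustration only; any module with the same forward signature, such as NBeatsBlock, could take ToyBlock's place.

```python
from typing import Iterable, Optional, Tuple

import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical block: backcast = half the input, forecast = window mean."""

    def __init__(self, horizon: int):
        super().__init__()
        self.horizon = horizon

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        forecast = x.mean(dim=1, keepdim=True).expand(-1, self.horizon)
        return 0.5 * x, forecast


def stack_forward(blocks: Iterable[nn.Module],
                  x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sketch of doubly residual stacking (not this library's code)."""
    residual = x
    forecast: Optional[torch.Tensor] = None
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Remove the part of the input this block already explains ...
        residual = residual - backcast
        # ... and accumulate the partial forecasts.
        forecast = block_forecast if forecast is None else forecast + block_forecast
    return residual, forecast


residual, forecast = stack_forward([ToyBlock(3), ToyBlock(3)], torch.ones(2, 6))
print(forecast)  # every entry is 1.5 (1.0 from the first block + 0.5 from the second)
```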