PatchTSModel

class PatchTSModel(decoder_length: int, encoder_length: int, patch_len: int = 4, stride: int = 1, num_layers: int = 3, hidden_size: int = 128, feedforward_size: int = 256, nhead: int = 16, lr: float = 0.001, loss: Optional[torch.nn.modules.module.Module] = None, train_batch_size: int = 128, test_batch_size: int = 128, optimizer_params: Optional[dict] = None, trainer_params: Optional[dict] = None, train_dataloader_params: Optional[dict] = None, test_dataloader_params: Optional[dict] = None, val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None)

Bases: etna.models.base.DeepBaseModel

PatchTS model using PyTorch layers.
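As the patch_len and stride arguments suggest, the model works on patches of the input window rather than on single timestamps. The plain-Python sketch below shows how those two values would cut an encoder window into patches and how many patches result. It assumes patches are taken without padding, the way torch.Tensor.unfold windows a tensor; whether etna's implementation pads the window is an assumption not settled here, and make_patches and num_patches are hypothetical helpers.

```python
from typing import List, Sequence


def make_patches(window: Sequence[float], patch_len: int, stride: int) -> List[List[float]]:
    """Cut a 1-D window into patches of length patch_len that start every stride steps.

    Assumes no padding: a trailing remainder shorter than patch_len is dropped,
    matching how torch.Tensor.unfold(dimension, size, step) windows a tensor.
    """
    if patch_len <= 0 or stride <= 0:
        raise ValueError("patch_len and stride must be positive")
    if len(window) < patch_len:
        raise ValueError("window must be at least patch_len long")
    return [
        list(window[start:start + patch_len])
        for start in range(0, len(window) - patch_len + 1, stride)
    ]


def num_patches(encoder_length: int, patch_len: int, stride: int) -> int:
    """Closed-form patch count for the same no-padding scheme."""
    return (encoder_length - patch_len) // stride + 1


# Toy encoder window with encoder_length=10.
window = [float(i) for i in range(10)]
print(make_patches(window, patch_len=4, stride=2))
# [[0.0, 1.0, 2.0, 3.0], [2.0, 3.0, 4.0, 5.0], [4.0, 5.0, 6.0, 7.0], [6.0, 7.0, 8.0, 9.0]]
print(num_patches(10, patch_len=4, stride=1))  # 7 patches with the default patch_len and stride
```

A smaller stride makes consecutive patches overlap more and increases the number of tokens the transformer attends over; stride equal to patch_len gives non-overlapping patches.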

Init PatchTS model.

Parameters
  • encoder_length (int) – encoder length

  • decoder_length (int) – decoder length

  • patch_len (int) – length of each patch

  • stride (int) – step between the starts of consecutive patches

  • num_layers (int) – number of layers

  • hidden_size (int) – size of the hidden state

  • feedforward_size (int) – size of the feedforward layers in the transformer

  • nhead (int) – number of transformer heads

  • lr (float) – learning rate

  • loss (Optional[torch.nn.Module]) – loss function; MSELoss by default

  • train_batch_size (int) – batch size for training

  • test_batch_size (int) – batch size for testing

  • optimizer_params (Optional[dict]) – parameters for the Adam optimizer (API reference: torch.optim.Adam)

  • trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (API reference: pytorch_lightning.trainer.trainer.Trainer)

  • train_dataloader_params (Optional[dict]) – parameters for the train dataloader, for example a sampler (API reference: torch.utils.data.DataLoader)

  • test_dataloader_params (Optional[dict]) – parameters for the test dataloader

  • val_dataloader_params (Optional[dict]) – parameters for the validation dataloader

  • split_params (Optional[dict]) – dictionary with parameters for torch.utils.data.random_split() used for train-test splitting:

      ◦ train_size: (float) value from 0 to 1, the fraction of samples to use for training

      ◦ generator: (Optional[torch.Generator]) generator for reproducible train-test splitting

      ◦ torch_dataset_size: (Optional[int]) number of samples in the dataset, in case the dataset does not implement __len__
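split_params controls how samples are divided before training. The plain-Python sketch below illustrates the two ideas behind train_size and generator: a fraction of the shuffled samples goes to training, and a fixed seed makes the split repeatable. It is an analogue, not etna's code path (which, per the description above, delegates to torch.utils.data.random_split()); fraction_split is a hypothetical helper and its rounding rule is an assumption.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def fraction_split(
    samples: Sequence[T], train_size: float, seed: Optional[int] = None
) -> Tuple[List[T], List[T]]:
    """Randomly split samples into a training part and a held-out part.

    train_size is the fraction of samples used for training; a fixed seed plays
    the role of a seeded torch.Generator and makes the split reproducible.
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")
    # Rounding the train count is an assumption of this sketch.
    n_train = round(len(samples) * train_size)
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # seeded shuffle -> same split on every run
    train = [samples[i] for i in indices[:n_train]]
    held_out = [samples[i] for i in indices[n_train:]]
    return train, held_out


train, held_out = fraction_split(list(range(100)), train_size=0.8, seed=42)
print(len(train), len(held_out))  # 80 20
```

With the real model, the same intent would be expressed through the keys listed above, for example split_params=dict(train_size=0.8, generator=torch.Generator().manual_seed(42)).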

Methods

  • fit(ts) – Fit model.

  • forecast(ts, prediction_size[, ...]) – Make predictions.

  • get_model() – Get model.

  • load(path) – Load an object.

  • params_to_tune() – Get default grid for tuning hyperparameters.

  • predict(ts, prediction_size[, return_components]) – Make predictions.

  • raw_fit(torch_dataset) – Fit model on a torch-like Dataset.

  • raw_predict(torch_dataset) – Make inference on a torch-like Dataset.

  • save(path) – Save the object.

  • set_params(**params) – Return a new object instance with modified parameters.

  • to_dict() – Collect all information about the etna object in a dict.

Attributes

  • context_size – Context size of the model, i.e. how many history points need to be passed to the model to make a forecast.

params_to_tune() → Dict[str, etna.distributions.distributions.BaseDistribution]

Get default grid for tuning hyperparameters.

This grid tunes parameters: num_layers, hidden_size, lr, encoder_length. Other parameters are expected to be set by the user.

Returns

Grid to tune.

Return type

Dict[str, etna.distributions.distributions.BaseDistribution]
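params_to_tune returns a mapping from parameter names to distributions that a tuner samples from. The plain-Python sketch below mimics that idea with a seeded random search over the four parameters listed above. IntRange, LogUniform, example_grid and random_search are hypothetical stand-ins rather than etna's BaseDistribution subclasses, and every range in example_grid is an illustrative assumption, not the default grid. The hidden_size range steps by nhead on the assumption that hidden_size is the transformer's model dimension, which torch.nn.MultiheadAttention requires to be divisible by the number of heads.

```python
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass(frozen=True)
class IntRange:
    """Hypothetical integer distribution: uniform over low..high (inclusive) in steps of step."""

    low: int
    high: int
    step: int = 1

    def sample(self, rng: random.Random) -> int:
        return rng.randrange(self.low, self.high + 1, self.step)


@dataclass(frozen=True)
class LogUniform:
    """Hypothetical float distribution: log-uniform over [low, high]."""

    low: float
    high: float

    def sample(self, rng: random.Random) -> float:
        return math.exp(rng.uniform(math.log(self.low), math.log(self.high)))


Distribution = Union[IntRange, LogUniform]


def example_grid(decoder_length: int, nhead: int) -> Dict[str, Distribution]:
    """Illustrative grid over the four tuned parameters; all ranges are assumptions."""
    return {
        "num_layers": IntRange(1, 3),
        # Stepping by nhead keeps hidden_size divisible by the number of heads.
        "hidden_size": IntRange(nhead, 16 * nhead, step=nhead),
        "lr": LogUniform(1e-5, 1e-2),
        "encoder_length": IntRange(decoder_length, 4 * decoder_length),
    }


def random_search(
    grid: Dict[str, Distribution], n_trials: int, seed: int = 0
) -> List[Dict[str, Union[int, float]]]:
    """Draw n_trials configurations; a fixed seed makes the draws reproducible."""
    rng = random.Random(seed)
    return [{name: dist.sample(rng) for name, dist in grid.items()} for _ in range(n_trials)]


for params in random_search(example_grid(decoder_length=7, nhead=16), n_trials=3, seed=42):
    print(params)
```

Sampling lr on a log scale spreads trials evenly across orders of magnitude, which is the usual choice for learning rates; parameters outside the grid (for example nhead) stay fixed at the values the user sets.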