DeepStateModel

class DeepStateModel(ssm: etna.models.nn.deepstate.state_space_model.CompositeSSM, input_size: int, encoder_length: int, decoder_length: int, num_layers: int = 1, n_samples: int = 5, lr: float = 0.001, train_batch_size: int = 16, test_batch_size: int = 16, optimizer_params: Optional[dict] = None, trainer_params: Optional[dict] = None, train_dataloader_params: Optional[dict] = None, test_dataloader_params: Optional[dict] = None, val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None)

Bases: etna.models.base.DeepBaseModel

DeepState model.

Initialize the DeepState model.
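In DeepState-style models, an RNN produces the parameters of a linear Gaussian state space model, and forecasts are generated by sampling from that model. The following is a minimal, standalone illustration and not etna's implementation: it fixes the state space parameters to constants instead of producing them with an RNN, uses a hypothetical level-trend helper `sample_level_trend_paths`, and assumes, for illustration only, that the sampled paths are summarized by their per-step mean (one way a parameter like n_samples can be used).

```python
import random
from statistics import mean
from typing import List


def sample_level_trend_paths(
    level: float,
    trend: float,
    horizon: int,
    n_samples: int,
    alpha: float = 0.1,
    beta: float = 0.05,
    sigma_obs: float = 0.1,
    seed: int = 0,
) -> List[List[float]]:
    """Sample future paths from a level-trend linear Gaussian SSM (hypothetical helper).

    The state z_t = (level_t, trend_t) evolves as z_t = F z_{t-1} + g * eps_t with
    F = [[1, 1], [0, 1]] and g = (alpha, beta); the observation is
    y_t = level_t + sigma_obs * noise_t.
    """
    rng = random.Random(seed)
    paths: List[List[float]] = []
    for _ in range(n_samples):
        lvl, trd = level, trend
        path: List[float] = []
        for _ in range(horizon):
            eps = rng.gauss(0.0, 1.0)  # one innovation shared by both state components
            # F z_{t-1}: the level moves by the current trend, the trend persists
            lvl, trd = lvl + trd + alpha * eps, trd + beta * eps
            path.append(lvl + sigma_obs * rng.gauss(0.0, 1.0))
        paths.append(path)
    return paths


def mean_forecast(paths: List[List[float]]) -> List[float]:
    """Collapse sampled paths into a point forecast by averaging each time step."""
    return [mean(step) for step in zip(*paths)]


if __name__ == "__main__":
    samples = sample_level_trend_paths(level=10.0, trend=1.0, horizon=3, n_samples=5)
    print(mean_forecast(samples))
```

With all noise scales set to zero, every sampled path reduces to the deterministic trend line, which makes the role of the transition matrix F easy to check by hand.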

Parameters
  • ssm (etna.models.nn.deepstate.state_space_model.CompositeSSM) – state space model of the system

  • input_size (int) – size of the input feature space, i.e. the number of features fed to the RNN part

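  • encoder_length (int) – length of the encoder window (number of past time steps the model conditions on)

  • decoder_length (int) – length of the decoder window (number of time steps to predict)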

  • num_layers (int) – number of layers in the RNN

  • n_samples (int) – number of samples to draw when generating predictions

  • lr (float) – learning rate

  • train_batch_size (int) – batch size for training

  • test_batch_size (int) – batch size for testing

  • optimizer_params (Optional[dict]) – parameters for the Adam optimizer (api reference torch.optim.Adam)

  • trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (api reference pytorch_lightning.trainer.trainer.Trainer)

  • train_dataloader_params (Optional[dict]) – parameters for the train dataloader, for example a sampler (api reference torch.utils.data.DataLoader)

  • test_dataloader_params (Optional[dict]) – parameters for the test dataloader

  • val_dataloader_params (Optional[dict]) – parameters for the validation dataloader

  • split_params (Optional[dict]) –

    dictionary of parameters for torch.utils.data.random_split() used for train-test splitting:
    • train_size: (float) value from 0 to 1, the fraction of samples to use for training

    • generator: (Optional[torch.Generator]) - generator for reproducible train-test splitting

    • torch_dataset_size: (Optional[int]) - number of samples in the dataset, needed when the dataset does not implement __len__
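The split_params keys can be illustrated without torch. This is a minimal sketch under stated assumptions, not etna's code: the helpers `resolve_dataset_size` and `split_indices` are hypothetical, the train count is assumed to be `int(n * train_size)` with the remainder held out, and a seeded `random.Random` stands in for the `torch.Generator` passed as `generator`. With torch itself, a reproducible configuration might look like `{"train_size": 0.8, "generator": torch.Generator().manual_seed(0)}`.

```python
import random
from typing import Any, List, Optional, Tuple


def resolve_dataset_size(dataset: Any, torch_dataset_size: Optional[int] = None) -> int:
    """Use len(dataset) when available, otherwise fall back to torch_dataset_size."""
    try:
        return len(dataset)
    except TypeError:  # raised for objects without __len__, e.g. generators
        if torch_dataset_size is None:
            raise ValueError("torch_dataset_size is required when the dataset has no __len__")
        return torch_dataset_size


def split_indices(n: int, train_size: float, seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Randomly split indices 0..n-1 into train and held-out parts (hypothetical helper)."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must lie in [0, 1]")
    n_train = int(n * train_size)  # assumption: the train count is truncated to an int
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # the fixed seed plays the role of torch.Generator
    return indices[:n_train], indices[n_train:]


if __name__ == "__main__":
    split_params = {"train_size": 0.8, "torch_dataset_size": None}
    n = resolve_dataset_size(list(range(10)), split_params["torch_dataset_size"])
    train_idx, holdout_idx = split_indices(n, split_params["train_size"], seed=42)
    print(len(train_idx), len(holdout_idx))
```

Falling back to an explicit size only when `len()` fails mirrors the description of torch_dataset_size above: it matters only for datasets that do not implement __len__.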

Methods

fit(ts)

Fit model.

forecast(ts, prediction_size[, ...])

Make predictions.

get_model()

Get model.

load(path)

Load an object.

params_to_tune()

Get grid for tuning hyperparameters.

predict(ts, prediction_size[, return_components])

Make predictions.

raw_fit(torch_dataset)

Fit the model on a torch-like Dataset.

Fit model on torch like Dataset.

raw_predict(torch_dataset)

Run inference on a torch-like Dataset.

Make inference on torch like Dataset.

save(path)

Save the object.

set_params(**params)

Return new object instance with modified parameters.

to_dict()

Collect all information about etna object in dict.

Attributes

context_size

Context size of the model.
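The link between encoder_length, decoder_length and the context size can be sketched with plain lists. This assumes, as a working hypothesis rather than something stated on this page, that the context size equals encoder_length, i.e. the number of observed steps the model needs before it can forecast. The helpers `make_windows` and `forecast_context` are hypothetical, and etna's own dataset construction may differ (for example in stride or padding).

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], encoder_length: int, decoder_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Cut a series into (encoder, decoder) pairs with a stride-1 sliding window."""
    total = encoder_length + decoder_length
    windows = []
    for start in range(len(series) - total + 1):
        encoder = list(series[start : start + encoder_length])
        decoder = list(series[start + encoder_length : start + total])
        windows.append((encoder, decoder))
    return windows


def forecast_context(history: Sequence[float], context_size: int) -> List[float]:
    """Return the last context_size observations that a forecast would condition on."""
    if len(history) < context_size:
        raise ValueError(f"need at least {context_size} observations, got {len(history)}")
    # slice by absolute position so that context_size == 0 yields an empty list
    return list(history[len(history) - context_size :])


if __name__ == "__main__":
    print(make_windows([1, 2, 3, 4, 5, 6], encoder_length=3, decoder_length=2))
    print(forecast_context([1, 2, 3, 4, 5], context_size=3))
```

A series shorter than encoder_length + decoder_length produces no training windows at all, which is one practical reason to check the series length before fitting.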