# Source code for tfts.models.auto_config
"""AutoConfig to set up a model's custom config."""
from collections import OrderedDict
import importlib
from typing import Dict
from .base import BaseConfig
# Registry of supported model names. AutoConfig.for_model uses each key both
# for the membership check and as the name of the submodule it imports; the
# value is the name of the attribute it then fetches from that module.
CONFIG_MAPPING_NAMES = OrderedDict(
    {
        "seq2seq": "Seq2seqConfig",
        "rnn": "RNNConfig",
        "wavenet": "WaveNetConfig",
        "tcn": "TCNConfig",
        "transformer": "TransformerConfig",
        "bert": "BertConfig",
        "informer": "InformerConfig",
        "autoformer": "AutoFormerConfig",
        "tft": "TFTransformerConfig",
        "unet": "UnetConfig",
        "nbeats": "NBeatsConfig",
        "dlinear": "DLinearConfig",
        "rwkv": "RWKVConfig",
        "patches_tst": "PatchTSTConfig",
        "deep_ar": "DeepARConfig",
    }
)
# [docs]
class AutoConfig(BaseConfig):
    """AutoConfig for tfts model.

    Looks a model name up in ``CONFIG_MAPPING_NAMES`` and returns an instance
    of the config class registered under that name.
    """

    def __init__(self, **kwargs: object):
        """Pass every keyword argument through to ``BaseConfig.__init__``.

        Args:
            **kwargs: Arbitrary keyword arguments, forwarded unchanged. The
                annotation describes each individual value, so it is
                ``object`` rather than a ``Dict`` type.
        """
        super().__init__(**kwargs)

    @classmethod
    def for_model(cls, model_name: str):
        """Build the config registered for ``model_name``.

        Args:
            model_name: A key of ``CONFIG_MAPPING_NAMES``. It is also used as
                the name of the submodule of ``tfts.models`` to import.

        Returns:
            An instance of the mapped config class, created with no arguments.

        Raises:
            ValueError: If ``model_name`` is not a key of
                ``CONFIG_MAPPING_NAMES``.

        Errors raised while importing the submodule or fetching the class
        from it are not caught here.
        """
        # Reject unknown names up front so no import is attempted for them.
        if model_name not in CONFIG_MAPPING_NAMES:
            raise ValueError(
                f"Unrecognized model: {model_name}. Should contain one of {', '.join(CONFIG_MAPPING_NAMES)}"
            )

        class_name = CONFIG_MAPPING_NAMES[model_name]
        # Relative import anchored at the tfts.models package; the submodule is
        # only loaded when its config is actually requested.
        module = importlib.import_module(f".{model_name}", "tfts.models")
        config_class = getattr(module, class_name)
        return config_class()

    def __call__(self, model_name: str):
        """Return ``self.for_model(model_name)`` so an instance can be called directly."""
        return self.for_model(model_name)