Source code for deepchem_server.core.model_mappings

from datetime import datetime
from functools import wraps
import logging
from typing import Any, Callable, Optional

import deepchem as dc
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression

from deepchem_server.core.model_config_mapper import DeepChemModelConfigMapper, ModelAddressWrapper


logger = logging.getLogger(__name__)


[docs] def sklearn_model(model: Callable) -> Callable: """Wrapper for sklearn models to integrate with DeepChem SklearnModel. Parameters ---------- model : Callable A sklearn model class to be wrapped. Returns ------- Callable A function that initializes a DeepChem SklearnModel with the given sklearn model. """ # The wrapper here is used for distinguishing SklearnModel parameters and scikit-learn model parameters @wraps(model) def initialize_sklearn_model(model_dir: Optional[str] = None, **kwargs) -> Any: """Initialize sklearn model wrapped in DeepChem SklearnModel. Parameters ---------- model_dir : str, optional Directory path for the model. **kwargs Additional keyword arguments for the sklearn model. Returns ------- Any DeepChem sklearn model instance. """ if model_dir is None: return dc.models.SklearnModel(model(**kwargs)) else: return dc.models.SklearnModel(model(**kwargs), model_dir=model_dir) return initialize_sklearn_model
model_address_map = ModelAddressWrapper({ "linear_regression": DeepChemModelConfigMapper( model_class=sklearn_model(LinearRegression), required_init_params=None, optional_init_params=["fit_intercept", "copy_X", "n_jobs", "positive"], required_train_params=None, optional_train_params=None, ), "random_forest_classifier": DeepChemModelConfigMapper( model_class=sklearn_model(RandomForestClassifier), required_init_params=None, optional_init_params=[ "n_estimators", "criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "bootstrap", "oob_score", "n_jobs", "random_state", "verbose", "warm_start", "class_weight", "ccp_alpha", "max_samples", ], required_train_params=None, optional_train_params=["sample_weight"], ), "random_forest_regressor": DeepChemModelConfigMapper( model_class=sklearn_model(RandomForestRegressor), required_init_params=None, optional_init_params=[ "n_estimators", "criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "bootstrap", "oob_score", "n_jobs", "random_state", "verbose", "warm_start", "ccp_alpha", "max_samples", ], required_train_params=None, optional_train_params=["sample_weight"], ), }) LOGS = {}
[docs] def update_logs(log_error: ImportError) -> None: """Update logs during import errors. Parameters ---------- log_error : ImportError Import error object to be logged. Returns ------- None Examples -------- >>> from deepchem_server.core import model_mappings >>> model_mappings.LOGS == {} True >>> e = ImportError('cannot import DummyModule') >>> model_mappings.update_logs(e) >>> list(model_mappings.LOGS.values())[0] ImportError('cannot import DummyModule') """ current_date_time = str(datetime.now()) LOGS[current_date_time] = log_error
MODEL_FEAT_MAP = {} try: # torch models from deepchem.models import GCNModel model_address_map["gcn"] = DeepChemModelConfigMapper( model_class=GCNModel, required_init_params=["n_tasks"], optional_init_params=[ "graph_conv_layers", "activation", "residual", "batchnorm", "dropout", "predictor_hidden_feats", "predictor_dropout", "mode", "number_atom_features", "n_classes", "self_loop", "output_types", "batch_size", "learning_rate", "optimizer", "tensorboard", "wandb", "log_frequency", "device", "regularization_loss", "wandb_logger", ], required_train_params=None, optional_train_params=[ "nb_epoch", "max_checkpoints_to_keep", "checkpoint_interval", "deterministic", "restore", "variables", "loss", "callbacks", "all_losses", ], ) MODEL_FEAT_MAP["gcn"] = "molgraphconv" except ImportError as e2: update_logs(e2) logger.error(f"torch models not imported: {e2}") model_names = model_address_map.keys()