Source code for deepchem_server.core.train

import ast
import logging
import math
from typing import Dict, Optional

from deepchem.models.torch_models import TorchModel

from deepchem_server.core import config, model_mappings
from deepchem_server.core.address import DeepchemAddress
from deepchem_server.core.cards import ModelCard
from deepchem_server.core.progress_logger import log_progress


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def _normalize_kwargs(kwargs: Optional[Dict]) -> Dict:
    """Return ``kwargs`` in the form expected for ``**`` unpacking.

    ``None`` and blank strings become an empty dict. Any other string is parsed
    with ``ast.literal_eval`` (Python literals only, nothing is executed).
    Every other value is returned unchanged.
    """
    if kwargs is None:
        return {}
    if isinstance(kwargs, str):
        if not kwargs.strip():
            # A blank string means "no keyword arguments"; literal_eval would
            # otherwise fail on it with a SyntaxError.
            return {}
        return ast.literal_eval(kwargs)
    return kwargs


def _estimate_total_steps(n_samples: int, batch_size: int, nb_epoch: int) -> int:
    """Estimate the number of optimisation steps, never returning less than 1.

    The lower bound keeps the progress computation free of division by zero
    for degenerate inputs (empty dataset, zero epochs, non-positive batch size).
    """
    if batch_size <= 0:
        return 1
    return max(1, math.ceil(n_samples / batch_size) * nb_epoch)


def _training_progress_percent(step: int, total_steps: int) -> int:
    """Map a training step onto the 10-90 percent band reserved for training.

    The fraction is clamped so that an under-estimated ``total_steps`` can
    never push the reported progress past the upload phase (95 and 100).
    """
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return int(10 + fraction * 80)


def train(model_type: str,
          dataset_address: str,
          model_name: str,
          init_kwargs: Optional[Dict] = None,
          train_kwargs: Optional[Dict] = None) -> str:
    """Trains a model on the specified dataset and writes output to datastore.

    Parameters
    ----------
    model_type: str
        A model string recognized by deepchem_server.core.model_mappings
    dataset_address: str
        The Deepchem server datastore address of the training dataset.
        The dataset in the address should be a DeepChem dataset.
    model_name: str
        The name under which the output trained model will be stored in the
        workspace.
    init_kwargs: Optional[Dict]
        Keyword arguments to pass to model on initialization. A string holding
        a Python dict literal is also accepted and parsed with
        ``ast.literal_eval``; ``None`` or a blank string means no arguments.
    train_kwargs: Optional[Dict]
        Keyword arguments to pass to model on training. Accepts the same forms
        as ``init_kwargs``.

    Returns
    -------
    str
        The address returned by the datastore for the uploaded model.

    Raises
    ------
    ValueError
        If ``model_type`` is not a key of ``model_mappings.model_address_map``
        or if no datastore has been configured.
    FileExistsError
        If the datastore already contains an object under ``model_name``.

    Examples
    --------
    >>> from deepchem_server.core.cards import DataCard
    >>> from deepchem_server.core import config, featurize
    >>> from deepchem_server.core.datastore import DiskDataStore
    >>> import tempfile
    >>> import pandas as pd
    >>> disk_datastore = DiskDataStore('profile', 'project', tempfile.mkdtemp())
    >>> config.set_datastore(disk_datastore)
    >>> df = pd.DataFrame([["CCC", 0], ["CCCCC", 1]], columns=["smiles", "label"])
    >>> card = DataCard(address='', file_type='csv', data_type='pandas.DataFrame')
    >>> data_address = disk_datastore.upload_data_from_memory(df, "test.csv", card)
    >>> dataset_address = featurize(data_address,
    ...     featurizer="ecfp",
    ...     output="feat_test",
    ...     dataset_column="smiles",
    ...     label_column="label")
    >>> train(model_type = "random_forest_regressor",
    ...     dataset_address = dataset_address,
    ...     model_name = "random_forest_model")
    'deepchem://profile/project/random_forest_model'
    """
    init_kwargs = _normalize_kwargs(init_kwargs)
    train_kwargs = _normalize_kwargs(train_kwargs)

    if model_type not in model_mappings.model_address_map:
        raise ValueError(f"Model type not recognized.\nLogs: {model_mappings.LOGS}")
    model = model_mappings.model_address_map[model_type](**init_kwargs)

    model_name = DeepchemAddress.get_key(model_name)
    datastore = config.get_datastore()
    if datastore is None:
        raise ValueError("Datastore not set")
    if datastore.exists(model_name):
        raise FileExistsError(f"Model name {model_name} already exists.")

    dataset_size = datastore.get_file_size(dataset_address)
    log_progress('training', 10, f"downloading dataset '{dataset_address}' ({dataset_size} bytes)")
    dataset = datastore.get(dataset_address)

    # if the model is a TorchModel, add a callback to log the epoch number
    if isinstance(model, TorchModel):
        batch_size = init_kwargs.get('batch_size', 100)
        nb_epoch = train_kwargs.get('nb_epoch', 10)
        total_steps = _estimate_total_steps(dataset.get_shape()[0][0], batch_size, nb_epoch)
        log_frequency = 1
        model.log_frequency = log_frequency

        def callback(_, step, **_extra):
            # **_extra lets the callback be invoked with additional keyword
            # arguments as well as with the plain (model, step) form.
            if step % log_frequency == 0:
                log_progress('training', _training_progress_percent(step, total_steps),
                             f"training model '{model_name}' ({step}/{total_steps})")

        # Merge caller-supplied callbacks with the progress callback instead of
        # passing ``callbacks`` twice, which would raise a TypeError. A copy is
        # used so the kwargs recorded on the model card stay untouched.
        fit_kwargs = dict(train_kwargs)
        user_callbacks = fit_kwargs.pop('callbacks', None)
        if user_callbacks is None:
            user_callbacks = []
        elif callable(user_callbacks):
            user_callbacks = [user_callbacks]
        model.fit(dataset, callbacks=[*user_callbacks, callback], **fit_kwargs)
    else:
        # TODO Add test for deepchem dataset
        log_progress('training', 50, f"training model '{model_name}'")
        model.fit(dataset, **train_kwargs)

    try:
        # This is a no-op for non sklearn models
        model.save()
    except NotImplementedError:
        # Models that do not implement save() are uploaded as they are.
        pass

    description = ''
    card = ModelCard(address='',
                     model_type=model_type,
                     train_dataset_address=dataset_address,
                     init_kwargs=init_kwargs,
                     train_kwargs=train_kwargs,
                     description=description)
    log_progress('training', 95, f"uploading model '{model_name}' to datastore")
    model_address = datastore.upload_data_from_memory(model, model_name, card, kind='model')
    log_progress('training', 100, f"model '{model_name}' uploaded to datastore")
    return model_address