完成训练模块的转移
This commit is contained in:
549
yolov5/utils/loggers/comet/__init__.py
Normal file
549
yolov5/utils/loggers/comet/__init__.py
Normal file
@ -0,0 +1,549 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[3] # YOLOv5 root directory
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
|
||||
try:
|
||||
import comet_ml
|
||||
|
||||
# Project Configuration
|
||||
config = comet_ml.config.get_config()
|
||||
COMET_PROJECT_NAME = config.get_string(os.getenv("COMET_PROJECT_NAME"), "comet.project_name", default="yolov5")
|
||||
except ImportError:
|
||||
comet_ml = None
|
||||
COMET_PROJECT_NAME = None
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import yaml
|
||||
|
||||
from utils.dataloaders import img2label_paths
|
||||
from utils.general import check_dataset, scale_boxes, xywh2xyxy
|
||||
from utils.metrics import box_iou
|
||||
|
||||
COMET_PREFIX = "comet://"
|
||||
|
||||
COMET_MODE = os.getenv("COMET_MODE", "online")
|
||||
|
||||
# Model Saving Settings
|
||||
COMET_MODEL_NAME = os.getenv("COMET_MODEL_NAME", "yolov5")
|
||||
|
||||
# Dataset Artifact Settings
|
||||
COMET_UPLOAD_DATASET = os.getenv("COMET_UPLOAD_DATASET", "false").lower() == "true"
|
||||
|
||||
# Evaluation Settings
|
||||
COMET_LOG_CONFUSION_MATRIX = os.getenv("COMET_LOG_CONFUSION_MATRIX", "true").lower() == "true"
|
||||
COMET_LOG_PREDICTIONS = os.getenv("COMET_LOG_PREDICTIONS", "true").lower() == "true"
|
||||
COMET_MAX_IMAGE_UPLOADS = int(os.getenv("COMET_MAX_IMAGE_UPLOADS", 100))
|
||||
|
||||
# Confusion Matrix Settings
|
||||
CONF_THRES = float(os.getenv("CONF_THRES", 0.001))
|
||||
IOU_THRES = float(os.getenv("IOU_THRES", 0.6))
|
||||
|
||||
# Batch Logging Settings
|
||||
COMET_LOG_BATCH_METRICS = os.getenv("COMET_LOG_BATCH_METRICS", "false").lower() == "true"
|
||||
COMET_BATCH_LOGGING_INTERVAL = os.getenv("COMET_BATCH_LOGGING_INTERVAL", 1)
|
||||
COMET_PREDICTION_LOGGING_INTERVAL = os.getenv("COMET_PREDICTION_LOGGING_INTERVAL", 1)
|
||||
COMET_LOG_PER_CLASS_METRICS = os.getenv("COMET_LOG_PER_CLASS_METRICS", "false").lower() == "true"
|
||||
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
|
||||
to_pil = T.ToPILImage()
|
||||
|
||||
|
||||
class CometLogger:
|
||||
"""Log metrics, parameters, source code, models and much more with Comet."""
|
||||
|
||||
def __init__(self, opt, hyp, run_id=None, job_type="Training", **experiment_kwargs) -> None:
|
||||
"""Initializes CometLogger with given options, hyperparameters, run ID, job type, and additional experiment
|
||||
arguments.
|
||||
"""
|
||||
self.job_type = job_type
|
||||
self.opt = opt
|
||||
self.hyp = hyp
|
||||
|
||||
# Comet Flags
|
||||
self.comet_mode = COMET_MODE
|
||||
|
||||
self.save_model = opt.save_period > -1
|
||||
self.model_name = COMET_MODEL_NAME
|
||||
|
||||
# Batch Logging Settings
|
||||
self.log_batch_metrics = COMET_LOG_BATCH_METRICS
|
||||
self.comet_log_batch_interval = COMET_BATCH_LOGGING_INTERVAL
|
||||
|
||||
# Dataset Artifact Settings
|
||||
self.upload_dataset = self.opt.upload_dataset or COMET_UPLOAD_DATASET
|
||||
self.resume = self.opt.resume
|
||||
|
||||
self.default_experiment_kwargs = {
|
||||
"log_code": False,
|
||||
"log_env_gpu": True,
|
||||
"log_env_cpu": True,
|
||||
"project_name": COMET_PROJECT_NAME,
|
||||
} | experiment_kwargs
|
||||
self.experiment = self._get_experiment(self.comet_mode, run_id)
|
||||
self.experiment.set_name(self.opt.name)
|
||||
|
||||
self.data_dict = self.check_dataset(self.opt.data)
|
||||
self.class_names = self.data_dict["names"]
|
||||
self.num_classes = self.data_dict["nc"]
|
||||
|
||||
self.logged_images_count = 0
|
||||
self.max_images = COMET_MAX_IMAGE_UPLOADS
|
||||
|
||||
if run_id is None:
|
||||
self.experiment.log_other("Created from", "YOLOv5")
|
||||
if not isinstance(self.experiment, comet_ml.OfflineExperiment):
|
||||
workspace, project_name, experiment_id = self.experiment.url.split("/")[-3:]
|
||||
self.experiment.log_other(
|
||||
"Run Path",
|
||||
f"{workspace}/{project_name}/{experiment_id}",
|
||||
)
|
||||
self.log_parameters(vars(opt))
|
||||
self.log_parameters(self.opt.hyp)
|
||||
self.log_asset_data(
|
||||
self.opt.hyp,
|
||||
name="hyperparameters.json",
|
||||
metadata={"type": "hyp-config-file"},
|
||||
)
|
||||
self.log_asset(
|
||||
f"{self.opt.save_dir}/opt.yaml",
|
||||
metadata={"type": "opt-config-file"},
|
||||
)
|
||||
|
||||
self.comet_log_confusion_matrix = COMET_LOG_CONFUSION_MATRIX
|
||||
|
||||
if hasattr(self.opt, "conf_thres"):
|
||||
self.conf_thres = self.opt.conf_thres
|
||||
else:
|
||||
self.conf_thres = CONF_THRES
|
||||
if hasattr(self.opt, "iou_thres"):
|
||||
self.iou_thres = self.opt.iou_thres
|
||||
else:
|
||||
self.iou_thres = IOU_THRES
|
||||
|
||||
self.log_parameters({"val_iou_threshold": self.iou_thres, "val_conf_threshold": self.conf_thres})
|
||||
|
||||
self.comet_log_predictions = COMET_LOG_PREDICTIONS
|
||||
if self.opt.bbox_interval == -1:
|
||||
self.comet_log_prediction_interval = 1 if self.opt.epochs < 10 else self.opt.epochs // 10
|
||||
else:
|
||||
self.comet_log_prediction_interval = self.opt.bbox_interval
|
||||
|
||||
if self.comet_log_predictions:
|
||||
self.metadata_dict = {}
|
||||
self.logged_image_names = []
|
||||
|
||||
self.comet_log_per_class_metrics = COMET_LOG_PER_CLASS_METRICS
|
||||
|
||||
self.experiment.log_others(
|
||||
{
|
||||
"comet_mode": COMET_MODE,
|
||||
"comet_max_image_uploads": COMET_MAX_IMAGE_UPLOADS,
|
||||
"comet_log_per_class_metrics": COMET_LOG_PER_CLASS_METRICS,
|
||||
"comet_log_batch_metrics": COMET_LOG_BATCH_METRICS,
|
||||
"comet_log_confusion_matrix": COMET_LOG_CONFUSION_MATRIX,
|
||||
"comet_model_name": COMET_MODEL_NAME,
|
||||
}
|
||||
)
|
||||
|
||||
# Check if running the Experiment with the Comet Optimizer
|
||||
if hasattr(self.opt, "comet_optimizer_id"):
|
||||
self.experiment.log_other("optimizer_id", self.opt.comet_optimizer_id)
|
||||
self.experiment.log_other("optimizer_objective", self.opt.comet_optimizer_objective)
|
||||
self.experiment.log_other("optimizer_metric", self.opt.comet_optimizer_metric)
|
||||
self.experiment.log_other("optimizer_parameters", json.dumps(self.hyp))
|
||||
|
||||
def _get_experiment(self, mode, experiment_id=None):
|
||||
"""Returns a new or existing Comet.ml experiment based on mode and optional experiment_id."""
|
||||
if mode == "offline":
|
||||
return (
|
||||
comet_ml.ExistingOfflineExperiment(
|
||||
previous_experiment=experiment_id,
|
||||
**self.default_experiment_kwargs,
|
||||
)
|
||||
if experiment_id is not None
|
||||
else comet_ml.OfflineExperiment(
|
||||
**self.default_experiment_kwargs,
|
||||
)
|
||||
)
|
||||
try:
|
||||
if experiment_id is not None:
|
||||
return comet_ml.ExistingExperiment(
|
||||
previous_experiment=experiment_id,
|
||||
**self.default_experiment_kwargs,
|
||||
)
|
||||
|
||||
return comet_ml.Experiment(**self.default_experiment_kwargs)
|
||||
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"COMET WARNING: "
|
||||
"Comet credentials have not been set. "
|
||||
"Comet will default to offline logging. "
|
||||
"Please set your credentials to enable online logging."
|
||||
)
|
||||
return self._get_experiment("offline", experiment_id)
|
||||
|
||||
return
|
||||
|
||||
def log_metrics(self, log_dict, **kwargs):
|
||||
"""Logs metrics to the current experiment, accepting a dictionary of metric names and values."""
|
||||
self.experiment.log_metrics(log_dict, **kwargs)
|
||||
|
||||
def log_parameters(self, log_dict, **kwargs):
|
||||
"""Logs parameters to the current experiment, accepting a dictionary of parameter names and values."""
|
||||
self.experiment.log_parameters(log_dict, **kwargs)
|
||||
|
||||
def log_asset(self, asset_path, **kwargs):
|
||||
"""Logs a file or directory as an asset to the current experiment."""
|
||||
self.experiment.log_asset(asset_path, **kwargs)
|
||||
|
||||
def log_asset_data(self, asset, **kwargs):
|
||||
"""Logs in-memory data as an asset to the current experiment, with optional kwargs."""
|
||||
self.experiment.log_asset_data(asset, **kwargs)
|
||||
|
||||
def log_image(self, img, **kwargs):
|
||||
"""Logs an image to the current experiment with optional kwargs."""
|
||||
self.experiment.log_image(img, **kwargs)
|
||||
|
||||
def log_model(self, path, opt, epoch, fitness_score, best_model=False):
|
||||
"""Logs model checkpoint to experiment with path, options, epoch, fitness, and best model flag."""
|
||||
if not self.save_model:
|
||||
return
|
||||
|
||||
model_metadata = {
|
||||
"fitness_score": fitness_score[-1],
|
||||
"epochs_trained": epoch + 1,
|
||||
"save_period": opt.save_period,
|
||||
"total_epochs": opt.epochs,
|
||||
}
|
||||
|
||||
model_files = glob.glob(f"{path}/*.pt")
|
||||
for model_path in model_files:
|
||||
name = Path(model_path).name
|
||||
|
||||
self.experiment.log_model(
|
||||
self.model_name,
|
||||
file_or_folder=model_path,
|
||||
file_name=name,
|
||||
metadata=model_metadata,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
def check_dataset(self, data_file):
|
||||
"""Validates the dataset configuration by loading the YAML file specified in `data_file`."""
|
||||
with open(data_file) as f:
|
||||
data_config = yaml.safe_load(f)
|
||||
|
||||
path = data_config.get("path")
|
||||
if path and path.startswith(COMET_PREFIX):
|
||||
path = data_config["path"].replace(COMET_PREFIX, "")
|
||||
return self.download_dataset_artifact(path)
|
||||
self.log_asset(self.opt.data, metadata={"type": "data-config-file"})
|
||||
|
||||
return check_dataset(data_file)
|
||||
|
||||
def log_predictions(self, image, labelsn, path, shape, predn):
|
||||
"""Logs predictions with IOU filtering, given image, labels, path, shape, and predictions."""
|
||||
if self.logged_images_count >= self.max_images:
|
||||
return
|
||||
detections = predn[predn[:, 4] > self.conf_thres]
|
||||
iou = box_iou(labelsn[:, 1:], detections[:, :4])
|
||||
mask, _ = torch.where(iou > self.iou_thres)
|
||||
if len(mask) == 0:
|
||||
return
|
||||
|
||||
filtered_detections = detections[mask]
|
||||
filtered_labels = labelsn[mask]
|
||||
|
||||
image_id = path.split("/")[-1].split(".")[0]
|
||||
image_name = f"{image_id}_curr_epoch_{self.experiment.curr_epoch}"
|
||||
if image_name not in self.logged_image_names:
|
||||
native_scale_image = PIL.Image.open(path)
|
||||
self.log_image(native_scale_image, name=image_name)
|
||||
self.logged_image_names.append(image_name)
|
||||
|
||||
metadata = [
|
||||
{
|
||||
"label": f"{self.class_names[int(cls)]}-gt",
|
||||
"score": 100,
|
||||
"box": {"x": xyxy[0], "y": xyxy[1], "x2": xyxy[2], "y2": xyxy[3]},
|
||||
}
|
||||
for cls, *xyxy in filtered_labels.tolist()
|
||||
]
|
||||
metadata.extend(
|
||||
{
|
||||
"label": f"{self.class_names[int(cls)]}",
|
||||
"score": conf * 100,
|
||||
"box": {"x": xyxy[0], "y": xyxy[1], "x2": xyxy[2], "y2": xyxy[3]},
|
||||
}
|
||||
for *xyxy, conf, cls in filtered_detections.tolist()
|
||||
)
|
||||
self.metadata_dict[image_name] = metadata
|
||||
self.logged_images_count += 1
|
||||
|
||||
return
|
||||
|
||||
def preprocess_prediction(self, image, labels, shape, pred):
|
||||
"""Processes prediction data, resizing labels and adding dataset metadata."""
|
||||
nl, _ = labels.shape[0], pred.shape[0]
|
||||
|
||||
# Predictions
|
||||
if self.opt.single_cls:
|
||||
pred[:, 5] = 0
|
||||
|
||||
predn = pred.clone()
|
||||
scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1])
|
||||
|
||||
labelsn = None
|
||||
if nl:
|
||||
tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
|
||||
scale_boxes(image.shape[1:], tbox, shape[0], shape[1]) # native-space labels
|
||||
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
||||
scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1]) # native-space pred
|
||||
|
||||
return predn, labelsn
|
||||
|
||||
def add_assets_to_artifact(self, artifact, path, asset_path, split):
|
||||
"""Adds image and label assets to a wandb artifact given dataset split and paths."""
|
||||
img_paths = sorted(glob.glob(f"{asset_path}/*"))
|
||||
label_paths = img2label_paths(img_paths)
|
||||
|
||||
for image_file, label_file in zip(img_paths, label_paths):
|
||||
image_logical_path, label_logical_path = map(lambda x: os.path.relpath(x, path), [image_file, label_file])
|
||||
|
||||
try:
|
||||
artifact.add(
|
||||
image_file,
|
||||
logical_path=image_logical_path,
|
||||
metadata={"split": split},
|
||||
)
|
||||
artifact.add(
|
||||
label_file,
|
||||
logical_path=label_logical_path,
|
||||
metadata={"split": split},
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error("COMET ERROR: Error adding file to Artifact. Skipping file.")
|
||||
logger.error(f"COMET ERROR: {e}")
|
||||
continue
|
||||
|
||||
return artifact
|
||||
|
||||
def upload_dataset_artifact(self):
|
||||
"""Uploads a YOLOv5 dataset as an artifact to the Comet.ml platform."""
|
||||
dataset_name = self.data_dict.get("dataset_name", "yolov5-dataset")
|
||||
path = str((ROOT / Path(self.data_dict["path"])).resolve())
|
||||
|
||||
metadata = self.data_dict.copy()
|
||||
for key in ["train", "val", "test"]:
|
||||
split_path = metadata.get(key)
|
||||
if split_path is not None:
|
||||
metadata[key] = split_path.replace(path, "")
|
||||
|
||||
artifact = comet_ml.Artifact(name=dataset_name, artifact_type="dataset", metadata=metadata)
|
||||
for key in metadata.keys():
|
||||
if key in ["train", "val", "test"]:
|
||||
if isinstance(self.upload_dataset, str) and (key != self.upload_dataset):
|
||||
continue
|
||||
|
||||
asset_path = self.data_dict.get(key)
|
||||
if asset_path is not None:
|
||||
artifact = self.add_assets_to_artifact(artifact, path, asset_path, key)
|
||||
|
||||
self.experiment.log_artifact(artifact)
|
||||
|
||||
return
|
||||
|
||||
def download_dataset_artifact(self, artifact_path):
|
||||
"""Downloads a dataset artifact to a specified directory using the experiment's logged artifact."""
|
||||
logged_artifact = self.experiment.get_artifact(artifact_path)
|
||||
artifact_save_dir = str(Path(self.opt.save_dir) / logged_artifact.name)
|
||||
logged_artifact.download(artifact_save_dir)
|
||||
|
||||
metadata = logged_artifact.metadata
|
||||
data_dict = metadata.copy()
|
||||
data_dict["path"] = artifact_save_dir
|
||||
|
||||
metadata_names = metadata.get("names")
|
||||
if isinstance(metadata_names, dict):
|
||||
data_dict["names"] = {int(k): v for k, v in metadata.get("names").items()}
|
||||
elif isinstance(metadata_names, list):
|
||||
data_dict["names"] = {int(k): v for k, v in zip(range(len(metadata_names)), metadata_names)}
|
||||
else:
|
||||
raise "Invalid 'names' field in dataset yaml file. Please use a list or dictionary"
|
||||
|
||||
return self.update_data_paths(data_dict)
|
||||
|
||||
def update_data_paths(self, data_dict):
|
||||
"""Updates data paths in the dataset dictionary, defaulting 'path' to an empty string if not present."""
|
||||
path = data_dict.get("path", "")
|
||||
|
||||
for split in ["train", "val", "test"]:
|
||||
if data_dict.get(split):
|
||||
split_path = data_dict.get(split)
|
||||
data_dict[split] = (
|
||||
f"{path}/{split_path}" if isinstance(split, str) else [f"{path}/{x}" for x in split_path]
|
||||
)
|
||||
|
||||
return data_dict
|
||||
|
||||
def on_pretrain_routine_end(self, paths):
|
||||
"""Called at the end of pretraining routine to handle paths if training is not being resumed."""
|
||||
if self.opt.resume:
|
||||
return
|
||||
|
||||
for path in paths:
|
||||
self.log_asset(str(path))
|
||||
|
||||
if self.upload_dataset and not self.resume:
|
||||
self.upload_dataset_artifact()
|
||||
|
||||
return
|
||||
|
||||
def on_train_start(self):
|
||||
"""Logs hyperparameters at the start of training."""
|
||||
self.log_parameters(self.hyp)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
"""Called at the start of each training epoch."""
|
||||
return
|
||||
|
||||
def on_train_epoch_end(self, epoch):
|
||||
"""Updates the current epoch in the experiment tracking at the end of each epoch."""
|
||||
self.experiment.curr_epoch = epoch
|
||||
|
||||
return
|
||||
|
||||
def on_train_batch_start(self):
|
||||
"""Called at the start of each training batch."""
|
||||
return
|
||||
|
||||
def on_train_batch_end(self, log_dict, step):
|
||||
"""Callback function that updates and logs metrics at the end of each training batch if conditions are met."""
|
||||
self.experiment.curr_step = step
|
||||
if self.log_batch_metrics and (step % self.comet_log_batch_interval == 0):
|
||||
self.log_metrics(log_dict, step=step)
|
||||
|
||||
return
|
||||
|
||||
def on_train_end(self, files, save_dir, last, best, epoch, results):
|
||||
"""Logs metadata and optionally saves model files at the end of training."""
|
||||
if self.comet_log_predictions:
|
||||
curr_epoch = self.experiment.curr_epoch
|
||||
self.experiment.log_asset_data(self.metadata_dict, "image-metadata.json", epoch=curr_epoch)
|
||||
|
||||
for f in files:
|
||||
self.log_asset(f, metadata={"epoch": epoch})
|
||||
self.log_asset(f"{save_dir}/results.csv", metadata={"epoch": epoch})
|
||||
|
||||
if not self.opt.evolve:
|
||||
model_path = str(best if best.exists() else last)
|
||||
name = Path(model_path).name
|
||||
if self.save_model:
|
||||
self.experiment.log_model(
|
||||
self.model_name,
|
||||
file_or_folder=model_path,
|
||||
file_name=name,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Check if running Experiment with Comet Optimizer
|
||||
if hasattr(self.opt, "comet_optimizer_id"):
|
||||
metric = results.get(self.opt.comet_optimizer_metric)
|
||||
self.experiment.log_other("optimizer_metric_value", metric)
|
||||
|
||||
self.finish_run()
|
||||
|
||||
def on_val_start(self):
|
||||
"""Called at the start of validation, currently a placeholder with no functionality."""
|
||||
return
|
||||
|
||||
def on_val_batch_start(self):
|
||||
"""Placeholder called at the start of a validation batch with no current functionality."""
|
||||
return
|
||||
|
||||
def on_val_batch_end(self, batch_i, images, targets, paths, shapes, outputs):
|
||||
"""Callback executed at the end of a validation batch, conditionally logs predictions to Comet ML."""
|
||||
if not (self.comet_log_predictions and ((batch_i + 1) % self.comet_log_prediction_interval == 0)):
|
||||
return
|
||||
|
||||
for si, pred in enumerate(outputs):
|
||||
if len(pred) == 0:
|
||||
continue
|
||||
|
||||
image = images[si]
|
||||
labels = targets[targets[:, 0] == si, 1:]
|
||||
shape = shapes[si]
|
||||
path = paths[si]
|
||||
predn, labelsn = self.preprocess_prediction(image, labels, shape, pred)
|
||||
if labelsn is not None:
|
||||
self.log_predictions(image, labelsn, path, shape, predn)
|
||||
|
||||
return
|
||||
|
||||
def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix):
|
||||
"""Logs per-class metrics to Comet.ml after validation if enabled and more than one class exists."""
|
||||
if self.comet_log_per_class_metrics and self.num_classes > 1:
|
||||
for i, c in enumerate(ap_class):
|
||||
class_name = self.class_names[c]
|
||||
self.experiment.log_metrics(
|
||||
{
|
||||
"mAP@.5": ap50[i],
|
||||
"mAP@.5:.95": ap[i],
|
||||
"precision": p[i],
|
||||
"recall": r[i],
|
||||
"f1": f1[i],
|
||||
"true_positives": tp[i],
|
||||
"false_positives": fp[i],
|
||||
"support": nt[c],
|
||||
},
|
||||
prefix=class_name,
|
||||
)
|
||||
|
||||
if self.comet_log_confusion_matrix:
|
||||
epoch = self.experiment.curr_epoch
|
||||
class_names = list(self.class_names.values())
|
||||
class_names.append("background")
|
||||
num_classes = len(class_names)
|
||||
|
||||
self.experiment.log_confusion_matrix(
|
||||
matrix=confusion_matrix.matrix,
|
||||
max_categories=num_classes,
|
||||
labels=class_names,
|
||||
epoch=epoch,
|
||||
column_label="Actual Category",
|
||||
row_label="Predicted Category",
|
||||
file_name=f"confusion-matrix-epoch-{epoch}.json",
|
||||
)
|
||||
|
||||
def on_fit_epoch_end(self, result, epoch):
|
||||
"""Logs metrics at the end of each training epoch."""
|
||||
self.log_metrics(result, epoch=epoch)
|
||||
|
||||
def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
|
||||
"""Callback to save model checkpoints periodically if conditions are met."""
|
||||
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
|
||||
self.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
|
||||
|
||||
def on_params_update(self, params):
|
||||
"""Logs updated parameters during training."""
|
||||
self.log_parameters(params)
|
||||
|
||||
def finish_run(self):
|
||||
"""Ends the current experiment and logs its completion."""
|
||||
self.experiment.end()
|
Reference in New Issue
Block a user