Composabl Train API Documentation
Overview
The Composabl Train API provides the infrastructure for training agents, using Ray for distributed computing. It handles the complete training lifecycle: algorithm configuration, resource management, checkpointing, and deployment preparation.
Trainer API
The `Trainer` class is the main interface for training agents.
Basic Usage
```python
from composabl import Trainer
from composabl_core.config.trainer_config import (
    BenchmarkConfig,
    PostProcessingConfig,
    RecordConfig,
)

# `agent` is an agent you have already built and configured

# Create a trainer with its configuration
trainer = Trainer({
    "license": "your-license-key",
    "target": {
        "local": {"address": "localhost:1337"}
    },
    "env": {
        "name": "my-environment"
    },
})

# Train the agent
trainer.train(agent, train_cycles=100)

# Evaluate performance: record episodes and run benchmarks
results = trainer.postprocess(
    agent,
    postprocess_config=PostProcessingConfig(
        file_path="model_files/",
        record=RecordConfig(
            avi_file_name="output.avi",
            gif_file_name="output.gif",
            max_frames=24 * 5,  # 120 frames
        ),
        benchmark=BenchmarkConfig(
            num_episodes_per_scenario=2,
            file_name="benchmark.json",
        ),
    ),
)

# Package the trained agent for deployment
deployed_agent = trainer.package(agent)

# Release the trainer's resources
trainer.close()
```
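Because `trainer.close()` releases the trainer's resources, it should run even when training or post-processing raises an exception. One way to guarantee that is the standard library's `contextlib.closing`, which calls `close()` when the `with` block exits. So that the sketch runs without the SDK, it substitutes a hypothetical `StandInTrainer` class for `Trainer`; with the real SDK you would wrap `Trainer(config)` instead, assuming `close()` is the only cleanup it needs.

```python
from contextlib import closing


class StandInTrainer:
    """Hypothetical stand-in for composabl.Trainer, used only to show cleanup."""

    def __init__(self, config):
        self.config = config
        self.closed = False

    def train(self, agent, train_cycles):
        # Simulate a failure partway through training
        raise RuntimeError("simulator disconnected")

    def close(self):
        self.closed = True


trainer = StandInTrainer({"target": {"local": {"address": "localhost:1337"}}})
try:
    # closing() calls trainer.close() when the block exits, even on error
    with closing(trainer) as t:
        t.train(agent=None, train_cycles=100)
except RuntimeError as err:
    print(f"training failed: {err}")

print(trainer.closed)  # True
```

The same pattern applies to `postprocess` and `package` calls placed inside the `with` block.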
Training Configuration
Complete Configuration Example
```python
config = {
    # License (required)
    "license": "your-license-key",
    # Training target (required)
    "target": {
        # See the Training Targets section for options
    },
    # Environment configuration
    "env": {
        "name": "environment-id",
        "init": {
            "param1": "value1",
            "difficulty": "medium",
        },
    },
}
```
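Since `license` and `target` are marked as required, it can help to check a config dict before handing it to the trainer. The sketch below is a hypothetical helper, not part of the Composabl SDK. It checks the two required keys and adds one assumption of its own: that `target` names exactly one target type, such as `local` or `docker`. Note that the empty `target` placeholder in the example above would be flagged until it is filled in.

```python
def validate_trainer_config(config: dict) -> list:
    """Return a list of problems in a trainer config dict; empty means none found.

    Hypothetical helper. It checks the two keys this page marks as required,
    plus one assumption: that "target" names exactly one target type.
    """
    problems = []
    if not config.get("license"):
        problems.append("missing required key 'license'")
    target = config.get("target")
    if not isinstance(target, dict) or not target:
        problems.append("missing required key 'target'")
    elif len(target) != 1:
        # Assumption: one target type per config (e.g. "local" or "docker")
        problems.append("'target' should name exactly one target type")
    return problems


config = {
    "license": "your-license-key",
    "target": {"local": {"address": "localhost:1337"}},
    "env": {"name": "environment-id", "init": {"difficulty": "medium"}},
}
print(validate_trainer_config(config))  # []
```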
Training Targets
Local Target
Train with a simulator running locally:
```python
config = {
    "target": {
        "local": {
            "address": "localhost:1337"
        }
    }
}
```
Docker Target
Train with simulators in Docker containers:
```python
config = {
    "target": {
        "docker": {
            "image": "composabl/sim-reactor:latest"
        }
    }
}
```
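The two targets differ only in the inner dictionary, so switching between a local simulator and a Docker image can be reduced to one small builder. `make_target` below is a hypothetical helper, not an SDK function; it assumes a config uses one target at a time and that a local address has the `host:port` form shown above.

```python
def make_target(local_address=None, docker_image=None):
    """Build the "target" section of a trainer config (hypothetical helper).

    Exactly one of local_address or docker_image must be given.
    """
    if (local_address is None) == (docker_image is None):
        raise ValueError("pass exactly one of local_address or docker_image")
    if local_address is not None:
        # Expect "host:port", e.g. "localhost:1337"
        host, sep, port = local_address.rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(f"expected 'host:port', got {local_address!r}")
        return {"local": {"address": local_address}}
    return {"docker": {"image": docker_image}}


config = {"target": make_target(docker_image="composabl/sim-reactor:latest")}
print(config)
```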
Benchmarking
For every scenario on the top-level skill, benchmarking runs `num_episodes_per_scenario` inference episodes, so the example below runs two episodes per scenario.
```python
# Benchmark the trained agent
eval_results = trainer.postprocess(
    agent,
    postprocess_config=PostProcessingConfig(
        file_path="model_files/",
        benchmark=BenchmarkConfig(
            num_episodes_per_scenario=2,
            file_name="benchmark.json",
        ),
    ),
)
```
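The benchmarking rule above amounts to a nested loop: every scenario gets the same number of episodes, so the total is the number of scenarios times `num_episodes_per_scenario`. The sketch below models only that counting. The scenario dictionaries and `benchmark_schedule` are hypothetical, and the sketch says nothing about the order or parallelism the SDK actually uses.

```python
from itertools import product


def benchmark_schedule(scenarios, num_episodes_per_scenario):
    """Yield (scenario, episode_index) pairs, the same episode count per scenario."""
    yield from product(scenarios, range(num_episodes_per_scenario))


# Hypothetical scenarios for a top-level skill
scenarios = [{"difficulty": "easy"}, {"difficulty": "medium"}, {"difficulty": "hard"}]
schedule = list(benchmark_schedule(scenarios, num_episodes_per_scenario=2))
print(len(schedule))  # 6
```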
Recording
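Recording writes the agent's episodes to video files: `avi_file_name` and `gif_file_name` name the AVI and GIF outputs, and `max_frames` limits how many frames are captured.

```python
# Evaluate with recording
eval_results = trainer.postprocess(
    agent,
    postprocess_config=PostProcessingConfig(
        file_path="model_files/",
        record=RecordConfig(
            avi_file_name="output.avi",
            gif_file_name="output.gif",
            max_frames=24 * 5,  # 120 frames
        ),
    ),
)
```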
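One plausible reading of `max_frames=24 * 5` is a 5-second clip at 24 frames per second, but the page does not state a frame rate, so treat that as an assumption. The sketch below assumes `max_frames` simply truncates the stream of rendered frames, and shows the arithmetic. `capped_frames` and `ASSUMED_FPS` are hypothetical names, and integers stand in for rendered images.

```python
from itertools import islice

MAX_FRAMES = 24 * 5  # 120 frames, matching max_frames in the recording example


def capped_frames(frames, max_frames):
    """Return an iterator over at most max_frames items from frames."""
    return islice(frames, max_frames)


# Stand-in frames: integers instead of rendered images
kept = list(capped_frames(range(1000), MAX_FRAMES))
print(len(kept))  # 120

ASSUMED_FPS = 24  # hypothetical playback rate; not stated on this page
print(len(kept) / ASSUMED_FPS)  # 5.0 seconds
```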