Train Model
Fine-tunes a base model on a training dataset to produce a new fine-tuned model.
POST /models/train
Body Parameters
Name | Type | Default value | Description |
---|---|---|---|
datasetId | UUID | "" | Required. ID of an existing training dataset. |
modelName | String | "" | Required. Name of the model that this task will create. |
baseModel | String | "" | Required. Name of the model to fine-tune, or a {model_id}/{checkpoint_id} pair. A model name must come from the list of supported models; a {model_id}/{checkpoint_id} pair can refer to any Marqtune-trained model in your account. |
maxTrainingTime | Float | 86400 | Optional. Maximum run time of the training task, in seconds. The default of 86400 seconds is 24 hours. The task is terminated automatically when this limit is reached. |
hyperparameters | Dictionary | "" | Required. Training task parameters; see the Training parameters guide for details. |
instanceType | String | marqtune.basic | Required. Instance type used for training: marqtune.basic or marqtune.performance. See the Getting Started with Marqtune guide for details. |
waitForCompletion | Boolean | True | Optional (py-marqtune client only). Makes the client keep polling until the operation completes. |
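Each row of the table corresponds to one field of the JSON request body. The sketch below uses a hypothetical helper, `build_train_body` (not part of py-marqtune), that assembles the body with the documented camelCase names and rejects missing required fields before anything is sent. Two assumptions are worth stating: `instanceType` is always included and filled from its documented default, because the table marks it as required but also gives a default; and `maxTrainingTime` is left out when not given, so the server applies its 86400-second default. `waitForCompletion` is omitted entirely because it only affects the py-marqtune client.

```python
import json
from typing import Any, Dict, Optional

# Values documented in the Body Parameters table.
DEFAULT_INSTANCE_TYPE = "marqtune.basic"
VALID_INSTANCE_TYPES = {"marqtune.basic", "marqtune.performance"}


def build_train_body(
    dataset_id: str,
    model_name: str,
    base_model: str,
    hyperparameters: Dict[str, Any],
    max_training_time: Optional[float] = None,
    instance_type: str = DEFAULT_INSTANCE_TYPE,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble the JSON body for POST /models/train."""
    body: Dict[str, Any] = {
        "datasetId": dataset_id,
        "modelName": model_name,
        "baseModel": base_model,
        "hyperparameters": hyperparameters,
    }
    # Reject empty required fields before the request leaves the client.
    for name, value in body.items():
        if not value:
            raise ValueError(f"{name} is required")
    if instance_type not in VALID_INSTANCE_TYPES:
        raise ValueError(f"unknown instanceType: {instance_type}")
    body["instanceType"] = instance_type
    # maxTrainingTime is optional; when omitted, the documented default (86400 s) applies.
    if max_training_time is not None:
        body["maxTrainingTime"] = max_training_time
    return body


body = build_train_body(
    dataset_id="dataset_id",
    model_name="test_model",
    base_model="Marqo/ViT-B-32.laion2b_s34b_b79k",
    hyperparameters={
        "leftKeys": ["query"],
        "rightKeys": ["my_image", "my_text"],
        "leftWeights": [1],
        "rightWeights": [0.9, 0.1],
    },
    max_training_time=600,
)
print(json.dumps(body, indent=2))
```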
Example: Training a model
```python
from marqtune.client import Client
from marqtune.enums import InstanceType

url = "https://marqtune.marqo.ai"
api_key = "{api_key}"  # replace with your Marqtune API key
marqtune_client = Client(url=url, api_key=api_key)

# Fine-tune a base model on an existing training dataset.
# wait_for_completion=True makes the client poll until training finishes.
marqtune_client.train_model(
    dataset_id="dataset_id",
    model_name="test_model",
    base_model="Marqo/ViT-B-32.laion2b_s34b_b79k",
    max_training_time=600,  # seconds
    instance_type=InstanceType.BASIC,
    hyperparameters={
        "leftKeys": ["query"],
        "rightKeys": ["my_image", "my_text"],
        "leftWeights": [1],
        "rightWeights": [0.9, 0.1],
    },
    wait_for_completion=True,
)
```
```shell
# Train a model.
curl -X POST 'https://marqtune.marqo.ai/models/train' \
  -H "Content-Type: application/json" \
  -H 'x-api-key: {api_key}' \
  -d '{
    "datasetId": "dataset_id",
    "modelName": "test_model",
    "baseModel": "Marqo/ViT-B-32.laion2b_s34b_b79k",
    "maxTrainingTime": 600,
    "hyperparameters": {"leftKeys": ["query"], "rightKeys": ["my_image", "my_text"], "leftWeights": [1], "rightWeights": [0.9, 0.1]},
    "instanceType": "marqtune.basic"
  }'
```
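If neither the py-marqtune client nor curl is available, the same request can be built with Python's standard library. The sketch below mirrors the cURL call; `{api_key}` and `dataset_id` are placeholders, and `make_train_request` is a hypothetical helper name. It only constructs the request. Sending it with `urllib.request.urlopen` is shown as a comment because that makes a real network call.

```python
import json
import urllib.request

MARQTUNE_URL = "https://marqtune.marqo.ai"
API_KEY = "{api_key}"  # placeholder: substitute your Marqtune API key


def make_train_request(body: dict, api_key: str = API_KEY) -> urllib.request.Request:
    """Build (but do not send) the POST /models/train request from the cURL example."""
    return urllib.request.Request(
        url=f"{MARQTUNE_URL}/models/train",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", "x-api-key": api_key},
        method="POST",
    )


req = make_train_request({
    "datasetId": "dataset_id",
    "modelName": "test_model",
    "baseModel": "Marqo/ViT-B-32.laion2b_s34b_b79k",
    "maxTrainingTime": 600,
    "hyperparameters": {
        "leftKeys": ["query"],
        "rightKeys": ["my_image", "my_text"],
        "leftWeights": [1],
        "rightWeights": [0.9, 0.1],
    },
    "instanceType": "marqtune.basic",
})

# To send it (urlopen raises urllib.error.HTTPError for 4xx/5xx responses):
# with urllib.request.urlopen(req) as resp:
#     print(resp.status, json.loads(resp.read()))
```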
Response: 202 Accepted
Training task has been initialised and will now be executed.
```json
{
  "statusCode": 202,
  "body": {
    "modelId": "model_id"
  }
}
```
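The 202 response returns the ID of the model the task will produce inside `body`. Training continues after this response, so that ID is the value to keep. The following is a minimal sketch that extracts it, assuming the response has already been decoded from JSON into a dict. `extract_model_id` is a hypothetical helper name.

```python
def extract_model_id(response: dict) -> str:
    """Return modelId from a 202 Accepted response to POST /models/train."""
    status = response.get("statusCode")
    if status != 202:
        raise ValueError(f"expected statusCode 202, got {status}")
    return response["body"]["modelId"]


accepted = {"statusCode": 202, "body": {"modelId": "model_id"}}
print(extract_model_id(accepted))  # model_id
```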
Response: 400 (Invalid dataset)
The dataset given in datasetId is not a training dataset.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Dataset must be of type 'training'"
  }
}
```
Response: 400 (Invalid base model)
The model given in baseModel could not be found.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Model with id {base_model} not found"
  }
}
```
Response: 400 (Invalid checkpoint)
The checkpoint given in baseModel does not exist. The error message lists the available checkpoints.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Invalid checkpoint. Available checkpoints: {checkpoints}"
  }
}
```
Response: 400 (Invalid hyperparameters)
A left or right key in hyperparameters (leftKeys or rightKeys) is not present in the data schema of the dataset.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Invalid <left|right> key: <hyperparameter key> not found in the data schema"
  }
}
```
Response: 400 (Invalid hyperparameters)
A weight key in hyperparameters is not present in the data schema of the dataset.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Invalid weight key: <weight_key> not found in the data schema"
  }
}
```
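Both hyperparameter errors mean the request refers to a column that the training dataset does not contain. You can catch the left/right key case before submitting by comparing `leftKeys` and `rightKeys` against the dataset's data schema. The sketch below assumes the schema is available locally as a dict that maps column names to types; the type strings shown are illustrative placeholders. It reproduces the server's message format for left and right keys only, because this section does not document the name of the weight-key field. `find_invalid_keys` is a hypothetical helper.

```python
from typing import Any, Dict, List


def find_invalid_keys(hyperparameters: Dict[str, Any], data_schema: Dict[str, str]) -> List[str]:
    """Return an error message for each leftKeys/rightKeys entry missing from the schema."""
    errors = []
    for side in ("left", "right"):
        for key in hyperparameters.get(f"{side}Keys", []):
            if key not in data_schema:
                errors.append(f"Invalid {side} key: {key} not found in the data schema")
    return errors


# Illustrative schema: column name -> type (type strings are placeholders).
schema = {"query": "text", "my_image": "image_pointer", "my_text": "text"}
params = {"leftKeys": ["query"], "rightKeys": ["my_image", "caption"]}
print(find_invalid_keys(params, schema))
# ['Invalid right key: caption not found in the data schema']
```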
Response: 400 (Invalid Request)
Request path or method is invalid.
```json
{
  "statusCode": 400,
  "body": {
    "message": "Invalid request method"
  }
}
```
Response: 401 (Unauthorised)
Unauthorised. Check your API key and try again.
```json
{
  "message": "Unauthorized."
}
```
Response: 500 (Internal server error)
Internal server error. Check your API key and try again.
```json
{
  "message": "Internal server error."
}
```
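The error payloads come in two shapes. 400 responses nest `message` under `body`, while 401 and 500 responses put `message` at the top level. The following is a minimal sketch of a helper that reads the message from either shape, assuming the payload has already been decoded from JSON. `error_message` is a hypothetical name, not part of py-marqtune.

```python
from typing import Any, Dict


def error_message(payload: Dict[str, Any]) -> str:
    """Extract the error text from a /models/train error payload.

    400 responses look like {"statusCode": 400, "body": {"message": ...}};
    401 and 500 responses look like {"message": ...}.
    """
    body = payload.get("body")
    if isinstance(body, dict) and "message" in body:
        return body["message"]
    return payload.get("message", "")


print(error_message({"statusCode": 400, "body": {"message": "Invalid request method"}}))
print(error_message({"message": "Unauthorized."}))
```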