Error using trainNetwork (line 184) Invalid training data. Responses must be a matrix of numeric responses, or a N-by-1 cell array of sequences

Good day, Sir.
I have 3 years of training data to train an LSTM model, 1 year of validation data to tune the optimization and regularization hyperparameters, and a test set to evaluate the final model. The training, validation, and test data were reshaped into the format shown below.
Input data:
  • X_train_reshaped = 1095x1 cell (4x24 double each);
  • X_val_reshaped = 366x1 cell (4x24 double each);
  • X_test_reshaped = 365x1 cell (4x24 double each)
Target (output) data:
  • y_train_cell = 1095x1 cell (1x24 double each);
  • y_val_cell = 366x1 cell (1x24 double each);
  • y_test_cell = 365x1 cell (1x24 double each)
% Reshape training data for LSTM input and output
num_features = size(X_train, 2);
num_samples_per_day = 24; % Number of hours in a day
num_days_train = size(X_train, 1) / num_samples_per_day;
X_train_reshaped = cell(num_days_train, 1);
y_train_cell = cell(num_days_train, 1);
% Repeat Reshape for validation and test data
% Apply grid search for hyperparameters
% Loop through hyperparameters
for num_layer = num_layers
    for hidden_unit = hidden_units
        for hidden_layer = hidden_layers
            for learning_rate = learning_rates
                for optimizer = optimizers
                    for dropout_rate = dropout_rates
                        for numEpochs = num_epochs
                            fprintf('Training LSTM model with num_layers=%d, hidden_units=%d, learning_rate=%.4f, optimizer=%s, dropout_rate=%.2f, num_epochs=%d\n', num_layer, hidden_unit, learning_rate, optimizer{1}, dropout_rate, numEpochs);
                            % Build LSTM model
                            layers = [
                                sequenceInputLayer(num_features)
                            ];
                            for i = 1:hidden_layer
                                layers = [
                                    layers
                                    lstmLayer(hidden_unit, 'OutputMode', 'sequence')
                                ];
                                if dropout_rate > 0
                                    layers = [
                                        layers
                                        dropoutLayer(dropout_rate)
                                    ];
                                end
                            end
                            % Output layer configuration
                            layers = [
                                layers
                                fullyConnectedLayer(num_samples_per_day)
                                regressionLayer
                            ];
                            % Specify training options
                            options = trainingOptions(optimizer{1}, ...
                                'MaxEpochs', numEpochs, ...
                                'InitialLearnRate', learning_rate, ...
                                'GradientThreshold', 1, ...
                                'ExecutionEnvironment', 'cpu', ...
                                'Verbose', false, ...
                                'Plots', 'training-progress');
                            % Train the LSTM model
                            net = trainNetwork(X_train_reshaped, y_train_cell, layers, options);
                            % Predict on validation set
                            y_pred = predict(net, X_val_reshaped);
                            % Evaluate the model; err(:) flattens the errors so
                            % RMSE and MAE are scalars rather than 1x24 vectors
                            err = cell2mat(y_pred) - cell2mat(y_val_cell);
                            rmse = sqrt(mean(err(:).^2));
                            mae = mean(abs(err(:)));
                            % Check if current model is the best so far
                            if rmse < best_rmse
                                best_rmse = rmse;
                                best_hyperparameters.hidden_units = hidden_unit;
                                best_hyperparameters.num_layers = hidden_layer;
                                best_hyperparameters.learning_rate = learning_rate;
                                best_hyperparameters.optimizer = optimizer{1};
                                best_hyperparameters.dropout_rate = dropout_rate;
                                best_hyperparameters.num_epochs = numEpochs;
                                best_mae = mae; % Update best MAE
                            end
                        end
                    end
                end
            end
        end
    end
end
Training LSTM model with num_layers=1, hidden_units=12, learning_rate=0.0010, optimizer=adam, dropout_rate=0.10, num_epochs=50
Error using trainNetwork (line 184)
Invalid training data. Responses must be a matrix of numeric responses, or a N-by-1 cell array of sequences, where N is the number of sequences. The feature dimension of all sequences must be the same.
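The reshape snippet above only preallocates X_train_reshaped and y_train_cell with cell(num_days_train, 1); the step that fills them is not shown. As a minimal sketch, assuming the raw data are hourly matrices with one row per hour (X_hourly and y_hourly below are hypothetical stand-ins filled with random values), mat2cell can cut them into per-day cells in the features-by-time-steps layout used for sequence data:

```matlab
% Minimal sketch: split hourly rows into one cell per day.
% X_hourly / y_hourly are HYPOTHETICAL stand-ins for the real data:
% one row per hour, 4 feature columns and 1 target column.
hoursPerDay = 24;
numDays = 3;
X_hourly = rand(numDays * hoursPerDay, 4);
y_hourly = rand(numDays * hoursPerDay, 1);

% Guard against a row count that is not a whole number of days, which
% would otherwise produce a short or missing sequence.
assert(mod(size(X_hourly, 1), hoursPerDay) == 0, 'Row count must be a multiple of 24.');
assert(size(y_hourly, 1) == size(X_hourly, 1), 'X and y row counts differ.');

% Transpose so time runs along the columns, then cut the columns into days.
% mat2cell returns a 1-by-numDays cell; the trailing ' makes it numDays-by-1.
splitByDay = @(M) mat2cell(M', size(M, 2), repmat(hoursPerDay, 1, size(M, 1) / hoursPerDay))';

X_cells = splitByDay(X_hourly);   % numDays-by-1 cell, each 4x24
y_cells = splitByDay(y_hourly);   % numDays-by-1 cell, each 1x24
```

If a preallocated cell is never assigned, it stays empty ([], size 0x0), so its feature dimension (0) no longer matches the others (4). That would plausibly produce the "feature dimension of all sequences must be the same" message above.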
  2 Comments
Ben on 8 Apr 2024 at 12:34
I notice you use fullyConnectedLayer(num_samples_per_day), where num_samples_per_day = 24 is the sequence length of each input/output sequence. You say each y_train_cell{i} is 1x24, i.e. a sequence with 1 feature and 24 time steps. For that, you need fullyConnectedLayer(1) at the end of the network.
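A sketch of the question's layer loop with the output size changed as suggested, assuming Deep Learning Toolbox and the example values from the printed log (12 hidden units, one LSTM layer, dropout 0.1); these are illustrative, not tuned settings:

```matlab
% Requires Deep Learning Toolbox. Example values taken from the log line.
num_features = 4;      % rows of each X_train_reshaped{i}
num_responses = 1;     % rows of each y_train_cell{i}, NOT the 24 time steps
hidden_unit = 12;
hidden_layer = 1;
dropout_rate = 0.1;

layers = sequenceInputLayer(num_features);
for i = 1:hidden_layer
    % 'sequence' output mode keeps one hidden state per time step
    layers = [layers; lstmLayer(hidden_unit, 'OutputMode', 'sequence')];
    if dropout_rate > 0
        layers = [layers; dropoutLayer(dropout_rate)];
    end
end
% One output value per time step, so a 4x24 input yields a 1x24 prediction
layers = [layers; fullyConnectedLayer(num_responses); regressionLayer];
```

The fully connected layer acts on the feature dimension at every time step, so its output size must equal the number of response rows (1), while the 24 steps pass through unchanged.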
Besides that, check that num_features is 4 and that every entry of X_train_reshaped and X_val_reshaped has size 4x24, and likewise that every response cell has size 1x24. It is easy for an empty array to slip into one of these cells during pre-processing.
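The size check described above can be scripted with cellfun. In this sketch X_cells and y_cells are hypothetical test data with one deliberately empty response cell; substitute X_train_reshaped, y_train_cell, and the validation and test cells:

```matlab
% HYPOTHETICAL test data: 5 days, with response cell 3 left empty on purpose.
X_cells = repmat({rand(4, 24)}, 5, 1);
y_cells = repmat({rand(1, 24)}, 5, 1);
y_cells{3} = [];   % simulates a cell that pre-processing failed to fill

% Indices of cells that are not numeric arrays of the expected size.
badX = find(~cellfun(@(c) isnumeric(c) && isequal(size(c), [4 24]), X_cells));
badY = find(~cellfun(@(c) isnumeric(c) && isequal(size(c), [1 24]), y_cells));

fprintf('Input cells with wrong size: %s\n', num2str(badX'));
fprintf('Response cells with wrong size: %s\n', num2str(badY'));

% Inputs and responses should also be N-by-1 cell arrays with the same N.
isValidShape = iscolumn(X_cells) && iscolumn(y_cells) && numel(X_cells) == numel(y_cells);
```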
Cris LaPierre on 8 Apr 2024 at 13:17
One observation: cell array inputs are only allowed for sequence data, but your data does not appear to be sequence data.
Can you describe what your data is?
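If each day is really a fixed-size set of values rather than a sequence, the other form the error message accepts, a numeric response matrix, applies. A minimal sketch, assuming Deep Learning Toolbox R2020b or later (for featureInputLayer) and treating each day as one observation: flatten each 4x24 input into a 96-element row and stack the 1x24 targets into an N-by-24 matrix. X_cells and y_cells are hypothetical stand-ins, and the hidden size of 64 is an arbitrary placeholder:

```matlab
% HYPOTHETICAL stand-ins for X_train_reshaped / y_train_cell.
numDays = 10;
X_cells = repmat({rand(4, 24)}, numDays, 1);
y_cells = repmat({rand(1, 24)}, numDays, 1);

% Flatten each 4x24 day into a 1x96 row (column-major: hour 1's 4 features
% first) and stack the rows into a numDays-by-96 matrix.
X_mat = cell2mat(cellfun(@(c) c(:)', X_cells, 'UniformOutput', false));
% The 1x24 response rows stack directly into a numDays-by-24 matrix.
Y_mat = cell2mat(y_cells);

layers = [
    featureInputLayer(size(X_mat, 2))
    fullyConnectedLayer(64)              % placeholder hidden size
    reluLayer
    fullyConnectedLayer(size(Y_mat, 2))  % 24 outputs, one per hour
    regressionLayer];
% net = trainNetwork(X_mat, Y_mat, layers, options);
```

Flattening discards the explicit time ordering an LSTM would exploit, so which form fits depends on the answer to the question above.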

Answers (0)