-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
23 lines (20 loc) · 791 Bytes
/
Copy pathtrain.py
File metadata and controls
23 lines (20 loc) · 791 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import pickle
from model import build_model, compile_model
from tensorflow.keras.callbacks import EarlyStopping
def load_data(x_path='X_train.npy', y_path='y_train.npy'):
    """Load the training features and labels from NumPy ``.npy`` files.

    Args:
        x_path: Path to the feature array. Defaults to ``X_train.npy``,
            resolved relative to the current working directory.
        y_path: Path to the label array. Defaults to ``y_train.npy``.

    Returns:
        A ``(X_train, y_train)`` tuple of NumPy arrays.

    Raises:
        ValueError: If the two arrays disagree on the number of samples.
    """
    X_train = np.load(x_path)
    y_train = np.load(y_path)
    # Fail here with a clear message instead of deep inside model.fit when
    # the feature and label files are out of sync.
    if len(X_train) != len(y_train):
        raise ValueError(
            f"sample count mismatch: {x_path} has {len(X_train)} rows, "
            f"{y_path} has {len(y_train)}"
        )
    return X_train, y_train
def train_model(*, epochs=10, batch_size=64, patience=3,
                model_path='spam_classifier_model.h5'):
    """Train the spam classifier on the saved training arrays and save it.

    Loads the data via ``load_data()`` and adds a trailing channel axis for
    the Conv1D-style input. Builds and compiles the model with the project's
    ``build_model``/``compile_model`` helpers. Fits with early stopping on
    ``val_loss`` and writes the model to ``model_path``.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size passed to ``model.fit``.
        patience: Epochs without ``val_loss`` improvement before stopping.
        model_path: Destination passed to ``model.save``.

    Returns:
        The history object returned by ``model.fit``.
    """
    X_train, y_train = load_data()
    # Expand dimensions to fit the Conv1D input (samples, steps, channels).
    # Only add the channel axis when it is missing, so input_shape always
    # matches the array actually passed to fit.
    if X_train.ndim == 2:
        X_train = np.expand_dims(X_train, axis=2)
    input_shape = X_train.shape[1:]
    model = build_model(input_shape)
    model = compile_model(model)
    # Without restore_best_weights the saved model would carry the weights
    # of the last epoch run, which can be several epochs past the best
    # val_loss when early stopping triggers.
    early_stopping = EarlyStopping(monitor='val_loss', patience=patience,
                                   restore_best_weights=True)
    # NOTE(review): Keras takes validation_split from the *last* 20% of the
    # arrays before shuffling; if the .npy data is ordered (e.g. by class)
    # the validation set will be skewed — confirm or shuffle upstream.
    history = model.fit(X_train, y_train, epochs=epochs,
                        batch_size=batch_size, validation_split=0.2,
                        callbacks=[early_stopping])
    model.save(model_path)
    return history
if __name__ == "__main__":
    # Script entry point: run the full training pipeline only when this file
    # is executed directly, not when it is imported.
    train_model()