130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
import logging
import os
from datetime import datetime

import numpy as np
import requests
from celery import shared_task
from django.conf import settings
from django.utils import timezone
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split

from namecreate.models import TrainingJob
|
||
|
||
|
||
def notify_go_service(model_path, metrics):
    """Tell the Go service that a new model file has been written.

    POSTs the JSON payload ``{"model_path", "metrics", "timestamp"}`` to
    ``<settings.GO_SERVICE_URL>/reload-model`` with a 10-second timeout.

    The call is best-effort: every exception is caught and logged, so a failed
    notification never fails the caller (the training task).

    Args:
        model_path: Path of the exported ONNX file, forwarded as-is.
        metrics: Evaluation metrics; must be JSON-serialisable because it is
            sent with ``requests.post(json=...)``.

    Returns:
        True only if the service answered HTTP 200. False when
        ``GO_SERVICE_URL`` is missing or empty, when the request (or building
        it) fails, or when any other status code is returned.
    """
    logger = logging.getLogger(__name__)

    try:
        # A missing setting means "no Go service configured", not an error,
        # so look it up with a default instead of letting AttributeError
        # be reported as a failed notification.
        go_service_url = getattr(settings, 'GO_SERVICE_URL', None)
        if not go_service_url:
            return False

        payload = {
            "model_path": model_path,
            "metrics": metrics,
            # timezone.now() is aware when USE_TZ=True, so the receiver gets
            # an explicit UTC offset instead of an ambiguous local time.
            "timestamp": timezone.now().isoformat(),
        }
        # Strip a trailing slash so "http://host/" does not become
        # "http://host//reload-model".
        endpoint = f"{go_service_url.rstrip('/')}/reload-model"

        response = requests.post(endpoint, json=payload, timeout=10)
    except Exception as exc:
        # Deliberately broad: this notification is best-effort. Log with the
        # traceback (the original used print) and report failure.
        logger.exception("Go servisi bildirimi başarısız: %s", exc)
        return False

    if response.status_code != 200:
        logger.warning(
            "Go service reload-model returned HTTP %s", response.status_code
        )
        return False
    return True
|
||
|
||
|
||
def _load_training_data(job):
    """Return the ``(X, y)`` arrays to train on for *job*.

    Uses the job's own ``features`` / ``labels`` when both are provided, and
    the scikit-learn Iris dataset as a demo when neither is.

    Raises:
        ValueError: if only one of ``features`` / ``labels`` is provided, if
            the features do not form a 2-D matrix, or if there is not exactly
            one label per feature row.
    """
    has_features = bool(job.features)
    has_labels = bool(job.labels)

    if has_features != has_labels:
        # Previously a job with only one of the two silently trained on the
        # Iris demo data and was reported as a success.
        raise ValueError("features and labels must be provided together")

    if not has_features:
        # Demo mode: Iris dataset.
        iris = load_iris()
        return iris.data.astype(np.float32), iris.target

    # User-supplied data.
    X = np.array(job.features, dtype=np.float32)
    y = np.array(job.labels, dtype=np.int32)
    if X.ndim != 2:
        raise ValueError(
            f"features must be a 2-D list of rows, got an array of shape {X.shape}"
        )
    if y.ndim != 1 or y.shape[0] != X.shape[0]:
        raise ValueError(
            "labels must be a flat list with one entry per feature row "
            f"({X.shape[0]} rows), got an array of shape {y.shape}"
        )
    return X, y


def _compute_metrics(y_true, y_pred):
    """Return accuracy and weighted precision / recall / F1 as plain floats."""
    # float() converts sklearn's numpy scalars to plain Python floats for the
    # JSON payload and the task result.
    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'precision': float(precision_score(y_true, y_pred, average='weighted')),
        'recall': float(recall_score(y_true, y_pred, average='weighted')),
        'f1_score': float(f1_score(y_true, y_pred, average='weighted')),
    }


def _export_onnx(model, feature_count, model_path):
    """Convert *model* to ONNX and write it to *model_path*.

    Parent directories are created as needed. The ONNX input is named
    ``float_input`` with shape ``[None, feature_count]`` (batch dimension left
    unspecified).
    """
    initial_type = [('float_input', FloatTensorType([None, feature_count]))]
    onnx_model = convert_sklearn(model, initial_types=initial_type)

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, "wb") as f:
        f.write(onnx_model.SerializeToString())


def _mark_job_failed(job, task_id, message):
    """Record *message* as the failure reason of the job for *task_id*.

    Re-uses *job* when it was already loaded; otherwise looks it up. If no
    TrainingJob exists for *task_id*, there is nothing to update and the
    function returns without raising (the original handler crashed here).
    """
    if job is None:
        try:
            job = TrainingJob.objects.get(task_id=task_id)
        except TrainingJob.DoesNotExist:
            return
    job.status = 'failed'
    job.error_message = message
    job.save(update_fields=['status', 'error_message'])


@shared_task(name='namecreate.tasks.train_model_task')
def train_model_task(task_id):
    """Train a classifier in the background and export it as ONNX.

    Steps:
    1. Mark the TrainingJob for *task_id* as running.
    2. Load its data, or the Iris demo set.
    3. Fit a RandomForestClassifier on an 80/20 train/test split.
    4. Compute test-set metrics.
    5. Write ``model_<model_version>.onnx`` under ``MEDIA_ROOT/models``.
    6. Notify the Go service.
    7. Store the results on the job.

    Args:
        task_id: Value of ``TrainingJob.task_id`` identifying the job.

    Returns:
        On success: ``{'status': 'success', 'task_id', 'model_path',
        'go_service_notified', 'metrics'}``.
        On any failure the job is marked ``'failed'`` (if it exists) and
        ``{'status': 'error', 'task_id', 'error'}`` is returned instead of
        raising.
    """
    logger = logging.getLogger(__name__)
    job = None
    try:
        job = TrainingJob.objects.get(task_id=task_id)
        job.status = 'running'
        # timezone.now() is aware when USE_TZ=True and naive otherwise,
        # which is what Django expects for DateTimeFields in either mode.
        job.started_at = timezone.now()
        job.save(update_fields=['status', 'started_at'])

        # 1. Load the dataset.
        X, y = _load_training_data(job)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # 2. Train the model.
        model = RandomForestClassifier(n_estimators=10, random_state=42)
        model.fit(X_train, y_train)

        # 3. Evaluate on the held-out split.
        metrics = _compute_metrics(y_test, model.predict(X_test))

        # 4-5. Export to ONNX under a versioned, timestamped file name.
        # NOTE(review): two jobs whose model_version falls in the same second
        # would write the same file — confirm model_version is unique enough.
        timestamp = job.model_version.strftime('%Y-%m-%d_%H-%M-%S')
        model_path = os.path.join(
            settings.MEDIA_ROOT, 'models', f"model_{timestamp}.onnx"
        )
        _export_onnx(model, X.shape[1], model_path)

        # 6. Notify the Go service (returns False instead of raising).
        go_notified = notify_go_service(model_path, metrics)

        # 7. Persist the results.
        job.status = 'completed'
        job.completed_at = timezone.now()
        job.model_path = model_path
        job.accuracy = metrics['accuracy']
        job.precision = metrics['precision']
        job.recall = metrics['recall']
        job.f1_score = metrics['f1_score']
        job.go_service_notified = go_notified
        job.save()

        return {
            'status': 'success',
            'task_id': task_id,
            'model_path': model_path,
            'go_service_notified': go_notified,
            'metrics': metrics,
        }

    except Exception as e:
        # Failures are reported through the job row and the return value
        # rather than re-raised, so callers must inspect 'status'.
        logger.exception("Training job %s failed", task_id)
        _mark_job_failed(job, task_id, str(e))
        return {
            'status': 'error',
            'task_id': task_id,
            'error': str(e),
        }
|