first commit
This commit is contained in:
129
namecreate/tasks.py
Normal file
129
namecreate/tasks.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import requests
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from celery import shared_task
|
||||
from django.conf import settings
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
from skl2onnx import convert_sklearn
|
||||
from skl2onnx.common.data_types import FloatTensorType
|
||||
from namecreate.models import TrainingJob
|
||||
|
||||
|
||||
def notify_go_service(model_path, metrics):
|
||||
"""Go servisine model yüklenmiş olduğunu bildirir."""
|
||||
try:
|
||||
go_service_url = settings.GO_SERVICE_URL
|
||||
if not go_service_url:
|
||||
return False
|
||||
|
||||
payload = {
|
||||
"model_path": model_path,
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{go_service_url}/reload-model",
|
||||
json=payload,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"Go servisi bildirimi başarısız: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(name='namecreate.tasks.train_model_task')
|
||||
def train_model_task(task_id):
|
||||
"""
|
||||
Makine öğrenme modelini arka planda eğitir ve ONNX olarak kaydeder.
|
||||
"""
|
||||
try:
|
||||
job = TrainingJob.objects.get(task_id=task_id)
|
||||
job.status = 'running'
|
||||
job.started_at = datetime.now()
|
||||
job.save(update_fields=['status', 'started_at'])
|
||||
|
||||
# 1. Veri Seti Yükleme
|
||||
if job.features and job.labels:
|
||||
# Kullanıcının gönderdiği veri
|
||||
X = np.array(job.features, dtype=np.float32)
|
||||
y = np.array(job.labels, dtype=np.int32)
|
||||
else:
|
||||
# Demo: Iris dataset
|
||||
iris = load_iris()
|
||||
X, y = iris.data.astype(np.float32), iris.target
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42
|
||||
)
|
||||
feature_count = X.shape[1]
|
||||
|
||||
# 2. Model Eğitimi
|
||||
model = RandomForestClassifier(n_estimators=10, random_state=42)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# 3. Metrikleri Hesapla
|
||||
y_pred = model.predict(X_test)
|
||||
accuracy = accuracy_score(y_test, y_pred)
|
||||
precision = precision_score(y_test, y_pred, average='weighted')
|
||||
recall = recall_score(y_test, y_pred, average='weighted')
|
||||
f1 = f1_score(y_test, y_pred, average='weighted')
|
||||
|
||||
# 4. ONNX Formatına Dönüştür (feature_count dinamik)
|
||||
initial_type = [('float_input', FloatTensorType([None, feature_count]))]
|
||||
onx = convert_sklearn(model, initial_types=initial_type)
|
||||
|
||||
# 5. Dosyaya Kaydet (Versiyonlu - Timestamp ile)
|
||||
timestamp = job.model_version.strftime('%Y-%m-%d_%H-%M-%S')
|
||||
model_filename = f"model_{timestamp}.onnx"
|
||||
model_path = os.path.join(settings.MEDIA_ROOT, 'models', model_filename)
|
||||
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(onx.SerializeToString())
|
||||
|
||||
# 6. Go Servisine Bilder
|
||||
metrics = {
|
||||
'accuracy': float(accuracy),
|
||||
'precision': float(precision),
|
||||
'recall': float(recall),
|
||||
'f1_score': float(f1),
|
||||
}
|
||||
go_notified = notify_go_service(model_path, metrics)
|
||||
|
||||
# 7. Veritabanına Kaydet
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.now()
|
||||
job.model_path = model_path
|
||||
job.accuracy = accuracy
|
||||
job.precision = precision
|
||||
job.recall = recall
|
||||
job.f1_score = f1
|
||||
job.go_service_notified = go_notified
|
||||
job.save()
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'task_id': task_id,
|
||||
'model_path': model_path,
|
||||
'go_service_notified': go_notified,
|
||||
'metrics': metrics
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
job = TrainingJob.objects.get(task_id=task_id)
|
||||
job.status = 'failed'
|
||||
job.error_message = str(e)
|
||||
job.save(update_fields=['status', 'error_message'])
|
||||
|
||||
return {
|
||||
'status': 'error',
|
||||
'task_id': task_id,
|
||||
'error': str(e)
|
||||
}
|
||||
Reference in New Issue
Block a user