Files
insta/namecreate/tasks.py
Beyhan Oğur 2be3a313ad first commit
2026-04-26 22:26:46 +03:00

130 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
from datetime import datetime

import numpy as np
import requests
from celery import shared_task
from django.conf import settings
from django.utils import timezone
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split

from namecreate.models import TrainingJob
def notify_go_service(model_path, metrics):
    """Notify the Go service that a new model file has been written.

    Sends a JSON ``POST`` to ``{GO_SERVICE_URL}/reload-model`` containing the
    model path, its evaluation metrics and a timestamp.

    This is deliberately best-effort: every failure is logged and reported
    as ``False`` instead of being raised, so a notification problem never
    fails the caller.

    Args:
        model_path: Filesystem path of the saved model file.
        metrics: JSON-serialisable mapping of metric name to value.

    Returns:
        True if the service answered with any 2xx status. False if
        ``GO_SERVICE_URL`` is missing or empty, the request raised, or the
        service answered with a non-2xx status.
    """
    try:
        # A missing setting means "notifications disabled", not an error.
        go_service_url = getattr(settings, 'GO_SERVICE_URL', None)
        if not go_service_url:
            return False
        payload = {
            "model_path": model_path,
            "metrics": metrics,
            # NOTE(review): naive local time without a UTC offset; confirm
            # the Go side parses this format.
            "timestamp": datetime.now().isoformat(),
        }
        response = requests.post(
            # rstrip avoids a "//reload-model" path when the configured URL
            # ends with a slash.
            f"{str(go_service_url).rstrip('/')}/reload-model",
            json=payload,
            timeout=10,
        )
        # Any 2xx (e.g. 201/204) is a successful acknowledgement.
        return 200 <= response.status_code < 300
    except Exception as e:  # best-effort boundary: log and report failure
        logging.getLogger(__name__).warning(
            "Go servisi bildirimi başarısız: %s", e
        )
        return False
def _load_training_data(job):
    """Return ``(X, y)`` for *job*: the user's data if given, else Iris.

    Raises:
        ValueError: if only one of ``job.features`` / ``job.labels`` is
            supplied, or if the features do not form a 2-D table.
    """
    has_features = bool(job.features)
    has_labels = bool(job.labels)
    if has_features != has_labels:
        # Falling back to the demo here would report a "completed" job that
        # was silently trained on the wrong data.
        raise ValueError("features and labels must be provided together")
    if not has_features:
        # Demo mode: no user data supplied, use the bundled Iris dataset.
        iris = load_iris()
        return iris.data.astype(np.float32), iris.target
    X = np.array(job.features, dtype=np.float32)
    y = np.array(job.labels, dtype=np.int32)
    if X.ndim != 2:
        raise ValueError(
            f"features must be a 2-D list (samples x features), got shape {X.shape}"
        )
    return X, y


def _compute_metrics(y_true, y_pred):
    """Return accuracy and weighted precision/recall/F1 as plain floats."""
    # zero_division=0 yields the same 0.0 sklearn uses by default for
    # ill-defined classes, without emitting UndefinedMetricWarning.
    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'precision': float(
            precision_score(y_true, y_pred, average='weighted', zero_division=0)
        ),
        'recall': float(
            recall_score(y_true, y_pred, average='weighted', zero_division=0)
        ),
        'f1_score': float(
            f1_score(y_true, y_pred, average='weighted', zero_division=0)
        ),
    }


def _save_onnx_model(model, feature_count, model_version):
    """Convert *model* to ONNX, write it under ``MEDIA_ROOT/models`` and return the path.

    The file is named ``model_<YYYY-mm-dd_HH-MM-SS>.onnx`` after
    *model_version*. NOTE(review): two jobs whose ``model_version`` falls in
    the same second write to the same file.
    """
    # float32 input with a dynamic batch dimension and the training feature count.
    initial_type = [('float_input', FloatTensorType([None, feature_count]))]
    onx = convert_sklearn(model, initial_types=initial_type)
    stamp = model_version.strftime('%Y-%m-%d_%H-%M-%S')
    model_path = os.path.join(settings.MEDIA_ROOT, 'models', f"model_{stamp}.onnx")
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, "wb") as f:
        f.write(onx.SerializeToString())
    return model_path


@shared_task(name='namecreate.tasks.train_model_task')
def train_model_task(task_id):
    """Train a classifier for a TrainingJob in the background and export it to ONNX.

    The job is processed in these steps:

    1. Mark the job ``running``.
    2. Load its data (the user's features/labels, or the Iris demo set).
    3. Train a RandomForest on an 80/20 split.
    4. Compute metrics on the held-out 20%.
    5. Save the model as ONNX.
    6. Notify the Go service.
    7. Store the results on the job.

    Errors are not propagated. They are logged, recorded on the job as
    ``status='failed'`` with ``error_message``, and returned as an error
    result.

    Args:
        task_id: ``TrainingJob.task_id`` of the job to run.

    Returns:
        On success: ``{'status': 'success', 'task_id', 'model_path',
        'go_service_notified', 'metrics'}``.
        On failure: ``{'status': 'error', 'task_id', 'error'}``.
    """
    logger = logging.getLogger(__name__)
    try:
        job = TrainingJob.objects.get(task_id=task_id)
    except TrainingJob.DoesNotExist as e:
        # There is no row to mark as failed, so just report the error.
        logger.error("TrainingJob %s not found", task_id)
        return {
            'status': 'error',
            'task_id': task_id,
            'error': str(e)
        }

    try:
        job.status = 'running'
        job.started_at = timezone.now()
        job.save(update_fields=['status', 'started_at'])

        X, y = _load_training_data(job)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        model = RandomForestClassifier(n_estimators=10, random_state=42)
        model.fit(X_train, y_train)
        metrics = _compute_metrics(y_test, model.predict(X_test))

        model_path = _save_onnx_model(model, X.shape[1], job.model_version)
        go_notified = notify_go_service(model_path, metrics)

        job.status = 'completed'
        job.completed_at = timezone.now()
        job.model_path = model_path
        job.accuracy = metrics['accuracy']
        job.precision = metrics['precision']
        job.recall = metrics['recall']
        job.f1_score = metrics['f1_score']
        job.go_service_notified = go_notified
        job.save()
        return {
            'status': 'success',
            'task_id': task_id,
            'model_path': model_path,
            'go_service_notified': go_notified,
            'metrics': metrics
        }
    except Exception as e:
        logger.exception("Training job %s failed", task_id)
        # update_fields persists only the failure state, not any partially
        # applied in-memory changes from the success path.
        job.status = 'failed'
        job.error_message = str(e)
        job.save(update_fields=['status', 'error_message'])
        return {
            'status': 'error',
            'task_id': task_id,
            'error': str(e)
        }