Add Regression Prediction lab
This commit is contained in:
32
RegressionPredictionLab/build.gradle
Normal file
32
RegressionPredictionLab/build.gradle
Normal file
@@ -0,0 +1,32 @@
|
||||
import org.springframework.boot.gradle.plugin.SpringBootPlugin
|
||||
|
||||
apply plugin: 'java'
|
||||
apply plugin: 'org.springframework.boot'
|
||||
apply plugin: 'io.spring.dependency-management'
|
||||
|
||||
description = "Regression Prediction Lab"
|
||||
|
||||
java {
|
||||
sourceCompatibility = JavaVersion.VERSION_21
|
||||
targetCompatibility = JavaVersion.VERSION_21
|
||||
}
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencyManagement {
|
||||
imports {
|
||||
mavenBom SpringBootPlugin.BOM_COORDINATES
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'org.springframework.boot:spring-boot-starter-web'
|
||||
|
||||
implementation 'org.apache.commons:commons-math3:3.6.1'
|
||||
implementation 'com.opencsv:opencsv:5.7.1'
|
||||
|
||||
runtimeOnly 'org.springframework.boot:spring-boot-devtools'
|
||||
testImplementation 'org.springframework.boot:spring-boot-starter-test'
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.example.regression.prediction;
|
||||
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.apache.commons.math3.stat.regression.SimpleRegression;
|
||||
import com.opencsv.CSVReader;
|
||||
import com.opencsv.exceptions.CsvException;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.List;
|
||||
|
||||
@SpringBootApplication
|
||||
public class CustomerPurchasePredictor {
|
||||
|
||||
public static void main(String[] args) {
|
||||
try {
|
||||
// 1. Load the dataset
|
||||
List<String[]> data = loadCSV("/customer_purchases.csv");
|
||||
|
||||
// 2. Prepare the regression model
|
||||
SimpleRegression regression = new SimpleRegression();
|
||||
|
||||
// Skip header row and add data points
|
||||
for (int i = 1; i < data.size(); i++) {
|
||||
String[] row = data.get(i);
|
||||
double income = Double.parseDouble(row[2]); // Independent variable (X)
|
||||
double purchaseAmount = Double.parseDouble(row[3]); // Dependent variable (Y)
|
||||
regression.addData(income, purchaseAmount);
|
||||
}
|
||||
|
||||
// 3. Print model statistics
|
||||
System.out.println("=== Model Summary ===");
|
||||
System.out.printf("R-squared: %.4f\n", regression.getRSquare());
|
||||
System.out.printf("Intercept: %.2f\n", regression.getIntercept());
|
||||
System.out.printf("Slope: %.4f\n", regression.getSlope());
|
||||
System.out.printf("Standard Error: %.4f\n\n", regression.getRegressionSumSquares());
|
||||
|
||||
// 4. Make predictions for new customers
|
||||
System.out.println("=== Predictions ===");
|
||||
predictPurchase(regression, 40000); // $40,000 income
|
||||
predictPurchase(regression, 55000); // $55,000 income
|
||||
predictPurchase(regression, 80000); // $80,000 income
|
||||
|
||||
} catch (IOException | CsvException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
private static List<String[]> loadCSV(String filePath) throws IOException, CsvException {
|
||||
try (CSVReader reader = new CSVReader(new InputStreamReader(new ClassPathResource(filePath).getInputStream()))) {
|
||||
return reader.readAll();
|
||||
}
|
||||
}
|
||||
|
||||
private static void predictPurchase(SimpleRegression regression, double income) {
|
||||
double predictedAmount = regression.predict(income);
|
||||
System.out.printf("Predicted purchase for $%,.2f income: $%,.2f\n",
|
||||
income, predictedAmount);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
spring.application.name=prediction
|
||||
@@ -0,0 +1,21 @@
|
||||
customer_id,age,income,purchase_amount
|
||||
1,25,30000,150
|
||||
2,30,35000,180
|
||||
3,35,40000,210
|
||||
4,40,45000,250
|
||||
5,45,50000,280
|
||||
6,50,55000,320
|
||||
7,55,60000,350
|
||||
8,60,65000,380
|
||||
9,65,70000,420
|
||||
10,70,75000,450
|
||||
11,28,32000,160
|
||||
12,32,38000,190
|
||||
13,38,42000,230
|
||||
14,42,48000,260
|
||||
15,48,52000,290
|
||||
16,52,58000,330
|
||||
17,58,62000,360
|
||||
18,62,68000,390
|
||||
19,68,72000,430
|
||||
20,72,78000,460
|
||||
|
@@ -0,0 +1,13 @@
|
||||
package com.example.regression.prediction;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
@SpringBootTest
|
||||
class PredictionApplicationTests {
|
||||
|
||||
@Test
|
||||
void contextLoads() {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -7,3 +7,4 @@ include 'RetailManagementSystem:front-end'
|
||||
include 'InventoryManagementSystem'
|
||||
include 'SmartClinicManagementSystem:app'
|
||||
include 'SoftwareDevChatbot'
|
||||
include 'RegressionPredictionLab'
|
||||
Reference in New Issue
Block a user