如何用Java案例实现数据采样?

wen java案例 1

本文目录导读:

如何用Java案例实现数据采样?

  1. 简单随机采样
  2. 分层采样
  3. 系统采样
  4. 时间序列采样
  5. 加权采样
  6. 使用建议

我来为您提供几个Java数据采样的实现案例,涵盖常见的采样算法:

简单随机采样

import java.util.*;
import java.util.stream.Collectors;
public class RandomSampling {
    /**
     * Draws a simple random sample of up to {@code sampleSize} elements, without replacement.
     *
     * <p>The input list is never modified. The returned list is always a new, independent
     * {@link ArrayList}; when {@code sampleSize} is at least {@code data.size()} it is a
     * copy of the whole input in its original order.
     *
     * @param data       the population to sample from
     * @param sampleSize the number of elements wanted; must not be negative
     * @param <T>        the element type
     * @return a new list holding {@code min(sampleSize, data.size())} elements
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> randomSample(List<T> data, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("sampleSize must be non-negative: " + sampleSize);
        }
        List<T> copy = new ArrayList<>(data);
        if (sampleSize >= copy.size()) {
            return copy;
        }
        Collections.shuffle(copy);
        // Copy the prefix rather than returning the subList view: a view would keep the
        // entire shuffled copy reachable for as long as the caller holds the sample.
        return new ArrayList<>(copy.subList(0, sampleSize));
    }
    /**
     * Draws a uniform random sample using reservoir sampling (single pass over the data).
     *
     * <p>The first {@code sampleSize} elements fill the reservoir; each later element,
     * the {@code i}-th seen (0-based), replaces a random slot with probability
     * {@code sampleSize / (i + 1)}. The data is walked with its iterator, so the cost is
     * linear even for lists without fast random access.
     *
     * @param data       the population to sample from
     * @param sampleSize the number of elements wanted; must not be negative
     * @param <T>        the element type
     * @return a new list holding {@code min(sampleSize, data.size())} elements
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> reservoirSample(List<T> data, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("sampleSize must be non-negative: " + sampleSize);
        }
        // Never allocate more slots than there are elements; an oversized request
        // must not translate into an oversized backing array.
        List<T> reservoir = new ArrayList<>(Math.min(sampleSize, data.size()));
        if (sampleSize == 0) {
            return reservoir;
        }
        Random random = new Random();
        int seen = 0;
        for (T item : data) {
            if (seen < sampleSize) {
                // Initial phase: fill the reservoir with the first k elements.
                reservoir.add(item);
            } else {
                // Replacement phase: keep the new element with probability k / (seen + 1).
                int slot = random.nextInt(seen + 1);
                if (slot < sampleSize) {
                    reservoir.set(slot, item);
                }
            }
            seen++;
        }
        return reservoir;
    }
    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        // Try simple random sampling
        System.out.println("简单随机采样: " + randomSample(data, 3));
        // Try reservoir sampling
        System.out.println("Reservoir Sampling: " + reservoirSample(data, 3));
    }
}

分层采样

import java.util.*;
import java.util.stream.Collectors;
public class StratifiedSampling {
    /**
     * Stratified sampling with equal allocation: groups the data by label and draws
     * (as nearly as possible) the same number of random elements from every group.
     *
     * <p>Slots are handed out one at a time, round-robin over the groups, so the
     * remainder of {@code totalSampleSize / numberOfGroups} is not lost and a group
     * that is too small to fill its share passes the shortfall on to the others. The
     * round-robin starts at a random group so that no label is systematically favoured
     * when the slots do not divide evenly. The result therefore holds exactly
     * {@code min(totalSampleSize, data.size())} elements.
     *
     * <p>Groups appear in the result in order of first occurrence of their label in
     * {@code data}; elements inside a group are in random order.
     *
     * @param data            the population to sample from
     * @param labelExtractor  maps an element to its group label; labels must be non-null
     *                        because they are used as grouping keys
     * @param totalSampleSize the total number of elements wanted; must not be negative
     * @param <T>             the element type
     * @return a new list with the sampled elements
     * @throws IllegalArgumentException if {@code totalSampleSize} is negative
     */
    public static <T> List<T> stratifiedSample(
            List<T> data,
            java.util.function.Function<T, String> labelExtractor,
            int totalSampleSize) {
        if (totalSampleSize < 0) {
            throw new IllegalArgumentException(
                    "totalSampleSize must be non-negative: " + totalSampleSize);
        }
        // Empty input has no groups; returning early avoids dividing work over zero strata.
        if (data.isEmpty() || totalSampleSize == 0) {
            return new ArrayList<>();
        }
        // Group by label, keeping labels in first-seen order so output order is predictable.
        Map<String, List<T>> groups = data.stream()
                .collect(Collectors.groupingBy(
                        labelExtractor, LinkedHashMap::new, Collectors.toList()));
        Random random = new Random();
        List<List<T>> strata = new ArrayList<>(groups.size());
        int[] sizes = new int[groups.size()];
        for (List<T> group : groups.values()) {
            List<T> shuffled = new ArrayList<>(group);
            Collections.shuffle(shuffled, random);
            sizes[strata.size()] = shuffled.size();
            strata.add(shuffled);
        }
        int[] quotas = allocateQuotas(sizes, totalSampleSize, random.nextInt(sizes.length));
        List<T> result = new ArrayList<>(Math.min(totalSampleSize, data.size()));
        for (int g = 0; g < quotas.length; g++) {
            result.addAll(strata.get(g).subList(0, quotas[g]));
        }
        return result;
    }
    /**
     * Distributes {@code total} slots over groups of the given sizes, one slot per group
     * per round starting at {@code firstGroup}, never exceeding a group's size.
     */
    private static int[] allocateQuotas(int[] sizes, int total, int firstGroup) {
        int[] quotas = new int[sizes.length];
        int remaining = total;
        boolean progressed = true;
        // Stop when all slots are placed or when every group is exhausted.
        while (remaining > 0 && progressed) {
            progressed = false;
            for (int j = 0; j < sizes.length && remaining > 0; j++) {
                int g = (firstGroup + j) % sizes.length;
                if (quotas[g] < sizes[g]) {
                    quotas[g]++;
                    remaining--;
                    progressed = true;
                }
            }
        }
        return quotas;
    }
    public static void main(String[] args) {
        // Build labelled data
        List<DataPoint> data = new ArrayList<>();
        data.add(new DataPoint("A", 1.0));
        data.add(new DataPoint("A", 1.5));
        data.add(new DataPoint("A", 2.0));
        data.add(new DataPoint("B", 3.0));
        data.add(new DataPoint("B", 3.5));
        data.add(new DataPoint("B", 4.0));
        data.add(new DataPoint("C", 5.0));
        data.add(new DataPoint("C", 5.5));
        data.add(new DataPoint("C", 6.0));
        // Run the stratified sampling
        List<DataPoint> sample = stratifiedSample(
                data,
                dp -> dp.label,
                3
        );
        System.out.println("分层采样结果:");
        sample.forEach(System.out::println);
    }
    /** A labelled numeric observation used by the demo in {@link #main(String[])}. */
    static class DataPoint {
        final String label;
        final double value;
        DataPoint(String label, double value) {
            this.label = label;
            this.value = value;
        }
        @Override
        public String toString() {
            return "DataPoint{" + "label='" + label + '\'' + ", value=" + value + '}';
        }
    }
}

系统采样

import java.util.*;
import java.util.stream.Collectors;
public class SystematicSampling {
    /**
     * Systematic sampling: picks {@code sampleSize} elements at evenly spaced positions
     * after a random start.
     *
     * <p>The positions are {@code floor((i * n + offset) / sampleSize)} for
     * {@code i = 0 .. sampleSize - 1}, with {@code offset} drawn uniformly from
     * {@code [0, n)}. Using the exact ratio {@code n / sampleSize} instead of a truncated
     * integer interval spreads the picks over the whole list, so the tail is reachable
     * even when {@code n} is not a multiple of {@code sampleSize}. When it is a multiple,
     * the picks are exactly {@code n / sampleSize} apart.
     *
     * @param data       the population to sample from, in the order to be stepped through
     * @param sampleSize the number of elements wanted; must not be negative
     * @param <T>        the element type
     * @return a new list with {@code min(sampleSize, data.size())} elements in input order;
     *         a full copy of {@code data} when {@code sampleSize >= data.size()}
     * @throws IllegalArgumentException if {@code sampleSize} is negative
     */
    public static <T> List<T> systematicSample(List<T> data, int sampleSize) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("sampleSize must be non-negative: " + sampleSize);
        }
        int n = data.size();
        if (sampleSize >= n) {
            return new ArrayList<>(data);
        }
        if (sampleSize == 0) {
            // Nothing requested; also keeps the division below safe.
            return new ArrayList<>();
        }
        int offset = new Random().nextInt(n); // random start, expressed in units of 1/sampleSize
        List<T> result = new ArrayList<>(sampleSize);
        for (int i = 0; i < sampleSize; i++) {
            // long arithmetic: i * n can exceed the int range for large lists.
            int index = (int) (((long) i * n + offset) / sampleSize);
            result.add(data.get(index));
        }
        return result;
    }
    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 1; i <= 100; i++) {
            data.add(i);
        }
        System.out.println("系统采样结果: " + systematicSample(data, 10));
    }
}

时间序列采样

import java.time.*;
import java.util.*;
import java.util.stream.Collectors;
public class TimeSeriesSampling {
    /**
     * Down-samples lists of {@link TimedPoint}s, either by elapsed time or to a target count.
     *
     * @param <T> the type of value carried by each point
     */
    public static class TimeSeriesSampler<T> {
        /**
         * Keeps the first point and then every point that lies at least {@code interval}
         * after the previously kept one.
         *
         * <p>The points are compared in list order, so {@code data} is expected to be
         * sorted by timestamp in ascending order.
         *
         * @param data     the time series; {@code null} or empty yields an empty list
         * @param interval the minimum spacing between two kept points
         * @return the kept points in input order
         * @throws NullPointerException if {@code interval} is {@code null} and
         *                              {@code data} is non-empty
         */
        public List<TimedPoint<T>> sampleByInterval(
                List<TimedPoint<T>> data,
                Duration interval) {
            if (data == null || data.isEmpty()) {
                return Collections.emptyList();
            }
            Objects.requireNonNull(interval, "interval");
            List<TimedPoint<T>> result = new ArrayList<>();
            TimedPoint<T> lastSample = data.get(0);
            result.add(lastSample);
            // Start after the first point: it is already in the result, and comparing it
            // with itself would add it a second time for a zero or negative interval.
            for (TimedPoint<T> point : data.subList(1, data.size())) {
                if (Duration.between(lastSample.timestamp, point.timestamp)
                        .compareTo(interval) >= 0) {
                    result.add(point);
                    lastSample = point;
                }
            }
            return result;
        }
        /**
         * Reduces the series to {@code sampleCount} points spread evenly over the list.
         *
         * <p>Point {@code i} of the result is taken from position
         * {@code floor(i * data.size() / sampleCount)}, which always includes the first
         * point and covers the whole series even when the size is not a multiple of
         * {@code sampleCount}.
         *
         * @param data        the time series; {@code null} or empty yields an empty list
         * @param sampleCount the number of points wanted; must not be negative
         * @return a new list with {@code min(sampleCount, data.size())} points in input order
         * @throws IllegalArgumentException if {@code sampleCount} is negative
         */
        public List<TimedPoint<T>> sampleByCount(
                List<TimedPoint<T>> data,
                int sampleCount) {
            if (sampleCount < 0) {
                throw new IllegalArgumentException(
                        "sampleCount must be non-negative: " + sampleCount);
            }
            if (data == null || data.isEmpty() || sampleCount == 0) {
                return new ArrayList<>();
            }
            int size = data.size();
            if (size <= sampleCount) {
                return new ArrayList<>(data);
            }
            List<TimedPoint<T>> result = new ArrayList<>(sampleCount);
            for (int i = 0; i < sampleCount; i++) {
                // long arithmetic: i * size can exceed the int range for large series.
                result.add(data.get((int) ((long) i * size / sampleCount)));
            }
            return result;
        }
    }
    /**
     * A value observed at a point in time.
     *
     * @param <T> the type of the observed value
     */
    static class TimedPoint<T> {
        final LocalDateTime timestamp;
        final T value;
        TimedPoint(LocalDateTime timestamp, T value) {
            this.timestamp = timestamp;
            this.value = value;
        }
        @Override
        public String toString() {
            return "TimedPoint{" +
                   "timestamp=" + timestamp +
                   ", value=" + value + '}';
        }
    }
    public static void main(String[] args) {
        // Build a time series with one point per minute
        List<TimedPoint<Double>> timeSeries = new ArrayList<>();
        LocalDateTime start = LocalDateTime.now();
        for (int i = 0; i < 100; i++) {
            timeSeries.add(new TimedPoint<>(
                start.plusMinutes(i),
                Math.random() * 100
            ));
        }
        TimeSeriesSampler<Double> sampler = new TimeSeriesSampler<>();
        // Sample at 10-minute intervals
        System.out.println("按时间间隔采样 (10分钟):");
        List<TimedPoint<Double>> intervalSample =
            sampler.sampleByInterval(timeSeries, Duration.ofMinutes(10));
        intervalSample.forEach(p -> System.out.println(p.timestamp + " -> " + p.value));
        // Sample down to a fixed number of points
        System.out.println("\n按固定数量采样 (10个):");
        List<TimedPoint<Double>> countSample =
            sampler.sampleByCount(timeSeries, 10);
        countSample.forEach(p -> System.out.println(p.timestamp + " -> " + p.value));
    }
}

加权采样

import java.util.*;
import java.util.stream.Collectors;
public class WeightedSampling {
    /**
     * Weighted random sampling without replacement: each draw picks one of the not yet
     * chosen elements with probability proportional to its weight.
     *
     * <p>Implementation: every element with a positive weight {@code w} gets an
     * independent exponential key {@code -ln(u) / w} with {@code u} uniform in
     * {@code (0, 1]}, and the elements with the smallest keys are returned in key order.
     * This produces the same distribution as repeated roulette-wheel draws that skip
     * already chosen elements, but needs no retry loop, so it always terminates.
     *
     * <p>Elements whose weight is zero are never selected. Consequently the result can
     * be shorter than {@code sampleSize} when fewer positive-weight elements exist.
     *
     * @param data       the population to sample from
     * @param weights    one finite, non-negative weight per element of {@code data};
     *                   the weights need not sum to 1
     * @param sampleSize the number of elements wanted; must not be negative
     * @param <T>        the element type
     * @return a new list with at most {@code sampleSize} distinct positions of
     *         {@code data}, in the order they were drawn
     * @throws IllegalArgumentException if the two lists differ in length, if
     *                                  {@code sampleSize} is negative, or if a weight is
     *                                  {@code null}, negative, NaN or infinite
     */
    public static <T> List<T> weightedSample(
            List<T> data,
            List<Double> weights,
            int sampleSize) {
        if (data.size() != weights.size()) {
            throw new IllegalArgumentException("数据和权重长度必须相同");
        }
        if (sampleSize < 0) {
            throw new IllegalArgumentException("sampleSize must be non-negative: " + sampleSize);
        }
        Random random = new Random();
        int n = data.size();
        double[] keys = new double[n];
        List<Integer> candidates = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            Double weight = weights.get(i);
            if (weight == null || weight.isNaN() || weight.isInfinite() || weight < 0) {
                throw new IllegalArgumentException(
                        "weights must be finite and non-negative, but weights[" + i + "] = "
                                + weight);
            }
            if (weight > 0) {
                // 1 - nextDouble() lies in (0, 1], which keeps the logarithm finite.
                keys[i] = -Math.log(1.0 - random.nextDouble()) / weight;
                candidates.add(i);
            }
        }
        // Smallest key first: that is the order in which sequential draws would pick them.
        candidates.sort(Comparator.comparingDouble(i -> keys[i]));
        int count = Math.min(sampleSize, candidates.size());
        List<T> result = new ArrayList<>(count);
        for (int index : candidates.subList(0, count)) {
            result.add(data.get(index));
        }
        return result;
    }
    public static void main(String[] args) {
        List<String> items = Arrays.asList("A", "B", "C", "D");
        List<Double> weights = Arrays.asList(0.1, 0.2, 0.3, 0.4); // selection weights
        System.out.println("加权采样结果:");
        for (int i = 0; i < 5; i++) {
            List<String> sample = weightedSample(items, weights, 2);
            System.out.println("采样" + (i+1) + ": " + sample);
        }
    }
}

使用建议

  1. 小数据集:使用简单随机采样
  2. 大数据流:使用Reservoir Sampling
  3. 类别不均衡:使用分层采样
  4. 时间序列数据:使用时间序列采样
  5. 有偏好的采样:使用加权采样

这些案例覆盖了常见的数据采样需求,您可以根据具体场景选择合适的采样方法。

上一篇:Java案例如何实现限时任务?

下一篇:当前分类已是最新一篇

抱歉,评论功能暂时关闭!