C# 从基础神经元到实现在0~9数字识别

news/2025/2/23 11:00:45

训练图片:mnist160

测试结果:1000次训练学习率为0.1时,准确率在60%以上

学习的图片越多,训练的时候越长(比如把 epochs*10 = 10000或更高时)效果越好

using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Windows.Forms;

namespace LLM
{
   

// 定义权重类


    class Weight
    {
        private static Random random = new Random();
        public double Value { get; set; }

        public Weight()
        {
            Value = random.NextDouble() - 0.5;
        }
    }

   

// 定义神经元连接类


    class NeuronLink
    {
        public Weight Weight { get; set; }
        public Neuron FromNeuron { get; set; }
        public Neuron ToNeuron { get; set; }

        public NeuronLink(Neuron fromNeuron, Neuron toNeuron)
        {
            FromNeuron = fromNeuron;
            ToNeuron = toNeuron;
            Weight = new Weight();
        }
    }

    // 定义神经元类


    class Neuron
    {
        private static Random random = new Random();
        public double Bias { get; set; }
        public double Output { get; set; }
        public double Error { get; set; }
        public NeuronLink[] InputLinks { get; set; }

        public Neuron(int inputCount, Neuron[] previousLayerNeurons)
        {
            Bias = random.NextDouble() - 0.5;
            InputLinks = new NeuronLink[inputCount];
            for (int i = 0; i < inputCount; i++)
            {
                InputLinks[i] = new NeuronLink(previousLayerNeurons[i], this);
            }
        }

       

// 激活函数(Sigmoid)


        private double Sigmoid(double x)
        {
            return 1.0 / (1.0 + Math.Exp(-x));
        }

        // 计算神经元的输出


        public double CalculateOutput()
        {
            double sum = Bias;
            foreach (var link in InputLinks)
            {
                sum += link.FromNeuron.Output * link.Weight.Value;
            }
            Output = Sigmoid(sum);
            return Output;
        }

        // 激活函数的导数


        public double SigmoidDerivative()
        {
            return Output * (1 - Output);
        }
    }

    // 定义层类


    class Layer
    {
        public Neuron[] Neurons { get; set; }

        public Layer(int neuronCount, Layer previousLayer)
        {
            Neurons = new Neuron[neuronCount];
            if (previousLayer == null)
            {
                for (int i = 0; i < neuronCount; i++)
                {


                    // 输入层神经元没有输入连接


                    Neurons[i] = new Neuron(0, new Neuron[0]);
                }
            }
            else
            {
                for (int i = 0; i < neuronCount; i++)
                {
                    Neurons[i] = new Neuron(previousLayer.Neurons.Length, previousLayer.Neurons);
                }
            }
        }
    }

    // 定义神经网络类


    class NeuralNetwork
    {
        private Layer inputLayer;
        private Layer hiddenLayer;
        private Layer outputLayer;

        public NeuralNetwork(int inputSize, int hiddenSize, int outputSize)
        {
            inputLayer = new Layer(inputSize, null);
            hiddenLayer = new Layer(hiddenSize, inputLayer);
            outputLayer = new Layer(outputSize, hiddenLayer);
        }

        // 前向传播


        public double[] FeedForward(double[] input)
        {


            // 设置输入层神经元的输出


            for (int i = 0; i < inputLayer.Neurons.Length; i++)
            {
                inputLayer.Neurons[i].Output = input[i];
            }

            // 计算隐藏层神经元的输出


            foreach (var neuron in hiddenLayer.Neurons)
            {
                neuron.CalculateOutput();
            }

            // 计算输出层神经元的输出


            double[] outputs = new double[outputLayer.Neurons.Length];
            for (int i = 0; i < outputLayer.Neurons.Length; i++)
            {
                outputs[i] = outputLayer.Neurons[i].CalculateOutput();
            }

            return outputs;
        }

        // 训练网络


        public void Train(double[] input, double[] target, double learningRate)
        {


            // 前向传播


            double[] output = FeedForward(input);

            // 计算输出层的误差


            for (int i = 0; i < outputLayer.Neurons.Length; i++)
            {
                outputLayer.Neurons[i].Error = (target[i] - output[i]) * outputLayer.Neurons[i].SigmoidDerivative();
            }

            // 反向传播到隐藏层


            for (int j = 0; j < hiddenLayer.Neurons.Length; j++)
            {
                double errorSum = 0;
                foreach (var link in hiddenLayer.Neurons[j].InputLinks)
                {
                    errorSum += link.ToNeuron.Error * link.Weight.Value;
                }
                hiddenLayer.Neurons[j].Error = errorSum * hiddenLayer.Neurons[j].SigmoidDerivative();
            }

            // 更新输出层的权重和偏置


            foreach (var neuron in outputLayer.Neurons)
            {
                neuron.Bias += learningRate * neuron.Error;
                foreach (var link in neuron.InputLinks)
                {
                    link.Weight.Value += learningRate * neuron.Error * link.FromNeuron.Output;
                }
            }

            // 更新隐藏层的权重和偏置


            foreach (var neuron in hiddenLayer.Neurons)
            {
                neuron.Bias += learningRate * neuron.Error;
                foreach (var link in neuron.InputLinks)
                {
                    link.Weight.Value += learningRate * neuron.Error * link.FromNeuron.Output;
                }
            }
        }
    }

    // 定义训练数据对象类


    class TrainingData
    {
        public double[] Input { get; set; }
        public double[] Target { get; set; }

        public TrainingData(double[] input, double[] target)
        {
            Input = input;
            Target = target;
        }
    }

    public class Program
    {

//测试


        public static void Main()
        {
            // 假设图片是 28x28 的黑白图片,输入层大小为 28x28
            int inputSize = 28 * 28;
            int hiddenSize = 30;
            int outputSize = 10; // 识别 0 - 9 数字

            NeuralNetwork neuralNetwork = new NeuralNetwork(inputSize, hiddenSize, outputSize);

            // 创建训练数据对象数组
            string dire = Application.StartupPath + "\\mnist160\\train\\";
            string[] directories = System.IO.Directory.GetDirectories(dire);
            List<TrainingData> allTrainingData = new List<TrainingData>();
            string[] files = null;
            foreach (string directory in directories)
            {
                files = System.IO.Directory.GetFiles(directory);
                for (int i = 0; i < files.Length; i++)
                {
                    // 读取图片
                    string imagePath = files[i];
                    double[] input = ReadImageAsInput(imagePath);
                    double[] target = new double[outputSize];
                    string dirname = new DirectoryInfo(directory).Name;
                    target[int.Parse(dirname)] = 1;
                    allTrainingData.Add(new TrainingData(input, target));
                }
            }
            TrainingData[] trainingData = allTrainingData.ToArray();

            // 训练网络


            double learningRate = 0.1;
            int epochs = 1000;
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                foreach (TrainingData data in trainingData)
                {
                    neuralNetwork.Train(data.Input, data.Target, learningRate);
                }
            }

            dire = Application.StartupPath + "\\mnist160\\test\\";
            directories = System.IO.Directory.GetDirectories(dire);
            foreach (string directory in directories)
            {
                files = System.IO.Directory.GetFiles(directory);


                // 测试网络


                foreach (var item in files)
                {
                    string testImagePath = item;
                    try
                    {
                        double[] testInput = ReadImageAsInput(testImagePath);
                        double[] output = neuralNetwork.FeedForward(testInput);

                        double maxVal = output[0];
                        for (int i = 1; i < output.Length; i++)
                        {
                            if (output[i] > maxVal)
                            {
                                maxVal = output[i];
                            }
                        }
                        int predictedDigit = Array.IndexOf(output, maxVal);

//输出结果


                        Console.WriteLine($"Predicted digit: {testImagePath}==={predictedDigit}");
                    }
                    catch (FileNotFoundException ex)
                    {
                        Console.WriteLine(ex.Message);
                    }
                }
            }
        }

        static double[] ReadImageAsInput(string imagePath)
        {
            if (!File.Exists(imagePath))
            {
                throw new FileNotFoundException($"Image file {imagePath} not found.");
            }

            using (Bitmap image = new Bitmap(imagePath))
            {
                double[] input = new double[28 * 28];
                int index = 0;
                for (int y = 0; y < 28; y++)
                {
                    for (int x = 0; x < 28; x++)
                    {
                        Color pixelColor = image.GetPixel(x, y);
                        // 将像素值转换为 0 到 1 之间的双精度值
                        input[index] = (pixelColor.R + pixelColor.G + pixelColor.B) / (3.0 * 255);
                        index++;
                    }
                }
                return input;
            }
        }
    }
}


http://www.niftyadmin.cn/n/5863352.html

相关文章

SAP S4HANA Administration (Mark Mergaerts Bert Vanstechelman)

SAP S4HANA Administration (Mark Mergaerts Bert Vanstechelman)

spring中事务为什么会回滚?什么原理?

事务回滚是保证数据一致性的关键机制&#xff0c;但如果事务回滚失效&#xff0c;可能会导致数据不一致的问题。我会用简单易懂的方式来讲解&#xff0c;帮助你理解事务回滚失效的常见原因及解决方法。 1. 什么是Spring事务回滚&#xff1f; 在Spring中&#xff0c;事务管理是…

中兴B863AV3.2-T/B863AV3.1-T2/B863AV3.1-T2K_电信高安_S905L3A-B_安卓9.0_线刷固件包

中兴B863AV3.2-T&#xff0f;B863AV3.1-T2&#xff0f;B863AV3.1-T2K_电信高安_S905L3A-B_安卓9.0_线刷固件包 B863AV3.2-T B863AV3.1-T2 已知可通刷贵州、江苏、贵州、北京、河南、陕西等省份。 线刷方法&#xff1a;&#xff08;新手参考借鉴一下&#xff09; 1、准备好一…

区块链相关方法-波士顿矩阵 (BCG Matrix)

波士顿矩阵&#xff08;BCG Matrix&#xff09;&#xff0c;又称市场增长率 - 相对市场份额矩阵、波士顿咨询集团法、四象限分析法、产品系列结构管理法等&#xff0c;由美国著名的管理学家、波士顿咨询公司创始人布鲁斯・亨德森于 1970 年首创1。以下是关于波士顿矩阵的详细介…

DeepSeek-R1之三_基于RAGFlowAI托管平台在Docker中部署搭建本地AI知识库

DeepSeekR1之三_基于RAGFlowAI托管平台在Docker中部署搭建本地AI知识库 文章目录 DeepSeekR1之三_基于RAGFlowAI托管平台在Docker中部署搭建本地AI知识库1. RAGFlow是什么1. 主要功能1. **"Quality in, quality out"**2. &#x1f371; **基于模板的文本切片**3. &am…

千峰React:函数组件使用(2)

前面写了三千字没保存&#xff0c;恨&#xff01; 批量渲染 function App() {const list [{id:0,text:aaaa},{id:1,text:bbbb},{id:2,text:cccc}]// for (let i 0; i < list.length; i) {// list[i] <li>{list[i]}</li>// }return (<div><…

从零开始用react + tailwindcs + express + mongodb实现一个聊天程序(一)

项目包含5个模块 1.首页 (聊天主页) 2.注册 3.登录 4.个人资料 5.设置主题 一、配置开发环境 建立项目文件夹 mkdir chat-project cd chat-project mkdir server && mkdir webcd server npm init cd web npm create vitelatest 创建前端项目时我们选择javascrip…

通俗易懂的DOM1级标准介绍

前言 在前端开发中&#xff0c;DOM&#xff08;文档对象模型&#xff09;是我们操作网页内容的核心工具。前面的文章我们介绍了DOM0级、DOM2级事件模型&#xff0c;没有DOM1级事件模型这种概念&#xff0c;但有DOM1级标准。今天我们就来讨论DOM1级标准&#xff0c;看看它到底做…