1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
//
// TrainXor_RandomSearch.cpp
// UMUTech @ 2018-07-05 23:45:52
// Be aware that I'm only a novice to ANN. My apologies for any wrong info.
//
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <random>
// Pseudo-random engine shared by weight initialisation and the search steps.
// It starts default-seeded here; main() reseeds it from std::random_device.
std::default_random_engine random_engine{};
// Overwrites w[0 .. size) with values drawn uniformly from [0, 1),
// consuming draws from the global random_engine in index order.
void RandomizeW(double* w, size_t size)
{
    std::uniform_real_distribution<double> unit_interval(0, 1);
    std::generate_n(w, size, [&unit_interval]() { return unit_interval(random_engine); });
}
// Writes each weight to std::cout as "<index>\t<value>" on its own line.
// `w` is only read, so it is taken as pointer-to-const; existing callers that
// pass a plain double* still convert implicitly.
void PrintW(const double* w, size_t size)
{
    for (size_t i = 0; i < size; ++i)
    {
        std::cout << i << "\t" << w[i] << "\n";
    }
}
// Rectified linear unit (ReLU): returns x when x > 0, otherwise 0.
double ActivationFunction(double x)
{
    return std::max(0.0, x);
}

// Evaluates the 2-2-1 network for the input pair {x[0], x[1]}.
// Weight layout in w (nine values, read only):
//   w[0], w[1], w[2] -> hidden neuron f: weight for x[0], weight for x[1], bias
//   w[3], w[4], w[5] -> hidden neuron g: same layout
//   w[6], w[7], w[8] -> output neuron:   weight for f, weight for g, bias
// Every neuron goes through ActivationFunction, so the result is never negative.
// w is taken as pointer-to-const; callers passing a plain double* still convert.
double AnnRun(const double x[2], const double* w)
{
    // The bias is subtracted (multiplied by -1); the original author's note says
    // this lets the result converge better to [0, 1].
    const double f = ActivationFunction(x[0] * w[0] + x[1] * w[1] - w[2]);
    const double g = ActivationFunction(x[0] * w[3] + x[1] * w[4] - w[5]);
    return ActivationFunction(f * w[6] + g * w[7] - w[8]);
}
// Trains the 2-2-1 ReLU network on XOR by random search (stochastic hill climbing):
// each iteration perturbs every weight by a uniform step in [-0.5, 0.5) and keeps
// the perturbation only if the summed squared error over the four XOR samples drops;
// otherwise the weights are rolled back. Every `restart_interval` iterations the
// weights are re-randomized (random restart). When the error reaches `target_error`,
// prints the iteration count, the weights, and the rounded prediction per input.
int main()
{
    const double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    const double expect_output[4] = {0, 1, 1, 0};

    // Iterations between random restarts of the search.
    const int restart_interval = 10000;
    // Training stops once the summed squared error is no longer above this.
    const double target_error = 0.01;
    // Error assigned before any weights have been evaluated.
    const double initial_error = 1000;

    // Two hidden neurons and one output neuron, each with two weights and a bias.
    constexpr std::size_t weight_count = 3 * 3;
    std::array<double, weight_count> w{};

    std::random_device rd;
    random_engine.seed(rd());

    // Step distribution never changes, so build it once outside the loop.
    std::uniform_real_distribution<double> step(-0.5, 0.5);

    double last_error = initial_error;
    int train_count = 0;
    for (; last_error > target_error; ++train_count)
    {
        if (train_count % restart_interval == 0)
        {
            std::cout << "Randomize\n";
            RandomizeW(w.data(), w.size());
            // A restart discards the previous weights, so their error is no longer a
            // meaningful bar: without this reset the new start could only be accepted
            // by beating the abandoned solution in a single jump.
            last_error = initial_error;
        }

        const std::array<double, weight_count> w_saved = w;

        // Randomly perturb every weight.
        for (double& weight : w)
        {
            weight += step(random_engine);
        }

        // Summed squared error over the four XOR samples.
        double error = 0.0;
        for (int s = 0; s < 4; ++s)
        {
            const double diff = AnnRun(input[s], w.data()) - expect_output[s];
            error += diff * diff;
        }

        if (error < last_error)
        {
            // Lower error: keep the perturbed weights.
            last_error = error;
        }
        else
        {
            // No improvement: restore the weights from before the perturbation.
            w = w_saved;
        }
    }

    printf("Finished in %d loops.\n", train_count);

    PrintW(w.data(), w.size());

    /* Run the network and see what it predicts. */
    for (int s = 0; s < 4; ++s)
    {
        printf("Output for [%1.f, %1.f] is %1.f.\n", input[s][0], input[s][1], AnnRun(input[s], w.data()));
    }

    return 0;
}
|