From bfa7c451c73b958717502ad27688e6e9c0930b05 Mon Sep 17 00:00:00 2001 From: "E. Almqvist" Date: Tue, 20 Oct 2020 20:11:49 +0200 Subject: [PATCH] Tweaked the network etc --- rgbAI/lib/ailib/ai.py | 24 ++++- rgbAI/main.py | 8 +- rgbAI/trained_data/2020-10-20.txt | 163 ++++++++++++++++++++++++++++++ 3 files changed, 187 insertions(+), 8 deletions(-) create mode 100644 rgbAI/trained_data/2020-10-20.txt diff --git a/rgbAI/lib/ailib/ai.py b/rgbAI/lib/ailib/ai.py index 8d44a48..abf364c 100644 --- a/rgbAI/lib/ailib/ai.py +++ b/rgbAI/lib/ailib/ai.py @@ -1,13 +1,15 @@ import numpy as np from copy import deepcopy as copy -import os DEBUG_BUFFER = { "cost": None, "lr": { "weight": None, "bias": None - } + }, + "inp": None, + "predicted": None, + "correct": None } def sigmoid(x): @@ -27,6 +29,10 @@ def calcCost( predicted:np.array, correct:np.array ): # cost function, lower -> def getThinkCost( inp:np.array, predicted:np.array ): corr = correctFunc(inp) + + global DEBUG_BUFFER + DEBUG_BUFFER["correct"] = corr + return calcCost( predicted, corr ) def genRandomMatrix( x:int, y:int, min: float=0.0, max: float=1.0 ): # generate a matrix with x, y dimensions with random values from min-max in it @@ -56,6 +62,7 @@ def compareAIobjects( inp, obj1, obj2 ): global DEBUG_BUFFER DEBUG_BUFFER["cost"] = cost1 + DEBUG_BUFFER["predicted"] = res1 res2 = think( inp, obj2 ) cost2 = getThinkCost( inp, res2 ) # get the second cost @@ -163,11 +170,14 @@ def mutateProps( inpObj, curCost:float, maxLayer:int, gradient:list ): return obj def printProgress(): - global DEBUG_BUFFER + import os + global DEBUG_BUFFER os.system("clear") print(f"LR: {DEBUG_BUFFER['lr']}") print(f"Cost: {DEBUG_BUFFER['cost']}") + print("") + print(f"inp: {DEBUG_BUFFER['inp']} | pre: {DEBUG_BUFFER['predicted']} cor: {DEBUG_BUFFER['correct']}") def learn( inputNum:int, targetCost:float, obj, theta:float, curCost: float=None, trainForever: bool=False ): # Calculate the derivative for: @@ -177,7 +187,11 @@ def learn( inputNum:int, targetCost:float, obj, theta:float, curCost: float=None # i.e. : W' = W - lr * gradient (respect to W in layer i) = W - lr*[ dC / dW[i] ... ] # So if we change all the weights with i.e. 0.01 = theta, then we can derive the gradient with math and stuff - inp = np.asarray(np.random.rand( 1, inputNum ))[0] # create a random learning sample + #inp = np.asarray(np.random.rand( 1, inputNum ))[0] # create a random learning sample + inp = np.asarray([1.0, 1.0, 1.0]) + + global DEBUG_BUFFER + DEBUG_BUFFER["inp"] = inp while( trainForever or not curCost or curCost > targetCost ): # targetCost is the target for the cost function maxLen = len(obj.bias) @@ -190,3 +204,5 @@ def learn( inputNum:int, targetCost:float, obj, theta:float, curCost: float=None print("DONE\n") print(obj.weights) print(obj.bias) + + return obj diff --git a/rgbAI/main.py b/rgbAI/main.py index 3c2c565..37224db 100755 --- a/rgbAI/main.py +++ b/rgbAI/main.py @@ -7,11 +7,11 @@ class rgb(object): if( not loadedWeights or not loadedBias ): # if one is null (None) then just generate new ones print("Generating weights and biases...") - self.weights = [ ai.genRandomMatrix(3, 8), ai.genRandomMatrix(8, 8), ai.genRandomMatrix(8, 8), ai.genRandomMatrix(8, 3) ] # array of matrices of weights - # 3 input neurons -> 8 hidden neurons -> 8 hidden neurons -> 3 output neurons + self.weights = [ ai.genRandomMatrix(3, 16), ai.genRandomMatrix(16, 16), ai.genRandomMatrix(16, 16), ai.genRandomMatrix(16, 3) ] # array of matrices of weights + # 3 input neurons -> 16 hidden neurons -> 16 hidden neurons -> 3 output neurons # Generate the biases - self.bias = [ ai.genRandomMatrix(1, 8), ai.genRandomMatrix(1, 8), ai.genRandomMatrix(1, 8), ai.genRandomMatrix(1, 3) ] + self.bias = [ ai.genRandomMatrix(1, 16), ai.genRandomMatrix(1, 16), ai.genRandomMatrix(1, 16), ai.genRandomMatrix(1, 3) ] # This doesn't look very good, but it works so... print( self.weights ) @@ -41,7 +41,7 @@ class rgb(object): def init(): bot = rgb() - bot.learn() + bot = bot.learn() inpArr = np.asarray([1.0, 1.0, 1.0]) res = bot.think( inpArr ) diff --git a/rgbAI/trained_data/2020-10-20.txt b/rgbAI/trained_data/2020-10-20.txt new file mode 100644 index 0000000..f5b67fe --- /dev/null +++ b/rgbAI/trained_data/2020-10-20.txt @@ -0,0 +1,163 @@ +[array([[ 0.13657383, -0.23042832, -0.15878268, 0.01625717, 0.09307448, + 0.18129695, 0.2116157 , 0.03005379, -0.23906321, 0.33900869, + 0.53861629, -0.24311345, 0.56263432, 0.16710256, -0.20531924, + -0.04025091], + [-0.14493201, -0.03999969, -0.04536574, 0.57153015, 0.02922364, + -0.0858576 , 0.51631858, 0.24529872, -0.20769185, 0.43454781, + -0.23097147, -0.16451023, 0.44440651, 0.60449483, -0.14961027, + 0.00530867], + [ 0.47570477, 0.6321929 , 0.18826831, -0.09720299, -0.04231244, + -0.054177 , 0.31102066, 0.33274832, 0.09374735, 0.22414503, + 0.2214549 , 0.18879144, 0.08037877, -0.11656515, -0.03204839, + 0.43241358]]), array([[ 0.16805438, 0.31040123, 0.57585574, 0.41922052, 0.55521636, + -0.06293517, -0.24256846, 0.18656932, -0.13466644, 0.13617174, + -0.15533991, 0.72374925, 0.30054544, 0.06856631, 0.23241547, + 0.7275567 ], + [ 0.70020351, 0.21788736, 0.73347487, 0.28366899, 0.55231818, + 0.51453774, 0.66627062, -0.15582512, -0.04486987, 0.08910557, + -0.22164193, 0.61888889, 0.72705845, 0.50076326, 0.17261128, + 0.60235239], + [ 0.1759519 , -0.03520717, -0.1548805 , 0.50866399, -0.09330594, + -0.10404663, 0.57911986, -0.18048319, 0.62594238, -0.23703238, + 0.57377906, -0.17134602, 0.49190176, 0.28946516, 0.35465312, + 0.64585328], + [ 0.14643067, 0.18990906, 0.12110283, 0.30103004, 0.71224839, + 0.49863677, 0.21176263, 0.30926655, 0.35554206, 0.67969733, + 0.54055667, 0.1821767 , 0.13505438, 0.01354167, 0.71708359, + 0.38325529], + [ 0.21117687, -0.14987266, 0.70728899, 0.28207962, -0.12873177, + -0.2409365 , 0.0372645 , 0.69199145, 0.45591382, -0.216259 , + 0.04768903, 0.61919638, 0.04028485, 0.43677636, 0.34351561, + -0.17574492], + [-0.08483528, 0.54262422, 0.26430655, -0.09718621, -0.04815816, + 0.36051139, 0.57963613, 0.18824451, 0.37752941, -0.15812105, + 0.11207639, -0.10671547, 0.15248495, 0.54355872, 0.02105179, + -0.18640931], + [ 0.60007982, 0.42082565, 0.50058767, -0.08092875, -0.15311035, + 0.07481274, -0.06533989, 0.23647472, 0.17528246, 0.55421468, + 0.60617809, 0.59650491, 0.50975201, 0.30861656, 0.36631228, + 0.06977136], + [ 0.57098305, 0.23880751, 0.40955559, 0.02095128, -0.03336961, + 0.66587073, -0.13214473, -0.14161289, -0.09664933, 0.67537527, + -0.0167142 , 0.49583207, 0.18313851, 0.15646813, 0.07427188, + -0.0718608 ], + [-0.00987391, 0.55932519, -0.16337579, -0.05121916, 0.05253206, + 0.6599166 , 0.04628956, 0.34949889, -0.1062319 , -0.10879115, + 0.56021827, 0.66759027, 0.23128271, 0.63949833, 0.03886562, + 0.4834711 ], + [ 0.32905225, 0.11594583, 0.65178434, -0.04223889, 0.41896363, + 0.57333568, -0.00198302, 0.65220265, 0.68543686, 0.3683571 , + -0.02603741, 0.20018236, 0.16239414, 0.41868448, 0.18097101, + 0.74529124], + [ 0.59622387, 0.10582384, 0.01224179, 0.62591807, -0.09860738, + 0.58708991, -0.15917445, -0.23433273, 0.38125883, 0.62119401, + 0.53009452, 0.36465919, 0.58722475, 0.50328685, 0.11084216, + 0.70020325], + [-0.08007414, 0.70406113, 0.28146817, -0.23281128, 0.72570987, + 0.39208543, 0.54568921, 0.64135421, 0.13809933, 0.17369943, + 0.37720265, 0.65876394, 0.72315255, 0.2288451 , 0.07655611, + 0.27356539], + [ 0.5132845 , 0.06317595, 0.18229536, -0.14240562, -0.19588846, + 0.41032122, 0.12737565, 0.39377548, 0.31637434, 0.25621596, + 0.57986848, 0.23893487, 0.72227966, 0.00080269, 0.17514379, + -0.11195404], + [ 0.39934212, 0.49710534, 0.2748569 , 0.043366 , 0.1890571 , + 0.18852156, -0.04550403, 0.37443474, 0.05917768, 0.18225269, + 0.35863941, 0.35493 , 0.02507568, 0.60231666, 0.12705062, + -0.18656577], + [-0.21762182, 0.28420968, 0.0664954 , 0.61769083, -0.22446421, + -0.15199846, 0.06882507, 0.41997501, -0.19671519, 0.53733707, + 0.13745353, -0.0277646 , 0.31886772, -0.09922969, 0.27121758, + 0.16454755], + [-0.10369499, 0.57374993, 0.67271829, 0.49812178, 0.43472414, + -0.19226256, 0.59947474, 0.73502608, 0.13070667, 0.23848043, + 0.04348717, -0.14117485, 0.29202391, 0.31610878, 0.5516485 , + 0.31596777]]), array([[ 0.34477668, -0.06425723, 0.58359001, -0.14815158, -0.12798902, + 0.01122578, -0.09406077, -0.08156269, 0.47678927, -0.16832157, + 0.2989899 , -0.03149517, 0.54102489, 0.27687279, 0.07762957, + 0.30630602], + [-0.18628032, 0.48140631, 0.45127588, 0.04217173, -0.04761508, + 0.53454407, 0.07562896, 0.02482407, 0.34120148, 0.34104659, + 0.00800182, 0.30055017, 0.10554095, 0.34496294, -0.12992065, + 0.34394313], + [-0.13523876, -0.0971336 , 0.63950755, 0.36780703, 0.56103293, + -0.14544073, 0.50998027, 0.69755403, 0.17758161, 0.71673517, + -0.19777061, 0.34413252, 0.44859899, 0.19193068, 0.49557605, + 0.35980889], + [ 0.31448824, 0.55498673, 0.45547964, 0.39560346, 0.43964609, + 0.65156501, 0.60092278, 0.16079557, 0.29967172, 0.45635459, + -0.15559738, 0.09677894, -0.1473755 , 0.48921584, 0.03380254, + 0.53496187], + [ 0.20548827, 0.34162208, 0.24291551, 0.09517771, 0.42799721, + -0.21587983, 0.6276331 , 0.17878733, -0.09394476, 0.49486964, + 0.05069814, -0.03731232, 0.61765291, 0.39744186, 0.62045988, + 0.20484799], + [ 0.73550131, 0.58870611, -0.2149463 , 0.12072379, 0.33219201, + 0.26109787, 0.0982951 , 0.1284664 , 0.64710056, 0.0152397 , + 0.18381652, 0.73768473, 0.00092139, 0.03235858, 0.15103459, + 0.51981606], + [ 0.08840694, 0.26297443, 0.30844606, -0.1824745 , 0.18769289, + 0.63229512, 0.42350295, 0.31092172, 0.25027055, 0.74948354, + 0.22888569, -0.01773346, -0.11401654, -0.03108823, -0.13451111, + 0.63129517], + [ 0.54493297, 0.63825299, 0.24219368, 0.22564053, 0.06205102, + 0.30013008, -0.20026111, 0.36154652, 0.012798 , 0.35960868, + 0.14043231, 0.25682745, 0.61403935, 0.36427773, -0.15509794, + 0.18238906], + [ 0.6308805 , -0.12446523, 0.16979726, 0.10546963, 0.19526405, + -0.10779107, 0.69112703, 0.11795955, 0.59537578, 0.10703378, + -0.17694736, 0.17672196, 0.68900486, -0.1828981 , 0.01611003, + 0.03623836], + [ 0.38947013, 0.27422655, 0.10986504, 0.20315569, -0.14706311, + 0.0376341 , -0.05265269, 0.397534 , 0.52942913, 0.65569717, + 0.68041829, -0.07898684, 0.29142613, -0.05156493, 0.18032057, + 0.53108115], + [ 0.68532707, 0.34462347, 0.31177709, 0.47765148, 0.05164668, + 0.44140461, 0.19140832, 0.73120288, 0.49197314, 0.22130038, + 0.65074739, -0.05677663, -0.05086617, 0.08561628, 0.1741059 , + 0.04711781], + [ 0.30792548, 0.47748182, -0.13492255, -0.17833334, 0.43625885, + 0.3399424 , 0.72112565, 0.72721836, 0.18942578, 0.37540945, + 0.44928576, 0.19833672, -0.06612378, 0.44243888, 0.42291882, + 0.47173218], + [ 0.32018247, 0.49908758, 0.05845455, 0.6764048 , 0.61555869, + 0.14059561, 0.15291535, 0.58760176, -0.18642172, 0.68411407, + 0.72957624, 0.17706223, 0.50910493, 0.18660864, 0.73956186, + 0.07935623], + [ 0.7506918 , 0.24560067, -0.02707633, 0.53720368, -0.05693541, + 0.70035174, -0.07927034, 0.43913472, 0.39959292, -0.1352873 , + 0.27732047, 0.06998323, 0.06058959, 0.58006482, -0.24189536, + 0.08312415], + [-0.00128582, 0.3933684 , 0.51069906, 0.43526889, 0.09348662, + 0.03881429, 0.22353862, -0.02266519, -0.23932934, -0.22200674, + 0.33799059, 0.0798281 , 0.61122009, 0.00619084, 0.56015456, + 0.4992037 ], + [ 0.31080398, 0.74357545, 0.34472705, 0.03124108, 0.32283712, + -0.11436023, -0.18757332, -0.0401612 , 0.56430347, 0.74013264, + -0.21379409, 0.09312194, 0.65398465, 0.47030407, -0.24225389, + 0.34722671]]), array([[-0.72712555, -0.2209617 , -0.9213511 ], + [-0.23583579, -0.5001959 , -1.03502363], + [-0.54812013, -0.55230555, -0.10796313], + [-0.40919237, -0.27320155, -0.23281955], + [-0.60576907, -0.72301966, -0.19292567], + [-0.25151921, -0.68429653, -0.62658037], + [-0.92354466, -0.72998648, -0.48274747], + [-0.30593891, -0.68002346, -0.55250976], + [-0.37302378, -0.1585595 , -0.2190985 ], + [-0.86951793, -0.47558322, -0.150991 ], + [-0.55356092, -0.38416017, -0.4341505 ], + [-0.77230558, -0.01271597, -0.6261722 ], + [-0.15072927, -0.66071288, -1.08968044], + [-0.76838745, -0.51264552, -0.10541016], + [-0.21458776, -0.18523628, -0.39002985], + [-0.37013168, -1.05653186, -0.70163522]])] +[array([[ 0.66669102, 0.35633831, -0.07490808, 0.55737519, 0.37080628, + 0.4170183 , 0.38348049, 0.04615565, 0.71602324, -0.1080244 , + 0.48753738, -0.00088796, 0.72443262, 0.43165147, 0.41511408, + 0.73129322]]), array([[ 0.64329873, -0.19463148, 0.50162217, 0.43143105, 0.10745485, + 0.26871549, 0.29615103, -0.11371631, 0.67287694, 0.2765771 , + 0.29021985, 0.37520536, 0.37632065, 0.26134689, 0.42186981, + -0.11257989]]), array([[ 0.55071741, 0.54357322, 0.49354933, 0.02243901, -0.02043161, + 0.18415808, 0.58299901, -0.1365678 , 0.39731518, 0.51806039, + -0.00358498, 0.23929732, 0.65433423, 0.10841841, 0.70569666, + 0.25425315]]), array([[-1.36108333, -1.52971319, -1.49134245]])]