;; A semiring is represented as a three-element list
;;   (list plus mult plus-unit)
;; where plus-unit is the identity element for plus.
(define (make-semiring @+@ @*@ +0+) (list @+@ @*@ +0+))
(define (semiring-plus sr)          (car sr))
(define (semiring-mult sr)          (cadr sr))
(define (semiring-plus-unit sr)     (caddr sr))

;; Ordinary arithmetic: (+, *, 0).
(define reals-semiring    (make-semiring + * 0))
;; Max-plus (tropical) semiring: "addition" is max, "multiplication" is +.
;; The identity for max is -inf.0; using 0 would clamp every fold result to
;; be >= 0 and give wrong answers for negative data.
;; The max is written as a comparison that returns one of its arguments
;; unchanged, because Racket's `max` coerces its result to inexact whenever
;; any argument is inexact, which would turn exact results into floats as
;; soon as -inf.0 takes part in a fold.
(define max-plus-semiring (make-semiring (λ (a b) (if (< a b) b a)) + -inf.0))
;; Boolean semiring: (or, and, #f).
(define bools-semiring    (make-semiring (λ (a b) (or a b)) (λ (a b) (and a b)) #f))

;; make-linalg : semiring -> procedure
;; Builds a small linear-algebra library over `semiring` and returns it as a
;; dispatcher procedure, called as (lib 'op arg ...), e.g. (lib 'scalar-prod u v).
;; Vectors are Racket vectors.  A matrix is a vector of row vectors;
;; (mat-ref m i j) is the element in column i of row j, so "x" counts columns
;; and "y" counts rows.
;; A request matching none of the clauses below raises a `match` failure.
(define (make-linalg semiring)
  (define @+@ (semiring-plus semiring))
  (define @*@ (semiring-mult semiring))
  (define +0+ (semiring-plus-unit semiring))

  ;; Combine u and v element-wise with $.  The lengths must agree: without
  ;; this check a longer v would be silently truncated and a shorter v would
  ;; fail with an index error inside build-vector.
  (define (vec-zip $ u v)
    (define n (vector-length u))
    (unless (= n (vector-length v))
      (error 'vec-zip "vector length mismatch: ~a vs ~a" n (vector-length v)))
    (build-vector n (λ (i) ($ (vector-ref u i) (vector-ref v i)))))
  ;; Left fold: ($ ... ($ ($ z v[0]) v[1]) ... v[n-1]); returns z when v is empty.
  (define (vec-foldl $ z v)
    (define (iter i s)
      (if (= i (vector-length v))
          s
          (iter (+ i 1)
                ($ s (vector-ref v i)))))
    (iter 0 z))

  (define (make-vec xs)     (list->vector xs))
  (define (build-vec n f)   (build-vector n f))
  (define (vec-len v)       (vector-length v))
  (define (vec-ref v i)     (vector-ref v i))
  (define (vec-map f v)     (build-vector (vector-length v)
                                          (λ (i) (f (vector-ref v i)))))
  ;; Semiring dot product: the @+@-sum of the element-wise @*@-products,
  ;; starting from the additive unit (so an empty product yields +0+).
  (define (scalar-prod u v) (vec-foldl @+@ +0+ (vec-zip @*@ u v)))
  ;; Build a matrix from a list of rows, each row itself a list.
  (define (make-mat xss)    (make-vec (map make-vec xss)))
  ;; Element in column i of row j.
  (define (mat-ref m i j)   (vec-ref (vec-ref m j) i))
  ;; Number of columns, taken from the first row (0 for a matrix with no rows).
  (define (mat-dimx m)      (if (= 0 (vec-len m))
                                0
                                (vec-len (vec-ref m 0))))
  ;; Number of rows.
  (define (mat-dimy m)      (vec-len m))
  ;; Matrix with x columns and y rows whose element at column i, row j is (f i j).
  (define (build-mat x y f)
    (build-vec y (λ (j) (build-vec x (λ (i) (f i j))))))
  (define (mat-transpose m)
    (build-mat (mat-dimy m) (mat-dimx m)
               (λ (i j) (mat-ref m j i))))
  ;; Matrix-vector product: one scalar product of each row of m with v.
  (define (mat*vec m v)     (vec-map (λ (row) (scalar-prod row v)) m))
  ;; Matrix product: each row of m1 against every column of m2
  ;; (the columns of m2 are the rows of its transpose).
  (define (mat*mat m1 m2)
    (let ([m2t (mat-transpose m2)])
      (vec-map (λ (r1) (vec-map (λ (c2) (scalar-prod r1 c2)) m2t)) m1)))

  ;; Dispatcher: `request` is the full argument list, e.g. '(vec-ref #(1 2) 0).
  (lambda request
    (match request
      [`(make-vec ,xs)       (make-vec xs)]
      [`(build-vec ,n ,f)    (build-vec n f)]
      [`(vec-ref ,v ,i)      (vec-ref v i)]
      [`(vec-len ,v)         (vec-len v)]
      [`(vec-map ,f ,v)      (vec-map f v)]
      [`(scalar-prod ,u ,v)  (scalar-prod u v)]
      [`(make-mat ,m)        (make-mat m)]
      [`(build-mat ,x ,y ,f) (build-mat x y f)]
      [`(mat-ref ,m ,i ,j)   (mat-ref m i j)]
      [`(mat-dimx ,m)        (mat-dimx m)]
      [`(mat-dimy ,m)        (mat-dimy m)]
      [`(mat-transpose ,m)   (mat-transpose m)]
      [`(mat*vec ,m ,v)      (mat*vec m v)]
      [`(mat*mat ,m1 ,m2)    (mat*mat m1 m2)])))

;; Vector front-ends: each forwards its arguments to the linear-algebra
;; dispatcher `lib` as a request of the form (lib 'op arg ...).
(define make-vec    (λ (lib elems)  (lib 'make-vec elems)))
(define build-vec   (λ (lib n gen)  (lib 'build-vec n gen)))
(define vec-ref     (λ (lib vec k)  (lib 'vec-ref vec k)))
(define vec-len     (λ (lib vec)    (lib 'vec-len vec)))
(define vec-map     (λ (lib fn vec) (lib 'vec-map fn vec)))
(define scalar-prod (λ (lib a b)    (lib 'scalar-prod a b)))

;; Matrix front-ends: each forwards its arguments to the linear-algebra
;; dispatcher `lib` as a request of the form (lib 'op arg ...).
(define make-mat      (λ (lib rows)         (lib 'make-mat rows)))
(define build-mat     (λ (lib cols nrows f) (lib 'build-mat cols nrows f)))
(define mat-ref       (λ (lib mat col row)  (lib 'mat-ref mat col row)))
(define mat-dimx      (λ (lib mat)          (lib 'mat-dimx mat)))
(define mat-dimy      (λ (lib mat)          (lib 'mat-dimy mat)))
(define mat-transpose (λ (lib mat)          (lib 'mat-transpose mat)))

;; Products.
(define mat*vec       (λ (lib mat vec)      (lib 'mat*vec mat vec)))
(define mat*mat       (λ (lib a b)          (lib 'mat*mat a b)))



;; --- Demo over the reals semiring --------------------------------------
(define r-lib (make-linalg reals-semiring))

;; Dot product of two real vectors.
(define u (make-vec r-lib (list 1 2 3)))
(define v (make-vec r-lib (list 4 5 6)))
(scalar-prod r-lib u v)

;; Product of a 2x3 matrix with a 3x2 matrix.
(define m1 (make-mat r-lib (list (list 1 2 3)
                                 (list 2 3 4))))
(define m2 (make-mat r-lib (list (list 5 6)
                                 (list 7 8)
                                 (list 9 10))))
(mat*mat r-lib m1 m2)

;; --- Demo over the boolean semiring (plus = or, mult = and) -------------
(define b-lib (make-linalg bools-semiring))

(define m3 (make-mat b-lib (list (list #t #f #f)
                                 (list #f #t #f))))
(define m4 (make-mat b-lib (list (list #t #f)
                                 (list #f #t)
                                 (list #f #t))))
(mat*mat b-lib m3 m4)
