Skip to content

Commit

Permalink
🐛 fix: run on Apple Silicon (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud authored Dec 8, 2023
1 parent 516833d commit 63934a9
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

🐛 **fix:** run on Apple Silicon ([110](https://github.com/owkin/GrAIdient/pull/110))\
⚙️ **core:** initForward,Backward model API ([109](https://github.com/owkin/GrAIdient/pull/109))\
🪜 **layer_1d:** Dropout1D ([#108](https://github.com/owkin/GrAIdient/pull/108))\
🪜 **feat:** VQGrad, VQGradSeq ([#107](https://github.com/owkin/GrAIdient/pull/107))
Expand Down
11 changes: 10 additions & 1 deletion Sources/GrAIdient/Metal/MetalKernel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,16 @@ public class MetalCommand
public func setBytes<T>(_ data: [T], atIndex index: Int)
{
let byteLength = data.count * MemoryLayout<T>.size
_encoder.setBytes(data, length: byteLength, index: index)
data.withUnsafeBufferPointer
{
dataPtr in

_encoder.setBytes(
UnsafeRawPointer(dataPtr.baseAddress)!,
length: byteLength,
index: index
)
}
}

///
Expand Down
72 changes: 28 additions & 44 deletions Sources/GrAIdient/Utils/Image.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//

import Foundation
import Cocoa
import AppKit

/// Error occuring when processing images.
public enum ImageError: Error
Expand Down Expand Up @@ -107,42 +107,14 @@ public class Image
let bufferPtr = metalBuffer.download()
let nbImages = metalBuffer.nbElems / (width * height * 3)

var output = [[UInt8]]()
for elem in 0..<nbImages
var images = [[Float]]()
for i in 0..<nbImages
{
var grid: [UInt8] = [UInt8](repeating: 0, count: width * height * 3)
grid.withUnsafeMutableBufferPointer { gridPtr in
Concurrency.slice(gridPtr.count)
{
(index: Int) in

let depth = index / (width * height)
let i = (index - depth * width * height) / width
let j = (index - depth * width * height) % width

let offsetGet = elem * 3 * height * width
let offsetSet = j + i * width

let valTmp = bufferPtr[index + offsetGet] * 255.0
let val: UInt8
if valTmp < 0
{
val = 0
}
else if valTmp > 255.0
{
val = 255
}
else
{
val = UInt8(valTmp)
}

gridPtr[3 * offsetSet + depth] = val
}}
output.append(grid)
images.append([Float](
bufferPtr[i * 3 * height * width..<(i+1) * 3 * height * width]
))
}
return output
return toRGB(toPixel(images), width: width, height: height)
}

///
Expand All @@ -157,7 +129,8 @@ public class Image
var output = [[UInt8]]()
for elem in 0..<images.count
{
output.append(images[elem].map {
output.append(images[elem].map
{
let valTmp = $0 * T(255.0)
let val: UInt8
if valTmp < 0
Expand Down Expand Up @@ -385,15 +358,26 @@ public extension NSImage
///
func extractPixels() throws -> [UInt8]
{
if let imageData = tiffRepresentation,
let imageRep = NSBitmapImageRep(data: imageData),
let dataPtr = imageRep.bitmapData
if let pixelData = (cgImage(
forProposedRect: nil, context: nil, hints: nil)!).dataProvider?.data
{
let bufferPtr = UnsafeBufferPointer(
start: dataPtr,
count: Int(3 * size.height * size.width)
)
return [UInt8](bufferPtr)
let data: UnsafePointer<UInt8> = CFDataGetBytePtr(pixelData)

var pixels = [UInt8]()
for i in 0..<Int(size.height) {
for j in 0..<Int(size.width)
{
let pos = CGPoint(x: j, y: i)

let pixelInfo: Int = (Int(size.width) * Int(pos.y) * 4) +
Int(pos.x) * 4

let r = data[pixelInfo]
let g = data[pixelInfo + 1]
let b = data[pixelInfo + 2]
pixels += [r, g, b]
}}
return pixels
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion Tests/GrAIExamples/Base/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import GrAIdient
/// Python library default path.
let PYTHON_LIB =
FileManager.default.homeDirectoryForCurrentUser.path +
"/opt/anaconda3/envs/graiexamples/lib/libpython3.9.dylib"
"/miniconda3/envs/graiexamples/lib/libpython3.9.dylib"

/// Set the Python library path.
func setPythonLib()
Expand Down
2 changes: 1 addition & 1 deletion Tests/GrAITorchTests/Base/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import GrAIdient
/// Python library default path.
let PYTHON_LIB =
FileManager.default.homeDirectoryForCurrentUser.path +
"/opt/anaconda3/envs/graitorch/lib/libpython3.9.dylib"
"/miniconda3/envs/graitorch/lib/libpython3.9.dylib"

/// Set the Python library path.
func setPythonLib()
Expand Down

0 comments on commit 63934a9

Please sign in to comment.